2th-version
parent 7e37329831
commit 07c7151272

@@ -0,0 +1,84 @@
import os


def find_all_python_files(root_dir='.'):
    """
    Find all Python files in the given directory and all of its subdirectories.

    Args:
        root_dir: Root directory path; defaults to the current directory.

    Returns:
        A list of paths to all Python files found.
    """
    python_files = []

    # Walk the root directory and every subdirectory
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Collect all .py files
        for filename in filenames:
            if filename.endswith('.py'):
                # Build the full file path
                full_path = os.path.join(dirpath, filename)
                python_files.append(full_path)

    return python_files


def export_file_contents_to_txt(python_files, output_file='python_contents.txt'):
    """
    Export the contents of all Python files into a single TXT file.

    Args:
        python_files: List of Python file paths.
        output_file: Name of the output TXT file.
    """
    with open(output_file, 'w', encoding='utf-8') as outfile:
        for file_path in python_files:
            # Use the relative path for nicer display
            rel_path = os.path.relpath(file_path)

            # Write a file separator
            outfile.write(f"\n{'=' * 80}\n")
            outfile.write(f"File: {rel_path}\n")
            outfile.write(f"{'=' * 80}\n\n")

            # Read the Python file and append its content to the output file
            try:
                with open(file_path, 'r', encoding='utf-8') as infile:
                    content = infile.read()
                    outfile.write(content)
                    outfile.write("\n")  # add a trailing newline after each file
            except Exception as e:
                outfile.write(f"[Unable to read file content: {str(e)}]\n")

    print(f"Exported the contents of {len(python_files)} Python files to {output_file}")


def main():
    # Get the current working directory
    current_dir = os.getcwd()
    print(f"Searching directory: {current_dir}")

    # Find all Python files
    python_files = find_all_python_files(current_dir)

    if python_files:
        print(f"Found {len(python_files)} Python files")

        # Export the contents of all files to a TXT file
        export_file_contents_to_txt(python_files)

        # Print the list of processed files
        for i, file_path in enumerate(python_files[:10]):
            rel_path = os.path.relpath(file_path)
            print(f"{i + 1}. {rel_path}")

        if len(python_files) > 10:
            print(f"... and {len(python_files) - 10} more files")
    else:
        print("No Python files found")


if __name__ == "__main__":
    main()
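A quick usage sketch, not part of the commit, assuming the new file above is export_all_pyfile.py (the first entry in python_files.txt below); the 'models' subdirectory and the models_contents.txt output name are only illustrative:

# Hypothetical usage: reuse the helpers above to export a single subtree.
from export_all_pyfile import find_all_python_files, export_file_contents_to_txt

# Collect .py files under one subdirectory (path is illustrative).
files = find_all_python_files('models')

# Dump their contents into a separate TXT file.
export_file_contents_to_txt(files, output_file='models_contents.txt')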
python_contents.txt  (new file, 10286 lines)
File diff suppressed because it is too large

python_files.txt  (new file, 54 lines)
@@ -0,0 +1,54 @@
export_all_pyfile.py
setup.py
main.py
interface/__init__.py
interface/api.py
interface/cli.py
interface/web/__init__.py
interface/web/app.py
interface/web/routes.py
config/model_config.py
config/__init__.py
config/system_config.py
training/__init__.py
training/callbacks.py
training/optimizer.py
training/scheduler.py
training/trainer.py
tests/__init__.py
tests/test_evaluation.py
tests/test_preprocessing.py
tests/test_models.py
utils/text_utils.py
utils/__init__.py
utils/time_utils.py
utils/logger.py
utils/file_utils.py
models/ensemble_model.py
models/transformer_model.py
models/__init__.py
models/base_model.py
models/rnn_model.py
models/model_factory.py
models/cnn_model.py
models/layers/__init__.py
scripts/predict.py
scripts/train.py
scripts/evaluate.py
inference/predictor.py
inference/batch_processor.py
inference/__init__.py
evaluation/metrics.py
evaluation/__init__.py
evaluation/visualization.py
evaluation/evaluator.py
preprocessing/__init__.py
preprocessing/tokenization.py
preprocessing/vectorizer.py
preprocessing/text_cleaner.py
preprocessing/feature_extraction.py
preprocessing/data_augmentation.py
data/__init__.py
data/dataset.py
data/dataloader.py
data/data_manager.py
@@ -11,6 +11,19 @@ import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Detect GPUs
physical_devices = tf.config.list_physical_devices('GPU')
print("Available physical GPU devices:", physical_devices)

if physical_devices:
    try:
        # Enable GPU memory growth
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enabled GPU memory growth")
    except RuntimeError as e:
        print(f"Error while enabling GPU memory growth: {e}")

# Add the project root directory to the system path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
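A minimal standalone check, not part of the commit, for confirming that TensorFlow actually places work on the GPU once memory growth is configured; the tensor shapes are arbitrary:

# Hypothetical verification sketch: log device placement for a tiny op.
import tensorflow as tf

tf.debugging.set_log_device_placement(True)

# With a visible GPU this matmul should be placed on /GPU:0.
a = tf.random.uniform((256, 256))
b = tf.random.uniform((256, 256))
c = tf.matmul(a, b)
print(c.device)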
@@ -190,12 +190,50 @@ class Trainer:
        }
        self.training_logger.log_training_start({**model_config, **train_config})

        # Check and configure GPU usage
        physical_devices = tf.config.list_physical_devices('GPU')
        logger.info(f"Available physical GPU devices: {physical_devices}")

        # Log the runtime environment currently in use
        logger.info(f"TensorFlow version: {tf.__version__}")
        if physical_devices:
            logger.info("The model will be trained on the GPU")
            try:
                # Enable GPU memory growth
                for gpu in physical_devices:
                    tf.config.experimental.set_memory_growth(gpu, True)
                logger.info("Enabled GPU memory growth")
            except RuntimeError as e:
                logger.warning(f"Error while enabling GPU memory growth: {e}")
        else:
            logger.warning("No GPU detected; training will run on the CPU")

        # Try to force GPU usage
        if physical_devices:
            try:
                # Place operations on the GPU
                with tf.device('/GPU:0'):
                    logger.info("Forced training onto GPU:0")
            except RuntimeError as e:
                logger.warning(f"Error while pinning the GPU device: {e}")

        # Prepare validation data
        validation_data = None
        if x_val is not None and y_val is not None:
            validation_data = (x_val, y_val)

        # Train the model
        if physical_devices:
            with tf.device('/GPU:0'):
                history = self.model.fit(
                    x_train, y_train,
                    validation_data=validation_data,
                    epochs=self.epochs,
                    callbacks=callbacks,
                    class_weights=class_weights,
                    verbose=1
                )
        else:
            history = self.model.fit(
                x_train, y_train,
                validation_data=validation_data,
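One side note on the device pinning in this hunk: tf.device only affects operations created inside its with block, which is why wrapping the fit call itself matters, while the earlier block that only logs a message has no lasting effect. A minimal sketch of that scoping, not part of the commit:

# Hypothetical illustration of tf.device scoping.
import tensorflow as tf

with tf.device('/CPU:0'):
    inside = tf.zeros((2, 2))   # created inside the block -> pinned to CPU:0

outside = tf.zeros((2, 2))      # created outside -> default device (GPU:0 if one is visible)

print(inside.device)
print(outside.device)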