superlishunqin 07c7151272 2th-version
2025-03-13 09:35:44 +08:00

320 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
训练器模块:实现模型训练流程,包括训练循环、验证等
"""
import os
import time
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from config.system_config import SAVED_MODELS_DIR
from config.model_config import (
NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
VALIDATION_SPLIT, RANDOM_SEED
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger, TrainingLogger
from utils.file_utils import ensure_dir
logger = get_logger("Trainer")
class Trainer:
    """Model trainer: wires up Keras callbacks, runs the training loop on the
    best available device, and records/plots the training history."""

    def __init__(self, model: TextClassificationModel,
                 epochs: int = NUM_EPOCHS,
                 batch_size: Optional[int] = None,
                 validation_split: float = VALIDATION_SPLIT,
                 early_stopping: bool = True,
                 early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
                 save_best_only: bool = True,
                 tensorboard: bool = True,
                 checkpoint: bool = True,
                 custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
        """
        Initialize the trainer.

        Args:
            model: Model to train (project wrapper exposing `fit`, `save`,
                `get_config`, `model_name` and `batch_size`).
            epochs: Number of training epochs.
            batch_size: Batch size; falls back to the model's default when None.
            validation_split: Fraction of data reserved for validation.
                NOTE(review): stored and logged but never forwarded to `fit`
                below — validation only happens via explicit x_val/y_val;
                confirm whether this is intentional.
            early_stopping: Whether to stop early when val_loss stalls.
            early_stopping_patience: Epochs to wait before early stopping.
            save_best_only: Whether checkpoints keep only the best model.
            tensorboard: Whether to write TensorBoard logs.
            checkpoint: Whether to save model checkpoints during training.
            custom_callbacks: Extra Keras callbacks appended after the defaults.
        """
        self.model = model
        self.epochs = epochs
        # Fall back to the model's own default batch size when not given.
        self.batch_size = batch_size or model.batch_size
        self.validation_split = validation_split
        self.early_stopping = early_stopping
        self.early_stopping_patience = early_stopping_patience
        self.save_best_only = save_best_only
        self.tensorboard = tensorboard
        self.checkpoint = checkpoint
        self.custom_callbacks = custom_callbacks or []
        # Populated by train() with the Keras History.history dict.
        self.history: Optional[Dict[str, List[float]]] = None
        # Per-model structured training log.
        self.training_logger = TrainingLogger(model.model_name)
        logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}")

    def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
        """
        Build the callback list: early stopping, LR decay, checkpointing,
        TensorBoard, then any user-supplied callbacks.

        Returns:
            List of Keras callbacks in the order they should run.
        """
        callbacks: List[tf.keras.callbacks.Callback] = []

        # Early stopping on validation loss, restoring the best weights seen.
        if self.early_stopping:
            callbacks.append(tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=self.early_stopping_patience,
                restore_best_weights=True,
                verbose=1
            ))

        # Halve the learning rate on plateau; uses half the early-stopping
        # patience so LR decay gets a chance to help before training stops.
        callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=self.early_stopping_patience // 2,
            min_lr=1e-6,
            verbose=1
        ))

        # One timestamp shared by checkpoint files and TensorBoard logs so the
        # artifacts of a single run carry matching names.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Model checkpointing on best (or every) val_loss.
        if self.checkpoint:
            checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
            ensure_dir(checkpoint_dir)
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"{self.model.model_name}_{timestamp}.h5"
            )
            callbacks.append(tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                save_best_only=self.save_best_only,
                monitor='val_loss',
                verbose=1
            ))

        # TensorBoard event logging (per-epoch scalars + histograms).
        if self.tensorboard:
            log_dir = os.path.join(
                SAVED_MODELS_DIR,
                'logs',
                f"{self.model.model_name}_{timestamp}"
            )
            ensure_dir(log_dir)
            callbacks.append(tf.keras.callbacks.TensorBoard(
                log_dir=log_dir,
                histogram_freq=1,
                update_freq='epoch'
            ))

        # User-supplied callbacks run last.
        callbacks.extend(self.custom_callbacks)
        return callbacks

    def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
        """
        Forward per-epoch metrics to the structured training logger.

        Args:
            epoch: Current epoch index (0-based, as supplied by Keras).
            logs: Metric name -> value mapping for this epoch.
        """
        self.training_logger.log_epoch(epoch, logs)

    def _select_device(self) -> str:
        """
        Log GPU availability, enable GPU memory growth when possible, and
        return the TF device string training should run under.

        Returns:
            '/GPU:0' when at least one GPU is visible, otherwise '/CPU:0'.
        """
        physical_devices = tf.config.list_physical_devices('GPU')
        logger.info(f"可用的物理 GPU 设备: {physical_devices}")
        logger.info(f"TensorFlow 版本: {tf.__version__}")
        if not physical_devices:
            logger.warning(f"未检测到 GPU将使用 CPU 进行训练")
            return '/CPU:0'
        logger.info(f"模型将使用 GPU 进行训练")
        try:
            # Memory growth keeps TF from grabbing all GPU memory up front.
            # It can only be set before the GPUs are initialized, hence the
            # RuntimeError guard for the already-initialized case.
            for gpu in physical_devices:
                tf.config.experimental.set_memory_growth(gpu, True)
            logger.info(f"已设置 GPU 内存增长模式")
        except RuntimeError as e:
            logger.warning(f"设置 GPU 内存增长时出错: {e}")
        return '/GPU:0'

    def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
              y_train: Optional[np.ndarray] = None,
              x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
              y_val: Optional[np.ndarray] = None,
              class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
        """
        Train the model.

        Args:
            x_train: Training features (array or tf.data.Dataset).
            y_train: Training labels (omit when x_train is a Dataset).
            x_val: Validation features.
            y_val: Validation labels.
            class_weights: Optional class-index -> weight mapping, forwarded
                to the model wrapper's `fit` as `class_weights`.

        Returns:
            The Keras History.history dict (metric name -> per-epoch values).
        """
        logger.info(f"开始训练模型: {self.model.model_name}")

        # Standard callbacks plus a lambda that streams per-epoch metrics to
        # the structured training logger.
        callbacks = self._create_callbacks()
        callbacks.append(tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
        ))

        start_time = time.time()

        # Record the combined model + trainer configuration at start.
        self.training_logger.log_training_start({
            **self.model.get_config(),
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "validation_split": self.validation_split,
            "early_stopping": self.early_stopping,
            "early_stopping_patience": self.early_stopping_patience,
        })

        device = self._select_device()

        # Validation data is only used when both features and labels are given.
        validation_data = (x_val, y_val) if x_val is not None and y_val is not None else None

        # A single fit call under an explicit device scope replaces the
        # previously duplicated GPU/CPU branches (and the no-op
        # `with tf.device(...)` block that never wrapped any computation).
        with tf.device(device):
            history = self.model.fit(
                x_train, y_train,
                validation_data=validation_data,
                epochs=self.epochs,
                callbacks=callbacks,
                class_weights=class_weights,
                verbose=1
            )

        train_time = time.time() - start_time
        self.history = history.history

        # Best validation metrics, when validation ran.
        best_metrics: Dict[str, float] = {}
        if 'val_loss' in history.history:
            best_metrics['val_loss'] = min(history.history['val_loss'])
        if 'val_accuracy' in history.history:
            best_metrics['val_accuracy'] = max(history.history['val_accuracy'])

        self.training_logger.log_training_end(train_time, best_metrics)
        logger.info(f"模型训练完成,用时: {train_time:.2f}")
        return history.history

    def plot_training_history(self, metrics: Optional[List[str]] = None,
                              save_path: Optional[str] = None) -> None:
        """
        Plot the recorded training history, one subplot per metric.

        Args:
            metrics: Metrics to plot; defaults to ['loss', 'accuracy'].
            save_path: Where to save the figure; when None the figure is shown
                interactively instead.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.history is None:
            raise ValueError("模型尚未训练,没有训练历史")
        if metrics is None:
            metrics = ['loss', 'accuracy']

        plt.figure(figsize=(12, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, len(metrics), i + 1)
            # Train and validation curves are each optional: only plot what
            # the history actually recorded.
            if metric in self.history:
                plt.plot(self.history[metric], label=f'train_{metric}')
            val_metric = f'val_{metric}'
            if val_metric in self.history:
                plt.plot(self.history[val_metric], label=f'val_{metric}')
            plt.title(f'Model {metric}')
            plt.xlabel('Epoch')
            plt.ylabel(metric)
            plt.legend()
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            # Close the saved figure so repeated calls don't accumulate open
            # figures (matplotlib warns and leaks memory otherwise).
            plt.close()
            logger.info(f"训练历史图已保存到: {save_path}")
        else:
            plt.show()

    def save_trained_model(self, filepath: Optional[str] = None) -> str:
        """
        Save the trained model via the model wrapper.

        Args:
            filepath: Destination path; the wrapper's default is used when None.

        Returns:
            The path the model was saved to.
        """
        return self.model.save(filepath)