"""
|
||
训练器模块:实现模型训练流程,包括训练循环、验证等
|
||
"""
import os
import time
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from config.system_config import SAVED_MODELS_DIR
from config.model_config import (
    NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
    VALIDATION_SPLIT, RANDOM_SEED
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger, TrainingLogger
from utils.file_utils import ensure_dir

logger = get_logger("Trainer")


class Trainer:
    """Model trainer responsible for training and validating models."""

    def __init__(self, model: TextClassificationModel,
                 epochs: int = NUM_EPOCHS,
                 batch_size: Optional[int] = None,
                 validation_split: float = VALIDATION_SPLIT,
                 early_stopping: bool = True,
                 early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
                 save_best_only: bool = True,
                 tensorboard: bool = True,
                 checkpoint: bool = True,
                 custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
        """
        Initialize the trainer.

        Args:
            model: Model to train.
            epochs: Number of training epochs.
            batch_size: Batch size; if None, the model's default is used.
            validation_split: Fraction of the data used for validation.
            early_stopping: Whether to use early stopping.
            early_stopping_patience: Patience (in epochs) for early stopping.
            save_best_only: Whether to save only the best model.
            tensorboard: Whether to enable TensorBoard logging.
            checkpoint: Whether to save model checkpoints.
            custom_callbacks: List of additional custom callbacks.
        """
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size or model.batch_size
        self.validation_split = validation_split
        self.early_stopping = early_stopping
        self.early_stopping_patience = early_stopping_patience
        self.save_best_only = save_best_only
        self.tensorboard = tensorboard
        self.checkpoint = checkpoint
        self.custom_callbacks = custom_callbacks or []

        # Training history
        self.history = None

        # Training logger
        self.training_logger = TrainingLogger(model.model_name)

        logger.info(f"Initialized trainer, model: {model.model_name}, epochs: {epochs}, batch size: {self.batch_size}")

    def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
        """
        Create the list of training callbacks.

        Returns:
            List of callbacks.
        """
        callbacks = []

        # Early stopping
        if self.early_stopping:
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=self.early_stopping_patience,
                restore_best_weights=True,
                verbose=1
            )
            callbacks.append(early_stopping)

        # Learning rate decay
        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=self.early_stopping_patience // 2,
            min_lr=1e-6,
            verbose=1
        )
        callbacks.append(reduce_lr)

        # Model checkpoints
        if self.checkpoint:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
            ensure_dir(checkpoint_dir)

            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"{self.model.model_name}_{timestamp}.h5"
            )

            model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                save_best_only=self.save_best_only,
                monitor='val_loss',
                verbose=1
            )
            callbacks.append(model_checkpoint)

        # TensorBoard
        if self.tensorboard:
            log_dir = os.path.join(
                SAVED_MODELS_DIR,
                'logs',
                f"{self.model.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            )
            ensure_dir(log_dir)

            tensorboard_callback = tf.keras.callbacks.TensorBoard(
                log_dir=log_dir,
                histogram_freq=1,
                update_freq='epoch'
            )
            callbacks.append(tensorboard_callback)

        # Add user-provided custom callbacks
        callbacks.extend(self.custom_callbacks)

        return callbacks

    def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
        """
        Log training progress for one epoch.

        Args:
            epoch: Current epoch index.
            logs: Metric values reported at the end of the epoch.
        """
        self.training_logger.log_epoch(epoch, logs)

    def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
              y_train: Optional[np.ndarray] = None,
              x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
              y_val: Optional[np.ndarray] = None,
              class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
        """
        Train the model.

        Args:
            x_train: Training features.
            y_train: Training labels.
            x_val: Validation features.
            y_val: Validation labels.
            class_weights: Class weights.

        Returns:
            Training history.
        """
        logger.info(f"Starting training for model: {self.model.model_name}")

        # Create callbacks
        callbacks = self._create_callbacks()

        # Add a callback that records training progress
        progress_callback = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
        )
        callbacks.append(progress_callback)

        # Record the start time
        start_time = time.time()

        # Log training start information
        model_config = self.model.get_config()
        train_config = {
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "validation_split": self.validation_split,
            "early_stopping": self.early_stopping,
            "early_stopping_patience": self.early_stopping_patience
        }
        self.training_logger.log_training_start({**model_config, **train_config})

        # Check and configure GPU usage
        physical_devices = tf.config.list_physical_devices('GPU')
        logger.info(f"Available physical GPU devices: {physical_devices}")

        # Log the devices currently in use
        logger.info(f"TensorFlow version: {tf.__version__}")
        if physical_devices:
            logger.info("Model will be trained on GPU")
            try:
                # Enable GPU memory growth (must be set before the GPUs are initialized)
                for gpu in physical_devices:
                    tf.config.experimental.set_memory_growth(gpu, True)
                logger.info("GPU memory growth enabled")
            except RuntimeError as e:
                logger.warning(f"Error while enabling GPU memory growth: {e}")
        else:
            logger.warning("No GPU detected, training will run on CPU")

        # Try to pin computation to the GPU
        # Note: this device context only covers the log statement below;
        # the fit() call further down opens its own device scope.
        if physical_devices:
            try:
                with tf.device('/GPU:0'):
                    logger.info("Explicitly pinned training to GPU:0")
            except RuntimeError as e:
                logger.warning(f"Error while selecting the GPU device: {e}")

        # Prepare validation data
        validation_data = None
        if x_val is not None and y_val is not None:
            validation_data = (x_val, y_val)

        # Train the model
        if physical_devices:
            with tf.device('/GPU:0'):
                history = self.model.fit(
                    x_train, y_train,
                    validation_data=validation_data,
                    epochs=self.epochs,
                    callbacks=callbacks,
                    class_weights=class_weights,
                    verbose=1
                )
        else:
            history = self.model.fit(
                x_train, y_train,
                validation_data=validation_data,
                epochs=self.epochs,
                callbacks=callbacks,
                class_weights=class_weights,
                verbose=1
            )

        # Compute the training time
        train_time = time.time() - start_time

        # Store the training history
        self.history = history.history

        # Find the best performance
        best_val_loss = min(history.history['val_loss']) if 'val_loss' in history.history else None
        best_val_acc = max(history.history['val_accuracy']) if 'val_accuracy' in history.history else None

        best_metrics = {}
        if best_val_loss is not None:
            best_metrics['val_loss'] = best_val_loss
        if best_val_acc is not None:
            best_metrics['val_accuracy'] = best_val_acc

        # Log training end information
        self.training_logger.log_training_end(train_time, best_metrics)

        logger.info(f"Model training finished in {train_time:.2f} seconds")

        return history.history

    def plot_training_history(self, metrics: Optional[List[str]] = None,
                              save_path: Optional[str] = None) -> None:
        """
        Plot the training history.

        Args:
            metrics: Metrics to plot; defaults to ['loss', 'accuracy'].
            save_path: Path to save the figure; if None, the figure is shown instead.
        """
        if self.history is None:
            raise ValueError("The model has not been trained yet; no training history available")

        if metrics is None:
            metrics = ['loss', 'accuracy']

        plt.figure(figsize=(12, 5))

        for i, metric in enumerate(metrics):
            plt.subplot(1, len(metrics), i + 1)

            if metric in self.history:
                plt.plot(self.history[metric], label=f'train_{metric}')

            val_metric = f'val_{metric}'
            if val_metric in self.history:
                plt.plot(self.history[val_metric], label=f'val_{metric}')

            plt.title(f'Model {metric}')
            plt.xlabel('Epoch')
            plt.ylabel(metric)
            plt.legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            logger.info(f"Training history plot saved to: {save_path}")
        else:
            plt.show()

    def save_trained_model(self, filepath: Optional[str] = None) -> str:
        """
        Save the trained model.

        Args:
            filepath: Path to save to; if None, a default path is used.

        Returns:
            The path the model was saved to.
        """
        return self.model.save(filepath)