"""
Trainer module: implements the model training workflow, including the
training loop, validation handling and training-history plotting.
"""

import os
import time
from typing import List, Dict, Tuple, Optional, Any, Union, Callable

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from config.system_config import SAVED_MODELS_DIR
from config.model_config import (
    NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
    VALIDATION_SPLIT, RANDOM_SEED
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger, TrainingLogger
from utils.file_utils import ensure_dir

logger = get_logger("Trainer")


class Trainer:
    """Model trainer responsible for training and validating a model."""

    def __init__(self, model: TextClassificationModel,
                 epochs: int = NUM_EPOCHS,
                 batch_size: Optional[int] = None,
                 validation_split: float = VALIDATION_SPLIT,
                 early_stopping: bool = True,
                 early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
                 save_best_only: bool = True,
                 tensorboard: bool = True,
                 checkpoint: bool = True,
                 custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
        """
        Initialise the trainer.

        Args:
            model: Model to train.
            epochs: Number of training epochs.
            batch_size: Batch size; falls back to ``model.batch_size`` when
                omitted (or falsy).
            validation_split: Validation split ratio (recorded in the
                training-start log).
            early_stopping: Whether to add an EarlyStopping callback.
            early_stopping_patience: Patience for early stopping; half of it
                is used as the patience of the learning-rate reduction.
            save_best_only: Whether the checkpoint callback keeps only the
                best model.
            tensorboard: Whether to add a TensorBoard callback.
            checkpoint: Whether to add a ModelCheckpoint callback.
            custom_callbacks: Extra callbacks appended after the built-in ones.
        """
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size or model.batch_size
        self.validation_split = validation_split
        self.early_stopping = early_stopping
        self.early_stopping_patience = early_stopping_patience
        self.save_best_only = save_best_only
        self.tensorboard = tensorboard
        self.checkpoint = checkpoint
        self.custom_callbacks = custom_callbacks or []

        # Training history; populated by train().
        self.history = None

        # Per-model training logger.
        self.training_logger = TrainingLogger(model.model_name)

        logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}")

    def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
        """
        Build the list of Keras callbacks for a training run.

        Returns:
            The callbacks in this order: optional EarlyStopping,
            ReduceLROnPlateau, optional ModelCheckpoint, optional TensorBoard,
            followed by any custom callbacks.
        """
        callbacks = []

        # A single timestamp per run so that the checkpoint file and the
        # TensorBoard directory of the same run always share the same suffix.
        run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"{self.model.model_name}_{run_stamp}"

        # Early stopping
        if self.early_stopping:
            callbacks.append(
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=self.early_stopping_patience,
                    restore_best_weights=True,
                    verbose=1
                )
            )

        # Learning-rate decay on plateau
        callbacks.append(
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=self.early_stopping_patience // 2,
                min_lr=1e-6,
                verbose=1
            )
        )

        # Model checkpoint
        if self.checkpoint:
            checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
            ensure_dir(checkpoint_dir)
            checkpoint_path = os.path.join(checkpoint_dir, f"{run_name}.h5")
            callbacks.append(
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=checkpoint_path,
                    save_best_only=self.save_best_only,
                    monitor='val_loss',
                    verbose=1
                )
            )

        # TensorBoard
        if self.tensorboard:
            log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', run_name)
            ensure_dir(log_dir)
            callbacks.append(
                tf.keras.callbacks.TensorBoard(
                    log_dir=log_dir,
                    histogram_freq=1,
                    update_freq='epoch'
                )
            )

        # User-supplied callbacks go last.
        callbacks.extend(self.custom_callbacks)

        return callbacks

    def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
        """
        Forward the end-of-epoch metrics to the training logger.

        Args:
            epoch: Epoch index as reported by the callback.
            logs: Metrics reported for that epoch.
        """
        self.training_logger.log_epoch(epoch, logs)

    def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
              y_train: Optional[np.ndarray] = None,
              x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
              y_val: Optional[np.ndarray] = None,
              class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
        """
        Train the model.

        Args:
            x_train: Training features, or a dataset yielding training batches.
            y_train: Training labels (omit when ``x_train`` is a dataset).
            x_val: Validation features, or a validation dataset.
            y_val: Validation labels (omit when ``x_val`` is a dataset).
            class_weights: Optional per-class weights.

        Returns:
            The training history (metric name -> per-epoch values).
        """
        logger.info(f"开始训练模型: {self.model.model_name}")

        # Built-in callbacks plus a per-epoch progress logger.
        callbacks = self._create_callbacks()
        callbacks.append(
            tf.keras.callbacks.LambdaCallback(
                on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
            )
        )

        # Start the clock before any logging so the duration covers the run.
        start_time = time.time()

        # Log the merged model + trainer configuration.
        model_config = self.model.get_config()
        train_config = {
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "validation_split": self.validation_split,
            "early_stopping": self.early_stopping,
            "early_stopping_patience": self.early_stopping_patience
        }
        self.training_logger.log_training_start({**model_config, **train_config})

        # Prepare validation data.
        validation_data = None
        if x_val is not None:
            if y_val is not None:
                validation_data = (x_val, y_val)
            elif isinstance(x_val, tf.data.Dataset):
                # A dataset already yields (features, labels); previously it
                # was silently dropped because y_val is None in this case.
                validation_data = x_val
            else:
                logger.warning("提供了 x_val 但缺少 y_val,已忽略验证数据")

        # NOTE(review): self.validation_split and self.batch_size are not
        # forwarded here; presumably the model wrapper applies its own
        # values - confirm against TextClassificationModel.fit.
        history = self.model.fit(
            x_train, y_train,
            validation_data=validation_data,
            epochs=self.epochs,
            callbacks=callbacks,
            class_weights=class_weights,
            verbose=1
        )

        train_time = time.time() - start_time

        # Keep the history for plot_training_history().
        self.history = history.history

        # Best validation metrics: lowest loss, highest accuracy. Missing or
        # empty series are skipped instead of raising on min()/max().
        best_metrics = {}
        for metric_name, pick_best in (('val_loss', min), ('val_accuracy', max)):
            values = history.history.get(metric_name)
            if values:
                best_metrics[metric_name] = pick_best(values)

        self.training_logger.log_training_end(train_time, best_metrics)

        logger.info(f"模型训练完成,用时: {train_time:.2f} 秒")

        return history.history

    def plot_training_history(self, metrics: Optional[List[str]] = None,
                              save_path: Optional[str] = None) -> None:
        """
        Plot the recorded training history, one subplot per metric.

        Args:
            metrics: Metric names to plot; defaults to ``['loss', 'accuracy']``.
                For each metric the ``val_``-prefixed series is drawn as well
                when present.
            save_path: Where to save the figure; when omitted the figure is
                shown instead.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.history is None:
            raise ValueError("模型尚未训练,没有训练历史")

        if metrics is None:
            metrics = ['loss', 'accuracy']

        fig = plt.figure(figsize=(12, 5))

        for i, metric in enumerate(metrics):
            ax = fig.add_subplot(1, len(metrics), i + 1)

            for key, label in ((metric, f'train_{metric}'),
                               (f'val_{metric}', f'val_{metric}')):
                if key in self.history:
                    ax.plot(self.history[key], label=label)

            ax.set_title(f'Model {metric}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric)
            # Only draw a legend when something was actually plotted;
            # otherwise matplotlib emits a "no artists with labels" warning.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

        fig.tight_layout()

        if save_path:
            try:
                parent_dir = os.path.dirname(save_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
                fig.savefig(save_path)
            finally:
                # Release the figure; pyplot keeps every figure alive until
                # it is explicitly closed, so repeated calls would leak.
                plt.close(fig)
            logger.info(f"训练历史图已保存到: {save_path}")
        else:
            plt.show()

    def save_trained_model(self, filepath: Optional[str] = None) -> str:
        """
        Save the trained model by delegating to ``self.model.save``.

        Args:
            filepath: Target path; ``None`` is passed through to the model,
                which is expected to pick its default location.

        Returns:
            The value returned by ``self.model.save`` (the save path).
        """
        return self.model.save(filepath)