""" 回调函数模块:提供用于模型训练的自定义回调函数 """ import os import time import numpy as np import tensorflow as tf from typing import List, Dict, Tuple, Optional, Any, Union import matplotlib.pyplot as plt from io import BytesIO from utils.logger import get_logger logger = get_logger("Callbacks") class MetricsHistory(tf.keras.callbacks.Callback): """跟踪训练过程中的指标历史""" def __init__(self, validation_data: Optional[Tuple] = None, metrics: Optional[List[str]] = None, save_path: Optional[str] = None): """ 初始化MetricsHistory回调 Args: validation_data: 验证数据,格式为(x_val, y_val) metrics: 要跟踪的指标列表 save_path: 指标历史的保存路径 """ super().__init__() self.validation_data = validation_data self.metrics = metrics or ['loss', 'accuracy'] self.save_path = save_path # 历史指标 self.history = {metric: [] for metric in self.metrics} if validation_data is not None: for metric in self.metrics: self.history[f'val_{metric}'] = [] def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: """ 每个epoch结束时调用 Args: epoch: 当前epoch索引 logs: 训练日志 """ logs = logs or {} # 记录训练指标 for metric in self.metrics: if metric in logs: self.history[metric].append(logs[metric]) # 记录验证指标 if self.validation_data is not None: for metric in self.metrics: val_metric = f'val_{metric}' if val_metric in logs: self.history[val_metric].append(logs[val_metric]) def plot_metrics(self, save_path: Optional[str] = None) -> None: """ 绘制指标历史 Args: save_path: 图像保存路径,如果为None则使用初始化时设置的路径 """ plt.figure(figsize=(12, 5)) for i, metric in enumerate(self.metrics): plt.subplot(1, len(self.metrics), i + 1) if metric in self.history: plt.plot(self.history[metric], label=f'train_{metric}') val_metric = f'val_{metric}' if val_metric in self.history: plt.plot(self.history[val_metric], label=f'val_{metric}') plt.title(f'Model {metric}') plt.xlabel('Epoch') plt.ylabel(metric) plt.legend() plt.tight_layout() save_path = save_path or self.save_path if save_path: plt.savefig(save_path) logger.info(f"指标历史图已保存到: {save_path}") else: plt.show() class ConfusionMatrixCallback(tf.keras.callbacks.Callback): """计算并显示验证集上的混淆矩阵""" def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray], class_names: Optional[List[str]] = None, log_dir: Optional[str] = None, freq: int = 1, fig_size: Tuple[int, int] = (10, 8)): """ 初始化ConfusionMatrixCallback Args: validation_data: 验证数据,格式为(x_val, y_val) class_names: 类别名称列表 log_dir: TensorBoard日志目录 freq: 计算混淆矩阵的频率(每多少个epoch计算一次) fig_size: 图像大小 """ super().__init__() self.x_val, self.y_val = validation_data self.class_names = class_names self.log_dir = log_dir self.freq = freq self.fig_size = fig_size # 如果提供了TensorBoard日志目录,创建一个文件写入器 if log_dir: self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix')) else: self.file_writer = None def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: """ 每个epoch结束时调用 Args: epoch: 当前epoch索引 logs: 训练日志 """ # 每freq个epoch计算一次混淆矩阵 if (epoch + 1) % self.freq == 0 or epoch == 0: # 获取预测结果 y_pred = np.argmax(self.model.predict(self.x_val), axis=1) # 确保y_val是一维数组 y_true = self.y_val if len(y_true.shape) > 1 and y_true.shape[1] > 1: y_true = np.argmax(y_true, axis=1) # 计算混淆矩阵 cm = tf.math.confusion_matrix(y_true, y_pred).numpy() # 归一化混淆矩阵 cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True) # 绘制混淆矩阵 fig = self._plot_confusion_matrix(cm_norm, epoch + 1) # 如果有TensorBoard日志,将图像添加到TensorBoard if self.file_writer: with self.file_writer.as_default(): # 将matplotlib图像转换为TensorBoard图像 buf = BytesIO() fig.savefig(buf, format='png') buf.seek(0) # 将PNG编码为字符串,并创建图像 image = tf.image.decode_png(buf.getvalue(), channels=4) image = tf.expand_dims(image, 0) # 添加到TensorBoard tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch) plt.close(fig) def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure: """ 绘制混淆矩阵 Args: cm: 混淆矩阵 epoch: 当前epoch Returns: matplotlib图像 """ fig, ax = plt.subplots(figsize=self.fig_size) # 使用热图显示混淆矩阵 im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) ax.figure.colorbar(im, ax=ax) # 设置坐标轴标签 if self.class_names: ax.set( xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]), xticklabels=self.class_names, yticklabels=self.class_names ) # 旋转x轴标签 plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # 在每个单元格中显示数值 thresh = cm.max() / 2.0 for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], '.2f'), ha="center", va="center", color="white" if cm[i, j] > thresh else "black") ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})") ax.set_ylabel('True label') ax.set_xlabel('Predicted label') fig.tight_layout() return fig class TimingCallback(tf.keras.callbacks.Callback): """测量训练时间的回调函数""" def __init__(self): """初始化TimingCallback""" super().__init__() self.epoch_times = [] self.batch_times = [] self.epoch_start_time = None self.batch_start_time = None self.training_start_time = None def on_train_begin(self, logs: Dict[str, float] = None) -> None: """ 训练开始时调用 Args: logs: 训练日志 """ self.training_start_time = time.time() def on_train_end(self, logs: Dict[str, float] = None) -> None: """ 训练结束时调用 Args: logs: 训练日志 """ training_time = time.time() - self.training_start_time logger.info(f"总训练时间: {training_time:.2f} 秒") if self.epoch_times: avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times) logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f} 秒") if self.batch_times: avg_batch_time = sum(self.batch_times) / len(self.batch_times) logger.info(f"平均每个batch时间: {avg_batch_time:.4f} 秒") def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None: """ 每个epoch开始时调用 Args: epoch: 当前epoch索引 logs: 训练日志 """ self.epoch_start_time = time.time() def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: """ 每个epoch结束时调用 Args: epoch: 当前epoch索引 logs: 训练日志 """ epoch_time = time.time() - self.epoch_start_time self.epoch_times.append(epoch_time) # 将epoch时间添加到日志中 if logs is not None: logs['epoch_time'] = epoch_time def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None: """ 每个batch开始时调用 Args: batch: 当前batch索引 logs: 训练日志 """ self.batch_start_time = time.time() def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: """ 每个batch结束时调用 Args: batch: 当前batch索引 logs: 训练日志 """ batch_time = time.time() - self.batch_start_time self.batch_times.append(batch_time) class LearningRateSchedulerCallback(tf.keras.callbacks.Callback): """学习率调度器回调函数""" def __init__(self, scheduler_func: Callable[[int, float], float], verbose: int = 0, log_dir: Optional[str] = None): """ 初始化LearningRateSchedulerCallback Args: scheduler_func: 学习率调度函数,接收(epoch, lr)参数,返回新的学习率 verbose: 详细程度 log_dir: TensorBoard日志目录 """ super().__init__() self.scheduler_func = scheduler_func self.verbose = verbose # 如果提供了TensorBoard日志目录,创建一个文件写入器 if log_dir: self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate')) else: self.file_writer = None # 学习率历史 self.lr_history = [] def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None: """ 每个epoch开始时调用 Args: epoch: 当前epoch索引 logs: 训练日志 """ if not hasattr(self.model.optimizer, 'lr'): raise ValueError('Optimizer must have a "lr" attribute.') # 获取当前学习率 current_lr = float(tf.keras.backend.get_value(self.model.optimizer.lr)) # 计算新的学习率 new_lr = self.scheduler_func(epoch, current_lr) # 设置新的学习率 tf.keras.backend.set_value(self.model.optimizer.lr, new_lr) # 记录学习率 self.lr_history.append(new_lr) # 记录到TensorBoard if self.file_writer: with self.file_writer.as_default(): tf.summary.scalar('learning_rate', new_lr, step=epoch) if self.verbose > 0: logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}") def get_lr_history(self) -> List[float]: """ 获取学习率历史 Returns: 学习率历史列表 """ return self.lr_history class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping): """增强版早停回调函数,支持最小变化率""" def __init__(self, monitor: str = 'val_loss', min_delta: float = 0, min_delta_ratio: float = 0, patience: int = 0, verbose: int = 0, mode: str = 'auto', baseline: Optional[float] = None, restore_best_weights: bool = False): """ 初始化EarlyStoppingCallback Args: monitor: 监控的指标 min_delta: 视为改进的最小绝对变化 min_delta_ratio: 视为改进的最小相对变化率 patience: 没有改进的轮数 verbose: 详细程度 mode: 'auto', 'min' 或 'max' baseline: 基准值 restore_best_weights: 是否恢复最佳权重 """ super().__init__( monitor=monitor, min_delta=min_delta, patience=patience, verbose=verbose, mode=mode, baseline=baseline, restore_best_weights=restore_best_weights ) self.min_delta_ratio = min_delta_ratio def _is_improvement(self, current: float, reference: float) -> bool: """ 判断是否有所改进 Args: current: 当前值 reference: 参考值 Returns: 是否有所改进 """ # 先检查绝对变化 if super()._is_improvement(current, reference): return True # 再检查相对变化率 if self.monitor_op == np.less: # 对于 'min' 模式,值越小越好 relative_delta = (reference - current) / reference if reference != 0 else 0 return relative_delta > self.min_delta_ratio else: # 对于 'max' 模式,值越大越好 relative_delta = (current - reference) / reference if reference != 0 else 0 return relative_delta > self.min_delta_ratio