# NOTE(review): removed web-viewer scraping residue (timestamps, file-size and
# "Raw Blame History" UI text) that preceded the module docstring and made the
# file unparseable as Python.
"""
回调函数模块:提供用于模型训练的自定义回调函数
"""
import os
import time
from io import BytesIO
# Callable added: it is used by LearningRateSchedulerCallback's signature but
# was missing from the original typing import.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from utils.logger import get_logger
logger = get_logger("Callbacks")
class MetricsHistory(tf.keras.callbacks.Callback):
    """Tracks per-epoch metric values during training.

    At the end of every epoch the metrics listed in ``metrics`` (and their
    ``val_``-prefixed counterparts, when validation data was supplied) are
    copied from the Keras ``logs`` dict into ``self.history``, which can
    then be rendered as line plots via ``plot_metrics``.
    """

    def __init__(self, validation_data: Optional[Tuple] = None,
                 metrics: Optional[List[str]] = None,
                 save_path: Optional[str] = None):
        """Initialize the callback.

        Args:
            validation_data: Optional ``(x_val, y_val)`` tuple; its mere
                presence enables tracking of ``val_<metric>`` entries.
            metrics: Metric names to track (defaults to loss and accuracy).
            save_path: Default file path used by ``plot_metrics``.
        """
        super().__init__()
        self.validation_data = validation_data
        self.metrics = metrics if metrics else ['loss', 'accuracy']
        self.save_path = save_path
        # One list of per-epoch values per tracked metric name.
        self.history = {name: [] for name in self.metrics}
        if validation_data is not None:
            self.history.update({f'val_{name}': [] for name in self.metrics})

    def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
        """Append this epoch's metric values from ``logs`` to the history.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Keras-provided mapping of metric name -> value.
        """
        logs = logs or {}
        # Training-set metrics.
        for name in self.metrics:
            if name in logs:
                self.history[name].append(logs[name])
        # Validation metrics are only tracked when validation data was given.
        if self.validation_data is None:
            return
        for name in self.metrics:
            key = f'val_{name}'
            if key in logs:
                self.history[key].append(logs[key])

    def plot_metrics(self, save_path: Optional[str] = None) -> None:
        """Plot the recorded history, one subplot per tracked metric.

        Args:
            save_path: Where to save the figure; falls back to the path set
                at construction time. When neither is set, the figure is
                shown interactively instead of saved.
        """
        plt.figure(figsize=(12, 5))
        for position, name in enumerate(self.metrics, start=1):
            plt.subplot(1, len(self.metrics), position)
            if name in self.history:
                plt.plot(self.history[name], label=f'train_{name}')
            key = f'val_{name}'
            if key in self.history:
                plt.plot(self.history[key], label=f'val_{name}')
            plt.title(f'Model {name}')
            plt.xlabel('Epoch')
            plt.ylabel(name)
            plt.legend()
        plt.tight_layout()
        target = save_path or self.save_path
        if target:
            plt.savefig(target)
            logger.info(f"指标历史图已保存到: {target}")
        else:
            plt.show()
class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
    """Computes and displays a normalized confusion matrix on the validation set.

    Every ``freq`` epochs (and always after the first epoch) the model's
    predictions on the validation data are compared against the labels and
    rendered as a row-normalized confusion-matrix heatmap, optionally logged
    to TensorBoard as an image.
    """

    def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray],
                 class_names: Optional[List[str]] = None,
                 log_dir: Optional[str] = None,
                 freq: int = 1,
                 fig_size: Tuple[int, int] = (10, 8)):
        """Initialize the callback.

        Args:
            validation_data: ``(x_val, y_val)`` pair; labels may be class
                indices or one-hot vectors.
            class_names: Optional tick labels for the matrix axes.
            log_dir: TensorBoard log directory; when given, matrices are
                also logged there as images.
            freq: Compute the matrix once every ``freq`` epochs.
            fig_size: Figure size in inches.
        """
        super().__init__()
        self.x_val, self.y_val = validation_data
        self.class_names = class_names
        self.log_dir = log_dir
        self.freq = freq
        self.fig_size = fig_size
        # Create a dedicated summary writer only when TensorBoard logging
        # was requested.
        if log_dir:
            self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix'))
        else:
            self.file_writer = None

    def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
        """Compute, plot and optionally log the confusion matrix.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Training logs (unused).
        """
        # Only evaluate every `freq` epochs, plus always after the first epoch.
        if (epoch + 1) % self.freq != 0 and epoch != 0:
            return
        # Predicted class indices on the validation set.
        y_pred = np.argmax(self.model.predict(self.x_val), axis=1)
        # Collapse one-hot labels to class indices if necessary.
        y_true = self.y_val
        if len(y_true.shape) > 1 and y_true.shape[1] > 1:
            y_true = np.argmax(y_true, axis=1)
        cm = tf.math.confusion_matrix(y_true, y_pred).numpy()
        # BUGFIX: row-normalize while guarding against classes absent from
        # the validation set — a zero row sum previously produced NaN rows
        # via 0/0 division. Such rows are now rendered as all zeros.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm.astype('float'), row_sums,
                            out=np.zeros_like(cm, dtype=float),
                            where=row_sums != 0)
        fig = self._plot_confusion_matrix(cm_norm, epoch + 1)
        if self.file_writer:
            with self.file_writer.as_default():
                # Convert the matplotlib figure into a PNG image tensor.
                buf = BytesIO()
                fig.savefig(buf, format='png')
                buf.seek(0)
                image = tf.image.decode_png(buf.getvalue(), channels=4)
                image = tf.expand_dims(image, 0)
                tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch)
        # Always close the figure to avoid leaking matplotlib resources,
        # regardless of whether it was written to TensorBoard.
        plt.close(fig)

    def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure:
        """Render a normalized confusion matrix as a heatmap figure.

        Args:
            cm: Row-normalized confusion matrix.
            epoch: 1-based epoch number, used in the title.

        Returns:
            The matplotlib figure (caller is responsible for closing it).
        """
        fig, ax = plt.subplots(figsize=self.fig_size)
        # Heatmap with a colorbar.
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)
        # Label the axes with class names when available.
        if self.class_names:
            ax.set(
                xticks=np.arange(cm.shape[1]),
                yticks=np.arange(cm.shape[0]),
                xticklabels=self.class_names,
                yticklabels=self.class_names
            )
            # Rotate x-axis labels so long names do not overlap.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        # Annotate every cell, switching text color for contrast against
        # the heatmap background.
        thresh = cm.max() / 2.0
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], '.2f'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
        ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})")
        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')
        fig.tight_layout()
        return fig
class TimingCallback(tf.keras.callbacks.Callback):
    """Measures wall-clock time spent in training, per epoch and per batch."""

    def __init__(self):
        """Set up empty timing accumulators."""
        super().__init__()
        # Durations collected over the whole run.
        self.epoch_times = []
        self.batch_times = []
        # Start timestamps for the currently running phase.
        self.epoch_start_time = None
        self.batch_start_time = None
        self.training_start_time = None

    def on_train_begin(self, logs: Dict[str, float] = None) -> None:
        """Record the overall training start timestamp."""
        self.training_start_time = time.time()

    def on_train_end(self, logs: Dict[str, float] = None) -> None:
        """Log the total training time and the per-epoch/per-batch averages."""
        total = time.time() - self.training_start_time
        logger.info(f"总训练时间: {total:.2f}")
        if self.epoch_times:
            logger.info(f"平均每个epoch时间: {sum(self.epoch_times) / len(self.epoch_times):.2f}")
        if self.batch_times:
            logger.info(f"平均每个batch时间: {sum(self.batch_times) / len(self.batch_times):.4f}")

    def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
        """Mark the start of an epoch.

        Args:
            epoch: Index of the starting epoch.
            logs: Training logs (unused).
        """
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
        """Record the epoch duration and expose it via ``logs['epoch_time']``.

        Args:
            epoch: Index of the finished epoch.
            logs: Training logs; receives an ``epoch_time`` entry.
        """
        elapsed = time.time() - self.epoch_start_time
        self.epoch_times.append(elapsed)
        # Surface the duration to other callbacks / History via the logs dict.
        if logs is not None:
            logs['epoch_time'] = elapsed

    def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None:
        """Mark the start of a batch.

        Args:
            batch: Index of the starting batch.
            logs: Training logs (unused).
        """
        self.batch_start_time = time.time()

    def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None:
        """Record the batch duration.

        Args:
            batch: Index of the finished batch.
            logs: Training logs (unused).
        """
        self.batch_times.append(time.time() - self.batch_start_time)
class LearningRateSchedulerCallback(tf.keras.callbacks.Callback):
    """Applies a user-supplied learning-rate schedule at each epoch start."""

    def __init__(self, scheduler_func: Callable[[int, float], float],
                 verbose: int = 0,
                 log_dir: Optional[str] = None):
        """Initialize the callback.

        Args:
            scheduler_func: Schedule function ``(epoch, current_lr) -> new_lr``.
            verbose: When > 0, log each learning-rate change.
            log_dir: TensorBoard log directory; when given, the learning rate
                is also logged there as a scalar.
        """
        super().__init__()
        self.scheduler_func = scheduler_func
        self.verbose = verbose
        # Dedicated summary writer for the learning-rate scalar, if requested.
        self.file_writer = (
            tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate'))
            if log_dir else None
        )
        # Learning rate applied at each epoch, in order.
        self.lr_history = []

    def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
        """Compute and apply the learning rate for this epoch.

        Args:
            epoch: Index of the starting epoch.
            logs: Training logs (unused).

        Raises:
            ValueError: If the optimizer exposes no ``lr`` attribute.
        """
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        previous = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        updated = self.scheduler_func(epoch, previous)
        tf.keras.backend.set_value(self.model.optimizer.lr, updated)
        self.lr_history.append(updated)
        if self.file_writer:
            with self.file_writer.as_default():
                tf.summary.scalar('learning_rate', updated, step=epoch)
        if self.verbose > 0:
            logger.info(f"Epoch {epoch + 1}: 学习率设置为 {updated:.6f}")

    def get_lr_history(self) -> List[float]:
        """Return the learning rates applied so far, one per epoch."""
        return self.lr_history
class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping):
    """EarlyStopping variant that also accepts a minimum *relative* change.

    In addition to the base class's absolute ``min_delta`` test, an epoch
    counts as an improvement when the monitored value changes by more than
    ``min_delta_ratio`` relative to the best value seen so far. A ratio of
    0 (the default) disables the relative test entirely.
    """

    def __init__(self, monitor: str = 'val_loss',
                 min_delta: float = 0,
                 min_delta_ratio: float = 0,
                 patience: int = 0,
                 verbose: int = 0,
                 mode: str = 'auto',
                 baseline: Optional[float] = None,
                 restore_best_weights: bool = False):
        """Initialize the callback.

        Args:
            monitor: Metric name to monitor.
            min_delta: Minimum absolute change counted as an improvement.
            min_delta_ratio: Minimum relative change counted as an
                improvement; 0 disables the relative test.
            patience: Epochs without improvement before stopping.
            verbose: Verbosity level.
            mode: One of 'auto', 'min' or 'max'.
            baseline: Baseline value for the monitored metric.
            restore_best_weights: Whether to restore the best weights seen.
        """
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            baseline=baseline,
            restore_best_weights=restore_best_weights
        )
        self.min_delta_ratio = min_delta_ratio

    def _is_improvement(self, current: float, reference: float) -> bool:
        """Return True when ``current`` improves on ``reference``.

        Args:
            current: Value of the monitored metric this epoch.
            reference: Best value observed so far.

        Returns:
            Whether the change passes the absolute (``min_delta``) or the
            relative (``min_delta_ratio``) improvement test.
        """
        # Absolute min_delta test, delegated to the base class.
        if super()._is_improvement(current, reference):
            return True
        # BUGFIX: previously the relative test ran even with the default
        # min_delta_ratio == 0, so ANY non-zero improvement passed and a
        # configured absolute min_delta was silently bypassed. The relative
        # test now only applies when a positive ratio was requested, and a
        # zero reference (undefined ratio) never counts as an improvement.
        if self.min_delta_ratio <= 0 or reference == 0:
            return False
        # NOTE(review): a negative reference flips the sign of the ratio;
        # this mirrors the original formula — confirm monitored metrics
        # are positive before relying on the relative test.
        if self.monitor_op == np.less:
            # 'min' mode: smaller is better.
            relative_delta = (reference - current) / reference
        else:
            # 'max' mode: larger is better.
            relative_delta = (current - reference) / reference
        return relative_delta > self.min_delta_ratio