431 lines
14 KiB
Python
431 lines
14 KiB
Python
"""
|
||
回调函数模块:提供用于模型训练的自定义回调函数
|
||
"""
|
||
import os
import time
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from utils.logger import get_logger
|
||
|
||
# Module-level logger used by every callback defined in this file.
logger = get_logger("Callbacks")
|
||
|
||
|
||
class MetricsHistory(tf.keras.callbacks.Callback):
    """Record selected training (and optionally validation) metrics per epoch.

    Values are read from the ``logs`` dict Keras passes to ``on_epoch_end``.
    ``validation_data`` is only used as a switch: when it is not None, the
    ``val_<metric>`` entries found in ``logs`` are recorded as well.
    """

    def __init__(self, validation_data: Optional[Tuple] = None,
                 metrics: Optional[List[str]] = None,
                 save_path: Optional[str] = None):
        """Initialise the MetricsHistory callback.

        Args:
            validation_data: Validation data as ``(x_val, y_val)``. Its contents
                are not read; a non-None value enables recording of the
                ``val_<metric>`` entries.
            metrics: Metric names to track. Defaults to ``['loss', 'accuracy']``
                when None or empty.
            save_path: Default path used by :meth:`plot_metrics` to save the figure.
        """
        super().__init__()
        self.validation_data = validation_data
        # Copy so later mutation of the caller's list cannot desync self.history.
        self.metrics = list(metrics) if metrics else ['loss', 'accuracy']
        self.save_path = save_path

        # One list of per-epoch values per tracked key.
        self.history = {metric: [] for metric in self.metrics}
        if validation_data is not None:
            for metric in self.metrics:
                self.history[f'val_{metric}'] = []

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Append this epoch's values for every tracked metric present in ``logs``.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Metric values reported by Keras for this epoch.
        """
        logs = logs or {}

        # Training metrics.
        for metric in self.metrics:
            if metric in logs:
                self.history[metric].append(logs[metric])

        # Validation metrics (only when validation tracking was enabled).
        if self.validation_data is not None:
            for metric in self.metrics:
                val_metric = f'val_{metric}'
                if val_metric in logs:
                    self.history[val_metric].append(logs[val_metric])

    def plot_metrics(self, save_path: Optional[str] = None) -> None:
        """Plot one panel per tracked metric (train and, if recorded, validation curves).

        The figure is saved when a path is available; otherwise it is shown
        with ``plt.show()``. A saved figure is closed afterwards so repeated
        calls do not accumulate open pyplot figures.

        Args:
            save_path: Output image path. Falls back to the path given at
                construction time when None.
        """
        fig = plt.figure(figsize=(12, 5))
        n_panels = len(self.metrics)

        for idx, metric in enumerate(self.metrics, start=1):
            ax = fig.add_subplot(1, n_panels, idx)

            if metric in self.history:
                ax.plot(self.history[metric], label=f'train_{metric}')

            val_metric = f'val_{metric}'
            if val_metric in self.history:
                ax.plot(self.history[val_metric], label=f'val_{metric}')

            ax.set_title(f'Model {metric}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric)
            ax.legend()

        fig.tight_layout()

        save_path = save_path or self.save_path
        if save_path:
            try:
                fig.savefig(save_path)
            finally:
                # Release the figure even if saving fails; pyplot keeps it alive otherwise.
                plt.close(fig)
            logger.info(f"指标历史图已保存到: {save_path}")
        else:
            plt.show()
|
||
|
||
|
||
class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
    """Compute a row-normalised confusion matrix on validation data.

    The matrix is computed on the first epoch and then every ``freq`` epochs.
    It is rendered with matplotlib and, when ``log_dir`` is given, written to
    TensorBoard as an image summary under a single tag (one step per epoch).
    """

    def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray],
                 class_names: Optional[List[str]] = None,
                 log_dir: Optional[str] = None,
                 freq: int = 1,
                 fig_size: Tuple[int, int] = (10, 8)):
        """Initialise the ConfusionMatrixCallback.

        Args:
            validation_data: ``(x_val, y_val)``; a Keras-style
                ``(x_val, y_val, sample_weight)`` is also accepted (weights ignored).
                ``y_val`` may hold class indices or one-hot rows.
            class_names: Optional tick labels, one per class index.
            log_dir: TensorBoard log directory; images go to ``<log_dir>/confusion_matrix``.
            freq: Compute the matrix every ``freq`` epochs (the first epoch is always computed).
            fig_size: Matplotlib figure size in inches.

        Raises:
            ValueError: If ``freq`` is not a positive number.
        """
        super().__init__()
        if freq <= 0:
            # Fail fast instead of a ZeroDivisionError at the end of the first epoch.
            raise ValueError(f"freq must be a positive integer, got {freq}")

        self.x_val, self.y_val = validation_data[0], validation_data[1]
        self.class_names = class_names
        self.log_dir = log_dir
        self.freq = freq
        self.fig_size = fig_size

        # Create a dedicated summary writer only when TensorBoard logging is requested.
        if log_dir:
            self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix'))
        else:
            self.file_writer = None

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Compute, plot and (optionally) log the confusion matrix for this epoch.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Training logs (unused).
        """
        # Always run on the first epoch, then every `freq` epochs.
        if not (epoch == 0 or (epoch + 1) % self.freq == 0):
            return

        cm_norm = self._normalize_rows(self._compute_confusion_matrix())
        fig = self._plot_confusion_matrix(cm_norm, epoch + 1)
        try:
            if self.file_writer:
                self._write_figure_to_tensorboard(fig, epoch)
        finally:
            # Always release the figure, even if the TensorBoard write fails.
            plt.close(fig)

    def _compute_confusion_matrix(self) -> np.ndarray:
        """Predict on ``x_val`` and return the raw (count) confusion matrix."""
        probs = np.asarray(self.model.predict(self.x_val))

        if probs.ndim == 1 or probs.shape[-1] == 1:
            # Single (sigmoid) output: argmax would always be 0, so threshold at 0.5.
            y_pred = (probs.reshape(-1) > 0.5).astype(np.int64)
            n_outputs = 2
        else:
            y_pred = np.argmax(probs, axis=1).reshape(-1)
            n_outputs = probs.shape[-1]

        y_true = np.asarray(self.y_val)
        if y_true.ndim > 1 and y_true.shape[1] > 1:
            # One-hot labels -> class indices.
            y_true = np.argmax(y_true, axis=1)
        # Flatten column vectors such as (n, 1) and force integer class ids.
        y_true = y_true.reshape(-1).astype(np.int64)

        # Fix the matrix size so it does not depend on which labels happen to
        # occur in this validation set, and so it always covers class_names.
        num_classes = max(
            len(self.class_names) if self.class_names else 0,
            n_outputs,
            int(y_true.max()) + 1 if y_true.size else 0,
            int(y_pred.max()) + 1 if y_pred.size else 0,
        )
        return tf.math.confusion_matrix(y_true, y_pred, num_classes=num_classes).numpy()

    @staticmethod
    def _normalize_rows(cm: np.ndarray) -> np.ndarray:
        """Divide each row by its sum; rows without samples stay all-zero instead of NaN."""
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    def _write_figure_to_tensorboard(self, fig: plt.Figure, epoch: int) -> None:
        """Encode ``fig`` as PNG and write it as an image summary at ``step=epoch``."""
        buf = BytesIO()
        fig.savefig(buf, format='png')
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        image = tf.expand_dims(image, 0)  # tf.summary.image expects a batch dimension

        with self.file_writer.as_default():
            # A constant tag groups all epochs under one TensorBoard card with a step slider.
            tf.summary.image('Confusion Matrix', image, step=epoch)

    def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure:
        """Draw ``cm`` as an annotated heat map.

        Args:
            cm: Confusion matrix (normally row-normalised).
            epoch: One-based epoch number shown in the title.

        Returns:
            The matplotlib Figure; the caller is responsible for closing it.
        """
        fig, ax = plt.subplots(figsize=self.fig_size)

        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)

        if self.class_names:
            n_rows, n_cols = cm.shape
            # Pad with numeric indices when the matrix has more classes than names,
            # so the number of tick labels always matches the number of ticks.
            labels = list(self.class_names) + [
                str(i) for i in range(len(self.class_names), max(n_rows, n_cols))
            ]
            ax.set(
                xticks=np.arange(n_cols),
                yticks=np.arange(n_rows),
                xticklabels=labels[:n_cols],
                yticklabels=labels[:n_rows],
            )

        # Rotate x tick labels so long class names do not overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # Annotate each cell; use white text on dark cells for contrast.
        thresh = cm.max() / 2.0 if cm.size else 0.0
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], '.2f'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")

        ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})")
        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')

        fig.tight_layout()
        return fig
|
||
|
||
|
||
class TimingCallback(tf.keras.callbacks.Callback):
    """Record how long the whole fit, each epoch and each batch take.

    All durations are in seconds, measured as differences of ``time.time()``.
    Per-epoch and per-batch durations accumulate in ``epoch_times`` and
    ``batch_times``; the total and the averages are logged when training ends.
    Each epoch's duration is also injected into the Keras logs as ``epoch_time``.
    """

    def __init__(self):
        """Start with empty duration lists and no start timestamps."""
        super().__init__()
        self.epoch_times, self.batch_times = [], []
        self.epoch_start_time = self.batch_start_time = self.training_start_time = None

    @staticmethod
    def _since(start: float) -> float:
        """Seconds elapsed between ``start`` (a ``time.time()`` value) and now."""
        return time.time() - start

    def on_train_begin(self, logs: Optional[Dict[str, float]] = None) -> None:
        """Stamp the start of training.

        Args:
            logs: Training logs (unused).
        """
        self.training_start_time = time.time()

    def on_train_end(self, logs: Optional[Dict[str, float]] = None) -> None:
        """Log total training time plus mean epoch and batch durations, when available.

        Args:
            logs: Training logs (unused).
        """
        training_time = self._since(self.training_start_time)
        logger.info(f"总训练时间: {training_time:.2f} 秒")

        epochs = self.epoch_times
        if epochs:
            avg_epoch_time = sum(epochs) / len(epochs)
            logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f} 秒")

        batches = self.batch_times
        if batches:
            avg_batch_time = sum(batches) / len(batches)
            logger.info(f"平均每个batch时间: {avg_batch_time:.4f} 秒")

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Stamp the start of an epoch.

        Args:
            epoch: Zero-based epoch index.
            logs: Training logs (unused).
        """
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Store the epoch duration and expose it to Keras as ``logs['epoch_time']``.

        Args:
            epoch: Zero-based epoch index.
            logs: Training logs; mutated in place when not None.
        """
        elapsed = self._since(self.epoch_start_time)
        self.epoch_times.append(elapsed)

        if logs is None:
            return
        logs['epoch_time'] = elapsed

    def on_batch_begin(self, batch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Stamp the start of a batch.

        Args:
            batch: Batch index within the current epoch.
            logs: Batch logs (unused).
        """
        self.batch_start_time = time.time()

    def on_batch_end(self, batch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Store the duration of the batch that just finished.

        Args:
            batch: Batch index within the current epoch.
            logs: Batch logs (unused).
        """
        self.batch_times.append(self._since(self.batch_start_time))
|
||
|
||
|
||
class LearningRateSchedulerCallback(tf.keras.callbacks.Callback):
    """Set the optimizer's learning rate at the start of every epoch.

    The new rate is ``scheduler_func(epoch, current_lr)``. Each applied rate is
    appended to ``lr_history`` and, when ``log_dir`` is given, written to
    TensorBoard as the scalar ``learning_rate``.
    """

    def __init__(self, scheduler_func: Callable[[int, float], float],
                 verbose: int = 0,
                 log_dir: Optional[str] = None):
        """Initialise the LearningRateSchedulerCallback.

        Args:
            scheduler_func: Called as ``scheduler_func(epoch, lr)``; must return
                the new learning rate as something convertible to ``float``.
            verbose: When > 0, log every applied learning rate.
            log_dir: TensorBoard log directory; scalars go to ``<log_dir>/learning_rate``.
        """
        super().__init__()
        self.scheduler_func = scheduler_func
        self.verbose = verbose

        # Create a dedicated summary writer only when TensorBoard logging is requested.
        if log_dir:
            self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate'))
        else:
            self.file_writer = None

        # Learning rates applied so far, one entry per epoch.
        self.lr_history = []

    def _lr_attribute(self) -> str:
        """Return the optimizer attribute holding the learning rate.

        ``learning_rate`` is preferred; ``lr`` is the deprecated alias that
        newer Keras optimizers no longer provide.

        Raises:
            ValueError: If the optimizer exposes neither attribute.
        """
        optimizer = self.model.optimizer
        for name in ('learning_rate', 'lr'):
            if hasattr(optimizer, name):
                return name
        raise ValueError('Optimizer must have a "learning_rate" or "lr" attribute.')

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """Compute and apply this epoch's learning rate.

        Args:
            epoch: Zero-based epoch index passed to ``scheduler_func``.
            logs: Training logs (unused).

        Raises:
            ValueError: If the optimizer has no learning-rate attribute.
        """
        optimizer = self.model.optimizer
        attr = self._lr_attribute()

        # np.asarray reads the value of a Keras/TF variable (or a plain number)
        # without relying on tf.keras.backend.get_value.
        current_lr = float(np.asarray(getattr(optimizer, attr)))

        # Coerce to float so numpy scalars / eager tensors format and store consistently.
        new_lr = float(self.scheduler_func(epoch, current_lr))

        # Assigning the attribute updates the optimizer's underlying variable.
        setattr(optimizer, attr, new_lr)

        self.lr_history.append(new_lr)

        if self.file_writer:
            with self.file_writer.as_default():
                tf.summary.scalar('learning_rate', new_lr, step=epoch)

        if self.verbose > 0:
            logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}")

    def get_lr_history(self) -> List[float]:
        """Return the list of learning rates applied so far (the internal list, not a copy)."""
        return self.lr_history
|
||
|
||
|
||
class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping):
    """Early stopping that also accepts a minimum *relative* improvement.

    An epoch counts as an improvement when the base class's absolute check
    (``min_delta``) passes, OR the change relative to ``|reference|`` exceeds
    ``min_delta_ratio``. A ``min_delta_ratio`` <= 0 (the default 0) disables
    the relative check, leaving the base-class behavior unchanged.
    """

    def __init__(self, monitor: str = 'val_loss',
                 min_delta: float = 0,
                 min_delta_ratio: float = 0,
                 patience: int = 0,
                 verbose: int = 0,
                 mode: str = 'auto',
                 baseline: Optional[float] = None,
                 restore_best_weights: bool = False):
        """Initialise the EarlyStoppingCallback.

        Args:
            monitor: Name of the metric to monitor.
            min_delta: Minimum absolute change that counts as an improvement.
            min_delta_ratio: Minimum relative change (fraction of ``|reference|``)
                that counts as an improvement; <= 0 disables this criterion.
            patience: Number of epochs without improvement before stopping.
            verbose: Verbosity level.
            mode: 'auto', 'min' or 'max'.
            baseline: Baseline value for the monitored metric.
            restore_best_weights: Whether to restore the best weights when stopping.
        """
        super().__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            baseline=baseline,
            restore_best_weights=restore_best_weights
        )
        self.min_delta_ratio = min_delta_ratio

    def _is_improvement(self, current: float, reference: Optional[float]) -> bool:
        """Return True if ``current`` improves on ``reference`` by either criterion.

        Args:
            current: Value of the monitored metric for this epoch.
            reference: Value to compare against (best so far or the baseline).

        Returns:
            Whether this epoch counts as an improvement.
        """
        # Absolute criterion (also handles the initial inf/None reference).
        if super()._is_improvement(current, reference):
            return True

        # Without this guard a ratio of 0 would accept any strictly better value
        # and silently override min_delta.
        if self.min_delta_ratio <= 0:
            return False

        # No meaningful ratio against a missing, zero or infinite reference.
        if reference is None or reference == 0 or not np.isfinite(reference):
            return False

        # The base class's comparison op encodes the direction: "less" for 'min' mode.
        lower_is_better = bool(self.monitor_op(0.0, 1.0))
        change = (reference - current) if lower_is_better else (current - reference)

        # Divide by |reference| so the sign follows the direction of improvement
        # even for negative-valued metrics.
        return change / abs(reference) > self.min_delta_ratio
|