"""
|
||
学习率调度器模块:提供各种学习率调度策略
|
||
"""
|
||
import numpy as np
|
||
import math
|
||
from typing import Callable, Optional, Union, Dict
|
||
import tensorflow as tf
|
||
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Scheduler")
|
||
|
||
|
||


def step_decay(epoch: int, initial_lr: float,
               drop_rate: float = 0.5,
               epochs_drop: int = 10) -> float:
    """
    Step (staircase) learning rate decay.

    Args:
        epoch: current epoch index
        initial_lr: initial learning rate
        drop_rate: multiplicative factor applied at each drop
        epochs_drop: number of epochs between drops

    Returns:
        The new learning rate.
    """
    return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop))
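

# For example, with initial_lr=0.1, drop_rate=0.5 and epochs_drop=10, step_decay
# returns 0.1 for epochs 0-8, 0.05 for epochs 9-18, 0.025 for epochs 19-28, and
# so on, halving the rate every ten epochs.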


def exponential_decay(epoch: int, initial_lr: float,
                      decay_rate: float = 0.9,
                      staircase: bool = False) -> float:
    """
    Exponential learning rate decay.

    Args:
        epoch: current epoch index
        initial_lr: initial learning rate
        decay_rate: decay rate per epoch
        staircase: whether to decay in discrete steps (floor of the exponent)

    Returns:
        The new learning rate.
    """
    if staircase:
        # With integer epochs the floor is a no-op; it only has an effect if a
        # fractional epoch value is passed in.
        return initial_lr * math.pow(decay_rate, math.floor(epoch))
    else:
        return initial_lr * math.pow(decay_rate, epoch)
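

# For example, with initial_lr=0.1 and decay_rate=0.9, exponential_decay returns
# 0.1 at epoch 0, 0.09 at epoch 1, and roughly 0.035 at epoch 10 (0.1 * 0.9**10).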


def cosine_decay(epoch: int, initial_lr: float,
                 total_epochs: int = 100,
                 min_lr: float = 0) -> float:
    """
    Cosine annealing learning rate decay.

    Args:
        epoch: current epoch index
        initial_lr: initial learning rate
        total_epochs: total number of epochs
        min_lr: minimum learning rate

    Returns:
        The new learning rate.
    """
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs))
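

# For example, with initial_lr=0.1, total_epochs=100 and min_lr=0, cosine_decay
# returns 0.1 at epoch 0, 0.05 at epoch 50, and 0 at epoch 100, following half a
# cosine wave from the initial rate down to the minimum.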


def cosine_decay_with_warmup(epoch: int, initial_lr: float,
                             total_epochs: int = 100,
                             warmup_epochs: int = 5,
                             min_lr: float = 0,
                             warmup_init_lr: float = 0) -> float:
    """
    Cosine annealing learning rate decay with linear warmup.

    Args:
        epoch: current epoch index
        initial_lr: initial (peak) learning rate
        total_epochs: total number of epochs
        warmup_epochs: number of warmup epochs
        min_lr: minimum learning rate
        warmup_init_lr: learning rate at the start of warmup

    Returns:
        The new learning rate.
    """
    if epoch < warmup_epochs:
        # Linear warmup from warmup_init_lr up to initial_lr
        return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs
    else:
        # Cosine annealing from initial_lr down to min_lr
        return min_lr + 0.5 * (initial_lr - min_lr) * (
            1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
        )
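

# For example, with initial_lr=0.1, warmup_epochs=5, warmup_init_lr=0 and
# total_epochs=100, the rate rises linearly from 0 to 0.1 over the first five
# epochs and then follows a cosine curve down to min_lr by epoch 100.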


def cyclical_learning_rate(epoch: int, initial_lr: float,
                           max_lr: float = 0.1,
                           step_size: int = 8,
                           gamma: float = 1.0) -> float:
    """
    Cyclical learning rate (triangular policy).

    Args:
        epoch: current epoch index
        initial_lr: initial (minimum) learning rate
        max_lr: maximum learning rate
        step_size: half-cycle length in epochs
        gamma: per-cycle amplitude decay factor

    Returns:
        The new learning rate.
    """
    # Index of the current cycle (one full cycle spans 2 * step_size epochs)
    cycle = math.floor(1 + epoch / (2 * step_size))

    # Position within the cycle, mapped so that x lies in [0, 1]
    x = abs(epoch / step_size - 2 * cycle + 1)

    # Shrink the amplitude of each successive cycle by gamma
    lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1)

    # Interpolate between initial_lr and the (possibly decayed) maximum
    return initial_lr + lr_range * max(0, 1 - x)
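

# This is the triangular cyclical learning rate policy. For example, with
# initial_lr=0.001, max_lr=0.1 and step_size=8, the rate climbs linearly from
# 0.001 to 0.1 over 8 epochs, falls back to 0.001 over the next 8, and repeats;
# a gamma below 1 shrinks the peak of each successive cycle.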


def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]:
    """
    Create a custom learning rate schedule function.

    The returned function uses the (epoch, lr) signature expected by
    tf.keras.callbacks.LearningRateScheduler.

    Args:
        scheduler_type: scheduler type, one of: 'step', 'exp', 'cosine',
            'cosine_warmup', 'cyclical'
        **kwargs: scheduler-specific parameters

    Returns:
        A learning rate schedule function.
    """
    scheduler_type = scheduler_type.lower()

    # Keras passes the optimizer's *current* learning rate to the schedule on
    # every epoch. Remember the value seen at epoch 0 and always decay from it,
    # otherwise the decay factor would be applied to an already-decayed rate
    # and compound from epoch to epoch.
    base_lr: Dict[str, float] = {}

    def _initial_lr(lr: float) -> float:
        base_lr.setdefault('value', lr)
        return base_lr['value']

    if scheduler_type == 'step':
        drop_rate = kwargs.get('drop_rate', 0.5)
        epochs_drop = kwargs.get('epochs_drop', 10)

        def scheduler(epoch, lr):
            return step_decay(epoch, _initial_lr(lr), drop_rate, epochs_drop)

        return scheduler

    elif scheduler_type == 'exp':
        decay_rate = kwargs.get('decay_rate', 0.9)
        staircase = kwargs.get('staircase', False)

        def scheduler(epoch, lr):
            return exponential_decay(epoch, _initial_lr(lr), decay_rate, staircase)

        return scheduler

    elif scheduler_type == 'cosine':
        total_epochs = kwargs.get('total_epochs', 100)
        min_lr = kwargs.get('min_lr', 0)

        def scheduler(epoch, lr):
            return cosine_decay(epoch, _initial_lr(lr), total_epochs, min_lr)

        return scheduler

    elif scheduler_type == 'cosine_warmup':
        total_epochs = kwargs.get('total_epochs', 100)
        warmup_epochs = kwargs.get('warmup_epochs', 5)
        min_lr = kwargs.get('min_lr', 0)
        warmup_init_lr = kwargs.get('warmup_init_lr', 0)

        def scheduler(epoch, lr):
            return cosine_decay_with_warmup(epoch, _initial_lr(lr), total_epochs,
                                            warmup_epochs, min_lr, warmup_init_lr)

        return scheduler

    elif scheduler_type == 'cyclical':
        max_lr = kwargs.get('max_lr', 0.1)
        step_size = kwargs.get('step_size', 8)
        gamma = kwargs.get('gamma', 1.0)

        def scheduler(epoch, lr):
            return cyclical_learning_rate(epoch, _initial_lr(lr), max_lr, step_size, gamma)

        return scheduler

    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
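

# Usage sketch (assumes a compiled Keras model named `model` and training data
# `x_train`/`y_train`, none of which are defined in this module):
#
#     schedule = create_custom_scheduler('cosine_warmup', total_epochs=100,
#                                        warmup_epochs=5, min_lr=1e-5)
#     lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
#     model.fit(x_train, y_train, epochs=100, callbacks=[lr_callback])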


class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback):
    """Keras callback implementing per-step warmup followed by cosine decay."""

    def __init__(self, learning_rate_base: float,
                 total_steps: int,
                 warmup_learning_rate: float = 0.0,
                 warmup_steps: int = 0,
                 hold_base_rate_steps: int = 0,
                 verbose: int = 0):
        """
        Initialize the warmup + cosine decay learning rate scheduler.

        Args:
            learning_rate_base: base learning rate
            total_steps: total number of training steps
            warmup_learning_rate: learning rate at the start of warmup
            warmup_steps: number of warmup steps
            hold_base_rate_steps: number of steps to hold the base rate after warmup
            verbose: verbosity level
        """
        super().__init__()

        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose

        # History of learning rates applied so far
        self.learning_rates = []

    def on_train_begin(self, logs: Optional[Dict] = None) -> None:
        """
        Called at the start of training.

        Args:
            logs: training logs
        """
        self.current_step = 0
        logger.info(f"Warmup cosine decay scheduler initialized: base lr={self.learning_rate_base}, "
                    f"warmup steps={self.warmup_steps}, total steps={self.total_steps}")

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """
        Called at the end of every batch.

        Args:
            batch: current batch index
            logs: training logs
        """
        self.current_step += 1
        lr = self._get_learning_rate()

        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        self.learning_rates.append(lr)

    def _get_learning_rate(self) -> float:
        """
        Compute the learning rate for the current step.

        Returns:
            The current learning rate.
        """
        if self.current_step < self.warmup_steps:
            # Warmup phase: increase the learning rate linearly
            lr = self.warmup_learning_rate + self.current_step * (
                (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
            )
        elif self.current_step < self.warmup_steps + self.hold_base_rate_steps:
            # Hold phase: keep the base learning rate
            lr = self.learning_rate_base
        else:
            # Cosine decay phase
            cosine_steps = self.total_steps - self.warmup_steps - self.hold_base_rate_steps
            cosine_current_step = self.current_step - self.warmup_steps - self.hold_base_rate_steps
            lr = 0.5 * self.learning_rate_base * (
                1 + math.cos(math.pi * cosine_current_step / cosine_steps)
            )

        return lr
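

# Usage sketch for the callback above (assumes a compiled Keras model named
# `model`, data `x_train`/`y_train`, and known `epochs`/`steps_per_epoch`
# values, none of which are defined in this module):
#
#     total_steps = epochs * steps_per_epoch
#     warmup_cb = WarmupCosineDecayScheduler(learning_rate_base=1e-3,
#                                            total_steps=total_steps,
#                                            warmup_learning_rate=1e-6,
#                                            warmup_steps=steps_per_epoch)
#     model.fit(x_train, y_train, epochs=epochs, callbacks=[warmup_cb])


if __name__ == "__main__":
    # Minimal self-check: print a few values from two of the schedules to make
    # sure the curves have the expected shape.
    schedule = create_custom_scheduler('step', drop_rate=0.5, epochs_drop=10)
    step_lrs = [round(schedule(epoch, 0.1), 4) for epoch in range(30)]
    logger.info(f"Step schedule (every 5 epochs): {step_lrs[::5]}")

    warmup_lrs = [round(cosine_decay_with_warmup(epoch, 0.1, total_epochs=30,
                                                 warmup_epochs=5), 4)
                  for epoch in range(30)]
    logger.info(f"Warmup + cosine schedule (every 5 epochs): {warmup_lrs[::5]}")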