"""
学习率调度器模块:提供各种学习率调度策略
"""
import numpy as np
import math
from typing import Callable, Optional, Union, Dict
import tensorflow as tf
from utils.logger import get_logger
logger = get_logger("Scheduler")

def step_decay(epoch: int, initial_lr: float,
               drop_rate: float = 0.5,
               epochs_drop: int = 10) -> float:
    """
    Step (staircase) learning-rate decay.

    Args:
        epoch: Current epoch index.
        initial_lr: Initial learning rate.
        drop_rate: Multiplicative decay factor.
        epochs_drop: Number of epochs between consecutive drops.

    Returns:
        The new learning rate.
    """
    return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop))
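
# Worked example (illustrative values): with initial_lr=0.01, drop_rate=0.5
# and epochs_drop=10, the (1 + epoch) offset makes the first drop land on
# epoch 9 rather than epoch 10:
#   epochs 0-8   -> 0.01
#   epochs 9-18  -> 0.005
#   epoch  19    -> 0.0025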

def exponential_decay(epoch: int, initial_lr: float,
                      decay_rate: float = 0.9,
                      staircase: bool = False,
                      decay_steps: int = 1) -> float:
    """
    Exponential learning-rate decay.

    Args:
        epoch: Current epoch index.
        initial_lr: Initial learning rate.
        decay_rate: Decay factor applied once per ``decay_steps`` epochs.
        staircase: If True, floor the exponent so the rate drops in discrete
            steps instead of decaying smoothly.
        decay_steps: Number of epochs per decay step.

    Returns:
        The new learning rate.
    """
    exponent = epoch / decay_steps
    if staircase:
        # Flooring the fractional exponent keeps the rate constant within
        # each window of `decay_steps` epochs.
        exponent = math.floor(exponent)
    return initial_lr * math.pow(decay_rate, exponent)
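
# Worked example (illustrative values): with decay_rate=0.9 and decay_steps=5,
# epoch 3 gives 0.9 ** 0.6 ~= 0.939 of the initial rate when smooth, but
# 0.9 ** 0 == 1.0 (no drop yet) when staircase=True.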

def cosine_decay(epoch: int, initial_lr: float,
                 total_epochs: int = 100,
                 min_lr: float = 0) -> float:
    """
    Cosine-annealing learning-rate decay.

    Args:
        epoch: Current epoch index.
        initial_lr: Initial learning rate.
        total_epochs: Total number of epochs.
        min_lr: Minimum learning rate.

    Returns:
        The new learning rate.
    """
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs))
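
# Worked example (illustrative values): the schedule traces half a cosine wave:
#   cosine_decay(0, 0.1, total_epochs=100)   -> 0.1   (start)
#   cosine_decay(50, 0.1, total_epochs=100)  -> 0.05  (midpoint)
#   cosine_decay(100, 0.1, total_epochs=100) -> 0.0   (min_lr)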

def cosine_decay_with_warmup(epoch: int, initial_lr: float,
                             total_epochs: int = 100,
                             warmup_epochs: int = 5,
                             min_lr: float = 0,
                             warmup_init_lr: float = 0) -> float:
    """
    Cosine-annealing learning-rate decay with linear warmup.

    Args:
        epoch: Current epoch index.
        initial_lr: Initial (peak) learning rate.
        total_epochs: Total number of epochs.
        warmup_epochs: Number of warmup epochs.
        min_lr: Minimum learning rate.
        warmup_init_lr: Learning rate at the start of warmup.

    Returns:
        The new learning rate.
    """
    if epoch < warmup_epochs:
        # Linear warmup from warmup_init_lr up to initial_lr.
        return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs
    else:
        # Cosine annealing over the remaining epochs.
        return min_lr + 0.5 * (initial_lr - min_lr) * (
            1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
        )
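
# Worked example (illustrative values): with initial_lr=0.1 and
# warmup_epochs=5, the rate ramps linearly from warmup_init_lr over epochs
# 0-4, peaks at 0.1 at epoch 5, then follows the half-cosine above down to
# min_lr at total_epochs.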

def cyclical_learning_rate(epoch: int, initial_lr: float,
                           max_lr: float = 0.1,
                           step_size: int = 8,
                           gamma: float = 1.0) -> float:
    """
    Cyclical (triangular) learning rate.

    Args:
        epoch: Current epoch index.
        initial_lr: Initial (base) learning rate.
        max_lr: Maximum learning rate.
        step_size: Half-cycle length in epochs.
        gamma: Per-cycle decay factor for the amplitude.

    Returns:
        The new learning rate.
    """
    # Index of the current cycle (1-based).
    cycle = math.floor(1 + epoch / (2 * step_size))
    # Position within the cycle, mapped so that x lies in [0, 2].
    x = abs(epoch / step_size - 2 * cycle + 1)
    # Shrink the amplitude by gamma once per completed cycle.
    lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1)
    # Triangular wave between initial_lr and initial_lr + lr_range.
    return initial_lr + lr_range * max(0, 1 - x)
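
# Worked example (illustrative values): with initial_lr=0.001, max_lr=0.1,
# step_size=8 and gamma=1.0, the rate traces a triangle wave:
#   epoch 0 -> 0.001 (base), epoch 8 -> 0.1 (peak), epoch 16 -> 0.001 (base).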

def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]:
    """
    Create a custom learning-rate schedule function.

    Args:
        scheduler_type: Scheduler type; one of 'step', 'exp', 'cosine',
            'cosine_warmup', 'cyclical'.
        **kwargs: Scheduler-specific parameters.

    Returns:
        A schedule function with signature ``(epoch, lr) -> new_lr``, suitable
        for ``tf.keras.callbacks.LearningRateScheduler``.

    Note:
        Keras passes the *current* learning rate as ``lr``, while the decay
        functions above expect the *initial* learning rate. Each closure
        therefore caches the learning rate seen on its first call and always
        decays from that value, so the decay does not compound across epochs.
    """
    scheduler_type = scheduler_type.lower()
    state: Dict[str, float] = {}

    def _initial_lr(lr: float) -> float:
        # Cache the learning rate from the first call (normally epoch 0).
        if 'initial_lr' not in state:
            state['initial_lr'] = lr
        return state['initial_lr']

    if scheduler_type == 'step':
        drop_rate = kwargs.get('drop_rate', 0.5)
        epochs_drop = kwargs.get('epochs_drop', 10)

        def scheduler(epoch, lr):
            return step_decay(epoch, _initial_lr(lr), drop_rate, epochs_drop)
        return scheduler
    elif scheduler_type == 'exp':
        decay_rate = kwargs.get('decay_rate', 0.9)
        staircase = kwargs.get('staircase', False)
        decay_steps = kwargs.get('decay_steps', 1)

        def scheduler(epoch, lr):
            return exponential_decay(epoch, _initial_lr(lr), decay_rate, staircase, decay_steps)
        return scheduler
    elif scheduler_type == 'cosine':
        total_epochs = kwargs.get('total_epochs', 100)
        min_lr = kwargs.get('min_lr', 0)

        def scheduler(epoch, lr):
            return cosine_decay(epoch, _initial_lr(lr), total_epochs, min_lr)
        return scheduler
    elif scheduler_type == 'cosine_warmup':
        total_epochs = kwargs.get('total_epochs', 100)
        warmup_epochs = kwargs.get('warmup_epochs', 5)
        min_lr = kwargs.get('min_lr', 0)
        warmup_init_lr = kwargs.get('warmup_init_lr', 0)

        def scheduler(epoch, lr):
            return cosine_decay_with_warmup(epoch, _initial_lr(lr), total_epochs,
                                            warmup_epochs, min_lr, warmup_init_lr)
        return scheduler
    elif scheduler_type == 'cyclical':
        max_lr = kwargs.get('max_lr', 0.1)
        step_size = kwargs.get('step_size', 8)
        gamma = kwargs.get('gamma', 1.0)

        def scheduler(epoch, lr):
            return cyclical_learning_rate(epoch, _initial_lr(lr), max_lr, step_size, gamma)
        return scheduler
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
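
# Example usage (a minimal sketch; `model`, `x_train` and `y_train` are
# assumed to be defined elsewhere):
#
#     schedule = create_custom_scheduler('cosine_warmup', total_epochs=100,
#                                        warmup_epochs=5, min_lr=1e-6)
#     lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
#     model.fit(x_train, y_train, epochs=100, callbacks=[lr_callback])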

class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback):
    """Per-step warmup + cosine-decay learning-rate scheduler."""

    def __init__(self, learning_rate_base: float,
                 total_steps: int,
                 warmup_learning_rate: float = 0.0,
                 warmup_steps: int = 0,
                 hold_base_rate_steps: int = 0,
                 verbose: int = 0):
        """
        Initialize the warmup cosine-decay scheduler.

        Args:
            learning_rate_base: Base (peak) learning rate.
            total_steps: Total number of training steps.
            warmup_learning_rate: Learning rate at the start of warmup.
            warmup_steps: Number of warmup steps.
            hold_base_rate_steps: Number of steps to hold the base rate after
                warmup, before cosine decay begins.
            verbose: Verbosity level; if > 0, log the learning rate each step.
        """
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.current_step = 0
        # History of learning rates, one entry per step.
        self.learning_rates = []

    def on_train_begin(self, logs: Optional[Dict] = None) -> None:
        """
        Called at the start of training.

        Args:
            logs: Training logs.
        """
        self.current_step = 0
        logger.info(f"Warmup cosine-decay scheduler initialized: "
                    f"base_lr={self.learning_rate_base}, "
                    f"warmup_steps={self.warmup_steps}, total_steps={self.total_steps}")

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """
        Called at the end of each batch.

        Args:
            batch: Current batch index.
            logs: Training logs.
        """
        self.current_step += 1
        lr = self._get_learning_rate()
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        self.learning_rates.append(lr)
        if self.verbose > 0:
            logger.info(f"Step {self.current_step}: learning rate = {lr:.6g}")

    def _get_learning_rate(self) -> float:
        """
        Compute the learning rate for the current step.

        Returns:
            The current learning rate.
        """
        if self.current_step < self.warmup_steps:
            # Warmup phase: increase the learning rate linearly.
            lr = self.warmup_learning_rate + self.current_step * (
                (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
            )
        elif self.current_step < self.warmup_steps + self.hold_base_rate_steps:
            # Hold phase: keep the base learning rate.
            lr = self.learning_rate_base
        else:
            # Cosine-decay phase. Clamp the step so the rate stays at zero
            # instead of rising again once total_steps is exceeded.
            cosine_steps = max(1, self.total_steps - self.warmup_steps - self.hold_base_rate_steps)
            cosine_current_step = min(
                cosine_steps,
                self.current_step - self.warmup_steps - self.hold_base_rate_steps,
            )
            lr = 0.5 * self.learning_rate_base * (
                1 + math.cos(math.pi * cosine_current_step / cosine_steps)
            )
        return lr
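

if __name__ == "__main__":
    # Smoke test (a minimal sketch for manual inspection, not used in
    # training): print the first few learning rates of each epoch-level
    # schedule, starting from an assumed initial rate of 1e-3.
    for name, params in [
        ('step', {'drop_rate': 0.5, 'epochs_drop': 10}),
        ('cosine', {'total_epochs': 30, 'min_lr': 1e-5}),
        ('cyclical', {'max_lr': 0.01, 'step_size': 4}),
    ]:
        schedule = create_custom_scheduler(name, **params)
        lr = 1e-3
        history = []
        for epoch in range(30):
            lr = schedule(epoch, lr)
            history.append(lr)
        logger.info(f"{name}: " + ", ".join(f"{v:.5f}" for v in history[:6]))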