115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
"""
|
||
日志工具模块
|
||
"""
|
||
import logging
|
||
import sys
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import os
|
||
|
||
from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT
|
||
|
||
|
||
def get_logger(name, level=None, log_file=None):
    """Return a configured ``logging.Logger``.

    Args:
        name: Logger name, passed to ``logging.getLogger``.
        level: Log level. Either a level name such as ``"INFO"`` (matched
            case-insensitively) or a numeric level such as ``logging.INFO``.
            Any falsy value (None, "", 0) falls back to ``LOG_LEVEL`` from the
            system config.
        log_file: Optional file name, resolved relative to ``LOG_DIR``. When
            given, records are additionally written to that file (UTF-8).
            When None, only console output (stdout) is configured.

    Returns:
        The logger instance.

    Note:
        Handlers are attached only the first time a logger ``name`` is
        configured. If the logger already has handlers it is returned after
        updating its level, so a ``log_file`` passed on a later call for the
        same name is NOT attached.
    """
    level = level or LOG_LEVEL

    logger = logging.getLogger(name)
    # Level names are looked up on the logging module ("info" -> logging.INFO);
    # numeric levels are passed straight through to setLevel.
    if isinstance(level, str):
        level = getattr(logging, level.upper())
    logger.setLevel(level)

    # Avoid attaching duplicate handlers on repeated calls with the same name.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(LOG_FORMAT)

    # Console output always goes to stdout.
    handlers = [logging.StreamHandler(sys.stdout)]

    if log_file:
        log_path = Path(LOG_DIR) / log_file
        # FileHandler does not create missing directories; ensure LOG_DIR (and
        # any sub-directory contained in log_file) exists first.
        log_path.parent.mkdir(parents=True, exist_ok=True)
        handlers.append(logging.FileHandler(log_path, encoding='utf-8'))

    # Attach only after every handler was created successfully. Otherwise a
    # failure opening the file would leave a logger with handlers that later
    # calls would return unchanged, silently without the file handler.
    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
|
||
|
||
|
||
def get_time_logger(name):
    """Return a logger that also writes to a per-call, timestamped file.

    The file name has the form ``<name>_<YYYYmmdd_HHMMSS>.log``; it is handed
    to ``get_logger``, which places it under ``LOG_DIR``. If a logger with this
    ``name`` already has handlers, ``get_logger`` returns it unchanged and the
    new file is not attached.

    Args:
        name: Logger name; also used as the log file name prefix.

    Returns:
        The logger instance.
    """
    stamp = format(datetime.now(), "%Y%m%d_%H%M%S")
    return get_logger(name, log_file="{}_{}.log".format(name, stamp))
|
||
|
||
|
||
class TrainingLogger:
    """Record training progress to a text log and a per-run metrics CSV."""

    def __init__(self, model_name):
        """Create the log file and the metrics CSV for one training run.

        Both files live under ``LOG_DIR`` and share one timestamp:
        ``training_<model_name>_<ts>.log`` and ``metrics_<model_name>_<ts>.csv``.

        Args:
            model_name: Model name; embedded in the logger name and in both
                file names.

        Note:
            ``get_logger`` attaches handlers only the first time a logger name
            is configured. A second TrainingLogger for the same ``model_name``
            in the same process therefore keeps logging to the first
            instance's file, and this instance's ``log_file`` may never be
            created.
        """
        # Neither FileHandler nor open() creates missing directories.
        Path(LOG_DIR).mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = f"training_{model_name}_{timestamp}.log"
        self.logger = get_logger(f"training_{model_name}", log_file=self.log_file)

        # Write the CSV header; log_epoch appends one row per epoch.
        self.csv_path = Path(LOG_DIR) / f"metrics_{model_name}_{timestamp}.csv"
        with open(self.csv_path, 'w', encoding='utf-8') as f:
            f.write("epoch,loss,accuracy,val_loss,val_accuracy\n")

    @staticmethod
    def _format_metric(value):
        """Format a metric with 4 decimals; fall back to str() for values that
        do not support the ``f`` format spec (e.g. None or strings)."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def log_epoch(self, epoch, metrics):
        """Log one epoch's metrics and append them as a row to the CSV.

        Args:
            epoch: Epoch number, written as-is in the first CSV column.
            metrics: Mapping of metric name to value. Every entry is logged;
                only loss, accuracy, val_loss and val_accuracy go into the CSV,
                and missing ones are left empty.
        """
        # Text log: every metric, numbers to 4 decimals.
        metrics_str = ", ".join(
            f"{k}: {self._format_metric(v)}" for k, v in metrics.items()
        )
        self.logger.info(f"Epoch {epoch}: {metrics_str}")

        # CSV row: fixed columns matching the header written in __init__.
        columns = ("loss", "accuracy", "val_loss", "val_accuracy")
        row = [epoch] + [metrics.get(col, '') for col in columns]
        csv_line = ",".join(f"{value}" for value in row) + "\n"
        with open(self.csv_path, 'a', encoding='utf-8') as f:
            f.write(csv_line)

    def log_training_start(self, config):
        """Log a banner with every key/value pair of the training config."""
        self.logger.info("=" * 50)
        self.logger.info("训练开始")
        self.logger.info("模型配置:")
        for key, value in config.items():
            self.logger.info(f" {key}: {value}")
        self.logger.info("=" * 50)

    def log_training_end(self, duration, best_metrics):
        """Log a closing banner with total duration and the best metrics.

        Args:
            duration: Elapsed training time, formatted with 2 decimals.
            best_metrics: Mapping of metric name to best value; numbers are
                shown with 4 decimals, other values via str().
        """
        self.logger.info("=" * 50)
        self.logger.info(f"训练结束,总用时: {duration:.2f}秒")
        self.logger.info("最佳性能:")
        for key, value in best_metrics.items():
            self.logger.info(f" {key}: {self._format_metric(value)}")
        self.logger.info("=" * 50)