2025-03-08 01:34:36 +08:00

115 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Logging utilities: logger factory helpers and a training-metrics recorder.
"""
import logging
import sys
from pathlib import Path
from datetime import datetime
import os
from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT
def _resolve_level(level):
    """Turn a level given as an int or a case-insensitive name into a logging level."""
    if isinstance(level, int):
        # Numeric levels (e.g. logging.DEBUG) are already what setLevel expects.
        return level
    return getattr(logging, str(level).upper())


def get_logger(name, level=None, log_file=None):
    """
    Return the logger called ``name``, attaching handlers as needed.

    A stdout handler is attached only when the logger has no handlers at
    all, so repeated calls do not duplicate console output. A file handler
    is attached whenever ``log_file`` is given and the logger does not
    already write to that file, so a later call can still add a log file to
    a logger that was first created console-only.

    Args:
        name: logger name
        level: log level as a name ("INFO", "info") or an int
            (logging.INFO); falls back to LOG_LEVEL when empty
        log_file: file name, resolved relative to LOG_DIR; None (the
            default) means console output only
    Returns:
        the configured logging.Logger
    Raises:
        AttributeError: if ``level`` is a name the logging module does not define
    """
    logger = logging.getLogger(name)
    logger.setLevel(_resolve_level(level or LOG_LEVEL))

    # Built lazily so an already-configured logger never touches LOG_FORMAT.
    formatter = None

    if not logger.handlers:
        formatter = logging.Formatter(LOG_FORMAT)
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    if log_file:
        log_path = Path(LOG_DIR) / log_file
        # FileHandler stores its target as an absolute path in baseFilename,
        # so compare against the same normalised form.
        target = os.path.abspath(log_path)
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target
            for handler in logger.handlers
        )
        if not already_attached:
            # The log directory may not exist on a fresh checkout.
            log_path.parent.mkdir(parents=True, exist_ok=True)
            file_handler = logging.FileHandler(log_path, encoding='utf-8')
            file_handler.setFormatter(formatter or logging.Formatter(LOG_FORMAT))
            logger.addHandler(file_handler)

    return logger
def get_time_logger(name):
    """
    Build a logger whose log file name carries the creation time.

    The file is named ``<name>_<YYYYmmdd_HHMMSS>.log`` and is passed to
    ``get_logger`` as its ``log_file`` argument.

    Args:
        name: logger name, also used as the file-name prefix
    Returns:
        the logger returned by ``get_logger``
    """
    created_at = datetime.now()
    stamped_file = "{}_{}.log".format(name, created_at.strftime("%Y%m%d_%H%M%S"))
    return get_logger(name, log_file=stamped_file)
class TrainingLogger:
    """Records training progress to a text log and a per-epoch metrics CSV."""

    # Metric columns written to the CSV, in order, after the leading "epoch".
    _CSV_METRICS = ("loss", "accuracy", "val_loss", "val_accuracy")

    def __init__(self, model_name):
        """
        Create the log file name, the logger and the metrics CSV for one run.

        Args:
            model_name: model name, embedded in the logger and file names
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Make sure the target directory exists before anything is opened in it.
        log_dir = Path(LOG_DIR)
        log_dir.mkdir(parents=True, exist_ok=True)

        self.log_file = f"training_{model_name}_{timestamp}.log"
        self.logger = get_logger(f"training_{model_name}", log_file=self.log_file)

        # Start the CSV with its header row; log_epoch appends one row per epoch.
        self.csv_path = log_dir / f"metrics_{model_name}_{timestamp}.csv"
        with open(self.csv_path, 'w', encoding='utf-8') as f:
            f.write(",".join(("epoch",) + self._CSV_METRICS) + "\n")

    @staticmethod
    def _format_metric(value):
        """Format ``value`` with four decimals, or as plain text if it is not numeric."""
        try:
            return format(value, ".4f")
        except (TypeError, ValueError):
            # Strings, None, etc. must not abort logging of the other metrics.
            return str(value)

    def log_epoch(self, epoch, metrics):
        """
        Log one epoch's metrics and append them to the CSV.

        Args:
            epoch: epoch identifier, written as the first CSV column
            metrics: mapping of metric name to value; every entry is logged,
                while only the fixed CSV columns are written to the file
                (missing ones are left empty)
        """
        metrics_str = ", ".join(
            f"{k}: {self._format_metric(v)}" for k, v in metrics.items()
        )
        self.logger.info(f"Epoch {epoch}: {metrics_str}")

        cells = [str(epoch)]
        cells.extend(str(metrics.get(column, '')) for column in self._CSV_METRICS)
        with open(self.csv_path, 'a', encoding='utf-8') as f:
            f.write(",".join(cells) + "\n")

    def log_training_start(self, config):
        """
        Log a banner followed by every key/value pair of the run configuration.

        Args:
            config: mapping of configuration names to values
        """
        self.logger.info("=" * 50)
        self.logger.info("训练开始")
        self.logger.info("模型配置:")
        for key, value in config.items():
            self.logger.info(f" {key}: {value}")
        self.logger.info("=" * 50)

    def log_training_end(self, duration, best_metrics):
        """
        Log a closing banner with the total duration and the best metrics.

        Args:
            duration: total training time, formatted with two decimals
            best_metrics: mapping of metric name to best value; non-numeric
                values are logged as plain text
        """
        self.logger.info("=" * 50)
        self.logger.info(f"训练结束,总用时: {duration:.2f}")
        self.logger.info("最佳性能:")
        for key, value in best_metrics.items():
            self.logger.info(f" {key}: {self._format_metric(value)}")
        self.logger.info("=" * 50)