2025-03-08 01:34:36 +08:00

419 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型基类:定义所有文本分类模型的通用接口
"""
import os
import time
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from abc import ABC, abstractmethod
from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
from config.model_config import (
BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
)
from utils.logger import get_logger
from utils.file_utils import ensure_dir, save_json
logger = get_logger("BaseModel")
class TextClassificationModel(ABC):
"""文本分类模型基类,定义所有模型的通用接口"""
def __init__(self, num_classes: int, model_name: str = "text_classifier",
             batch_size: int = BATCH_SIZE,
             learning_rate: float = LEARNING_RATE):
    """
    Initialize the text classification model wrapper.

    Args:
        num_classes: number of target classes
        model_name: human-readable identifier used in file names and logs
        batch_size: mini-batch size used for training and evaluation
        learning_rate: learning rate for the default Adam optimizer
    """
    self.num_classes = num_classes
    self.model_name = model_name
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    # Underlying Keras model — created later by build().
    self.model = None
    # Training history dict — populated by fit().
    self.history = None
    # Serializable configuration, persisted alongside the model by save().
    self.config = dict(
        model_name=model_name,
        num_classes=num_classes,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    # Best validation-set performance observed so far.
    self.best_val_loss = float('inf')
    self.best_val_accuracy = 0.0
    logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")
@abstractmethod
def build(self) -> None:
    """Construct the model architecture; every concrete subclass must implement this."""
def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
            loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
            metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
    """
    Compile the underlying Keras model.

    Args:
        optimizer: optimizer to use; defaults to Adam with self.learning_rate
        loss: loss function; defaults to sparse_categorical_crossentropy
        metrics: evaluation metrics; defaults to ['accuracy']

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    # Resolve each argument to its default when the caller passed None.
    resolved_optimizer = (optimizer if optimizer is not None
                          else tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
    resolved_loss = loss if loss is not None else 'sparse_categorical_crossentropy'
    resolved_metrics = metrics if metrics is not None else ['accuracy']
    self.model.compile(optimizer=resolved_optimizer, loss=resolved_loss, metrics=resolved_metrics)
    logger.info(f"模型已编译,优化器: {resolved_optimizer.__class__.__name__}, 损失函数: {resolved_loss}, 评估指标: {resolved_metrics}")
def summary(self) -> None:
    """Print the model architecture summary to stdout.

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    self.model.summary()
def fit(self, x_train: Union[np.ndarray, tf.data.Dataset],
        y_train: Optional[np.ndarray] = None,
        validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
        epochs: int = 10,
        callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
        class_weights: Optional[Dict[int, float]] = None,
        verbose: int = 1) -> tf.keras.callbacks.History:
    """
    Train the model.

    Args:
        x_train: training features; either a NumPy array or a tf.data.Dataset
            (a Dataset is assumed to yield (features, labels) pairs, so
            y_train is ignored in that case)
        y_train: training labels (used only with NumPy-array input)
        validation_data: validation data as a (features, labels) tuple or Dataset
        epochs: number of training epochs
        callbacks: Keras callbacks; when None, the defaults from
            _get_default_callbacks() are used
        class_weights: optional per-class loss weights
        verbose: Keras verbosity level

    Returns:
        The Keras History object returned by model.fit.

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    # Record wall-clock start so total training time can be logged.
    start_time = time.time()
    if callbacks is None:
        callbacks = self._get_default_callbacks()
    if isinstance(x_train, tf.data.Dataset):
        # A Dataset carries its own batching, so batch_size is not passed.
        history = self.model.fit(
            x_train,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=verbose
        )
    else:
        history = self.model.fit(
            x_train, y_train,
            batch_size=self.batch_size,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=verbose
        )
    train_time = time.time() - start_time
    # Fix: copy the history dict instead of aliasing it, so inserting
    # 'train_time' below does not mutate the Keras History object's records.
    self.history = dict(history.history)
    self.history['train_time'] = train_time
    logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
    return history
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
             y_test: Optional[np.ndarray] = None,
             verbose: int = 1) -> Dict[str, float]:
    """
    Evaluate the model on a test set.

    Args:
        x_test: test features; NumPy array or tf.data.Dataset (a Dataset is
            assumed to yield (features, labels) pairs, so y_test is ignored)
        y_test: test labels (used only with NumPy-array input)
        verbose: Keras verbosity level

    Returns:
        Dict mapping each metric name to its float value.

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    if isinstance(x_test, tf.data.Dataset):
        results = self.model.evaluate(x_test, verbose=verbose)
    else:
        results = self.model.evaluate(x_test, y_test, batch_size=self.batch_size, verbose=verbose)
    # Fix: Keras returns a bare scalar when the model tracks a single metric;
    # normalize to a list so zip() below works in both cases.
    if not isinstance(results, (list, tuple)):
        results = [results]
    metrics_names = self.model.metrics_names
    evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)}
    logger.info(f"模型评估结果: {evaluation_results}")
    return evaluation_results
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
            batch_size: Optional[int] = None,
            verbose: int = 0) -> np.ndarray:
    """
    Run the model forward to get prediction probabilities.

    Args:
        x: input data (NumPy array, list, or tf.data.Dataset)
        batch_size: batch size for array input; defaults to self.batch_size
        verbose: Keras verbosity level

    Returns:
        Array of per-class prediction scores.

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    # Fix: a tf.data.Dataset is already batched — Keras raises a ValueError
    # when batch_size is passed together with a Dataset, so omit it here.
    if isinstance(x, tf.data.Dataset):
        return self.model.predict(x, verbose=verbose)
    if batch_size is None:
        batch_size = self.batch_size
    return self.model.predict(x, batch_size=batch_size, verbose=verbose)
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
                    batch_size: Optional[int] = None,
                    verbose: int = 0) -> np.ndarray:
    """
    Predict the class index for each input sample.

    Args:
        x: input data
        batch_size: batch size; defaults to self.batch_size
        verbose: Keras verbosity level

    Returns:
        Array of predicted class indices (argmax over class scores).
    """
    # Take the index of the highest-scoring class per sample.
    probabilities = self.predict(x, batch_size, verbose)
    return probabilities.argmax(axis=1)
def save(self, filepath: Optional[str] = None,
         save_format: str = 'tf',
         include_optimizer: bool = True) -> str:
    """
    Persist the model and its JSON configuration to disk.

    Args:
        filepath: destination path; when None, a timestamped path under
            CLASSIFIERS_DIR is generated from the model name
        save_format: Keras save format, 'tf' or 'h5'
        include_optimizer: whether to serialize optimizer state as well

    Returns:
        The path the model was written to.

    Raises:
        ValueError: if build() has not been called yet.
    """
    if self.model is None:
        raise ValueError("模型尚未构建请先调用build方法")
    if filepath is None:
        # Default location: timestamped file inside the classifiers directory.
        ensure_dir(CLASSIFIERS_DIR)
        stamp = time.strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{stamp}")
    self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)
    # Write the training configuration next to the saved model.
    config_path = f"{filepath}_config.json"
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(self.config, f, ensure_ascii=False, indent=4)
    logger.info(f"模型已保存到: {filepath}")
    return filepath
@classmethod
def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
    """
    Load a previously saved model together with its JSON configuration.

    Args:
        filepath: path of the saved model
        custom_objects: mapping of names to custom objects for Keras
            deserialization

    Returns:
        A model instance wrapping the loaded Keras model.
    """
    # Try to read the sidecar config written by save(); fall back to defaults.
    config_path = f"{filepath}_config.json"
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
        config = {}
    # Rebuild the wrapper from the stored (or default) configuration.
    instance = cls(
        config.get('num_classes', 1),
        config.get('model_name', 'loaded_model'),
        config.get('batch_size', BATCH_SIZE),
        config.get('learning_rate', LEARNING_RATE),
    )
    instance.model = load_model(filepath, custom_objects=custom_objects)
    instance.config = config
    logger.info(f"从 {filepath} 加载模型成功")
    return instance
def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
    """Build the default callback list: early stopping, LR reduction,
    best-model checkpointing, and TensorBoard logging — all monitoring
    validation loss."""
    default_callbacks = []
    # Stop training when val_loss plateaus, restoring the best weights.
    default_callbacks.append(tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1
    ))
    # Shrink the learning rate when val_loss stops improving.
    default_callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=REDUCE_LR_FACTOR,
        patience=REDUCE_LR_PATIENCE,
        min_lr=1e-6,
        verbose=1
    ))
    # Keep only the best checkpoint (by val_loss) on disk.
    checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
    ensure_dir(os.path.dirname(checkpoint_path))
    default_callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    ))
    # Write TensorBoard logs to a timestamped run directory.
    log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
    ensure_dir(log_dir)
    default_callbacks.append(tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1
    ))
    return default_callbacks
def get_config(self) -> Dict[str, Any]:
    """Return a shallow copy of the model configuration dict."""
    return dict(self.config)
def get_model(self) -> Model:
    """Return the underlying Keras Model instance (None before build())."""
    return self.model
def get_training_history(self) -> Optional[Dict[str, List[float]]]:
    """Return the recorded training history dict (None before fit())."""
    return self.history
def plot_training_history(self, save_path: Optional[str] = None,
                          metrics: Optional[List[str]] = None) -> None:
    """
    Plot training curves from the recorded history.

    Args:
        save_path: file to save the figure to; when None, show it interactively
        metrics: metric names to plot; defaults to ['loss', 'accuracy']

    Raises:
        ValueError: if the model has not been trained (no history available).
    """
    if self.history is None:
        raise ValueError("模型尚未训练,没有训练历史")
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt
    if metrics is None:
        metrics = ['loss', 'accuracy']
    plt.figure(figsize=(12, 5))
    # One subplot per metric; plot train and (when present) validation curves.
    for i, metric in enumerate(metrics):
        plt.subplot(1, len(metrics), i + 1)
        if metric in self.history:
            plt.plot(self.history[metric], label=f'train_{metric}')
        val_metric = f'val_{metric}'
        if val_metric in self.history:
            plt.plot(self.history[val_metric], label=f'val_{metric}')
        plt.title(f'Model {metric}')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        # Fix: close the figure after saving so repeated calls do not
        # accumulate open matplotlib figures (memory leak).
        plt.close()
        logger.info(f"训练历史图已保存到: {save_path}")
    else:
        plt.show()