"""
Model base class: defines the common interface for all text classification models.
"""
import os
import time
import json

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model

from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from abc import ABC, abstractmethod

from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
from config.model_config import (
    BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
    REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
)
from utils.logger import get_logger
from utils.file_utils import ensure_dir, save_json

logger = get_logger("BaseModel")
class TextClassificationModel(ABC):
    """Abstract base class defining the common interface of all text classification models."""

    def __init__(self, num_classes: int, model_name: str = "text_classifier",
                 batch_size: int = BATCH_SIZE,
                 learning_rate: float = LEARNING_RATE):
        """
        Initialize the text classification model.

        Args:
            num_classes: Number of target classes.
            model_name: Human-readable identifier used for logging and file names.
            batch_size: Mini-batch size used for training/evaluation/prediction.
            learning_rate: Learning rate for the default optimizer.
        """
        self.num_classes = num_classes
        self.model_name = model_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Underlying Keras model; created lazily by the subclass's build().
        self.model = None

        # Training history dict; populated by fit().
        self.history = None

        # Serializable training configuration, persisted alongside the model by save().
        self.config = {
            "model_name": model_name,
            "num_classes": num_classes,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
        }

        # Best validation metrics observed so far.
        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0.0

        logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")
@abstractmethod
|
||
def build(self) -> None:
|
||
"""构建模型架构,这是一个抽象方法,子类必须实现"""
|
||
pass
|
||
|
||
def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
|
||
loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
|
||
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
|
||
"""
|
||
编译模型
|
||
|
||
Args:
|
||
optimizer: 优化器,默认为Adam
|
||
loss: 损失函数,默认为sparse_categorical_crossentropy
|
||
metrics: 评估指标,默认为accuracy
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 默认优化器
|
||
if optimizer is None:
|
||
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
||
|
||
# 默认损失函数
|
||
if loss is None:
|
||
loss = 'sparse_categorical_crossentropy'
|
||
|
||
# 默认评估指标
|
||
if metrics is None:
|
||
metrics = ['accuracy']
|
||
|
||
# 编译模型
|
||
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||
logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}")
|
||
|
||
def summary(self) -> None:
|
||
"""打印模型概要"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
self.model.summary()
|
||
|
||
def fit(self, x_train: Union[np.ndarray, tf.data.Dataset],
|
||
y_train: Optional[np.ndarray] = None,
|
||
validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
|
||
epochs: int = 10,
|
||
callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
|
||
class_weights: Optional[Dict[int, float]] = None,
|
||
verbose: int = 1) -> tf.keras.callbacks.History:
|
||
"""
|
||
训练模型
|
||
|
||
Args:
|
||
x_train: 训练数据特征
|
||
y_train: 训练数据标签
|
||
validation_data: 验证数据
|
||
epochs: 训练轮数
|
||
callbacks: 回调函数列表
|
||
class_weights: 类别权重
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
训练历史
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 记录开始时间
|
||
start_time = time.time()
|
||
|
||
# 添加默认回调函数
|
||
if callbacks is None:
|
||
callbacks = self._get_default_callbacks()
|
||
|
||
# 训练模型
|
||
if isinstance(x_train, tf.data.Dataset):
|
||
# 如果输入是TensorFlow Dataset
|
||
history = self.model.fit(
|
||
x_train,
|
||
epochs=epochs,
|
||
validation_data=validation_data,
|
||
callbacks=callbacks,
|
||
class_weight=class_weights,
|
||
verbose=verbose
|
||
)
|
||
else:
|
||
# 如果输入是NumPy数组
|
||
history = self.model.fit(
|
||
x_train, y_train,
|
||
batch_size=self.batch_size,
|
||
epochs=epochs,
|
||
validation_data=validation_data,
|
||
callbacks=callbacks,
|
||
class_weight=class_weights,
|
||
verbose=verbose
|
||
)
|
||
|
||
# 计算训练时间
|
||
train_time = time.time() - start_time
|
||
|
||
# 保存训练历史
|
||
self.history = history.history
|
||
self.history['train_time'] = train_time
|
||
|
||
logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
|
||
|
||
return history
|
||
|
||
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
|
||
y_test: Optional[np.ndarray] = None,
|
||
verbose: int = 1) -> Dict[str, float]:
|
||
"""
|
||
评估模型
|
||
|
||
Args:
|
||
x_test: 测试数据特征
|
||
y_test: 测试数据标签
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
评估结果字典
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 评估模型
|
||
if isinstance(x_test, tf.data.Dataset):
|
||
# 如果输入是TensorFlow Dataset
|
||
results = self.model.evaluate(x_test, verbose=verbose)
|
||
else:
|
||
# 如果输入是NumPy数组
|
||
results = self.model.evaluate(x_test, y_test, batch_size=self.batch_size, verbose=verbose)
|
||
|
||
# 构建评估结果字典
|
||
metrics_names = self.model.metrics_names
|
||
evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)}
|
||
|
||
logger.info(f"模型评估结果: {evaluation_results}")
|
||
|
||
return evaluation_results
|
||
|
||
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用模型进行预测
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 使用模型进行预测
|
||
if batch_size is None:
|
||
batch_size = self.batch_size
|
||
|
||
return self.model.predict(x, batch_size=batch_size, verbose=verbose)
|
||
|
||
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用模型预测类别
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测的类别索引
|
||
"""
|
||
# 获取模型预测概率
|
||
predictions = self.predict(x, batch_size, verbose)
|
||
|
||
# 获取最大概率的类别索引
|
||
return np.argmax(predictions, axis=1)
|
||
|
||
def save(self, filepath: Optional[str] = None,
|
||
save_format: str = 'tf',
|
||
include_optimizer: bool = True) -> str:
|
||
"""
|
||
保存模型
|
||
|
||
Args:
|
||
filepath: 保存路径,如果为None则使用默认路径
|
||
save_format: 保存格式,'tf'或'h5'
|
||
include_optimizer: 是否包含优化器状态
|
||
|
||
Returns:
|
||
保存路径
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 如果未指定保存路径,使用默认路径
|
||
if filepath is None:
|
||
ensure_dir(CLASSIFIERS_DIR)
|
||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||
filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}")
|
||
|
||
# 保存模型
|
||
self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)
|
||
|
||
# 保存模型配置
|
||
config_path = f"{filepath}_config.json"
|
||
with open(config_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self.config, f, ensure_ascii=False, indent=4)
|
||
|
||
logger.info(f"模型已保存到: {filepath}")
|
||
|
||
return filepath
|
||
|
||
@classmethod
|
||
def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
|
||
"""
|
||
加载模型
|
||
|
||
Args:
|
||
filepath: 模型文件路径
|
||
custom_objects: 自定义对象字典
|
||
|
||
Returns:
|
||
加载的模型实例
|
||
"""
|
||
# 加载模型配置
|
||
config_path = f"{filepath}_config.json"
|
||
|
||
try:
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
except FileNotFoundError:
|
||
logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
|
||
config = {}
|
||
|
||
# 创建模型实例
|
||
model_name = config.get('model_name', 'loaded_model')
|
||
num_classes = config.get('num_classes', 1)
|
||
batch_size = config.get('batch_size', BATCH_SIZE)
|
||
learning_rate = config.get('learning_rate', LEARNING_RATE)
|
||
|
||
instance = cls(num_classes, model_name, batch_size, learning_rate)
|
||
|
||
# 加载Keras模型
|
||
instance.model = load_model(filepath, custom_objects=custom_objects)
|
||
|
||
# 加载配置
|
||
instance.config = config
|
||
|
||
logger.info(f"从 {filepath} 加载模型成功")
|
||
|
||
return instance
|
||
|
||
def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
|
||
"""获取默认的回调函数列表"""
|
||
# 早停
|
||
early_stopping = tf.keras.callbacks.EarlyStopping(
|
||
monitor='val_loss',
|
||
patience=EARLY_STOPPING_PATIENCE,
|
||
restore_best_weights=True,
|
||
verbose=1
|
||
)
|
||
|
||
# 学习率衰减
|
||
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
|
||
monitor='val_loss',
|
||
factor=REDUCE_LR_FACTOR,
|
||
patience=REDUCE_LR_PATIENCE,
|
||
min_lr=1e-6,
|
||
verbose=1
|
||
)
|
||
|
||
# 模型检查点
|
||
checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
|
||
ensure_dir(os.path.dirname(checkpoint_path))
|
||
|
||
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||
filepath=checkpoint_path,
|
||
save_best_only=True,
|
||
monitor='val_loss',
|
||
verbose=1
|
||
)
|
||
|
||
# TensorBoard日志
|
||
log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
|
||
ensure_dir(log_dir)
|
||
|
||
tensorboard = tf.keras.callbacks.TensorBoard(
|
||
log_dir=log_dir,
|
||
histogram_freq=1
|
||
)
|
||
|
||
return [early_stopping, reduce_lr, model_checkpoint, tensorboard]
|
||
|
||
def get_config(self) -> Dict[str, Any]:
|
||
"""获取模型配置"""
|
||
return self.config.copy()
|
||
|
||
def get_model(self) -> Model:
|
||
"""获取Keras模型实例"""
|
||
return self.model
|
||
|
||
def get_training_history(self) -> Optional[Dict[str, List[float]]]:
|
||
"""获取训练历史"""
|
||
return self.history
|
||
|
||
def plot_training_history(self, save_path: Optional[str] = None,
|
||
metrics: Optional[List[str]] = None) -> None:
|
||
"""
|
||
绘制训练历史
|
||
|
||
Args:
|
||
save_path: 保存路径,如果为None则显示图像
|
||
metrics: 要绘制的指标列表,默认为['loss', 'accuracy']
|
||
"""
|
||
if self.history is None:
|
||
raise ValueError("模型尚未训练,没有训练历史")
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
if metrics is None:
|
||
metrics = ['loss', 'accuracy']
|
||
|
||
# 创建图形
|
||
plt.figure(figsize=(12, 5))
|
||
|
||
# 绘制指标
|
||
for i, metric in enumerate(metrics):
|
||
plt.subplot(1, len(metrics), i + 1)
|
||
|
||
if metric in self.history:
|
||
plt.plot(self.history[metric], label=f'train_{metric}')
|
||
|
||
val_metric = f'val_{metric}'
|
||
if val_metric in self.history:
|
||
plt.plot(self.history[val_metric], label=f'val_{metric}')
|
||
|
||
plt.title(f'Model {metric}')
|
||
plt.xlabel('Epoch')
|
||
plt.ylabel(metric)
|
||
plt.legend()
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图像
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"训练历史图已保存到: {save_path}")
|
||
else:
|
||
plt.show() |