"""
Model base class: defines the common interface shared by all text
classification models.
"""
import os
import time
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Optional, Any, Union, Callable

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model

from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
from config.model_config import (
    BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
    REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
)
from utils.logger import get_logger
from utils.file_utils import ensure_dir, save_json

logger = get_logger("BaseModel")


class TextClassificationModel(ABC):
    """Abstract base class for text classification models.

    Subclasses implement :meth:`build`, which is expected to assign a Keras
    model to ``self.model``. The base class wraps that model with compile,
    fit, evaluate, predict, save and load helpers.
    """

    def __init__(self,
                 num_classes: int,
                 model_name: str = "text_classifier",
                 batch_size: int = BATCH_SIZE,
                 learning_rate: float = LEARNING_RATE):
        """Initialise the model wrapper.

        Args:
            num_classes: Number of target classes.
            model_name: Name used in logs, checkpoint paths and saved files.
            batch_size: Batch size used when feeding NumPy arrays.
            learning_rate: Learning rate for the default Adam optimizer.
        """
        self.num_classes = num_classes
        self.model_name = model_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Underlying Keras model; assigned by build() (subclass) or load().
        self.model = None

        # History dict from the most recent fit() call.
        self.history = None

        # Serializable configuration; written next to the model by save().
        self.config = {
            "model_name": model_name,
            "num_classes": num_classes,
            "batch_size": batch_size,
            "learning_rate": learning_rate
        }

        # Best validation metrics observed so far; refreshed by fit().
        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0.0

        logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")

    @abstractmethod
    def build(self) -> None:
        """Build the model architecture; subclasses must implement this."""
        pass

    def _ensure_model_built(self) -> None:
        """Raise ValueError if ``self.model`` has not been created yet."""
        if self.model is None:
            raise ValueError("模型尚未构建,请先调用build方法")

    def compile(self,
                optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
                loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
                metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
        """Compile the underlying Keras model.

        Args:
            optimizer: Optimizer; defaults to Adam with ``self.learning_rate``.
            loss: Loss; defaults to ``'sparse_categorical_crossentropy'``.
            metrics: Metric list; defaults to ``['accuracy']``.

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()

        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        if loss is None:
            loss = 'sparse_categorical_crossentropy'
        if metrics is None:
            metrics = ['accuracy']

        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}")

    def summary(self) -> None:
        """Print the Keras model summary.

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()
        self.model.summary()

    def fit(self,
            x_train: Union[np.ndarray, tf.data.Dataset],
            y_train: Optional[np.ndarray] = None,
            validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
            epochs: int = 10,
            callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
            class_weights: Optional[Dict[int, float]] = None,
            verbose: int = 1) -> tf.keras.callbacks.History:
        """Train the model.

        Args:
            x_train: Training features, or a ``tf.data.Dataset`` that already
                yields ``(features, labels)`` batches.
            y_train: Training labels; ignored when ``x_train`` is a Dataset.
            validation_data: Validation data (tuple or Dataset).
            epochs: Number of epochs.
            callbacks: Callback list; when None, the defaults from
                ``_get_default_callbacks`` are used (they monitor
                ``val_loss``, so validation data is expected).
            class_weights: Optional per-class weights.
            verbose: Keras verbosity level.

        Returns:
            The Keras ``History`` object. Its ``history`` dict is also stored
            in ``self.history`` with an extra ``'train_time'`` entry (seconds).

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()

        start_time = time.time()

        if callbacks is None:
            callbacks = self._get_default_callbacks()

        common_kwargs = dict(
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=verbose
        )

        if isinstance(x_train, tf.data.Dataset):
            # A Dataset is already batched and carries its labels, so
            # neither y nor batch_size may be passed to Keras.
            history = self.model.fit(x_train, **common_kwargs)
        else:
            history = self.model.fit(
                x_train,
                y_train,
                batch_size=self.batch_size,
                **common_kwargs
            )

        train_time = time.time() - start_time

        self.history = history.history
        self.history['train_time'] = train_time

        # Keep the declared best-validation attributes in sync with training.
        self._update_best_validation_metrics(history.history)

        logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")

        return history

    def _update_best_validation_metrics(self, history: Dict[str, Any]) -> None:
        """Refresh best_val_loss / best_val_accuracy from a Keras history dict."""
        val_losses = history.get('val_loss')
        if val_losses:
            self.best_val_loss = min(self.best_val_loss, float(min(val_losses)))

        # Older Keras versions name the accuracy metric 'acc' instead of 'accuracy'.
        val_accuracies = history.get('val_accuracy') or history.get('val_acc')
        if val_accuracies:
            self.best_val_accuracy = max(self.best_val_accuracy, float(max(val_accuracies)))

    def evaluate(self,
                 x_test: Union[np.ndarray, tf.data.Dataset],
                 y_test: Optional[np.ndarray] = None,
                 verbose: int = 1) -> Dict[str, float]:
        """Evaluate the model.

        Args:
            x_test: Test features, or a batched ``tf.data.Dataset``.
            y_test: Test labels; ignored when ``x_test`` is a Dataset.
            verbose: Keras verbosity level.

        Returns:
            Mapping from metric name (including ``'loss'``) to its value.

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()

        # return_dict=True yields {name: value} directly. Zipping against
        # metrics_names breaks when evaluate() returns a bare scalar (model
        # compiled without metrics).
        if isinstance(x_test, tf.data.Dataset):
            results = self.model.evaluate(x_test, verbose=verbose, return_dict=True)
        else:
            results = self.model.evaluate(
                x_test,
                y_test,
                batch_size=self.batch_size,
                verbose=verbose,
                return_dict=True
            )

        evaluation_results = {name: float(value) for name, value in results.items()}

        logger.info(f"模型评估结果: {evaluation_results}")

        return evaluation_results

    def predict(self,
                x: Union[np.ndarray, tf.data.Dataset, List],
                batch_size: Optional[int] = None,
                verbose: int = 0) -> np.ndarray:
        """Run inference and return the raw model outputs.

        Args:
            x: Input data.
            batch_size: Batch size for array inputs; defaults to
                ``self.batch_size``. Ignored for ``tf.data.Dataset`` inputs,
                which are already batched.
            verbose: Keras verbosity level.

        Returns:
            Model predictions.

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()

        if isinstance(x, tf.data.Dataset):
            # Keras documents that batch_size must not be specified for
            # datasets; mirrors the handling in fit() and evaluate().
            return self.model.predict(x, verbose=verbose)

        if batch_size is None:
            batch_size = self.batch_size

        return self.model.predict(x, batch_size=batch_size, verbose=verbose)

    def predict_classes(self,
                        x: Union[np.ndarray, tf.data.Dataset, List],
                        batch_size: Optional[int] = None,
                        verbose: int = 0,
                        threshold: float = 0.5) -> np.ndarray:
        """Predict class indices.

        Args:
            x: Input data.
            batch_size: Batch size (see :meth:`predict`).
            verbose: Keras verbosity level.
            threshold: Decision threshold used only when the model has a
                single output unit (1-D output or shape ``(n, 1)``).

        Returns:
            Array of predicted class indices, one per sample.
        """
        predictions = np.asarray(self.predict(x, batch_size, verbose))

        if predictions.ndim == 1 or (predictions.ndim == 2 and predictions.shape[1] == 1):
            # Single-unit output: argmax over one column would always be 0,
            # so threshold the score instead.
            return (predictions.reshape(-1) > threshold).astype(np.int64)

        # Class scores are in the last dimension.
        return np.argmax(predictions, axis=-1)

    def save(self,
             filepath: Optional[str] = None,
             save_format: str = 'tf',
             include_optimizer: bool = True) -> str:
        """Save the Keras model plus a ``<filepath>_config.json`` sidecar.

        Args:
            filepath: Target path; when None a timestamped path under
                ``CLASSIFIERS_DIR`` is used.
            save_format: ``'tf'`` or ``'h5'``.
            include_optimizer: Whether to store the optimizer state.

        Returns:
            The path the model was saved to.

        Raises:
            ValueError: If the model has not been built.
        """
        self._ensure_model_built()

        if filepath is None:
            ensure_dir(CLASSIFIERS_DIR)
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}")
        else:
            # The h5 writer and the config sidecar both need the parent
            # directory to exist already.
            parent_dir = os.path.dirname(filepath)
            if parent_dir:
                ensure_dir(parent_dir)

        self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)

        config_path = f"{filepath}_config.json"
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.config, f, ensure_ascii=False, indent=4)

        logger.info(f"模型已保存到: {filepath}")

        return filepath

    @classmethod
    def load(cls,
             filepath: str,
             custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
        """Load a model previously written by :meth:`save`.

        Args:
            filepath: Path passed to (or returned by) :meth:`save`.
            custom_objects: Custom Keras objects needed for deserialization.

        Returns:
            A new instance of ``cls`` wrapping the loaded Keras model.
        """
        config_path = f"{filepath}_config.json"
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except FileNotFoundError:
            logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
            config = {}

        model_name = config.get('model_name', 'loaded_model')
        num_classes = config.get('num_classes', 1)
        batch_size = config.get('batch_size', BATCH_SIZE)
        learning_rate = config.get('learning_rate', LEARNING_RATE)

        instance = cls(num_classes, model_name, batch_size, learning_rate)

        instance.model = load_model(filepath, custom_objects=custom_objects)

        # Merge rather than replace: when the config file is missing,
        # `config` is empty and replacing would wipe the config built by
        # __init__ (and a later save() would then write an empty file).
        instance.config.update(config)

        logger.info(f"从 {filepath} 加载模型成功")

        return instance

    def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
        """Return the default callbacks.

        The list contains early stopping, LR reduction on plateau, best-model
        checkpointing and TensorBoard logging. The first three monitor
        ``val_loss``.
        """
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            restore_best_weights=True,
            verbose=1
        )

        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=REDUCE_LR_FACTOR,
            patience=REDUCE_LR_PATIENCE,
            min_lr=1e-6,
            verbose=1
        )

        checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
        ensure_dir(os.path.dirname(checkpoint_path))
        model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        )

        log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
        ensure_dir(log_dir)
        tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=1
        )

        return [early_stopping, reduce_lr, model_checkpoint, tensorboard]

    def get_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the model configuration."""
        return self.config.copy()

    def get_model(self) -> Model:
        """Return the underlying Keras model (None if not built)."""
        return self.model

    def get_training_history(self) -> Optional[Dict[str, List[float]]]:
        """Return the history dict from the last fit() call, or None."""
        return self.history

    def plot_training_history(self,
                              save_path: Optional[str] = None,
                              metrics: Optional[List[str]] = None) -> None:
        """Plot train/validation curves from the last training run.

        Args:
            save_path: Where to save the figure; when None the figure is shown.
            metrics: Metrics to plot; defaults to ``['loss', 'accuracy']``.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.history is None:
            raise ValueError("模型尚未训练,没有训练历史")

        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        if metrics is None:
            metrics = ['loss', 'accuracy']

        fig = plt.figure(figsize=(12, 5))

        for i, metric in enumerate(metrics):
            plt.subplot(1, len(metrics), i + 1)

            if metric in self.history:
                plt.plot(self.history[metric], label=f'train_{metric}')

            val_metric = f'val_{metric}'
            if val_metric in self.history:
                plt.plot(self.history[val_metric], label=f'val_{metric}')

            plt.title(f'Model {metric}')
            plt.xlabel('Epoch')
            plt.ylabel(metric)
            # Only draw a legend when something was plotted; an empty legend
            # triggers a matplotlib warning.
            handles, _ = plt.gca().get_legend_handles_labels()
            if handles:
                plt.legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            # Release the figure; otherwise repeated calls accumulate open
            # figures (and memory) in long-running processes.
            plt.close(fig)
            logger.info(f"训练历史图已保存到: {save_path}")
        else:
            plt.show()