""" 模型工厂:统一创建和管理不同类型的模型 """ from typing import List, Dict, Tuple, Optional, Any, Union import os import glob import time import numpy as np from config.system_config import CLASSIFIERS_DIR from config.model_config import ( BATCH_SIZE, LEARNING_RATE ) from models.base_model import TextClassificationModel from models.cnn_model import CNNTextClassifier from models.rnn_model import RNNTextClassifier from models.transformer_model import TransformerTextClassifier from utils.logger import get_logger logger = get_logger("ModelFactory") class ModelFactory: """模型工厂,用于创建和管理不同类型的模型""" @staticmethod def create_model(model_type: str, num_classes: int, vocab_size: int, embedding_matrix: Optional[np.ndarray] = None, model_config: Optional[Dict[str, Any]] = None, **kwargs) -> TextClassificationModel: """ 创建指定类型的模型 Args: model_type: 模型类型,可选值: 'cnn', 'rnn', 'transformer' num_classes: 类别数量 vocab_size: 词汇表大小 embedding_matrix: 预训练词嵌入矩阵 model_config: 模型配置字典 **kwargs: 其他参数 Returns: 创建的模型实例 """ model_type = model_type.lower() # 合并配置 config = model_config or {} config.update(kwargs) # 创建模型 if model_type == 'cnn': model = CNNTextClassifier( num_classes=num_classes, vocab_size=vocab_size, embedding_matrix=embedding_matrix, **config ) elif model_type == 'rnn': model = RNNTextClassifier( num_classes=num_classes, vocab_size=vocab_size, embedding_matrix=embedding_matrix, **config ) elif model_type == 'transformer': model = TransformerTextClassifier( num_classes=num_classes, vocab_size=vocab_size, embedding_matrix=embedding_matrix, **config ) else: raise ValueError(f"不支持的模型类型: {model_type}") logger.info(f"已创建 {model_type.upper()} 模型") return model @staticmethod def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel: """ 加载保存的模型 Args: model_path: 模型路径 custom_objects: 自定义对象字典 Returns: 加载的模型实例 """ # 添加Transformer相关的自定义对象 if custom_objects is None: custom_objects = {} if 'TransformerBlock' not in custom_objects: from models.transformer_model import TransformerBlock custom_objects['TransformerBlock'] = TransformerBlock # 根据配置确定模型类型 model_config_path = f"{model_path}_config.json" import json with open(model_config_path, 'r', encoding='utf-8') as f: config = json.load(f) model_type = config.get('model_type', '').lower() # 根据模型类型选择加载方法 if model_type == 'cnn': model = CNNTextClassifier.load(model_path, custom_objects) elif model_type == 'rnn': model = RNNTextClassifier.load(model_path, custom_objects) elif model_type == 'transformer': model = TransformerTextClassifier.load(model_path, custom_objects) else: # 如果无法确定模型类型,使用基类加载 logger.warning(f"无法确定模型类型,使用基类加载: {model_path}") model = TextClassificationModel.load(model_path, custom_objects) logger.info(f"已加载模型: {model_path}") return model @staticmethod def get_available_models() -> List[Dict[str, Any]]: """ 获取可用的已保存模型列表 Returns: 模型信息列表,每个元素是包含模型信息的字典 """ model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*")) model_files = [f for f in model_files if not f.endswith("_config.json")] models_info = [] for model_file in model_files: config_file = f"{model_file}_config.json" if os.path.exists(config_file): try: import json with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) # 获取模型文件的创建时间 created_time = os.path.getctime(model_file) created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time)) # 获取模型文件大小 file_size = os.path.getsize(model_file) / (1024 * 1024) # MB models_info.append({ "path": model_file, "name": config.get("model_name", os.path.basename(model_file)), "type": config.get("model_type", "unknown"), "num_classes": config.get("num_classes", 0), "created_time": created_time_str, "file_size": f"{file_size:.2f} MB", "config": config }) except Exception as e: logger.error(f"读取模型配置失败: {config_file}, 错误: {e}") # 按创建时间降序排序 models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True) return models_info