"""
模型工厂:统一创建和管理不同类型的模型
"""
from typing import List, Dict, Tuple, Optional, Any, Union
import os
import glob
import json
import time

import numpy as np
from config.system_config import CLASSIFIERS_DIR
from config.model_config import BATCH_SIZE, LEARNING_RATE
from models.base_model import TextClassificationModel
from models.cnn_model import CNNTextClassifier
from models.rnn_model import RNNTextClassifier
from models.transformer_model import TransformerTextClassifier
from utils.logger import get_logger

logger = get_logger("ModelFactory")


class ModelFactory:
"""模型工厂,用于创建和管理不同类型的模型"""
@staticmethod
def create_model(model_type: str, num_classes: int, vocab_size: int,
embedding_matrix: Optional[np.ndarray] = None,
model_config: Optional[Dict[str, Any]] = None,
**kwargs) -> TextClassificationModel:
"""
创建指定类型的模型
Args:
model_type: 模型类型,可选值: 'cnn', 'rnn', 'transformer'
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_matrix: 预训练词嵌入矩阵
model_config: 模型配置字典
**kwargs: 其他参数
Returns:
创建的模型实例
"""
model_type = model_type.lower()
# 合并配置
config = model_config or {}
config.update(kwargs)
# 创建模型
if model_type == 'cnn':
model = CNNTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
elif model_type == 'rnn':
model = RNNTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
elif model_type == 'transformer':
model = TransformerTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
else:
raise ValueError(f"不支持的模型类型: {model_type}")
logger.info(f"已创建 {model_type.upper()} 模型")
return model
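
    # A minimal usage sketch for create_model (the class count and vocabulary
    # size below are illustrative assumptions, not project defaults):
    #
    #     model = ModelFactory.create_model(
    #         model_type="cnn",
    #         num_classes=10,
    #         vocab_size=50000,
    #     )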
    @staticmethod
    def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel:
        """
        Load a saved model.

        Args:
            model_path: Path to the saved model
            custom_objects: Dictionary of custom objects needed for deserialization

        Returns:
            The loaded model instance
        """
        # Register the Transformer-specific custom objects needed for deserialization
        if custom_objects is None:
            custom_objects = {}
        if 'TransformerBlock' not in custom_objects:
            from models.transformer_model import TransformerBlock
            custom_objects['TransformerBlock'] = TransformerBlock

        # Determine the model type from the saved configuration file
        model_config_path = f"{model_path}_config.json"
        with open(model_config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        model_type = config.get('model_type', '').lower()

        # Dispatch to the loader that matches the model type
        if model_type == 'cnn':
            model = CNNTextClassifier.load(model_path, custom_objects)
        elif model_type == 'rnn':
            model = RNNTextClassifier.load(model_path, custom_objects)
        elif model_type == 'transformer':
            model = TransformerTextClassifier.load(model_path, custom_objects)
        else:
            # Fall back to the base class when the model type cannot be determined
            logger.warning(f"Unable to determine model type, loading with the base class: {model_path}")
            model = TextClassificationModel.load(model_path, custom_objects)

        logger.info(f"Loaded model: {model_path}")
        return model
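
    # A minimal usage sketch for load_model (the path is hypothetical; it should
    # point at a model previously saved by one of the classifier classes, with
    # its "<path>_config.json" file next to it):
    #
    #     model = ModelFactory.load_model("path/to/saved/cnn_model")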
    @staticmethod
    def get_available_models() -> List[Dict[str, Any]]:
        """
        List the saved models available on disk.

        Returns:
            A list of dictionaries, one per model, containing its metadata
        """
        model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*"))
        model_files = [f for f in model_files if not f.endswith("_config.json")]

        models_info = []
        for model_file in model_files:
            config_file = f"{model_file}_config.json"
            if os.path.exists(config_file):
                try:
                    with open(config_file, 'r', encoding='utf-8') as f:
                        config = json.load(f)

                    # Creation time of the model file
                    created_time = os.path.getctime(model_file)
                    created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time))

                    # Size of the model file in megabytes
                    file_size = os.path.getsize(model_file) / (1024 * 1024)  # MB

                    models_info.append({
                        "path": model_file,
                        "name": config.get("model_name", os.path.basename(model_file)),
                        "type": config.get("model_type", "unknown"),
                        "num_classes": config.get("num_classes", 0),
                        "created_time": created_time_str,
                        "file_size": f"{file_size:.2f} MB",
                        "config": config
                    })
                except Exception as e:
                    logger.error(f"Failed to read model configuration: {config_file}, error: {e}")

        # Sort by creation time, newest first
        models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True)
        return models_info
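

# A minimal self-test sketch, assuming CLASSIFIERS_DIR resolves to an existing
# directory and the CNN classifier can be built with only the arguments shown;
# the class count and vocabulary size are illustrative, not project defaults.
if __name__ == "__main__":
    # List models that were previously saved under CLASSIFIERS_DIR.
    for info in ModelFactory.get_available_models():
        print(f"{info['name']} ({info['type']}) - {info['created_time']} - {info['file_size']}")

    # Build a fresh CNN classifier with illustrative dimensions.
    demo_model = ModelFactory.create_model(
        model_type="cnn",
        num_classes=10,
        vocab_size=50000,
    )
    logger.info(f"Demo model created: {demo_model.__class__.__name__}")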