170 lines
5.8 KiB
Python
170 lines
5.8 KiB
Python
"""
|
|
模型工厂:统一创建和管理不同类型的模型
|
|
"""
|
|
from typing import List, Dict, Tuple, Optional, Any, Union
|
|
import os
|
|
import glob
|
|
import time
|
|
import numpy as np
|
|
|
|
from config.system_config import CLASSIFIERS_DIR
|
|
from config.model_config import (
|
|
BATCH_SIZE, LEARNING_RATE
|
|
)
|
|
from models.base_model import TextClassificationModel
|
|
from models.cnn_model import CNNTextClassifier
|
|
from models.rnn_model import RNNTextClassifier
|
|
from models.transformer_model import TransformerTextClassifier
|
|
from utils.logger import get_logger
|
|
|
|
# Module-level logger used by the ModelFactory methods below for info/warning/error messages.
logger = get_logger("ModelFactory")
|
|
|
|
|
|
class ModelFactory:
    """Factory for creating, loading and listing text-classification models.

    Every method is a ``@staticmethod``; the class only acts as a namespace.
    """

    @staticmethod
    def _model_classes() -> Dict[str, type]:
        """Return the mapping from lowercase model-type key to classifier class.

        Built on each call so the concrete classes are resolved at call time,
        and shared by ``create_model`` and ``load_model`` so both agree on
        which model types are supported.
        """
        return {
            'cnn': CNNTextClassifier,
            'rnn': RNNTextClassifier,
            'transformer': TransformerTextClassifier,
        }

    @staticmethod
    def _path_size_bytes(path: str) -> int:
        """Return the size in bytes of a file, or the summed size of all files under a directory.

        ``os.path.getsize`` on a directory only reports the size of the
        directory entry itself, which is meaningless for directory-style saved
        models. Symbolic links inside a directory are skipped so that a
        dangling link cannot make the whole entry unreadable.
        """
        if not os.path.isdir(path):
            return os.path.getsize(path)
        total = 0
        for dirpath, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                file_path = os.path.join(dirpath, filename)
                if not os.path.islink(file_path):
                    total += os.path.getsize(file_path)
        return total

    @staticmethod
    def create_model(model_type: str, num_classes: int, vocab_size: int,
                     embedding_matrix: Optional[np.ndarray] = None,
                     model_config: Optional[Dict[str, Any]] = None,
                     **kwargs) -> "TextClassificationModel":
        """Create a model of the requested type.

        Args:
            model_type: Model type, case-insensitive; one of 'cnn', 'rnn',
                'transformer'.
            num_classes: Number of output classes.
            vocab_size: Vocabulary size.
            embedding_matrix: Optional pre-trained word-embedding matrix.
            model_config: Optional dict of extra constructor arguments. It is
                copied, never modified.
            **kwargs: Further constructor arguments; they override keys of the
                same name in ``model_config``.

        Returns:
            The newly constructed model instance.

        Raises:
            ValueError: If ``model_type`` is not a supported type.
            TypeError: If the merged config repeats ``num_classes``,
                ``vocab_size`` or ``embedding_matrix`` (duplicate keyword
                argument).
        """
        model_type = model_type.lower()

        # Merge into a fresh dict so the caller's model_config is never
        # modified by kwargs.
        config: Dict[str, Any] = dict(model_config or {})
        config.update(kwargs)

        model_cls = ModelFactory._model_classes().get(model_type)
        if model_cls is None:
            raise ValueError(f"不支持的模型类型: {model_type}")

        model = model_cls(
            num_classes=num_classes,
            vocab_size=vocab_size,
            embedding_matrix=embedding_matrix,
            **config
        )

        logger.info(f"已创建 {model_type.upper()} 模型")

        return model

    @staticmethod
    def load_model(model_path: str,
                   custom_objects: Optional[Dict[str, Any]] = None) -> "TextClassificationModel":
        """Load a saved model.

        The model type is read from the ``"model_type"`` key of
        ``f"{model_path}_config.json"``. If that key is missing, empty, not a
        known type, or the JSON is not an object, the base
        ``TextClassificationModel.load`` is used and a warning is logged.

        Args:
            model_path: Path the model was saved under (without the
                ``_config.json`` suffix).
            custom_objects: Optional dict of custom objects forwarded to the
                class's ``load``. It is copied, never modified.
                ``TransformerBlock`` is added to the copy when absent.

        Returns:
            The loaded model instance.

        Raises:
            OSError: If the config file cannot be opened (e.g. it does not exist).
            ValueError: If the config file is not valid JSON.
        """
        import json

        # Copy so that injecting TransformerBlock does not leak into the
        # caller's dict.
        custom_objects = dict(custom_objects or {})
        if 'TransformerBlock' not in custom_objects:
            from models.transformer_model import TransformerBlock
            custom_objects['TransformerBlock'] = TransformerBlock

        # Determine the model type from the saved config.
        model_config_path = f"{model_path}_config.json"
        with open(model_config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Tolerate "model_type": null or a non-object JSON document: both mean
        # "type unknown" and fall through to the base-class loader below.
        raw_type = config.get('model_type') if isinstance(config, dict) else None
        model_type = str(raw_type or '').lower()

        model_cls = ModelFactory._model_classes().get(model_type)
        if model_cls is None:
            # The model type could not be determined; fall back to the base class.
            logger.warning(f"无法确定模型类型,使用基类加载: {model_path}")
            model_cls = TextClassificationModel
        model = model_cls.load(model_path, custom_objects)

        logger.info(f"已加载模型: {model_path}")

        return model

    @staticmethod
    def get_available_models(models_dir: Optional[str] = None) -> List[Dict[str, Any]]:
        """List the saved models found in a directory.

        An entry ``X`` in the directory is reported only if a sibling
        ``X_config.json`` exists. Entries whose config cannot be read or
        parsed are logged and skipped.

        Args:
            models_dir: Directory to scan; defaults to ``CLASSIFIERS_DIR``.

        Returns:
            One dict per model, with keys ``path``, ``name``, ``type``,
            ``num_classes``, ``created_time`` (local time,
            ``%Y-%m-%d %H:%M:%S``), ``file_size`` (e.g. ``"1.23 MB"``; for a
            directory, the total size of the files inside it) and ``config``
            (the parsed JSON). The list is sorted newest first.
        """
        import json

        root = CLASSIFIERS_DIR if models_dir is None else models_dir
        candidates = [
            path for path in glob.glob(os.path.join(root, "*"))
            if not path.endswith("_config.json")
        ]

        # (raw timestamp, info) pairs: sorting on the raw timestamp is exact,
        # unlike the formatted local-time string (second-granular and
        # non-monotonic across a DST fall-back).
        entries: List[Tuple[float, Dict[str, Any]]] = []

        for model_file in candidates:
            config_file = f"{model_file}_config.json"
            if not os.path.exists(config_file):
                continue

            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)

                # NOTE: per the Python docs, getctime is the creation time on
                # Windows but the metadata-change time on Unix.
                created_ts = os.path.getctime(model_file)
                created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_ts))

                size_mb = ModelFactory._path_size_bytes(model_file) / (1024 * 1024)

                entries.append((created_ts, {
                    "path": model_file,
                    "name": config.get("model_name", os.path.basename(model_file)),
                    "type": config.get("model_type", "unknown"),
                    "num_classes": config.get("num_classes", 0),
                    "created_time": created_time_str,
                    "file_size": f"{size_mb:.2f} MB",
                    "config": config
                }))
            except Exception as e:
                # Best-effort listing: one broken entry must not hide the others.
                logger.error(f"读取模型配置失败: {config_file}, 错误: {e}")

        # Newest first.
        entries.sort(key=lambda entry: entry[0], reverse=True)

        return [info for _ts, info in entries]
|