217 lines
6.9 KiB
Python
217 lines
6.9 KiB
Python
"""
|
||
集成模型:实现多个模型的集成
|
||
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import os
|
||
|
||
from config.system_config import CLASSIFIERS_DIR
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger used by EnsembleModel for initialization, prediction progress and save/load messages.
logger = get_logger("EnsembleModel")
|
||
|
||
|
||
class EnsembleModel:
    """Ensemble of text classification models that combines their predictions.

    Two voting strategies are supported:

    * ``'soft'``: weighted sum of every model's prediction matrix. Because the
      weights are normalized to sum to 1, this is a weighted average of the
      per-class scores.
    * ``'hard'``: every model votes for its argmax class. Each vote counts with the
      model's normalized weight, so each output row holds weighted vote shares.

    For binary ensembles (``num_classes == 2``), a model that emits a single
    column ``p`` (shape ``(n,)`` or ``(n, 1)``) is treated as the probability of
    class 1. That column is expanded to ``[1 - p, p]`` before combining.
    """

    def __init__(self, models: List[TextClassificationModel],
                 weights: Optional[List[float]] = None,
                 voting: str = 'soft',
                 name: str = "ensemble_model"):
        """Create an ensemble.

        Args:
            models: Member models. Each must expose ``num_classes``, and all must
                agree on it.
            weights: Per-model weights, one per model. They are normalized to sum
                to 1. Defaults to equal weights.
            voting: ``'hard'`` (weighted majority vote) or ``'soft'`` (weighted
                score average). Matching is case-insensitive.
            name: Ensemble name, used to build the default save directory.

        Raises:
            ValueError: If ``models`` is empty, or ``weights`` has the wrong length,
                or ``weights`` has a non-positive or non-finite sum, or ``voting``
                is unknown, or the models disagree on ``num_classes``.
        """
        self.models = models
        self.num_models = len(models)

        if self.num_models == 0:
            raise ValueError("模型列表不能为空")

        if weights is None:
            self.weights = np.ones(self.num_models) / self.num_models
        else:
            if len(weights) != self.num_models:
                raise ValueError("权重数量必须与模型数量相同")
            weight_array = np.asarray(weights, dtype=float)
            total = weight_array.sum()
            # Dividing by a zero or non-finite total would silently produce NaN/inf
            # weights, and therefore NaN predictions, so fail fast instead.
            if not np.isfinite(total) or total <= 0:
                raise ValueError("权重之和必须为正的有限数值")
            self.weights = weight_array / total

        self.voting = voting.lower()
        if self.voting not in ['hard', 'soft']:
            raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'")

        # The class count comes from the first model and is enforced for the rest.
        self.num_classes = models[0].num_classes
        for i, model in enumerate(models[1:], 1):
            if model.num_classes != self.num_classes:
                raise ValueError(
                    f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同")

        self.name = name

        logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}")

    def _as_class_matrix(self, predictions) -> np.ndarray:
        """Convert one model's raw predictions into a 2-D array of class scores.

        For binary ensembles, a single-column output (shape ``(n,)`` or ``(n, 1)``)
        is read as the class-1 probability ``p`` and expanded to ``[1 - p, p]``.
        Any other output is converted to an ndarray and otherwise left unchanged.
        """
        predictions = np.asarray(predictions)
        if self.num_classes == 2:
            if predictions.ndim == 1:
                predictions = predictions.reshape(-1, 1)
            if predictions.shape[1:] == (1,):
                predictions = np.hstack([1 - predictions, predictions])
        return predictions

    def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
                batch_size: Optional[int] = None,
                verbose: int = 0) -> np.ndarray:
        """Predict with every member model and combine the results.

        Args:
            x: Input data, passed unchanged to each model's ``predict``.
            batch_size: Batch size forwarded to each model's ``predict``.
            verbose: Verbosity forwarded to each model's ``predict``.

        Returns:
            An array of shape ``(n_samples, n_classes)``. For soft voting it holds
            the weighted score average; for hard voting, the weighted vote shares.

        Raises:
            ValueError: If the member models return predictions of different shapes.
        """
        all_predictions = []
        for i, model in enumerate(self.models):
            logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果")
            predictions = model.predict(x, batch_size, verbose)
            all_predictions.append(self._as_class_matrix(predictions))

        # All models must agree on the output shape. Otherwise hard voting would
        # silently mis-count (or index past the vote matrix), and soft voting would
        # fail with an opaque broadcasting error.
        reference_shape = all_predictions[0].shape
        for i, predictions in enumerate(all_predictions[1:], 1):
            if predictions.shape != reference_shape:
                raise ValueError(
                    f"模型 {i} 的预测形状 {predictions.shape} 与第一个模型的预测形状 {reference_shape} 不同")

        if self.voting == 'hard':
            # Size the vote matrix from the predictions rather than from len(x).
            # x may be a tf.data.Dataset, whose len() can raise or count batches
            # instead of samples.
            num_samples = reference_shape[0]
            ensemble_result = np.zeros((num_samples, self.num_classes))
            rows = np.arange(num_samples)
            for weight, predictions in zip(self.weights, all_predictions):
                # Each row receives exactly one vote per model, so no (row, class)
                # pair repeats within this fancy-indexed update and += is exact.
                ensemble_result[rows, np.argmax(predictions, axis=1)] += weight
            return ensemble_result

        # Soft voting: weighted sum of the prediction matrices (the weights sum to 1).
        weighted_predictions = [pred * weight for pred, weight in zip(all_predictions, self.weights)]
        return np.sum(weighted_predictions, axis=0)

    def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
                        batch_size: Optional[int] = None,
                        verbose: int = 0) -> np.ndarray:
        """Predict class indices with the ensemble.

        Args:
            x: Input data, passed unchanged to each model's ``predict``.
            batch_size: Batch size forwarded to each model's ``predict``.
            verbose: Verbosity forwarded to each model's ``predict``.

        Returns:
            A 1-D array holding the argmax class index of the combined predictions
            for each sample.
        """
        predictions = self.predict(x, batch_size, verbose)
        return np.argmax(predictions, axis=1)

    def save(self, directory: Optional[str] = None) -> str:
        """Save every member model plus a JSON config describing the ensemble.

        Each member model is saved under ``<directory>/model_<i>``. The config is
        written to ``<directory>/ensemble_config.json``. It stores both the full
        paths (``model_paths``, kept for older loaders) and the paths relative to
        ``directory`` (``model_relpaths``), so a moved or copied ensemble can
        still be loaded.

        Args:
            directory: Target directory. Defaults to a timestamped subdirectory of
                ``CLASSIFIERS_DIR`` named after the ensemble.

        Returns:
            The directory the ensemble was saved to.
        """
        import json

        if directory is None:
            import time
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}")

        os.makedirs(directory, exist_ok=True)

        model_paths = []
        model_relpaths = []
        for i, model in enumerate(self.models):
            relpath = f"model_{i}"
            model_path = os.path.join(directory, relpath)
            model.save(model_path)
            model_paths.append(model_path)
            model_relpaths.append(relpath)

        config = {
            "name": self.name,
            "num_models": self.num_models,
            "model_paths": model_paths,
            "model_relpaths": model_relpaths,
            "weights": self.weights.tolist(),
            "voting": self.voting,
            "num_classes": self.num_classes
        }

        config_path = os.path.join(directory, "ensemble_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=4)

        logger.info(f"集成模型已保存到目录: {directory}")

        return directory

    @classmethod
    def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel':
        """Load an ensemble previously written by ``save``.

        Member models are resolved relative to ``directory`` when the config
        contains ``model_relpaths``. Configs written before that key existed fall
        back to the stored full ``model_paths``.

        Args:
            directory: Directory containing ``ensemble_config.json``.
            custom_objects: Passed through to ``ModelFactory.load_model`` for each
                member model.

        Returns:
            The reconstructed ``EnsembleModel``.
        """
        import json
        from models.model_factory import ModelFactory

        config_path = os.path.join(directory, "ensemble_config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        relpaths = config.get("model_relpaths")
        if relpaths is not None:
            # Resolve against the directory being loaded, so relocated ensembles work.
            model_paths = [os.path.join(directory, relpath) for relpath in relpaths]
        else:
            # Legacy configs store only paths that include the original save directory.
            model_paths = config["model_paths"]

        models = [ModelFactory.load_model(model_path, custom_objects) for model_path in model_paths]

        ensemble = cls(
            models=models,
            weights=config["weights"],
            voting=config["voting"],
            name=config["name"]
        )

        logger.info(f"从目录 {directory} 加载集成模型成功")

        return ensemble
|