2025-03-08 01:34:36 +08:00

217 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
集成模型:实现多个模型的集成
"""
import numpy as np
import tensorflow as tf
from typing import List, Dict, Tuple, Optional, Any, Union
import os
from config.system_config import CLASSIFIERS_DIR
from models.base_model import TextClassificationModel
from utils.logger import get_logger
# Module-level logger; EnsembleModel uses it for its init, predict, save and load messages.
logger = get_logger("EnsembleModel")
class EnsembleModel:
    """Combine several text-classification models into a single predictor.

    Each member model is queried for its predictions, and the results are
    merged either by weighted majority vote ('hard') or by weighted
    averaging of the predicted probabilities ('soft').
    """

    def __init__(self, models: List[TextClassificationModel],
                 weights: Optional[List[float]] = None,
                 voting: str = 'soft',
                 name: str = "ensemble_model"):
        """Initialise the ensemble.

        Args:
            models: Member models; must be non-empty and all must report the
                same ``num_classes``.
            weights: One weight per model. They are normalised to sum to 1.
                Defaults to equal weights.
            voting: ``'hard'`` for weighted majority voting, ``'soft'`` for
                weighted probability averaging (case-insensitive).
            name: Name of the ensemble, used for the default save directory.

        Raises:
            ValueError: If ``models`` is empty, the number of weights does not
                match the number of models, the weights do not sum to a
                non-zero finite value, ``voting`` is not supported, or the
                models disagree on the number of classes.
        """
        self.models = models
        self.num_models = len(models)

        if self.num_models == 0:
            raise ValueError("模型列表不能为空")

        if weights is None:
            self.weights = np.ones(self.num_models) / self.num_models
        else:
            if len(weights) != self.num_models:
                raise ValueError("权重数量必须与模型数量相同")
            raw_weights = np.asarray(weights, dtype=float)
            total = raw_weights.sum()
            # A zero or non-finite sum would silently turn every weight into
            # NaN/inf during normalisation, so reject it up front.
            if not np.isfinite(total) or total == 0:
                raise ValueError("权重之和必须是非零的有限值")
            self.weights = raw_weights / total

        self.voting = voting.lower()
        if self.voting not in ('hard', 'soft'):
            raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'")

        # The first model defines the expected class count; the rest must match.
        self.num_classes = models[0].num_classes
        for i, model in enumerate(models[1:], 1):
            if model.num_classes != self.num_classes:
                raise ValueError(
                    f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同")

        self.name = name
        logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}")

    def _collect_predictions(self, x, batch_size, verbose) -> List[np.ndarray]:
        """Run every member model on ``x`` and return their predictions as arrays."""
        all_predictions = []
        for i, model in enumerate(self.models):
            logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果")
            predictions = np.asarray(model.predict(x, batch_size, verbose))
            # A binary model may emit a single column of shape (n, 1); expand
            # it to two columns (1 - p, p) so every model yields (n, 2).
            if self.num_classes == 2 and predictions.shape[1:] == (1,):
                predictions = np.hstack([1 - predictions, predictions])
            all_predictions.append(predictions)
        return all_predictions

    def _hard_vote(self, all_predictions: List[np.ndarray]) -> np.ndarray:
        """Return per-class weighted vote totals, one row per sample."""
        # Take the sample count from the predictions rather than len(x):
        # len() of a tf.data.Dataset is not the number of samples and may
        # raise when the cardinality is unknown.
        num_samples = all_predictions[0].shape[0]
        ensemble_result = np.zeros((num_samples, self.num_classes))
        rows = np.arange(num_samples)
        for weight, pred in zip(self.weights, all_predictions):
            voted = np.argmax(pred, axis=1)
            # Each row index occurs exactly once per model, so the in-place
            # indexed add cannot drop duplicate updates.
            ensemble_result[rows, voted] += weight
        return ensemble_result

    def _soft_vote(self, all_predictions: List[np.ndarray]) -> np.ndarray:
        """Return the weighted sum of the models' predicted probabilities."""
        weighted_predictions = [pred * weight
                                for pred, weight in zip(all_predictions, self.weights)]
        return np.sum(weighted_predictions, axis=0)

    def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
                batch_size: Optional[int] = None,
                verbose: int = 0) -> np.ndarray:
        """Predict with the ensemble.

        Args:
            x: Input data, forwarded unchanged to every member model.
            batch_size: Batch size forwarded to every member model.
            verbose: Verbosity level forwarded to every member model.

        Returns:
            For soft voting, the weighted sum of the members' predictions.
            For hard voting, an array of shape (num_samples, num_classes)
            holding the total weight of the models that voted for each class.
        """
        all_predictions = self._collect_predictions(x, batch_size, verbose)
        if self.voting == 'hard':
            return self._hard_vote(all_predictions)
        return self._soft_vote(all_predictions)

    def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
                        batch_size: Optional[int] = None,
                        verbose: int = 0) -> np.ndarray:
        """Predict class indices with the ensemble.

        Args:
            x: Input data.
            batch_size: Batch size.
            verbose: Verbosity level.

        Returns:
            For each sample, the index of the class with the highest
            ensemble score.
        """
        predictions = self.predict(x, batch_size, verbose)
        return np.argmax(predictions, axis=1)

    def save(self, directory: Optional[str] = None) -> str:
        """Save every member model and the ensemble configuration.

        Each model is saved under ``<directory>/model_<i>`` and the ensemble
        settings are written to ``<directory>/ensemble_config.json``.

        Args:
            directory: Target directory. Defaults to a timestamped
                sub-directory of ``CLASSIFIERS_DIR`` named after the ensemble.

        Returns:
            The directory the ensemble was saved to.
        """
        import json

        if directory is None:
            import time
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}")
        os.makedirs(directory, exist_ok=True)

        # Save the member models one by one, remembering where each went.
        model_paths = []
        for i, model in enumerate(self.models):
            model_path = os.path.join(directory, f"model_{i}")
            model.save(model_path)
            model_paths.append(model_path)

        config = {
            "name": self.name,
            "num_models": self.num_models,
            "model_paths": model_paths,
            "weights": self.weights.tolist(),
            "voting": self.voting,
            "num_classes": self.num_classes
        }
        config_path = os.path.join(directory, "ensemble_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=4)

        logger.info(f"集成模型已保存到目录: {directory}")
        return directory

    @staticmethod
    def _resolve_model_path(directory: str, stored_path: str) -> str:
        """Return a usable path for a member model recorded in the config.

        The config stores paths as they were at save time. If such a path no
        longer exists (for example because the ensemble directory was moved
        or renamed), fall back to the entry of the same name inside
        ``directory``. When neither exists the stored path is returned
        unchanged, so the loader sees exactly what it would have seen before.
        """
        if os.path.exists(stored_path):
            return stored_path
        relocated = os.path.join(
            directory, os.path.basename(os.path.normpath(stored_path)))
        return relocated if os.path.exists(relocated) else stored_path

    @classmethod
    def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel':
        """Load an ensemble previously written by ``save``.

        Args:
            directory: Directory containing ``ensemble_config.json``.
            custom_objects: Custom-object dictionary forwarded to
                ``ModelFactory.load_model`` for every member model.

        Returns:
            The reconstructed ensemble instance.
        """
        import json
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from models.model_factory import ModelFactory

        config_path = os.path.join(directory, "ensemble_config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        models = [
            ModelFactory.load_model(
                cls._resolve_model_path(directory, model_path), custom_objects)
            for model_path in config["model_paths"]
        ]

        ensemble = cls(
            models=models,
            weights=config["weights"],
            voting=config["voting"],
            name=config["name"]
        )
        logger.info(f"从目录 {directory} 加载集成模型成功")
        return ensemble