
"""
评估器模块:实现模型评估流程
"""
import json
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from config.system_config import SAVED_MODELS_DIR
from evaluation.metrics import ClassificationMetrics
from models.base_model import TextClassificationModel
from utils.file_utils import ensure_dir, save_json
from utils.logger import get_logger

logger = get_logger("Evaluator")


class ModelEvaluator:
    """Model evaluator responsible for assessing model performance."""

    def __init__(self, model: TextClassificationModel,
                 class_names: Optional[List[str]] = None,
                 output_dir: Optional[str] = None,
                 batch_size: Optional[int] = None):
        """
        Initialize the model evaluator.

        Args:
            model: Model to evaluate.
            class_names: List of class names.
            output_dir: Output directory for saving evaluation results.
            batch_size: Batch size; if None, the model's default is used.
        """
        self.model = model
        self.class_names = class_names
        self.batch_size = batch_size or model.batch_size

        # Set the output directory
        if output_dir is None:
            self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name)
        else:
            self.output_dir = output_dir
        ensure_dir(self.output_dir)

        # Create the metrics calculator
        self.metrics = ClassificationMetrics(class_names)

        # Evaluation results and cached predictions (filled in by evaluate())
        self.evaluation_results = None
        self._y_true = None
        self._y_pred = None

        logger.info(f"Initialized model evaluator, model: {model.model_name}")

    def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
                 y_test: Optional[np.ndarray] = None,
                 batch_size: Optional[int] = None,
                 verbose: int = 1) -> Dict[str, float]:
        """
        Evaluate the model.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            verbose: Verbosity level.

        Returns:
            Dictionary of evaluation metrics.
        """
        batch_size = batch_size or self.batch_size
        logger.info(f"Starting evaluation of model: {self.model.model_name}")
        start_time = time.time()

        # Evaluate with the model's built-in metrics
        model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose)

        # Get predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)

        # Make sure y_test is a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # For a TensorFlow Dataset, extract the labels from the dataset itself
            y_test = np.concatenate([y for _, y in x_test], axis=0)
        if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
            y_test = np.argmax(y_test, axis=1)

        # Compute all custom metrics
        all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob)

        # Merge the model's built-in metrics with the custom ones
        metrics_names = self.model.model.metrics_names
        model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)}
        all_metrics.update(model_metrics_dict)

        # Record evaluation time
        evaluation_time = time.time() - start_time
        all_metrics['evaluation_time'] = evaluation_time

        # Cache labels and predictions for later plotting, then store the results
        self._y_true = y_test
        self._y_pred = y_pred
        self.evaluation_results = {
            'metrics': all_metrics,
            'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(),
            'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True)
        }

        logger.info(f"Model evaluation finished in {evaluation_time:.2f}s")
        logger.info(f"Key metrics: accuracy={all_metrics.get('accuracy', float('nan')):.4f}, "
                    f"f1_macro={all_metrics.get('f1_macro', float('nan')):.4f}")
        return all_metrics
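
    # Illustrative call (a sketch; `evaluator`, `x_test`, and `y_test` are
    # assumed to exist outside this module):
    #   metrics = evaluator.evaluate(x_test, y_test)
    #   print(metrics["accuracy"], metrics["f1_macro"])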

    def save_evaluation_results(self, save_plots: bool = True) -> str:
        """
        Save the evaluation results.

        Args:
            save_plots: Whether to save visualization plots.

        Returns:
            Path the results were saved to.
        """
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before saving results")

        # Save the evaluation results as JSON
        results_path = os.path.join(self.output_dir, 'evaluation_results.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4)

        # Save the metrics as CSV
        metrics_df = pd.DataFrame(
            self.evaluation_results['metrics'].items(),
            columns=['Metric', 'Value']
        ).set_index('Metric')
        metrics_path = os.path.join(self.output_dir, 'metrics.csv')
        metrics_df.to_csv(metrics_path)

        # Save visualization plots
        if save_plots:
            self._save_plots()

        logger.info(f"Evaluation results saved to: {self.output_dir}")
        return self.output_dir

    def _save_plots(self) -> None:
        """Save visualization plots for the evaluation results."""
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before saving plots")

        # Create the plots directory
        plots_dir = os.path.join(self.output_dir, 'plots')
        ensure_dir(plots_dir)

        # Confusion matrix plot, using the labels and predictions cached by evaluate()
        if self._y_true is not None and self._y_pred is not None:
            cm_path = os.path.join(plots_dir, 'confusion_matrix.png')
            self.metrics.plot_confusion_matrix(
                self._y_true,
                self._y_pred,
                normalize='true',
                save_path=cm_path
            )

        # Bar chart of the main evaluation metrics
        metrics_path = os.path.join(plots_dir, 'metrics_bar.png')
        metrics = self.evaluation_results['metrics']
        main_metrics = {
            'accuracy': metrics.get('accuracy', 0),
            'precision_macro': metrics.get('precision_macro', 0),
            'recall_macro': metrics.get('recall_macro', 0),
            'f1_macro': metrics.get('f1_macro', 0)
        }
        plt.figure(figsize=(10, 6))
        plt.bar(main_metrics.keys(), main_metrics.values())
        plt.title('Main Evaluation Metrics')
        plt.ylabel('Score')
        plt.ylim(0, 1)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(metrics_path)
        plt.close()

        # If per-class metrics are available, plot them as well
        if 'classification_report' in self.evaluation_results:
            report = self.evaluation_results['classification_report']

            # Extract precision, recall and F1 for each class
            class_metrics = {}
            for key, value in report.items():
                if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
                    if isinstance(value, dict):
                        class_metrics[key] = value

            if class_metrics:
                # Plot F1 score per class
                class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png')
                classes = list(class_metrics.keys())
                f1_scores = [m['f1-score'] for m in class_metrics.values()]

                plt.figure(figsize=(12, 6))
                bars = plt.bar(classes, f1_scores)
                # Print the value above each bar
                for bar in bars:
                    height = bar.get_height()
                    plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                             f'{height:.2f}',
                             ha='center', va='bottom', rotation=0)
                plt.title('F1 Score by Class')
                plt.ylabel('F1 Score')
                plt.ylim(0, 1.1)
                plt.xticks(rotation=45, ha='right')
                plt.tight_layout()
                plt.savefig(class_f1_path)
                plt.close()

                # Plot precision and recall per class
                class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png')
                precisions = [m['precision'] for m in class_metrics.values()]
                recalls = [m['recall'] for m in class_metrics.values()]

                plt.figure(figsize=(12, 6))
                x = np.arange(len(classes))
                width = 0.35
                plt.bar(x - width / 2, precisions, width, label='Precision')
                plt.bar(x + width / 2, recalls, width, label='Recall')
                plt.ylabel('Score')
                plt.title('Precision and Recall by Class')
                plt.xticks(x, classes, rotation=45, ha='right')
                plt.legend()
                plt.ylim(0, 1.1)
                plt.tight_layout()
                plt.savefig(class_prec_rec_path)
                plt.close()

        logger.info(f"Evaluation plots saved to: {plots_dir}")

    def compare_models(self, other_evaluators: List['ModelEvaluator'],
                       metrics: Optional[List[str]] = None,
                       save_path: Optional[str] = None) -> pd.DataFrame:
        """
        Compare the evaluation results of several models.

        Args:
            other_evaluators: List of other model evaluators.
            metrics: Metrics to compare; defaults to
                ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro'].
            save_path: Path to save the comparison results to.

        Returns:
            Comparison results as a DataFrame.
        """
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before comparing models")

        # Default comparison metrics
        if metrics is None:
            metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']

        # Collect the metrics of every model
        models_metrics = {}

        # Current model
        models_metrics[self.model.model_name] = {
            metric: self.evaluation_results['metrics'].get(metric, 'N/A')
            for metric in metrics
        }

        # Other models
        for evaluator in other_evaluators:
            if evaluator.evaluation_results is None:
                logger.warning(f"Model {evaluator.model.model_name} has not been evaluated yet, skipping")
                continue
            models_metrics[evaluator.model.model_name] = {
                metric: evaluator.evaluation_results['metrics'].get(metric, 'N/A')
                for metric in metrics
            }

        # Build the comparison DataFrame
        comparison_df = pd.DataFrame(models_metrics).T

        # Save the comparison results and plot them
        if save_path:
            comparison_df.to_csv(save_path)
            logger.info(f"Model comparison results saved to: {save_path}")

            # Bar chart comparing the models
            comparison_df.plot(kind='bar', figsize=(12, 6))
            plt.title('Model Comparison')
            plt.ylabel('Score')
            plt.ylim(0, 1)
            plt.legend(loc='lower right')
            plt.tight_layout()

            # If save_path is a CSV file, save the plot next to it as PNG
            if save_path.endswith('.csv'):
                plot_path = save_path.replace('.csv', '.png')
            else:
                plot_path = save_path + '.png'
            plt.savefig(plot_path)
            plt.close()
            logger.info(f"Model comparison plot saved to: {plot_path}")

        return comparison_df
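
    # Illustrative comparison (a sketch; `evaluator_a` and `evaluator_b` are
    # hypothetical ModelEvaluator instances that have already been evaluated):
    #   df = evaluator_a.compare_models([evaluator_b], save_path="model_comparison.csv")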

    def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset],
                                   y_test: Optional[np.ndarray] = None,
                                   batch_size: Optional[int] = None,
                                   verbose: int = 0) -> pd.DataFrame:
        """
        Evaluate the model's performance on each individual class.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            verbose: Verbosity level.

        Returns:
            Per-class performance metrics as a DataFrame.
        """
        batch_size = batch_size or self.batch_size

        # Get predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose)
        y_pred = np.argmax(y_prob, axis=1)

        # Make sure y_test is a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # For a TensorFlow Dataset, extract the labels from the dataset itself
            y_test = np.concatenate([y for _, y in x_test], axis=0)
        if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
            y_test = np.argmax(y_test, axis=1)

        # Get the classification report
        report = self.metrics.classification_report(y_test, y_pred, output_dict=True)

        # Extract per-class metrics
        class_metrics = {}
        for key, value in report.items():
            if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
                if isinstance(value, dict):
                    class_metrics[key] = value

        # Convert to a DataFrame
        class_performance_df = pd.DataFrame(class_metrics).T

        # Add support (number of samples) per class
        class_counts = np.bincount(y_test)
        for idx, count in enumerate(class_counts):
            if str(idx) in class_performance_df.index:
                class_performance_df.loc[str(idx), 'support'] = count

        # Add class names
        if self.class_names:
            class_performance_df['class_name'] = [
                self.class_names[int(idx)]
                if str(idx).isdigit() and int(idx) < len(self.class_names) else idx
                for idx in class_performance_df.index
            ]

        # Save the per-class performance metrics
        performance_path = os.path.join(self.output_dir, 'class_performance.csv')
        class_performance_df.to_csv(performance_path)
        logger.info(f"Per-class performance metrics saved to: {performance_path}")
        return class_performance_df

    def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset],
                            y_test: Optional[np.ndarray] = None,
                            batch_size: Optional[int] = None,
                            num_samples: int = 10,
                            save_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """
        Analyze misclassified samples.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            num_samples: Number of misclassified samples to show.
            save_path: Path to save the analysis to.

        Returns:
            DataFrame of misclassified samples, or None if there are none.
        """
        # This analysis is only meaningful with access to the original text
        logger.info("Misclassification analysis needs the original text data; "
                    "this method may have to be adapted to the actual data type")
        # In practice this should be adjusted to the data at hand, e.g. for
        # serialized text the token indices have to be mapped back to words
        # through the vocabulary. Only a basic skeleton is shown here.
        batch_size = batch_size or self.batch_size

        # Get predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)

        # Make sure y_test is a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # For a TensorFlow Dataset, extract both the features and the labels
            x_test_extracted = []
            y_test_extracted = []
            for x, y in x_test:
                x_test_extracted.append(x)
                y_test_extracted.append(y)
            x_test = np.concatenate(x_test_extracted, axis=0)
            y_test = np.concatenate(y_test_extracted, axis=0)
        if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
            y_test = np.argmax(y_test, axis=1)

        # Find the misclassified samples
        misclassified_indices = np.where(y_pred != y_test)[0]

        # Nothing to do if everything was classified correctly
        if len(misclassified_indices) == 0:
            logger.info("No misclassified samples")
            return None

        # Randomly pick a subset of the misclassified samples
        if len(misclassified_indices) > num_samples:
            misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False)

        # Collect the analysis results
        misclassified_data = []
        for idx in misclassified_indices:
            true_label = y_test[idx]
            pred_label = y_pred[idx]
            true_class = self.class_names[true_label] if self.class_names else str(true_label)
            pred_class = self.class_names[pred_label] if self.class_names else str(pred_label)

            # For serialized text, the sample would have to be decoded back into
            # words here; a placeholder string is used as an example.
            sample_text = f"Sample {idx}"
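            # Hedged sketch of decoding a padded token-id sequence back to text;
            # `index_to_word` is a hypothetical vocabulary mapping that is NOT
            # defined in this module:
            # sample_text = " ".join(
            #     index_to_word.get(int(token), "<UNK>")
            #     for token in np.asarray(x_test[idx]) if int(token) != 0
            # )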
            misclassified_data.append({
                'sample_id': idx,
                'true_label': true_label,
                'predicted_label': pred_label,
                'true_class': true_class,
                'predicted_class': pred_class,
                'confidence': float(y_prob[idx, pred_label]),
                'sample_text': sample_text
            })

        # Build the DataFrame
        misclassified_df = pd.DataFrame(misclassified_data)

        # Save the results
        if save_path:
            misclassified_df.to_csv(save_path)
            logger.info(f"Misclassification analysis saved to: {save_path}")
        return misclassified_df
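

# Minimal usage sketch (illustrative only). `load_trained_model`, `x_test`,
# `y_test`, and the class names below are placeholders that do not exist in
# this project; adapt them to however models and test data are actually loaded.
if __name__ == "__main__":
    # model = load_trained_model()                      # hypothetical loader
    # evaluator = ModelEvaluator(model, class_names=["sports", "finance", "tech"])
    # evaluator.evaluate(x_test, y_test)
    # evaluator.save_evaluation_results(save_plots=True)
    # evaluator.evaluate_class_performance(x_test, y_test)
    pass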