"""
|
||
评估器模块:实现模型评估流程
|
||
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import time
|
||
import os
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import json
|
||
|
||
from config.system_config import SAVED_MODELS_DIR
|
||
from models.base_model import TextClassificationModel
|
||
from evaluation.metrics import ClassificationMetrics
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir, save_json
|
||
|
||
logger = get_logger("Evaluator")
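

# The label-handling logic below was repeated verbatim in `evaluate`,
# `evaluate_class_performance`, and `plot_error_analysis`; these two helpers
# consolidate it. A minimal sketch assuming the usual Keras conventions:
# a batched (features, labels) tf.data.Dataset, and labels given either as
# 1-D class indices or as 2-D one-hot matrices.

def _labels_from_dataset(dataset: tf.data.Dataset) -> np.ndarray:
    """Concatenate the label tensors of a batched (features, labels) dataset."""
    return np.concatenate([np.asarray(y) for _, y in dataset], axis=0)


def _to_class_indices(y: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Collapse one-hot (or probability) label matrices to 1-D class indices."""
    if y is None:
        return None
    y = np.asarray(y)
    if y.ndim > 1 and y.shape[1] > 1:
        return np.argmax(y, axis=1)
    return y.ravel()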


class ModelEvaluator:
    """Model evaluator responsible for measuring model performance."""

    def __init__(self, model: TextClassificationModel,
                 class_names: Optional[List[str]] = None,
                 output_dir: Optional[str] = None,
                 batch_size: Optional[int] = None):
        """
        Initialize the model evaluator.

        Args:
            model: Model to evaluate.
            class_names: List of class names.
            output_dir: Output directory for saving evaluation results.
            batch_size: Batch size; if None, the model's default is used.
        """
        self.model = model
        self.class_names = class_names
        self.batch_size = batch_size or model.batch_size

        # Set up the output directory
        if output_dir is None:
            self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name)
        else:
            self.output_dir = output_dir

        ensure_dir(self.output_dir)

        # Metrics calculator
        self.metrics = ClassificationMetrics(class_names)

        # Evaluation results, plus the labels and predictions behind them
        # (kept so that _save_plots can redraw the confusion matrix)
        self.evaluation_results = None
        self._y_true = None
        self._y_pred = None

        logger.info(f"Initialized model evaluator, model: {model.model_name}")

    def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
                 y_test: Optional[np.ndarray] = None,
                 batch_size: Optional[int] = None,
                 verbose: int = 1) -> Dict[str, float]:
        """
        Evaluate the model.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            verbose: Verbosity level.

        Returns:
            Dictionary of evaluation metrics.
        """
        batch_size = batch_size or self.batch_size

        logger.info(f"Starting evaluation of model: {self.model.model_name}")
        start_time = time.time()

        # Built-in Keras evaluation
        model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose)

        # Predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)

        # Normalize y_test to a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # For a TensorFlow Dataset, the labels have to be extracted from it
            y_test = _labels_from_dataset(x_test)
        y_test = _to_class_indices(y_test)

        # Compute all custom metrics
        all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob)

        # Merge the model's built-in metrics with the custom ones
        metrics_names = self.model.model.metrics_names
        model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)}
        all_metrics.update(model_metrics_dict)

        # Record evaluation time
        evaluation_time = time.time() - start_time
        all_metrics['evaluation_time'] = evaluation_time

        # Keep the ground truth and predictions for later plotting
        self._y_true = y_test
        self._y_pred = y_pred

        # Store the evaluation results
        self.evaluation_results = {
            'metrics': all_metrics,
            'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(),
            'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True)
        }

        logger.info(f"Model evaluation finished in {evaluation_time:.2f} s")
        # Fall back to NaN so the numeric format spec never hits a string
        accuracy = all_metrics.get('accuracy', float('nan'))
        f1_macro = all_metrics.get('f1_macro', float('nan'))
        logger.info(f"Main metrics: accuracy={accuracy:.4f}, f1_macro={f1_macro:.4f}")

        return all_metrics

    def save_evaluation_results(self, save_plots: bool = True) -> str:
        """
        Save the evaluation results.

        Args:
            save_plots: Whether to save visualization plots.

        Returns:
            Path the results were saved to.
        """
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before saving results")

        # Save the evaluation results as JSON
        results_path = os.path.join(self.output_dir, 'evaluation_results.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4)

        # Save the metrics as CSV
        metrics_df = pd.DataFrame(
            self.evaluation_results['metrics'].items(),
            columns=['Metric', 'Value']
        ).set_index('Metric')

        metrics_path = os.path.join(self.output_dir, 'metrics.csv')
        metrics_df.to_csv(metrics_path)

        # Save visualization plots
        if save_plots:
            self._save_plots()

        logger.info(f"Evaluation results saved to: {self.output_dir}")

        return self.output_dir

    def _save_plots(self) -> None:
        """Save visualization plots of the evaluation results."""
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before saving plots")

        # Create the plots directory
        plots_dir = os.path.join(self.output_dir, 'plots')
        ensure_dir(plots_dir)

        # Confusion matrix plot, redrawn from the labels and predictions
        # recorded during evaluate()
        if self._y_true is not None and self._y_pred is not None:
            cm_path = os.path.join(plots_dir, 'confusion_matrix.png')
            self.metrics.plot_confusion_matrix(
                self._y_true,
                self._y_pred,
                normalize='true',
                save_path=cm_path
            )

        # Bar chart of the main metrics
        metrics_path = os.path.join(plots_dir, 'metrics_bar.png')
        metrics = self.evaluation_results['metrics']

        # Main metrics to display
        main_metrics = {
            'accuracy': metrics.get('accuracy', 0),
            'precision_macro': metrics.get('precision_macro', 0),
            'recall_macro': metrics.get('recall_macro', 0),
            'f1_macro': metrics.get('f1_macro', 0)
        }

        plt.figure(figsize=(10, 6))
        plt.bar(main_metrics.keys(), main_metrics.values())
        plt.title('Main Evaluation Metrics')
        plt.ylabel('Score')
        plt.ylim(0, 1)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(metrics_path)
        plt.close()

        # If per-class metrics are available, plot them as well
        if 'classification_report' in self.evaluation_results:
            report = self.evaluation_results['classification_report']

            # Extract per-class precision, recall and F1
            class_metrics = {}
            for key, value in report.items():
                if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
                    if isinstance(value, dict):
                        class_metrics[key] = value

            if class_metrics:
                # Per-class F1 scores
                class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png')

                classes = list(class_metrics.keys())
                f1_scores = [m['f1-score'] for m in class_metrics.values()]

                plt.figure(figsize=(12, 6))
                bars = plt.bar(classes, f1_scores)

                # Annotate each bar with its value
                for bar in bars:
                    height = bar.get_height()
                    plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                             f'{height:.2f}',
                             ha='center', va='bottom', rotation=0)

                plt.title('F1 Score by Class')
                plt.ylabel('F1 Score')
                plt.ylim(0, 1.1)
                plt.xticks(rotation=45, ha='right')
                plt.tight_layout()
                plt.savefig(class_f1_path)
                plt.close()

                # Per-class precision and recall
                class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png')

                precisions = [m['precision'] for m in class_metrics.values()]
                recalls = [m['recall'] for m in class_metrics.values()]

                plt.figure(figsize=(12, 6))
                x = np.arange(len(classes))
                width = 0.35

                plt.bar(x - width / 2, precisions, width, label='Precision')
                plt.bar(x + width / 2, recalls, width, label='Recall')

                plt.ylabel('Score')
                plt.title('Precision and Recall by Class')
                plt.xticks(x, classes, rotation=45, ha='right')
                plt.legend()
                plt.ylim(0, 1.1)
                plt.tight_layout()
                plt.savefig(class_prec_rec_path)
                plt.close()

        logger.info(f"Evaluation plots saved to: {plots_dir}")

    def compare_models(self, other_evaluators: List['ModelEvaluator'],
                       metrics: Optional[List[str]] = None,
                       save_path: Optional[str] = None) -> pd.DataFrame:
        """
        Compare the evaluation results of several models.

        Args:
            other_evaluators: Evaluators of the other models.
            metrics: Metrics to compare; defaults to
                ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro'].
            save_path: Path to save the comparison results to.

        Returns:
            Comparison DataFrame.
        """
        if self.evaluation_results is None:
            raise ValueError("Call evaluate() before comparing models")

        # Default comparison metrics
        if metrics is None:
            metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']

        # Collect the metrics of every model; missing values become NaN so
        # that both the CSV and the bar chart stay numeric
        models_metrics = {}

        # This model
        models_metrics[self.model.model_name] = {
            metric: self.evaluation_results['metrics'].get(metric, np.nan)
            for metric in metrics
        }

        # Other models
        for evaluator in other_evaluators:
            if evaluator.evaluation_results is None:
                logger.warning(f"Model {evaluator.model.model_name} has not been evaluated yet, skipping")
                continue

            models_metrics[evaluator.model.model_name] = {
                metric: evaluator.evaluation_results['metrics'].get(metric, np.nan)
                for metric in metrics
            }

        # Build the comparison DataFrame
        comparison_df = pd.DataFrame(models_metrics).T

        # Save the comparison results and a bar chart alongside them
        if save_path:
            comparison_df.to_csv(save_path)
            logger.info(f"Model comparison results saved to: {save_path}")

            comparison_df.plot(kind='bar', figsize=(12, 6))
            plt.title('Model Comparison')
            plt.ylabel('Score')
            plt.ylim(0, 1)
            plt.legend(loc='lower right')
            plt.tight_layout()

            # If save_path is a CSV file, swap the extension for PNG
            if save_path.endswith('.csv'):
                plot_path = save_path.replace('.csv', '.png')
            else:
                plot_path = save_path + '.png'

            plt.savefig(plot_path)
            plt.close()

            logger.info(f"Model comparison chart saved to: {plot_path}")

        return comparison_df

    def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset],
                                   y_test: Optional[np.ndarray] = None,
                                   batch_size: Optional[int] = None,
                                   verbose: int = 0) -> pd.DataFrame:
        """
        Evaluate the model's performance on each individual class.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            verbose: Verbosity level.

        Returns:
            DataFrame of per-class performance metrics.
        """
        batch_size = batch_size or self.batch_size

        # Predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose)
        y_pred = np.argmax(y_prob, axis=1)

        # Normalize y_test to a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # For a TensorFlow Dataset, the labels have to be extracted from it
            y_test = _labels_from_dataset(x_test)
        y_test = _to_class_indices(y_test)

        # Classification report
        report = self.metrics.classification_report(y_test, y_pred, output_dict=True)

        # Extract the per-class metrics
        class_metrics = {}
        for key, value in report.items():
            if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
                if isinstance(value, dict):
                    class_metrics[key] = value

        # Convert to a DataFrame
        class_performance_df = pd.DataFrame(class_metrics).T

        # Add the support (sample count) per class
        class_counts = np.bincount(y_test)
        for idx, count in enumerate(class_counts):
            if str(idx) in class_performance_df.index:
                class_performance_df.loc[str(idx), 'support'] = count

        # Add class names (report keys may already be names rather than
        # indices, so only map entries that are numeric)
        if self.class_names:
            class_performance_df['class_name'] = [
                self.class_names[int(idx)]
                if str(idx).isdigit() and int(idx) < len(self.class_names) else idx
                for idx in class_performance_df.index
            ]

        # Save the per-class metrics
        performance_path = os.path.join(self.output_dir, 'class_performance.csv')
        class_performance_df.to_csv(performance_path)
        logger.info(f"Per-class performance metrics saved to: {performance_path}")

        return class_performance_df

    def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset],
                            y_test: Optional[np.ndarray] = None,
                            batch_size: Optional[int] = None,
                            num_samples: int = 10,
                            save_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """
        Analyze misclassified samples.

        Args:
            x_test: Test features.
            y_test: Test labels.
            batch_size: Batch size.
            num_samples: Number of misclassified samples to show.
            save_path: Path to save the analysis to.

        Returns:
            DataFrame of misclassified samples, or None if there are none.
        """
        # This analysis needs access to the original text, so the method may
        # have to be adapted to the actual data type in use
        logger.info("Misclassification analysis needs the original text data; adapt this method to the actual data type")

        # In practice the samples should be decoded here, e.g. by mapping
        # token indices back to text through the vocabulary.
        # What follows is only a basic skeleton.
        batch_size = batch_size or self.batch_size

        # Predictions
        y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)

        # Normalize y_test to a 1-D array of class indices
        if isinstance(x_test, tf.data.Dataset):
            # Extract both features and labels from the TensorFlow Dataset;
            # the features are kept so a real implementation can decode them
            batches = [(np.asarray(x), np.asarray(y)) for x, y in x_test]
            x_test = np.concatenate([x for x, _ in batches], axis=0)
            y_test = np.concatenate([y for _, y in batches], axis=0)
        y_test = _to_class_indices(y_test)

        # Find the misclassified samples
        misclassified_indices = np.where(y_pred != y_test)[0]

        if len(misclassified_indices) == 0:
            logger.info("No misclassified samples")
            return None

        # Randomly pick some misclassified samples
        if len(misclassified_indices) > num_samples:
            misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False)

        # Collect the analysis results
        misclassified_data = []

        for idx in misclassified_indices:
            true_label = int(y_test[idx])
            pred_label = int(y_pred[idx])

            true_class = self.class_names[true_label] if self.class_names else str(true_label)
            pred_class = self.class_names[pred_label] if self.class_names else str(pred_label)

            # For serialized text this is where decoding would happen;
            # the placeholder below stands in for the original sample
            sample_text = f"Sample {idx}"

            misclassified_data.append({
                'sample_id': int(idx),
                'true_label': true_label,
                'predicted_label': pred_label,
                'true_class': true_class,
                'predicted_class': pred_class,
                'confidence': float(y_prob[idx, pred_label]),
                'sample_text': sample_text
            })

        # Build the DataFrame
        misclassified_df = pd.DataFrame(misclassified_data)

        # Save the results
        if save_path:
            misclassified_df.to_csv(save_path)
            logger.info(f"Misclassification analysis saved to: {save_path}")

        return misclassified_df
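

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `SomeModel`, `x_test` and `y_test`
# are hypothetical stand-ins for any trained TextClassificationModel subclass
# and a matching test set, either NumPy arrays or a batched tf.data.Dataset):
#
#     model = SomeModel(...)  # a trained TextClassificationModel
#     evaluator = ModelEvaluator(model, class_names=['sports', 'tech', 'finance'])
#     metrics = evaluator.evaluate(x_test, y_test)        # dict of metric values
#     evaluator.save_evaluation_results(save_plots=True)  # JSON + CSV + plots
#     per_class = evaluator.evaluate_class_performance(x_test, y_test)
# ---------------------------------------------------------------------------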