357 lines
12 KiB
Python
357 lines
12 KiB
Python
"""
Evaluation metrics module: implements various evaluation metrics.

Provides ``ClassificationMetrics``, which wraps scikit-learn metric functions
and matplotlib plotting helpers for evaluating classification results.
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from sklearn.metrics import (
|
||
accuracy_score, precision_score, recall_score, f1_score,
|
||
confusion_matrix, classification_report, roc_auc_score,
|
||
precision_recall_curve, average_precision_score
|
||
)
|
||
import matplotlib.pyplot as plt
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import pandas as pd
|
||
|
||
from utils.logger import get_logger
|
||
|
||
# Module-wide logger (named "Metrics") obtained from the project's utils.logger
# helper; used to report saved plot paths and metric computation failures.
logger = get_logger("Metrics")
|
||
|
||
|
||
class ClassificationMetrics:
    """Compute, report and plot common classification evaluation metrics.

    A thin convenience layer over scikit-learn's metric functions and
    matplotlib. Where noted, label arguments may be either 1-D class indices or
    one-hot matrices. Probability arguments are ``(n_samples, n_classes)``
    matrices; a 1-D vector or a single column is interpreted as the
    positive-class probability of a binary problem.
    """

    def __init__(self, class_names: Optional[List[str]] = None):
        """Initialize the metrics helper.

        Args:
            class_names: Optional display names indexed by class id. Used for
                plot tick labels and legends and as ``target_names`` of the
                classification report.
        """
        self.class_names = class_names

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_label_indices(y: np.ndarray) -> np.ndarray:
        """Return ``y`` as a 1-D vector of class indices.

        A one-hot matrix ``(n, k>1)`` is collapsed with ``argmax``; a single
        column ``(n, 1)`` is flattened; 1-D input is returned as an array.
        """
        y = np.asarray(y)
        if y.ndim > 1:
            return np.argmax(y, axis=1) if y.shape[1] > 1 else y.ravel()
        return y

    @staticmethod
    def _as_probability_matrix(y_prob: np.ndarray) -> np.ndarray:
        """Return ``y_prob`` as a 2-D ``(n_samples, n_classes)`` matrix.

        A 1-D vector or single-column matrix is treated as the positive-class
        probability ``p`` of a binary problem and expanded to ``[1 - p, p]``,
        so binary-specific code can always read column 1.
        """
        y_prob = np.asarray(y_prob)
        if y_prob.ndim == 1 or (y_prob.ndim == 2 and y_prob.shape[1] == 1):
            positive = y_prob.reshape(-1)
            return np.column_stack([1.0 - positive, positive])
        return y_prob

    @staticmethod
    def _normalize_confusion_matrix(cm: np.ndarray,
                                    normalize: Optional[str]) -> np.ndarray:
        """Normalize a count confusion matrix.

        Args:
            cm: Square matrix of counts (rows = true, columns = predicted).
            normalize: 'true' (per row), 'pred' (per column), 'all' (by the
                grand total) or None (return ``cm`` unchanged).

        Returns:
            The normalized float matrix. Cells whose divisor is zero (a class
            with no samples) are 0 rather than NaN. For an unrecognised mode,
            a warning is logged and the raw counts are returned, as before.
        """
        if normalize is None:
            return cm

        counts = cm.astype(float)
        if normalize == 'true':
            totals = counts.sum(axis=1, keepdims=True)
        elif normalize == 'pred':
            totals = counts.sum(axis=0, keepdims=True)
        elif normalize == 'all':
            totals = counts.sum()
        else:
            logger.warning(f"未知的归一化方式 normalize={normalize!r},返回未归一化的混淆矩阵")
            return cm

        # An empty row/column yields 0/0; map the resulting NaN to 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            normalized = counts / totals
        return np.nan_to_num(normalized)

    def _class_label(self, class_id: int) -> str:
        """Return the display name of ``class_id``, or ``"Class <id>"`` if unnamed."""
        if self.class_names and 0 <= class_id < len(self.class_names):
            return self.class_names[class_id]
        return f"Class {class_id}"

    @staticmethod
    def _save_or_show(save_path: Optional[str], description: str) -> None:
        """Save the current pyplot figure to ``save_path``, or show it if no path.

        After saving, the figure is closed. Repeated calls (e.g. once per
        epoch) would otherwise accumulate open figures and leak memory.
        """
        if save_path:
            plt.savefig(save_path)
            logger.info(f"{description}已保存到: {save_path}")
            plt.close()
        else:
            plt.show()

    # ------------------------------------------------------------------
    # Scalar metrics
    # ------------------------------------------------------------------

    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute accuracy.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.

        Returns:
            Fraction of correctly predicted samples.
        """
        return accuracy_score(y_true, y_pred)

    def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
                  average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """Compute precision (``zero_division=0``).

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            average: 'micro', 'macro', 'weighted', 'samples' or None
                (None returns one score per class).

        Returns:
            Precision score(s).
        """
        return precision_score(y_true, y_pred, average=average, zero_division=0)

    def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
               average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """Compute recall (``zero_division=0``).

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            average: 'micro', 'macro', 'weighted', 'samples' or None
                (None returns one score per class).

        Returns:
            Recall score(s).
        """
        return recall_score(y_true, y_pred, average=average, zero_division=0)

    def f1(self, y_true: np.ndarray, y_pred: np.ndarray,
           average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """Compute the F1 score (``zero_division=0``).

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            average: 'micro', 'macro', 'weighted', 'samples' or None
                (None returns one score per class).

        Returns:
            F1 score(s).
        """
        return f1_score(y_true, y_pred, average=average, zero_division=0)

    # ------------------------------------------------------------------
    # Confusion matrix
    # ------------------------------------------------------------------

    def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                         normalize: Optional[str] = None) -> np.ndarray:
        """Compute the confusion matrix.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            normalize: 'true', 'pred', 'all' or None (raw counts).

        Returns:
            The (optionally normalized) confusion matrix. Rows are true labels
            and columns are predicted labels.
        """
        # The bare name below is scikit-learn's module-level function, not this
        # method: method bodies do not see the class namespace.
        cm = confusion_matrix(y_true, y_pred)
        return self._normalize_confusion_matrix(cm, normalize)

    def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                              normalize: Optional[str] = None,
                              figsize: Tuple[int, int] = (10, 8),
                              save_path: Optional[str] = None) -> None:
        """Plot the confusion matrix as an annotated heat map.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            normalize: 'true', 'pred', 'all' or None (raw counts).
            figsize: Figure size in inches.
            save_path: File to save the figure to. If None, the figure is
                shown instead.
        """
        cm = self.confusion_matrix(y_true, y_pred, normalize)
        n_rows, n_cols = cm.shape

        if self.class_names is None:
            class_names = [str(i) for i in range(n_rows)]
        else:
            class_names = self.class_names

        plt.figure(figsize=figsize)
        im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.colorbar(im)

        plt.xticks(np.arange(n_cols), class_names, rotation=45, ha='right')
        plt.yticks(np.arange(n_rows), class_names)

        if normalize is not None:
            plt.title(f"Normalized ({normalize}) Confusion Matrix")
        else:
            plt.title("Confusion Matrix")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

        # Use white text on dark cells and black text on light cells.
        thresh = cm.max() / 2.0
        for i, j in np.ndindex(cm.shape):
            value = cm[i, j]
            text = f"{value:.2f}" if normalize is not None else f"{value}"
            plt.text(j, i, text, ha="center", va="center",
                     color="white" if value > thresh else "black")

        plt.tight_layout()
        self._save_or_show(save_path, "混淆矩阵图")

    def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
                              output_dict: bool = False) -> Union[str, Dict]:
        """Build scikit-learn's classification report.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            output_dict: If True, return a dict instead of formatted text.

        Returns:
            The report as a string, or as a dict if ``output_dict`` is True.
        """
        target_names = self.class_names if self.class_names else None
        # Bare name resolves to scikit-learn's function (see confusion_matrix).
        return classification_report(y_true, y_pred,
                                     target_names=target_names,
                                     output_dict=output_dict,
                                     zero_division=0)

    # ------------------------------------------------------------------
    # Probability-based metrics
    # ------------------------------------------------------------------

    def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray,
                multi_class: str = 'ovr') -> Union[float, np.ndarray]:
        """Compute the ROC AUC.

        Args:
            y_true: Ground-truth labels (class indices or one-hot).
            y_prob: Predicted probabilities, ``(n_samples, n_classes)``. A 1-D
                vector or single column is taken as positive-class scores.
            multi_class: 'ovr' or 'ovo'. Only used with more than two classes.

        Returns:
            The macro-averaged AUC for multi-class input, otherwise the binary
            AUC of column 1. On any error, the error is logged and 0.0 is
            returned. NOTE: this 0.0 is indistinguishable from a genuine AUC
            of 0.
        """
        try:
            y_true = self._to_label_indices(y_true)
            y_prob = self._as_probability_matrix(y_prob)

            if y_prob.shape[1] > 2:
                return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro')
            return roc_auc_score(y_true, y_prob[:, 1])
        except Exception as e:
            logger.error(f"计算AUC-ROC时出错: {e}")
            return 0.0

    def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray,
                          average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """Compute average precision over one-vs-rest class indicators.

        Args:
            y_true: Ground-truth labels (class indices or one-hot).
            y_prob: Predicted probabilities, ``(n_samples, n_classes)``. A 1-D
                vector or single column is taken as positive-class scores.
            average: 'micro', 'macro', 'weighted', 'samples' or None.

        Returns:
            Average precision. NOTE(review): for two-column binary input this
            averages the AP of both columns, not only the positive class. On
            any error, the error is logged and 0.0 is returned.
        """
        try:
            labels = self._to_label_indices(y_true).astype(int)
            y_prob = self._as_probability_matrix(y_prob)

            # One-hot encode the labels against the probability columns. This
            # matches keras' to_categorical without needing TensorFlow.
            y_true_onehot = np.eye(y_prob.shape[1])[labels]
            return average_precision_score(y_true_onehot, y_prob, average=average)
        except Exception as e:
            logger.error(f"计算平均精确率时出错: {e}")
            return 0.0

    def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray,
                                    class_id: Optional[int] = None,
                                    figsize: Tuple[int, int] = (10, 8),
                                    save_path: Optional[str] = None) -> None:
        """Plot one-vs-rest precision-recall curves.

        Args:
            y_true: Ground-truth labels (class indices or one-hot).
            y_prob: Predicted probabilities, ``(n_samples, n_classes)``. A 1-D
                vector or single column is taken as positive-class scores.
            class_id: Class to plot. If None or out of range, every class is
                plotted (out of range also logs a warning).
            figsize: Figure size in inches.
            save_path: File to save the figure to. If None, the figure is
                shown instead.
        """
        y_true = self._to_label_indices(y_true)
        y_prob = self._as_probability_matrix(y_prob)
        n_classes = y_prob.shape[1]

        if class_id is not None and 0 <= class_id < n_classes:
            class_ids = [class_id]
        else:
            if class_id is not None:
                logger.warning(f"class_id={class_id} 超出范围,将绘制所有类别")
            class_ids = list(range(n_classes))

        plt.figure(figsize=figsize)

        for cid in class_ids:
            # Binarize: the current class versus all others.
            y_true_bin = (y_true == cid).astype(int)
            scores = y_prob[:, cid]
            prec_vals, rec_vals, _ = precision_recall_curve(y_true_bin, scores)
            avg_prec = average_precision_score(y_true_bin, scores)
            plt.step(rec_vals, prec_vals, where='post',
                     label=f'{self._class_label(cid)} (AP = {avg_prec:.3f})')

        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc='lower left')
        plt.grid(True)

        self._save_or_show(save_path, "精确率-召回率曲线图")

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------

    def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
                              y_prob: Optional[np.ndarray] = None) -> Dict[str, float]:
        """Compute every supported scalar metric.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.
            y_prob: Optional predicted probabilities. When given, 'auc_roc' and
                'average_precision' are also computed.

        Returns:
            A dict mapping metric name to value. Keys, in insertion order:
            'accuracy', the macro scores, the optional probability metrics,
            then the micro and weighted scores.
        """
        metrics: Dict[str, float] = {
            'accuracy': self.accuracy(y_true, y_pred),
            'precision_macro': self.precision(y_true, y_pred, average='macro'),
            'recall_macro': self.recall(y_true, y_pred, average='macro'),
            'f1_macro': self.f1(y_true, y_pred, average='macro'),
        }

        if y_prob is not None:
            try:
                metrics['auc_roc'] = self.auc_roc(y_true, y_prob)
                metrics['average_precision'] = self.average_precision(y_true, y_prob)
            except Exception as e:
                logger.error(f"计算概率指标时出错: {e}")

        for avg in ('micro', 'weighted'):
            metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg)
            metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg)
            metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg)

        return metrics

    def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame:
        """Convert a metrics dict into a DataFrame.

        Args:
            metrics: Mapping of metric name to value.

        Returns:
            A DataFrame indexed by 'Metric', with a single 'Value' column.
        """
        # Materialize the items view; older pandas rejects dict views.
        rows = list(metrics.items())
        return pd.DataFrame(rows, columns=['Metric', 'Value']).set_index('Metric')
|