2025-03-08 01:34:36 +08:00

357 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Evaluation metrics module: classification metrics (mostly thin wrappers around
sklearn.metrics) plus matplotlib plots of the confusion matrix and PR curves.
"""
import numpy as np
import tensorflow as tf
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report, roc_auc_score,
precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import pandas as pd
from utils.logger import get_logger
# Module-level logger named "Metrics"; get_logger comes from the project's
# utils.logger module (implementation not visible here).
logger = get_logger("Metrics")
class ClassificationMetrics:
    """Compute and plot common classification metrics.

    Most methods are thin wrappers around ``sklearn.metrics``. When
    ``class_names`` is given, integer label ``i`` is displayed as
    ``class_names[i]``.
    """

    def __init__(self, class_names: Optional[List[str]] = None):
        """
        Initialize the metrics helper.

        Args:
            class_names: Optional display names, indexed by integer class label.
        """
        self.class_names = class_names

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_class_indices(y: np.ndarray) -> np.ndarray:
        """Return ``y`` as a label array.

        Input with more than one column (one-hot rows) is collapsed with
        ``argmax`` over axis 1; a single-column 2-D array is flattened to 1-D;
        anything else is returned unchanged.
        """
        y = np.asarray(y)
        if y.ndim > 1 and y.shape[1] > 1:
            return np.argmax(y, axis=1)
        if y.ndim == 2 and y.shape[1] == 1:
            return y.ravel()
        return y

    @staticmethod
    def _squeeze_scores(y_prob: np.ndarray) -> np.ndarray:
        """Return ``y_prob`` as an array, flattening an ``(n, 1)`` score column to ``(n,)``."""
        y_prob = np.asarray(y_prob)
        if y_prob.ndim == 2 and y_prob.shape[1] == 1:
            return y_prob.ravel()
        return y_prob

    @staticmethod
    def _one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
        """One-hot encode integer ``labels`` into an ``(n, num_classes)`` float array.

        Raises:
            IndexError: If a label is >= ``num_classes``.
        """
        return np.eye(num_classes)[np.asarray(labels).astype(int)]

    @staticmethod
    def _normalize_counts(counts: np.ndarray, normalize: Optional[str]) -> np.ndarray:
        """
        Normalize a confusion-count matrix.

        Args:
            counts: Raw confusion-matrix counts.
            normalize: 'true' (divide by row sums), 'pred' (divide by column
                sums), 'all' (divide by the grand total) or None (return
                ``counts`` unchanged).
        Returns:
            The normalized float matrix. Cells whose divisor is zero are 0.0
            rather than NaN.
        Raises:
            ValueError: If ``normalize`` is not one of the values above.
        """
        if normalize is None:
            return counts
        axis_by_mode = {'true': 1, 'pred': 0, 'all': None}
        if normalize not in axis_by_mode:
            raise ValueError(
                f"normalize must be 'true', 'pred', 'all' or None, got {normalize!r}")
        counts = np.asarray(counts)
        totals = counts.sum(axis=axis_by_mode[normalize], keepdims=True)
        # A class that never occurs in y_true (or is never predicted) gives an
        # all-zero row (column); plain division would turn it into NaN.
        return np.divide(counts, totals,
                         out=np.zeros(counts.shape, dtype=float),
                         where=totals != 0)

    def _labels_for_class_names(self, *label_arrays: np.ndarray) -> Optional[np.ndarray]:
        """Return ``arange(len(class_names))`` when it is safe to pass as sklearn ``labels``.

        That is the case only when ``class_names`` is set and every label in
        ``label_arrays`` is an integer in ``[0, len(class_names))``. Passing it
        keeps one row/column/report line per class even when a class is absent
        from the given batch. Otherwise returns None so sklearn infers labels.
        """
        if not self.class_names:
            return None
        values = np.concatenate([np.ravel(np.asarray(a)) for a in label_arrays])
        if values.size == 0 or not np.issubdtype(values.dtype, np.integer):
            return None
        if values.min() < 0 or values.max() >= len(self.class_names):
            return None
        return np.arange(len(self.class_names))

    def _class_name(self, index: int) -> str:
        """Display name for class ``index``; falls back to ``"Class <index>"``."""
        if self.class_names and 0 <= index < len(self.class_names):
            return self.class_names[index]
        return f"Class {index}"

    # ------------------------------------------------------------------
    # Label-based metrics
    # ------------------------------------------------------------------
    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """
        Compute accuracy.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
        Returns:
            Fraction of correctly predicted samples.
        """
        return accuracy_score(y_true, y_pred)

    def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
                  average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """
        Compute precision.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            average: Averaging mode: 'micro', 'macro', 'weighted', 'samples', or
                None for per-class scores.
        Returns:
            Precision (0 for classes with no predicted samples).
        """
        return precision_score(y_true, y_pred, average=average, zero_division=0)

    def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
               average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """
        Compute recall.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            average: Averaging mode: 'micro', 'macro', 'weighted', 'samples', or
                None for per-class scores.
        Returns:
            Recall (0 for classes with no true samples).
        """
        return recall_score(y_true, y_pred, average=average, zero_division=0)

    def f1(self, y_true: np.ndarray, y_pred: np.ndarray,
           average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """
        Compute the F1 score.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            average: Averaging mode: 'micro', 'macro', 'weighted', 'samples', or
                None for per-class scores.
        Returns:
            F1 score (0 where it is undefined).
        """
        return f1_score(y_true, y_pred, average=average, zero_division=0)

    def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                         normalize: Optional[str] = None) -> np.ndarray:
        """
        Compute the confusion matrix (rows = true labels, columns = predictions).

        When ``class_names`` is set and the labels are integer class indices,
        the matrix always has ``len(class_names)`` rows/columns, even if some
        class does not occur in this batch.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            normalize: 'true', 'pred', 'all' or None (raw counts).
        Returns:
            The confusion matrix.
        Raises:
            ValueError: If ``normalize`` is not a supported value.
        """
        # Inside a method body this name resolves to the module-level sklearn
        # function, not to this method.
        counts = confusion_matrix(y_true, y_pred,
                                  labels=self._labels_for_class_names(y_true, y_pred))
        return self._normalize_counts(counts, normalize)

    def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                              normalize: Optional[str] = None,
                              figsize: Tuple[int, int] = (10, 8),
                              save_path: Optional[str] = None) -> None:
        """
        Plot the confusion matrix as an annotated heat map.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            normalize: 'true', 'pred', 'all' or None (raw counts).
            figsize: Figure size in inches.
            save_path: File to save the figure to; if None the figure is shown.
        """
        cm = self.confusion_matrix(y_true, y_pred, normalize)
        n_rows, n_cols = cm.shape

        # Use the configured names only when they line up with the matrix:
        # matplotlib rejects a tick-label list whose length differs from the ticks.
        if self.class_names and len(self.class_names) == n_rows:
            tick_labels = list(self.class_names)
        else:
            tick_labels = [str(i) for i in range(n_rows)]

        fig = plt.figure(figsize=figsize)
        im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.colorbar(im)
        plt.xticks(np.arange(n_cols), tick_labels, rotation=45, ha='right')
        plt.yticks(np.arange(n_rows), tick_labels)

        if normalize is not None:
            plt.title(f"Normalized ({normalize}) Confusion Matrix")
        else:
            plt.title("Confusion Matrix")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

        # Annotate every cell; use white text on the darker half of the colour map.
        thresh = cm.max() / 2.0
        cell_format = "{:.2f}" if normalize is not None else "{}"
        for (i, j), value in np.ndenumerate(cm):
            plt.text(j, i, cell_format.format(value),
                     ha="center", va="center",
                     color="white" if value > thresh else "black")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            logger.info(f"混淆矩阵图已保存到: {save_path}")
            # pyplot keeps every open figure alive until it is closed.
            plt.close(fig)
        else:
            plt.show()

    def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
                              output_dict: bool = False) -> Union[str, Dict]:
        """
        Build sklearn's per-class classification report.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            output_dict: Return a dict instead of a formatted string.
        Returns:
            The classification report.
        """
        target_names = self.class_names if self.class_names else None
        # Module-level sklearn function (method names are not in scope here).
        # Explicit labels keep the report valid when a class is missing from
        # this batch; otherwise sklearn raises on the target_names length.
        return classification_report(y_true, y_pred,
                                     labels=self._labels_for_class_names(y_true, y_pred),
                                     target_names=target_names,
                                     output_dict=output_dict,
                                     zero_division=0)

    # ------------------------------------------------------------------
    # Probability-based metrics
    # ------------------------------------------------------------------
    def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray,
                multi_class: str = 'ovr') -> Union[float, np.ndarray]:
        """
        Compute the area under the ROC curve.

        Args:
            y_true: True labels, as class indices or one-hot rows.
            y_prob: Predicted probabilities, shape ``(n, n_classes)``; for binary
                problems the positive-class score may also be ``(n,)`` or ``(n, 1)``.
            multi_class: sklearn multi-class strategy, 'ovr' or 'ovo'.
        Returns:
            Macro-averaged AUC for more than two classes, positive-class AUC
            otherwise. Returns 0.0 after logging the error if sklearn rejects
            the input (kept for backward compatibility).
        """
        try:
            y_true = self._to_class_indices(y_true)
            y_prob = self._squeeze_scores(y_prob)
            if y_prob.ndim == 1:
                # Binary problem given only the positive-class score (e.g. sigmoid output).
                return roc_auc_score(y_true, y_prob)
            if y_prob.shape[1] > 2:
                return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro')
            # Two columns: score with the probability of class 1.
            return roc_auc_score(y_true, y_prob[:, 1])
        except Exception as e:
            logger.error(f"计算AUC-ROC时出错: {e}")
            return 0.0

    def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray,
                          average: Optional[str] = 'macro') -> Union[float, np.ndarray]:
        """
        Compute average precision (area under the precision-recall curve).

        Args:
            y_true: True labels, as class indices or one-hot rows.
            y_prob: Predicted probabilities, shape ``(n, n_classes)``; for binary
                problems the positive-class score may also be ``(n,)`` or ``(n, 1)``.
            average: Averaging mode: 'micro', 'macro', 'weighted', 'samples' or None.
        Returns:
            Average precision. Returns 0.0 after logging the error if the input
            is rejected (kept for backward compatibility).
        """
        try:
            y_true = self._to_class_indices(y_true)
            y_prob = self._squeeze_scores(y_prob)
            if y_prob.ndim == 1:
                # Binary problem given only the positive-class score.
                return average_precision_score(y_true, y_prob, average=average)
            # Multi-column scores: compare against one-hot targets, one column per class.
            return average_precision_score(
                self._one_hot(y_true, y_prob.shape[1]),
                y_prob,
                average=average
            )
        except Exception as e:
            logger.error(f"计算平均精确率时出错: {e}")
            return 0.0

    def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray,
                                    class_id: Optional[int] = None,
                                    figsize: Tuple[int, int] = (10, 8),
                                    save_path: Optional[str] = None) -> None:
        """
        Plot one-vs-rest precision-recall curves.

        Args:
            y_true: True labels, as class indices or one-hot rows.
            y_prob: Predicted probabilities, shape ``(n, n_classes)``.
            class_id: Class to plot; if None or out of range, every class is plotted.
            figsize: Figure size in inches.
            save_path: File to save the figure to; if None the figure is shown.
        """
        y_true = self._to_class_indices(y_true)
        y_prob = np.asarray(y_prob)
        n_classes = y_prob.shape[1]

        fig = plt.figure(figsize=figsize)
        if class_id is not None and 0 <= class_id < n_classes:
            class_ids = [class_id]
        else:
            class_ids = range(n_classes)

        for i in class_ids:
            # One-vs-rest: class i is the positive class.
            y_true_bin = (y_true == i).astype(int)
            prec_curve, rec_curve, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
            avg_prec = average_precision_score(y_true_bin, y_prob[:, i])
            plt.step(rec_curve, prec_curve, where='post',
                     label=f'{self._class_name(i)} (AP = {avg_prec:.3f})')

        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc='lower left')
        plt.grid(True)

        if save_path:
            plt.savefig(save_path)
            logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
            # pyplot keeps every open figure alive until it is closed.
            plt.close(fig)
        else:
            plt.show()

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------
    def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
                              y_prob: Optional[np.ndarray] = None) -> Dict[str, float]:
        """
        Compute every supported metric.

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            y_prob: Optional predicted probabilities; enables 'auc_roc' and
                'average_precision'.
        Returns:
            Dict mapping metric name to value.
        """
        metrics = {}
        # Basic label metrics.
        metrics['accuracy'] = self.accuracy(y_true, y_pred)
        metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro')
        metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro')
        metrics['f1_macro'] = self.f1(y_true, y_pred, average='macro')
        # Probability-based metrics, only when scores are available.
        if y_prob is not None:
            try:
                metrics['auc_roc'] = self.auc_roc(y_true, y_prob)
                metrics['average_precision'] = self.average_precision(y_true, y_prob)
            except Exception as e:
                logger.error(f"计算概率指标时出错: {e}")
        # Alternative averaging modes.
        for avg in ['micro', 'weighted']:
            metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg)
            metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg)
            metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg)
        return metrics

    def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame:
        """
        Convert a metrics dict into a one-column DataFrame.

        Args:
            metrics: Mapping of metric name to value.
        Returns:
            DataFrame indexed by 'Metric' with a single 'Value' column.
        """
        return pd.DataFrame(list(metrics.items()), columns=['Metric', 'Value']).set_index('Metric')