""" 评估指标模块:实现各种评估指标 """ import numpy as np import tensorflow as tf from sklearn.metrics import ( accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_auc_score, precision_recall_curve, average_precision_score ) import matplotlib.pyplot as plt from typing import List, Dict, Tuple, Optional, Any, Union, Callable import pandas as pd from utils.logger import get_logger logger = get_logger("Metrics") class ClassificationMetrics: """分类评估指标类,计算各种分类评估指标""" def __init__(self, class_names: Optional[List[str]] = None): """ 初始化分类评估指标类 Args: class_names: 类别名称列表 """ self.class_names = class_names def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: """ 计算准确率 Args: y_true: 真实标签 y_pred: 预测标签 Returns: 准确率 """ return accuracy_score(y_true, y_pred) def precision(self, y_true: np.ndarray, y_pred: np.ndarray, average: str = 'macro') -> Union[float, np.ndarray]: """ 计算精确率 Args: y_true: 真实标签 y_pred: 预测标签 average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None Returns: 精确率 """ return precision_score(y_true, y_pred, average=average, zero_division=0) def recall(self, y_true: np.ndarray, y_pred: np.ndarray, average: str = 'macro') -> Union[float, np.ndarray]: """ 计算召回率 Args: y_true: 真实标签 y_pred: 预测标签 average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None Returns: 召回率 """ return recall_score(y_true, y_pred, average=average, zero_division=0) def f1(self, y_true: np.ndarray, y_pred: np.ndarray, average: str = 'macro') -> Union[float, np.ndarray]: """ 计算F1分数 Args: y_true: 真实标签 y_pred: 预测标签 average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None Returns: F1分数 """ return f1_score(y_true, y_pred, average=average, zero_division=0) def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, normalize: Optional[str] = None) -> np.ndarray: """ 计算混淆矩阵 Args: y_true: 真实标签 y_pred: 预测标签 normalize: 归一化方式,可选值: 'true', 'pred', 'all', None Returns: 混淆矩阵 """ cm = confusion_matrix(y_true, y_pred) if normalize is not None: if normalize == 'true': cm = cm.astype('float') / cm.sum(axis=1, keepdims=True) elif normalize == 'pred': cm = cm.astype('float') / cm.sum(axis=0, keepdims=True) elif normalize == 'all': cm = cm.astype('float') / cm.sum() return cm def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, normalize: Optional[str] = None, figsize: Tuple[int, int] = (10, 8), save_path: Optional[str] = None) -> None: """ 绘制混淆矩阵 Args: y_true: 真实标签 y_pred: 预测标签 normalize: 归一化方式,可选值: 'true', 'pred', 'all', None figsize: 图像大小 save_path: 保存路径,如果为None则显示图像 """ # 计算混淆矩阵 cm = self.confusion_matrix(y_true, y_pred, normalize) # 确定类别名称 if self.class_names is None: class_names = [str(i) for i in range(cm.shape[0])] else: class_names = self.class_names # 创建图像 plt.figure(figsize=figsize) # 使用热图显示混淆矩阵 im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) plt.colorbar(im) # 设置坐标轴标签 plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right') plt.yticks(np.arange(cm.shape[0]), class_names) # 设置标题 if normalize is not None: plt.title(f"Normalized ({normalize}) Confusion Matrix") else: plt.title("Confusion Matrix") plt.ylabel('True label') plt.xlabel('Predicted label') # 在每个单元格中显示数值 thresh = cm.max() / 2.0 for i in range(cm.shape[0]): for j in range(cm.shape[1]): if normalize is not None: plt.text(j, i, f"{cm[i, j]:.2f}", ha="center", va="center", color="white" if cm[i, j] > thresh else "black") else: plt.text(j, i, f"{cm[i, j]}", ha="center", va="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() # 
保存或显示图像 if save_path: plt.savefig(save_path) logger.info(f"混淆矩阵图已保存到: {save_path}") else: plt.show() def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray, output_dict: bool = False) -> Union[str, Dict]: """ 生成分类报告 Args: y_true: 真实标签 y_pred: 预测标签 output_dict: 是否以字典形式返回 Returns: 分类报告 """ target_names = self.class_names if self.class_names else None return classification_report(y_true, y_pred, target_names=target_names, output_dict=output_dict, zero_division=0) def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray, multi_class: str = 'ovr') -> Union[float, np.ndarray]: """ 计算AUC-ROC Args: y_true: 真实标签 y_prob: 预测概率 multi_class: 多分类处理方式,可选值: 'ovr', 'ovo' Returns: AUC-ROC """ try: # 如果y_true是one-hot编码,转换为类别索引 if len(y_true.shape) > 1 and y_true.shape[1] > 1: y_true = np.argmax(y_true, axis=1) # 多分类 if y_prob.shape[1] > 2: return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro') # 二分类 else: return roc_auc_score(y_true, y_prob[:, 1]) except Exception as e: logger.error(f"计算AUC-ROC时出错: {e}") return 0.0 def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray, average: str = 'macro') -> Union[float, np.ndarray]: """ 计算平均精确率 Args: y_true: 真实标签 y_prob: 预测概率 average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None Returns: 平均精确率 """ try: # 如果y_true是one-hot编码,转换为类别索引 if len(y_true.shape) > 1 and y_true.shape[1] > 1: y_true = np.argmax(y_true, axis=1) # 多分类:使用sklearn的方法 return average_precision_score( tf.keras.utils.to_categorical(y_true, num_classes=y_prob.shape[1]), y_prob, average=average ) except Exception as e: logger.error(f"计算平均精确率时出错: {e}") return 0.0 def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray, class_id: Optional[int] = None, figsize: Tuple[int, int] = (10, 8), save_path: Optional[str] = None) -> None: """ 绘制精确率-召回率曲线 Args: y_true: 真实标签 y_prob: 预测概率 class_id: 要绘制的类别ID,如果为None则绘制所有类别 figsize: 图像大小 save_path: 保存路径,如果为None则显示图像 """ # 如果y_true是one-hot编码,转换为类别索引 if len(y_true.shape) > 1 and y_true.shape[1] > 1: y_true = np.argmax(y_true, axis=1) # 创建图像 plt.figure(figsize=figsize) # 确定要绘制的类别 if class_id is not None and class_id < y_prob.shape[1]: # 绘制指定类别的PR曲线 y_true_bin = (y_true == class_id).astype(int) precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, class_id]) avg_prec = average_precision_score(y_true_bin, y_prob[:, class_id]) plt.step(recall, precision, where='post', label=f'Class {class_id} (AP = {avg_prec:.3f})') else: # 绘制所有类别的PR曲线 for i in range(y_prob.shape[1]): y_true_bin = (y_true == i).astype(int) precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i]) avg_prec = average_precision_score(y_true_bin, y_prob[:, i]) class_name = self.class_names[i] if self.class_names else f"Class {i}" plt.step(recall, precision, where='post', label=f'{class_name} (AP = {avg_prec:.3f})') plt.xlabel('Recall') plt.ylabel('Precision') plt.title('Precision-Recall Curve') plt.legend(loc='lower left') plt.grid(True) # 保存或显示图像 if save_path: plt.savefig(save_path) logger.info(f"精确率-召回率曲线图已保存到: {save_path}") else: plt.show() def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None) -> Dict[str, float]: """ 计算所有评估指标 Args: y_true: 真实标签 y_pred: 预测标签 y_prob: 预测概率 Returns: 包含所有评估指标的字典 """ metrics = {} # 基础指标 metrics['accuracy'] = self.accuracy(y_true, y_pred) metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro') metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro') metrics['f1_macro'] = self.f1(y_true, 
y_pred, average='macro') # 如果提供了预测概率,计算AUC-ROC和平均精确率 if y_prob is not None: try: metrics['auc_roc'] = self.auc_roc(y_true, y_prob) metrics['average_precision'] = self.average_precision(y_true, y_prob) except Exception as e: logger.error(f"计算概率指标时出错: {e}") # 类别级别的指标 for avg in ['micro', 'weighted']: metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg) metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg) metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg) return metrics def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame: """ 将评估指标转换为DataFrame Args: metrics: 评估指标字典 Returns: 评估指标DataFrame """ return pd.DataFrame(metrics.items(), columns=['Metric', 'Value']).set_index('Metric')
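

# Minimal usage sketch, assuming a 3-class problem with synthetic data.
# The class names, labels, and probabilities below are illustrative only
# and not part of the module itself.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 3, size=100)
    scores = rng.random((100, 3))
    y_prob = scores / scores.sum(axis=1, keepdims=True)  # row-normalize to probabilities
    y_pred = np.argmax(y_prob, axis=1)

    metrics = ClassificationMetrics(class_names=["cat", "dog", "bird"])
    results = metrics.calculate_all_metrics(y_true, y_pred, y_prob)
    print(metrics.metrics_to_dataframe(results))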