2025-03-08 01:34:36 +08:00

371 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
可视化模块:实现评估结果的可视化
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from typing import List, Dict, Tuple, Optional, Any, Union
import os
import itertools
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from utils.logger import get_logger
from utils.file_utils import ensure_dir
# Module-level logger used by EvaluationVisualizer to report where figures were saved.
logger = get_logger("Visualization")
class EvaluationVisualizer:
    """Plot evaluation results with matplotlib/seaborn.

    Supported plots: confusion matrices, metric comparisons, ROC and
    precision-recall curves, feature importances and 2-D embeddings.

    Each ``plot_*`` method draws on a figure of its own. It saves the figure
    to an explicit ``save_path`` or, when ``output_dir`` is set, to
    ``output_dir/<default file name>``. It then closes the figure, even if
    saving fails.
    """

    def __init__(self, output_dir: Optional[str] = None,
                 class_names: Optional[List[str]] = None,
                 figsize: Tuple[int, int] = (10, 8)):
        """
        Initialize the visualizer.

        Args:
            output_dir: Directory for saved figures (created with
                ``ensure_dir`` when given). If None, a figure is only saved
                when a method receives an explicit ``save_path``.
            class_names: Display names, indexed by integer class id.
            figsize: Default figure size passed to matplotlib.
        """
        self.output_dir = output_dir
        if output_dir:
            ensure_dir(output_dir)
        self.class_names = class_names
        self.figsize = figsize

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _class_label(self, label: Any) -> str:
        """Return the display name for a class label.

        Integer-valued labels in ``[0, len(class_names))`` map to
        ``class_names``. Other integer-valued labels (including negative ones)
        become ``'Class <n>'``. Labels that are not integer-valued, such as
        strings, are shown with ``str()``.
        """
        try:
            index = int(label)
        except (TypeError, ValueError, OverflowError):
            return str(label)
        if index != label:
            # e.g. 1.5 or the string '3': show verbatim instead of truncating
            # it into an index.
            return str(label)
        if self.class_names and 0 <= index < len(self.class_names):
            return self.class_names[index]
        return f'Class {index}'

    @staticmethod
    def _normalize_confusion_matrix(cm: np.ndarray,
                                    normalize: Optional[str]
                                    ) -> Tuple[np.ndarray, str]:
        """Normalize a confusion matrix and return it with a title prefix.

        Args:
            cm: Confusion matrix with rows = true labels and
                columns = predicted labels.
            normalize: 'true' (divide by row sums), 'pred' (divide by column
                sums), 'all' (divide by the grand total) or None (unchanged).

        Returns:
            ``(matrix, prefix)``. ``prefix`` is '' when ``normalize`` is None.
            Rows or columns whose total is zero are returned as 0, not NaN.

        Raises:
            ValueError: If ``normalize`` is not one of the supported values.
        """
        cm = np.asarray(cm)
        if normalize is None:
            return cm, ''
        if normalize == 'true':
            totals = cm.sum(axis=1, keepdims=True)
        elif normalize == 'pred':
            totals = cm.sum(axis=0, keepdims=True)
        elif normalize == 'all':
            totals = cm.sum()
        else:
            raise ValueError(f"不支持的归一化方式: {normalize}")
        cm = cm.astype(float)
        # Dividing by an empty row/column would produce NaN cells; keep 0.
        normalized = np.divide(cm, totals, out=np.zeros_like(cm),
                               where=totals != 0)
        return normalized, f'Normalized (by {normalize}) '

    @staticmethod
    def _prepare_curve_inputs(y_true: np.ndarray,
                              y_prob: np.ndarray
                              ) -> Tuple[np.ndarray, np.ndarray]:
        """Coerce labels/probabilities for one-vs-rest curve plotting.

        ``y_true`` given as a one-hot/indicator matrix is reduced to class
        indices with argmax. Any other ``y_true`` is flattened, so ``(n, 1)``
        becomes ``(n,)``. A 1-D ``y_prob`` is treated as the class-1
        probability of a binary problem and expanded to ``[1 - p, p]``.
        """
        y_true = np.asarray(y_true)
        y_prob = np.asarray(y_prob)
        if y_true.ndim > 1 and y_true.shape[1] > 1:
            y_true = np.argmax(y_true, axis=1)
        else:
            y_true = y_true.ravel()
        if y_prob.ndim == 1:
            y_prob = np.column_stack([1.0 - y_prob, y_prob])
        return y_true, y_prob

    def _finalize_figure(self, fig: "plt.Figure", save_path: Optional[str],
                         default_filename: str, log_prefix: str) -> None:
        """Lay out ``fig``, save it if a target path resolves, then close it.

        Args:
            fig: Figure to finish.
            save_path: Explicit output path. If None and ``output_dir`` is set,
                ``output_dir/default_filename`` is used.
            default_filename: File name used with ``output_dir``.
            log_prefix: Message logged as ``"<log_prefix>: <path>"`` after
                saving.
        """
        try:
            fig.tight_layout()
            if save_path is None and self.output_dir:
                save_path = os.path.join(self.output_dir, default_filename)
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"{log_prefix}: {save_path}")
        finally:
            # Always release the figure so repeated calls cannot pile up
            # open figures.
            plt.close(fig)

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_confusion_matrix(self, cm: np.ndarray,
                              normalize: Optional[str] = None,
                              title: str = 'Confusion Matrix',
                              cmap: str = 'Blues',
                              save_path: Optional[str] = None) -> None:
        """
        Plot a confusion matrix as an annotated heatmap.

        Args:
            cm: Confusion matrix; rows are true labels, columns are predicted
                labels.
            normalize: 'true', 'pred', 'all' or None. Empty rows/columns are
                shown as 0.
            title: Figure title; prefixed with 'Normalized (by ...) ' when
                ``normalize`` is set.
            cmap: Colormap name.
            save_path: Output path; defaults to output_dir/confusion_matrix.png.

        Raises:
            ValueError: If ``normalize`` is not a supported value.
        """
        cm, prefix = self._normalize_confusion_matrix(cm, normalize)
        title = prefix + title
        # 'd' only works for integer counts; normalized or float matrices need
        # a float format, or seaborn raises while annotating.
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

        fig, ax = plt.subplots(figsize=self.figsize)
        sns.heatmap(cm, annot=True, fmt=fmt, cmap=cmap, square=True,
                    cbar=True, ax=ax)
        if self.class_names:
            # Heatmap cells are centred at i + 0.5.
            tick_marks = np.arange(len(self.class_names)) + 0.5
            ax.set_xticks(tick_marks)
            ax.set_xticklabels(self.class_names, rotation=45, ha='right')
            ax.set_yticks(tick_marks)
            ax.set_yticklabels(self.class_names, rotation=0)
        ax.set_title(title)
        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')
        self._finalize_figure(fig, save_path, 'confusion_matrix.png',
                              "混淆矩阵图已保存到")

    def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]],
                                selected_metrics: Optional[List[str]] = None,
                                title: str = 'Metrics Comparison',
                                save_path: Optional[str] = None,
                                ylim: Optional[Tuple[float, float]] = (0, 1)
                                ) -> None:
        """
        Plot a grouped bar chart comparing metrics across models.

        Args:
            metrics_dict: ``{model_name: {metric_name: value}}``.
            selected_metrics: Metrics to show; all metrics if None/empty.
            title: Figure title.
            save_path: Output path; defaults to
                output_dir/metrics_comparison.png.
            ylim: y-axis limits. Defaults to (0, 1); pass None to autoscale,
                e.g. for metrics that are not bounded by 1.
        """
        df = pd.DataFrame(metrics_dict).T
        if selected_metrics:
            df = df[selected_metrics]

        # Draw into an explicit Axes: DataFrame.plot() without ``ax`` creates
        # its own figure, which would leave a separately created one open.
        fig, ax = plt.subplots(figsize=self.figsize)
        df.plot(kind='bar', ax=ax)
        ax.set_title(title)
        ax.set_ylabel('Score')
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.legend(loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        self._finalize_figure(fig, save_path, 'metrics_comparison.png',
                              "评估指标比较图已保存到")

    def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
                        title: str = 'ROC Curves',
                        save_path: Optional[str] = None) -> None:
        """
        Plot one-vs-rest ROC curves, one per class.

        Args:
            y_true: Class indices ``(n,)``/``(n, 1)`` or a one-hot matrix.
            y_prob: Predicted probabilities ``(n, num_classes)``. A 1-D array
                is treated as the class-1 probability of a binary problem.
            title: Figure title.
            save_path: Output path; defaults to output_dir/roc_curves.png.

        Classes that lack positive or negative samples in ``y_true`` have an
        undefined ROC curve. They are skipped with a warning.
        """
        y_true, y_prob = self._prepare_curve_inputs(y_true, y_prob)

        curves = []
        for i in range(y_prob.shape[1]):
            class_name = self._class_label(i)
            # One-vs-rest: class i is positive, every other class negative.
            y_true_bin = (y_true == i).astype(int)
            n_pos = int(y_true_bin.sum())
            if n_pos == 0 or n_pos == y_true_bin.size:
                logger.warning(f"类别 {class_name} 缺少正样本或负样本,跳过ROC曲线")
                continue
            fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
            roc_auc = auc(fpr, tpr)
            curves.append((fpr, tpr, f'{class_name} (AUC = {roc_auc:.3f})'))

        fig, ax = plt.subplots(figsize=self.figsize)
        for fpr, tpr, label in curves:
            ax.plot(fpr, tpr, lw=2, label=label)
        # Random-guess baseline.
        ax.plot([0, 1], [0, 1], 'k--', lw=2)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(title)
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)
        self._finalize_figure(fig, save_path, 'roc_curves.png',
                              "ROC曲线图已保存到")

    def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
                                     title: str = 'Precision-Recall Curves',
                                     save_path: Optional[str] = None) -> None:
        """
        Plot one-vs-rest precision-recall curves, one per class.

        Args:
            y_true: Class indices ``(n,)``/``(n, 1)`` or a one-hot matrix.
            y_prob: Predicted probabilities ``(n, num_classes)``. A 1-D array
                is treated as the class-1 probability of a binary problem.
            title: Figure title.
            save_path: Output path; defaults to
                output_dir/precision_recall_curves.png.

        Classes with no positive samples in ``y_true`` have undefined recall.
        They are skipped with a warning.
        """
        y_true, y_prob = self._prepare_curve_inputs(y_true, y_prob)

        curves = []
        for i in range(y_prob.shape[1]):
            class_name = self._class_label(i)
            # One-vs-rest: class i is positive, every other class negative.
            y_true_bin = (y_true == i).astype(int)
            if not y_true_bin.any():
                logger.warning(f"类别 {class_name} 没有正样本,跳过精确率-召回率曲线")
                continue
            precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
            # Trapezoidal area under the PR curve, shown as "AUC". sklearn's
            # average_precision_score is the step-wise summary, for those who
            # prefer it.
            pr_auc = auc(recall, precision)
            curves.append((recall, precision, f'{class_name} (AUC = {pr_auc:.3f})'))

        fig, ax = plt.subplots(figsize=self.figsize)
        for recall, precision, label in curves:
            ax.plot(recall, precision, lw=2, label=label)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title(title)
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        self._finalize_figure(fig, save_path, 'precision_recall_curves.png',
                              "精确率-召回率曲线图已保存到")

    def plot_feature_importance(self, feature_names: List[str],
                                importance: np.ndarray,
                                title: str = 'Feature Importance',
                                top_n: int = 20,
                                save_path: Optional[str] = None) -> None:
        """
        Plot the ``top_n`` most important features as a horizontal bar chart.

        Args:
            feature_names: Feature names, aligned with ``importance``.
            importance: Importance scores. They are flattened, so a
                ``(1, n_features)`` array works as well as ``(n_features,)``.
            title: Figure title.
            top_n: Number of highest-scoring features to show.
            save_path: Output path; defaults to
                output_dir/feature_importance.png.
        """
        df = pd.DataFrame({'Feature': list(feature_names),
                           'Importance': np.ravel(importance)})
        df = df.sort_values('Importance', ascending=False).head(top_n)

        fig, ax = plt.subplots(figsize=self.figsize)
        sns.barplot(x='Importance', y='Feature', data=df, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Importance')
        ax.set_ylabel('Feature')
        self._finalize_figure(fig, save_path, 'feature_importance.png',
                              "特征重要性图已保存到")

    def plot_embedding_visualization(self, embeddings: np.ndarray,
                                     labels: np.ndarray,
                                     method: str = 'tsne',
                                     title: str = 'Embedding Visualization',
                                     save_path: Optional[str] = None,
                                     max_samples: int = 5000,
                                     random_state: int = 42) -> None:
        """
        Project embeddings to 2-D and draw a scatter plot coloured by class.

        Args:
            embeddings: Embedding vectors, shape ``(n_samples, dim)``.
            labels: Class labels ``(n_samples,)``/``(n_samples, 1)`` or a
                one-hot matrix.
            method: 'tsne' or 'pca' (case-insensitive).
            title: Figure title.
            save_path: Output path; defaults to
                output_dir/embedding_visualization_<method>.png.
            max_samples: When there are more samples than this, a random
                subset of this size is plotted.
            random_state: Seed for subsampling and for TSNE/PCA, so repeated
                calls give the same picture.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        method_key = method.lower()
        if method_key not in ('tsne', 'pca'):
            raise ValueError(f"不支持的降维方法: {method}")

        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        if len(embeddings) > max_samples:
            # Seeded generator: the subset is reproducible, like the reducers.
            rng = np.random.default_rng(random_state)
            indices = rng.choice(len(embeddings), max_samples, replace=False)
            embeddings = embeddings[indices]
            labels = labels[indices]

        if labels.ndim > 1 and labels.shape[1] > 1:
            labels = np.argmax(labels, axis=1)
        else:
            labels = labels.ravel()

        if method_key == 'tsne':
            # sklearn requires perplexity < n_samples. Keep the usual 30 and
            # shrink it only for tiny inputs.
            perplexity = float(min(30, max(len(embeddings) - 1, 1)))
            reducer = TSNE(n_components=2, perplexity=perplexity,
                           random_state=random_state)
        else:
            reducer = PCA(n_components=2, random_state=random_state)
        embeddings_2d = reducer.fit_transform(embeddings)

        unique_labels = np.unique(labels)
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))

        fig, ax = plt.subplots(figsize=self.figsize)
        for color, label in zip(colors, unique_labels):
            mask = labels == label
            ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                       c=[color], label=self._class_label(label), alpha=0.7)
        ax.set_title(title)
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        self._finalize_figure(fig, save_path,
                              f'embedding_visualization_{method}.png',
                              "嵌入向量可视化图已保存到")