371 lines
13 KiB
Python
371 lines
13 KiB
Python
"""
Visualization module: plotting utilities for evaluation results.
"""
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
import seaborn as sns
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import os
|
||
import itertools
|
||
from sklearn.metrics import roc_curve, precision_recall_curve, auc
|
||
from sklearn.manifold import TSNE
|
||
from sklearn.decomposition import PCA
|
||
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir
|
||
|
||
# Module-level logger, obtained from the project's get_logger helper with the
# name "Visualization"; the plotting methods below use it to report saved files.
logger = get_logger("Visualization")
|
||
|
||
|
||
class EvaluationVisualizer:
|
||
"""评估结果可视化类"""
|
||
|
||
def __init__(self, output_dir: Optional[str] = None,
             class_names: Optional[List[str]] = None,
             figsize: Tuple[int, int] = (10, 8)):
    """
    Store the output location and the plotting defaults.

    Args:
        output_dir: Default directory for saved figures. A non-empty value
            is handed to ``ensure_dir`` (presumably creating the directory
            when it is missing -- confirm against utils.file_utils).
        class_names: Display names for the classes.
        figsize: Default figure size, stored as ``self.figsize``.
    """
    self.output_dir = output_dir

    # Directory preparation only applies when a non-empty path was supplied.
    needs_directory = bool(output_dir)
    if needs_directory:
        ensure_dir(output_dir)

    self.class_names, self.figsize = class_names, figsize
|
||
|
||
def plot_confusion_matrix(self, cm: np.ndarray,
                          normalize: Optional[str] = None,
                          title: str = 'Confusion Matrix',
                          cmap: str = 'Blues',
                          save_path: Optional[str] = None) -> None:
    """
    Draw a confusion matrix as an annotated seaborn heatmap.

    Args:
        cm: Confusion matrix. Rows are plotted as the true labels and
            columns as the predicted labels.
        normalize: Normalization mode. 'true' divides each row by its sum,
            'pred' divides each column by its sum, 'all' divides by the
            grand total, and None plots the raw values.
        title: Figure title; a "Normalized (by ...)" prefix is added when
            normalization is applied.
        cmap: Colormap name passed to seaborn.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/confusion_matrix.png`` if ``output_dir`` is set;
            otherwise nothing is saved.

    Raises:
        ValueError: If ``normalize`` is not one of the supported modes.
    """
    cm = np.asarray(cm)

    # normalize mode -> (axis summed to obtain the denominator, title prefix)
    modes = {
        'true': (1, 'Normalized (by true) '),
        'pred': (0, 'Normalized (by pred) '),
        'all': (None, 'Normalized (by all) '),
    }
    if normalize is not None:
        if normalize not in modes:
            raise ValueError(f"不支持的归一化方式: {normalize}")
        axis, prefix = modes[normalize]
        totals = cm.sum(axis=axis, keepdims=True)
        # Rows/columns with a zero total stay at 0 instead of becoming NaN.
        cm = np.divide(cm.astype('float'), totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=totals != 0)
        title = prefix + title

    # The integer format code only works for integer matrices; anything that
    # has been normalized or arrives as floats is shown with two decimals.
    if normalize is None and np.issubdtype(cm.dtype, np.integer):
        fmt = 'd'
    else:
        fmt = '.2f'

    fig = plt.figure(figsize=self.figsize)
    try:
        sns.heatmap(cm, annot=True, fmt=fmt,
                    cmap=cmap, square=True, cbar=True)

        if self.class_names:
            # Heatmap cells are centred at integer + 0.5.
            tick_marks = np.arange(len(self.class_names)) + 0.5
            plt.xticks(tick_marks, self.class_names, rotation=45, ha='right')
            plt.yticks(tick_marks, self.class_names, rotation=0)

        plt.title(title)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()

        if save_path is None and self.output_dir:
            save_path = os.path.join(self.output_dir, 'confusion_matrix.png')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"混淆矩阵图已保存到: {save_path}")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
||
|
||
def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]],
                            selected_metrics: Optional[List[str]] = None,
                            title: str = 'Metrics Comparison',
                            save_path: Optional[str] = None) -> None:
    """
    Draw a grouped bar chart comparing evaluation metrics across models.

    Args:
        metrics_dict: Mapping of the form ``{model_name: {metric_name: value}}``.
        selected_metrics: Metrics to include; all metrics are plotted when
            this is None or empty.
        title: Figure title.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/metrics_comparison.png`` if ``output_dir`` is
            set; otherwise nothing is saved.
    """
    # One row per model, one column per metric.
    df = pd.DataFrame(metrics_dict).T

    if selected_metrics:
        df = df[selected_metrics]

    # Draw on an explicit axes: calling df.plot() without ``ax`` opens a
    # second figure and leaves the first one empty and never closed.
    fig, ax = plt.subplots(figsize=self.figsize)
    try:
        df.plot(kind='bar', ax=ax)

        ax.set_title(title)
        ax.set_ylabel('Score')
        ax.set_ylim(0, 1)
        ax.legend(loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        if save_path is None and self.output_dir:
            save_path = os.path.join(self.output_dir, 'metrics_comparison.png')

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"评估指标比较图已保存到: {save_path}")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
||
|
||
def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
                    title: str = 'ROC Curves',
                    save_path: Optional[str] = None) -> None:
    """
    Draw one-vs-rest ROC curves, one per class, with their AUC in the legend.

    Args:
        y_true: True labels, either class indices (shape ``(n,)`` or
            ``(n, 1)``) or a one-hot matrix, which is reduced with argmax.
        y_prob: Predicted scores with one column per class. A 1-D array is
            treated as the score of class 1 in a binary problem and is
            expanded to two columns ``[1 - p, p]``.
        title: Figure title.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/roc_curves.png`` if ``output_dir`` is set;
            otherwise nothing is saved.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    # Reduce labels to a flat vector of class indices.
    if y_true.ndim > 1 and y_true.shape[1] > 1:
        y_true = np.argmax(y_true, axis=1)
    else:
        y_true = y_true.ravel()

    # Binary scores given as a flat vector: build the two-column form.
    if y_prob.ndim == 1:
        y_prob = np.column_stack([1.0 - y_prob, y_prob])

    num_classes = y_prob.shape[1]

    fig = plt.figure(figsize=self.figsize)
    try:
        for i in range(num_classes):
            # One-vs-rest: the current class is positive, all others negative.
            y_true_bin = (y_true == i).astype(int)

            fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
            roc_auc = auc(fpr, tpr)

            if self.class_names and i < len(self.class_names):
                class_name = self.class_names[i]
            else:
                class_name = f'Class {i}'

            plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})')

        # Diagonal reference line for a random classifier.
        plt.plot([0, 1], [0, 1], 'k--', lw=2)

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(title)
        plt.legend(loc='lower right')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        if save_path is None and self.output_dir:
            save_path = os.path.join(self.output_dir, 'roc_curves.png')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"ROC曲线图已保存到: {save_path}")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
||
|
||
def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
                                 title: str = 'Precision-Recall Curves',
                                 save_path: Optional[str] = None) -> None:
    """
    Draw one-vs-rest precision-recall curves, one per class.

    The legend shows, for each class, the area under its curve as computed
    by ``auc(recall, precision)``.

    Args:
        y_true: True labels, either class indices (shape ``(n,)`` or
            ``(n, 1)``) or a one-hot matrix, which is reduced with argmax.
        y_prob: Predicted scores with one column per class. A 1-D array is
            treated as the score of class 1 in a binary problem and is
            expanded to two columns ``[1 - p, p]``.
        title: Figure title.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/precision_recall_curves.png`` if ``output_dir``
            is set; otherwise nothing is saved.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    # Reduce labels to a flat vector of class indices.
    if y_true.ndim > 1 and y_true.shape[1] > 1:
        y_true = np.argmax(y_true, axis=1)
    else:
        y_true = y_true.ravel()

    # Binary scores given as a flat vector: build the two-column form.
    if y_prob.ndim == 1:
        y_prob = np.column_stack([1.0 - y_prob, y_prob])

    num_classes = y_prob.shape[1]

    fig = plt.figure(figsize=self.figsize)
    try:
        for i in range(num_classes):
            # One-vs-rest: the current class is positive, all others negative.
            y_true_bin = (y_true == i).astype(int)

            precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
            pr_auc = auc(recall, precision)

            if self.class_names and i < len(self.class_names):
                class_name = self.class_names[i]
            else:
                class_name = f'Class {i}'

            plt.plot(recall, precision, lw=2, label=f'{class_name} (AUC = {pr_auc:.3f})')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(title)
        plt.legend(loc='best')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        if save_path is None and self.output_dir:
            save_path = os.path.join(self.output_dir, 'precision_recall_curves.png')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
||
|
||
def plot_feature_importance(self, feature_names: List[str],
                            importance: np.ndarray,
                            title: str = 'Feature Importance',
                            top_n: int = 20,
                            save_path: Optional[str] = None) -> None:
    """
    Draw a horizontal bar chart of the most important features.

    Args:
        feature_names: Names of the features.
        importance: Importance value for each feature, aligned with
            ``feature_names``.
        title: Figure title.
        top_n: Number of highest-importance features to show.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/feature_importance.png`` if ``output_dir`` is
            set; otherwise nothing is saved.
    """
    # Pair names with scores, rank descending, keep the leading top_n rows.
    ranked = pd.DataFrame({'Feature': feature_names, 'Importance': importance})
    ranked = ranked.sort_values('Importance', ascending=False)
    top_features = ranked.head(top_n)

    plt.figure(figsize=self.figsize)
    sns.barplot(data=top_features, x='Importance', y='Feature')

    plt.title(title)
    plt.xlabel('Importance')
    plt.ylabel('Feature')
    plt.tight_layout()

    # Fall back to the default file name inside output_dir when no explicit
    # path was given.
    target = save_path
    if target is None and self.output_dir:
        target = os.path.join(self.output_dir, 'feature_importance.png')

    if target:
        plt.savefig(target, dpi=300, bbox_inches='tight')
        logger.info(f"特征重要性图已保存到: {target}")

    plt.close()
|
||
|
||
def plot_embedding_visualization(self, embeddings: np.ndarray,
                                 labels: np.ndarray,
                                 method: str = 'tsne',
                                 title: str = 'Embedding Visualization',
                                 save_path: Optional[str] = None) -> None:
    """
    Project embeddings to 2-D and draw a scatter plot coloured by class.

    At most 5000 samples are plotted; larger inputs are subsampled with a
    fixed seed so repeated calls produce the same figure.

    Args:
        embeddings: Embedding vectors, shape ``(n_samples, embedding_dim)``.
        labels: Class labels, either class indices (shape ``(n_samples,)``
            or ``(n_samples, 1)``) or a one-hot matrix, which is reduced
            with argmax.
        method: Dimensionality reduction method, 'tsne' or 'pca'
            (case-insensitive).
        title: Figure title.
        save_path: Destination file. When None, the figure is written to
            ``<output_dir>/embedding_visualization_<method>.png`` if
            ``output_dir`` is set; otherwise nothing is saved.

    Raises:
        ValueError: If ``method`` is neither 'tsne' nor 'pca'.
    """
    method_key = method.lower()
    if method_key not in ('tsne', 'pca'):
        raise ValueError(f"不支持的降维方法: {method}")

    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    # Reduce labels to a flat vector of class indices.
    if labels.ndim > 1 and labels.shape[1] > 1:
        labels = np.argmax(labels, axis=1)
    else:
        labels = labels.ravel()

    # Subsample large inputs; seeded so the plot is reproducible, matching
    # the fixed random_state used by the reducers.
    max_samples = 5000
    if len(embeddings) > max_samples:
        rng = np.random.RandomState(42)
        indices = rng.choice(len(embeddings), max_samples, replace=False)
        embeddings_sample = embeddings[indices]
        labels_sample = labels[indices]
    else:
        embeddings_sample = embeddings
        labels_sample = labels

    if method_key == 'tsne':
        # t-SNE requires perplexity < n_samples; cap the default of 30 so
        # small inputs can still be plotted.
        n_samples = len(embeddings_sample)
        perplexity = float(max(1, min(30, n_samples - 1)))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    else:
        reducer = PCA(n_components=2, random_state=42)

    embeddings_2d = reducer.fit_transform(embeddings_sample)

    unique_labels = np.unique(labels_sample)
    # One distinct colour per class.
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))

    fig = plt.figure(figsize=self.figsize)
    try:
        for color, label in zip(colors, unique_labels):
            mask = labels_sample == label

            # Only non-negative in-range labels map to a class name; a
            # negative label must not index class_names from the end.
            if self.class_names and 0 <= label < len(self.class_names):
                class_name = self.class_names[int(label)]
            else:
                class_name = f'Class {int(label)}'

            plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                        c=[color], label=class_name, alpha=0.7)

        plt.title(title)
        plt.legend(loc='best')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        if save_path is None and self.output_dir:
            save_path = os.path.join(self.output_dir, f'embedding_visualization_{method}.png')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"嵌入向量可视化图已保存到: {save_path}")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
||
|