"""
数据管理模块:负责数据的存储、读取和转换
"""
import os
import time
from typing import List, Dict, Tuple, Optional, Any

import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from config.system_config import (
    PROCESSED_DATA_DIR, CATEGORY_TO_ID, ID_TO_CATEGORY
)
from config.model_config import (
VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED
)
from utils.logger import get_logger
from utils.file_utils import (
save_pickle, load_pickle, save_json, load_json, ensure_dir
)
from data.dataloader import DataLoader
logger = get_logger("DataManager")
class DataManager:
"""数据管理类,负责数据的存储、读取和转换"""
def __init__(self, processed_dir: Optional[str] = None):
"""
初始化数据管理器
Args:
processed_dir: 处理后数据的存储目录,默认使用配置文件中的路径
"""
self.processed_dir = processed_dir or PROCESSED_DATA_DIR
ensure_dir(self.processed_dir)
        # Containers for the split datasets
self.train_texts = []
self.train_labels = []
self.val_texts = []
self.val_labels = []
self.test_texts = []
self.test_labels = []
        # Corpus statistics
self.stats = {}
        # Label encoding maps
self.label_to_id = CATEGORY_TO_ID
self.id_to_label = ID_TO_CATEGORY
logger.info(f"数据管理器初始化完成,处理后数据将存储在 {self.processed_dir}")
def load_and_split_data(self, data_loader: DataLoader,
categories: Optional[List[str]] = None,
val_split: float = VALIDATION_SPLIT,
test_split: float = TEST_SPLIT,
sample_ratio: float = 1.0,
balanced: bool = False,
n_per_category: int = 1000,
save: bool = True) -> Dict[str, Any]:
"""
加载并分割数据集
Args:
data_loader: 数据加载器实例
categories: 要包含的类别列表,默认为所有类别
val_split: 验证集比例
test_split: 测试集比例
sample_ratio: 采样比例默认为1.0(全部)
balanced: 是否平衡各类别的样本数量
n_per_category: 平衡模式下每个类别的样本数量
save: 是否保存处理后的数据
Returns:
包含分割后数据集的字典
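
        Example:
            A minimal sketch, assuming DataLoader can be default-constructed
            (its actual arguments depend on the project's corpus layout)::

                loader = DataLoader()
                manager = DataManager()
                splits = manager.load_and_split_data(loader, sample_ratio=0.1)
                print(len(splits["train_texts"]))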
"""
start_time = time.time()
        # Load the data
if balanced:
logger.info(f"加载平衡数据集,每个类别 {n_per_category} 个样本")
data = data_loader.load_balanced_data(
n_per_category=n_per_category,
categories=categories,
shuffle=True
)
else:
logger.info(f"加载数据集,采样比例 {sample_ratio}")
data = data_loader.load_data(
categories=categories,
sample_ratio=sample_ratio,
shuffle=True,
return_generator=False
)
logger.info(f"加载了 {len(data)} 个样本")
        # Separate texts and labels
texts = [text for text, _ in data]
labels = [label for _, label in data]
        # Encode the labels as integer ids
encoded_labels = np.array([self.label_to_id[label] for label in labels])
        # Compute corpus statistics
self._compute_stats(texts, labels)
        # Split into train/validation/test sets
        # Carve out the test set first
if test_split > 0:
train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split(
texts, encoded_labels,
test_size=test_split,
random_state=RANDOM_SEED,
stratify=encoded_labels if len(set(encoded_labels)) > 1 else None
)
else:
train_val_texts, train_val_labels = texts, encoded_labels
self.test_texts, self.test_labels = [], []
        # Then split the remainder into train and validation; val_split is
        # rescaled because it was specified as a fraction of the full dataset
if val_split > 0:
self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split(
train_val_texts, train_val_labels,
test_size=val_split / (1 - test_split),
random_state=RANDOM_SEED,
stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
)
else:
self.train_texts, self.train_labels = train_val_texts, train_val_labels
self.val_texts, self.val_labels = [], []
        # Log the split sizes
        logger.info("Dataset split results:")
        logger.info(f"  Training set: {len(self.train_texts)} samples")
        logger.info(f"  Validation set: {len(self.val_texts)} samples")
        logger.info(f"  Test set: {len(self.test_texts)} samples")
        # Persist the processed splits
if save:
self.save_data()
elapsed = time.time() - start_time
logger.info(f"数据加载和分割完成,用时 {elapsed:.2f}")
return {
"train_texts": self.train_texts,
"train_labels": self.train_labels,
"val_texts": self.val_texts,
"val_labels": self.val_labels,
"test_texts": self.test_texts,
"test_labels": self.test_labels,
"stats": self.stats
}
def _compute_stats(self, texts: List[str], labels: List[str]) -> None:
"""
计算数据统计信息
Args:
texts: 文本列表
labels: 标签列表
"""
        # Number of samples
num_samples = len(texts)
        # Category distribution
label_counter = Counter(labels)
label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()}
        # Text length statistics
text_lengths = [len(text) for text in texts]
avg_length = sum(text_lengths) / len(text_lengths)
max_length = max(text_lengths)
min_length = min(text_lengths)
        # Lengths of the 5 shortest and 5 longest texts
sorted_lengths = sorted(text_lengths)
shortest_lengths = sorted_lengths[:5]
longest_lengths = sorted_lengths[-5:]
        # 95th percentile of text length (cast to a built-in float so the
        # value can be JSON-serialized later)
        percentile_95 = float(np.percentile(text_lengths, 95))
        # Store the statistics
self.stats = {
"num_samples": num_samples,
"num_categories": len(label_counter),
"label_counter": label_counter,
"label_distribution": label_distribution,
"text_length": {
"average": avg_length,
"max": max_length,
"min": min_length,
"percentile_95": percentile_95,
"shortest_5": shortest_lengths,
"longest_5": longest_lengths
}
}
def save_data(self, save_dir: Optional[str] = None) -> None:
"""
保存处理后的数据
Args:
save_dir: 保存目录,默认使用初始化时设置的目录
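
        Example:
            A minimal sketch (the path is hypothetical)::

                manager.save_data("data/processed")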
"""
save_dir = save_dir or self.processed_dir
ensure_dir(save_dir)
        # Save the training set
save_pickle(
{"texts": self.train_texts, "labels": self.train_labels},
os.path.join(save_dir, "train_data.pkl")
)
        # Save the validation set
if len(self.val_texts) > 0:
save_pickle(
{"texts": self.val_texts, "labels": self.val_labels},
os.path.join(save_dir, "val_data.pkl")
)
        # Save the test set
if len(self.test_texts) > 0:
save_pickle(
{"texts": self.test_texts, "labels": self.test_labels},
os.path.join(save_dir, "test_data.pkl")
)
        # Save the label mapping
save_json(
{"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
os.path.join(save_dir, "label_mapping.json")
)
        # Save the corpus statistics; convert the Counter to a plain dict so
        # it can be JSON-serialized
stats_for_json = self.stats.copy()
if "label_counter" in stats_for_json:
stats_for_json["label_counter"] = dict(stats_for_json["label_counter"])
save_json(
stats_for_json,
os.path.join(save_dir, "data_stats.json")
)
logger.info(f"已将处理后的数据保存到 {save_dir}")
def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]:
"""
加载处理后的数据
Args:
load_dir: 加载目录,默认使用初始化时设置的目录
Returns:
包含加载的数据集的字典
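
        Example:
            A minimal sketch, assuming save_data() previously wrote to the
            (hypothetical) directory below::

                manager = DataManager("data/processed")
                splits = manager.load_data()
                texts, labels = splits["train_texts"], splits["train_labels"]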
"""
load_dir = load_dir or self.processed_dir
        # Load the training set
train_data_path = os.path.join(load_dir, "train_data.pkl")
if os.path.exists(train_data_path):
train_data = load_pickle(train_data_path)
self.train_texts = train_data["texts"]
self.train_labels = train_data["labels"]
logger.info(f"已加载训练集,包含 {len(self.train_texts)} 个样本")
else:
logger.warning(f"训练集文件不存在: {train_data_path}")
self.train_texts, self.train_labels = [], []
        # Load the validation set
val_data_path = os.path.join(load_dir, "val_data.pkl")
if os.path.exists(val_data_path):
val_data = load_pickle(val_data_path)
self.val_texts = val_data["texts"]
self.val_labels = val_data["labels"]
logger.info(f"已加载验证集,包含 {len(self.val_texts)} 个样本")
else:
logger.warning(f"验证集文件不存在: {val_data_path}")
self.val_texts, self.val_labels = [], []
        # Load the test set
test_data_path = os.path.join(load_dir, "test_data.pkl")
if os.path.exists(test_data_path):
test_data = load_pickle(test_data_path)
self.test_texts = test_data["texts"]
self.test_labels = test_data["labels"]
logger.info(f"已加载测试集,包含 {len(self.test_texts)} 个样本")
else:
logger.warning(f"测试集文件不存在: {test_data_path}")
self.test_texts, self.test_labels = [], []
        # Load the label mapping
mapping_path = os.path.join(load_dir, "label_mapping.json")
if os.path.exists(mapping_path):
mapping = load_json(mapping_path)
self.label_to_id = mapping["label_to_id"]
self.id_to_label = mapping["id_to_label"]
            # Convert keys back to int (JSON serialization stringifies all keys)
self.id_to_label = {int(k): v for k, v in self.id_to_label.items()}
logger.info(f"已加载标签编码映射,共 {len(self.label_to_id)} 个类别")
        # Load the corpus statistics
stats_path = os.path.join(load_dir, "data_stats.json")
if os.path.exists(stats_path):
self.stats = load_json(stats_path)
logger.info("已加载数据统计信息")
return {
"train_texts": self.train_texts,
"train_labels": self.train_labels,
"val_texts": self.val_texts,
"val_labels": self.val_labels,
"test_texts": self.test_texts,
"test_labels": self.test_labels,
"stats": self.stats
}
def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]:
"""
获取指定数据集的标签分布
Args:
dataset: 数据集名称,可选值:'train', 'val', 'test'
Returns:
标签分布字典,键为类别名称,值为比例
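
        Example:
            A quick sketch, assuming `manager` is a DataManager holding
            loaded data::

                dist = manager.get_label_distribution("train")
                # dist maps category name -> percentage of samples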
"""
if dataset == "train":
labels = self.train_labels
elif dataset == "val":
labels = self.val_labels
elif dataset == "test":
labels = self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
        # Compute the label distribution
label_counter = Counter(labels)
num_samples = len(labels)
        # Map numeric label ids back to category names
distribution = {}
for label_id, count in label_counter.items():
label_name = self.id_to_label.get(label_id, str(label_id))
distribution[label_name] = count / num_samples * 100
return distribution
def visualize_label_distribution(self, dataset: str = "train",
save_path: Optional[str] = None) -> None:
"""
可视化标签分布
Args:
dataset: 数据集名称,可选值:'train', 'val', 'test', 'all'
save_path: 图表保存路径默认为None显示而不保存
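
        Example:
            A sketch (the output path is hypothetical)::

                manager.visualize_label_distribution("all", save_path="label_dist.png")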
"""
plt.figure(figsize=(12, 8))
if dataset == "all":
# 显示所有数据集的标签分布
train_dist = self.get_label_distribution("train")
val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {}
test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {}
            # Prepare the data
categories = list(train_dist.keys())
train_values = [train_dist.get(cat, 0) for cat in categories]
val_values = [val_dist.get(cat, 0) for cat in categories]
test_values = [test_dist.get(cat, 0) for cat in categories]
            # Draw the grouped bar chart
x = np.arange(len(categories))
width = 0.25
            plt.bar(x - width, train_values, width, label="Training")
            # val_values/test_values are zero-filled lists even for an empty
            # split, so check the distribution dicts instead of the value lists
            if val_dist:
                plt.bar(x, val_values, width, label="Validation")
            if test_dist:
                plt.bar(x + width, test_values, width, label="Testing")
plt.xlabel("Categories")
plt.ylabel("Percentage (%)")
plt.title("Label Distribution Across Datasets")
plt.xticks(x, categories, rotation=45, ha="right")
plt.legend()
plt.tight_layout()
else:
            # Show the label distribution of a single split
distribution = self.get_label_distribution(dataset)
            # Sort categories by percentage, descending
sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True)
categories = [item[0] for item in sorted_items]
values = [item[1] for item in sorted_items]
            # Draw the bar chart
plt.bar(categories, values, color='skyblue')
plt.xlabel("Categories")
plt.ylabel("Percentage (%)")
plt.title(f"Label Distribution in {dataset.capitalize()} Dataset")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
        # Save or show the figure
if save_path:
plt.savefig(save_path)
logger.info(f"标签分布图已保存到 {save_path}")
else:
plt.show()
def visualize_text_length_distribution(self, dataset: str = "train",
bins: int = 50,
save_path: Optional[str] = None) -> None:
"""
可视化文本长度分布
Args:
dataset: 数据集名称,可选值:'train', 'val', 'test'
bins: 直方图的箱数
save_path: 图表保存路径默认为None显示而不保存
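
        Example:
            A sketch, assuming `manager` holds loaded data::

                manager.visualize_text_length_distribution("train", bins=30)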
"""
if dataset == "train":
texts = self.train_texts
elif dataset == "val":
texts = self.val_texts
elif dataset == "test":
texts = self.test_texts
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
        # Compute text lengths
text_lengths = [len(text) for text in texts]
        # Draw the histogram
plt.figure(figsize=(10, 6))
plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7)
        # Compute and plot summary statistics
avg_length = sum(text_lengths) / len(text_lengths)
median_length = np.median(text_lengths)
percentile_95 = np.percentile(text_lengths, 95)
plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}')
plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}')
plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1,
label=f'95th Percentile: {percentile_95:.1f}')
plt.xlabel('Text Length (characters)')
plt.ylabel('Frequency')
plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset')
plt.legend()
plt.tight_layout()
        # Save or show the figure
if save_path:
plt.savefig(save_path)
logger.info(f"文本长度分布图已保存到 {save_path}")
else:
plt.show()
def get_data_summary(self) -> Dict[str, Any]:
"""
获取数据集的摘要信息
Returns:
包含数据摘要的字典
"""
        # Basic information about the splits
summary = {
"train_size": len(self.train_texts),
"val_size": len(self.val_texts),
"test_size": len(self.test_texts),
"num_categories": len(self.label_to_id),
"categories": list(self.label_to_id.keys()),
}
        # Label distribution of the training set
if len(self.train_texts) > 0:
summary["train_label_distribution"] = self.get_label_distribution("train")
        # Label distribution of the validation set
if len(self.val_texts) > 0:
summary["val_label_distribution"] = self.get_label_distribution("val")
        # Label distribution of the test set
if len(self.test_texts) > 0:
summary["test_label_distribution"] = self.get_label_distribution("test")
        # Add further statistics if available
if self.stats:
            # Only surface the key statistics
if "text_length" in self.stats:
summary["text_length_stats"] = self.stats["text_length"]
return summary
def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame:
"""
将数据导出为Pandas DataFrame
Args:
dataset: 数据集名称,可选值:'train', 'val', 'test'
Returns:
Pandas DataFrame
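
        Example:
            A sketch, assuming `manager` holds loaded data::

                df = manager.export_to_pandas("train")
                print(df["label"].value_counts())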
"""
if dataset == "train":
texts = self.train_texts
labels_ids = self.train_labels
elif dataset == "val":
texts = self.val_texts
labels_ids = self.val_labels
elif dataset == "test":
texts = self.test_texts
labels_ids = self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
        # Map numeric label ids back to category names
labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids]
        # Build the DataFrame
df = pd.DataFrame({
"text": texts,
"label_id": labels_ids,
"label": labels
})
return df
def get_label_name(self, label_id: int) -> str:
"""
获取标签ID对应的类别名称
Args:
label_id: 标签ID
Returns:
类别名称
"""
return self.id_to_label.get(label_id, str(label_id))
def get_label_id(self, label_name: str) -> int:
"""
获取类别名称对应的标签ID
Args:
label_name: 类别名称
Returns:
标签ID
"""
return self.label_to_id.get(label_name, -1)
def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]:
"""
获取指定数据集的文本和标签
Args:
dataset: 数据集名称,可选值:'train', 'val', 'test'
Returns:
(文本列表, 标签数组)的元组
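
        Example:
            A sketch::

                texts, labels = manager.get_data("val")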
"""
if dataset == "train":
return self.train_texts, self.train_labels
elif dataset == "val":
return self.val_texts, self.val_labels
elif dataset == "test":
return self.test_texts, self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")