"""
|
||
数据管理模块:负责数据的存储、读取和转换
|
||
"""
|
||
import os
import time
from typing import List, Dict, Tuple, Optional, Any

import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from config.system_config import (
    PROCESSED_DATA_DIR, CATEGORY_TO_ID, ID_TO_CATEGORY
)
from config.model_config import (
    VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED
)
from utils.logger import get_logger
from utils.file_utils import (
    save_pickle, load_pickle, save_json, load_json, ensure_dir
)
from data.dataloader import DataLoader

logger = get_logger("DataManager")

class DataManager:
    """Data manager: handles dataset storage, loading, and transformation."""

    def __init__(self, processed_dir: Optional[str] = None):
        """
        Initialize the data manager.

        Args:
            processed_dir: Directory for processed data; defaults to the path
                from the configuration file.
        """
        self.processed_dir = processed_dir or PROCESSED_DATA_DIR
        ensure_dir(self.processed_dir)

        # Storage for the split datasets
        self.train_texts = []
        self.train_labels = []
        self.val_texts = []
        self.val_labels = []
        self.test_texts = []
        self.test_labels = []

        # Dataset statistics
        self.stats = {}

        # Label encoding maps
        self.label_to_id = CATEGORY_TO_ID
        self.id_to_label = ID_TO_CATEGORY

        logger.info(f"DataManager initialized; processed data will be stored in {self.processed_dir}")

    def load_and_split_data(self, data_loader: DataLoader,
                            categories: Optional[List[str]] = None,
                            val_split: float = VALIDATION_SPLIT,
                            test_split: float = TEST_SPLIT,
                            sample_ratio: float = 1.0,
                            balanced: bool = False,
                            n_per_category: int = 1000,
                            save: bool = True) -> Dict[str, Any]:
        """
        Load the dataset and split it into train/validation/test sets.

        Args:
            data_loader: DataLoader instance.
            categories: Categories to include; defaults to all categories.
            val_split: Fraction of the data used for validation.
            test_split: Fraction of the data used for testing.
            sample_ratio: Sampling ratio; defaults to 1.0 (use everything).
            balanced: Whether to balance the number of samples per category.
            n_per_category: Samples per category in balanced mode.
            save: Whether to save the processed data.

        Returns:
            Dictionary containing the split datasets.
        """
        start_time = time.time()

        # Load the data
        if balanced:
            logger.info(f"Loading balanced dataset, {n_per_category} samples per category")
            data = data_loader.load_balanced_data(
                n_per_category=n_per_category,
                categories=categories,
                shuffle=True
            )
        else:
            logger.info(f"Loading dataset, sample ratio {sample_ratio}")
            data = data_loader.load_data(
                categories=categories,
                sample_ratio=sample_ratio,
                shuffle=True,
                return_generator=False
            )

        logger.info(f"Loaded {len(data)} samples")

        # Separate texts and labels
        texts = [text for text, _ in data]
        labels = [label for _, label in data]

        # Encode the labels
        encoded_labels = np.array([self.label_to_id[label] for label in labels])

        # Compute dataset statistics
        self._compute_stats(texts, labels)

        # Split into train, validation, and test sets.
        # Carve off the test set first.
        if test_split > 0:
            train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split(
                texts, encoded_labels,
                test_size=test_split,
                random_state=RANDOM_SEED,
                stratify=encoded_labels if len(set(encoded_labels)) > 1 else None
            )
        else:
            train_val_texts, train_val_labels = texts, encoded_labels
            self.test_texts, self.test_labels = [], []

        # Then split the remainder into train and validation sets.
        # val_split is a fraction of the *full* dataset, so it has to be
        # rescaled relative to what is left after removing the test set.
        if val_split > 0:
            self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split(
                train_val_texts, train_val_labels,
                test_size=val_split / (1 - test_split),
                random_state=RANDOM_SEED,
                stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
            )
        else:
            self.train_texts, self.train_labels = train_val_texts, train_val_labels
            self.val_texts, self.val_labels = [], []

        # Log the split sizes
        logger.info("Dataset split results:")
        logger.info(f"  Training set:   {len(self.train_texts)} samples")
        logger.info(f"  Validation set: {len(self.val_texts)} samples")
        logger.info(f"  Test set:       {len(self.test_texts)} samples")

        # Save the processed data
        if save:
            self.save_data()

        elapsed = time.time() - start_time
        logger.info(f"Data loading and splitting finished in {elapsed:.2f} seconds")

        return {
            "train_texts": self.train_texts,
            "train_labels": self.train_labels,
            "val_texts": self.val_texts,
            "val_labels": self.val_labels,
            "test_texts": self.test_texts,
            "test_labels": self.test_labels,
            "stats": self.stats
        }
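    # Worked example of the rescaling above (illustrative numbers, not taken
    # from the config): with test_split=0.15 and val_split=0.15, the second
    # split uses test_size = 0.15 / (1 - 0.15) ≈ 0.176 of the remaining 85%,
    # i.e. 0.176 * 0.85 ≈ 0.15 of the full dataset, so the final proportions
    # come out to roughly 70/15/15 train/val/test, as intended.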
    def _compute_stats(self, texts: List[str], labels: List[str]) -> None:
        """
        Compute dataset statistics.

        Args:
            texts: List of texts.
            labels: List of labels.
        """
        # Number of samples
        num_samples = len(texts)

        # Category distribution (percentages)
        label_counter = Counter(labels)
        label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()}

        # Text length statistics
        text_lengths = [len(text) for text in texts]
        avg_length = sum(text_lengths) / len(text_lengths)
        max_length = max(text_lengths)
        min_length = min(text_lengths)

        # Lengths of the five shortest and five longest texts
        sorted_lengths = sorted(text_lengths)
        shortest_lengths = sorted_lengths[:5]
        longest_lengths = sorted_lengths[-5:]

        # 95th percentile of text length. Cast to float so the stats stay
        # JSON-serializable: np.percentile returns a NumPy scalar, which the
        # stdlib json encoder rejects.
        percentile_95 = float(np.percentile(text_lengths, 95))

        # Store the statistics
        self.stats = {
            "num_samples": num_samples,
            "num_categories": len(label_counter),
            "label_counter": label_counter,
            "label_distribution": label_distribution,
            "text_length": {
                "average": avg_length,
                "max": max_length,
                "min": min_length,
                "percentile_95": percentile_95,
                "shortest_5": shortest_lengths,
                "longest_5": longest_lengths
            }
        }
    def save_data(self, save_dir: Optional[str] = None) -> None:
        """
        Save the processed data.

        Args:
            save_dir: Target directory; defaults to the directory set at
                initialization.
        """
        save_dir = save_dir or self.processed_dir
        ensure_dir(save_dir)

        # Save the training set
        save_pickle(
            {"texts": self.train_texts, "labels": self.train_labels},
            os.path.join(save_dir, "train_data.pkl")
        )

        # Save the validation set
        if len(self.val_texts) > 0:
            save_pickle(
                {"texts": self.val_texts, "labels": self.val_labels},
                os.path.join(save_dir, "val_data.pkl")
            )

        # Save the test set
        if len(self.test_texts) > 0:
            save_pickle(
                {"texts": self.test_texts, "labels": self.test_labels},
                os.path.join(save_dir, "test_data.pkl")
            )

        # Save the label encoding maps
        save_json(
            {"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
            os.path.join(save_dir, "label_mapping.json")
        )

        # Save the dataset statistics.
        # Convert the Counter to a plain dict so it is JSON-serializable.
        stats_for_json = self.stats.copy()
        if "label_counter" in stats_for_json:
            stats_for_json["label_counter"] = dict(stats_for_json["label_counter"])

        save_json(
            stats_for_json,
            os.path.join(save_dir, "data_stats.json")
        )

        logger.info(f"Processed data saved to {save_dir}")
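    # Resulting on-disk layout under save_dir (the val/test pickles exist only
    # when those splits are non-empty):
    #
    #   train_data.pkl       {"texts": [...], "labels": array([...])}
    #   val_data.pkl
    #   test_data.pkl
    #   label_mapping.json   {"label_to_id": {...}, "id_to_label": {...}}
    #   data_stats.json      output of _compute_stats, Counter flattened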
    def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]:
        """
        Load previously processed data.

        Args:
            load_dir: Source directory; defaults to the directory set at
                initialization.

        Returns:
            Dictionary containing the loaded datasets.
        """
        load_dir = load_dir or self.processed_dir

        # Load the training set
        train_data_path = os.path.join(load_dir, "train_data.pkl")
        if os.path.exists(train_data_path):
            train_data = load_pickle(train_data_path)
            self.train_texts = train_data["texts"]
            self.train_labels = train_data["labels"]
            logger.info(f"Loaded training set with {len(self.train_texts)} samples")
        else:
            logger.warning(f"Training set file not found: {train_data_path}")
            self.train_texts, self.train_labels = [], []

        # Load the validation set
        val_data_path = os.path.join(load_dir, "val_data.pkl")
        if os.path.exists(val_data_path):
            val_data = load_pickle(val_data_path)
            self.val_texts = val_data["texts"]
            self.val_labels = val_data["labels"]
            logger.info(f"Loaded validation set with {len(self.val_texts)} samples")
        else:
            logger.warning(f"Validation set file not found: {val_data_path}")
            self.val_texts, self.val_labels = [], []

        # Load the test set
        test_data_path = os.path.join(load_dir, "test_data.pkl")
        if os.path.exists(test_data_path):
            test_data = load_pickle(test_data_path)
            self.test_texts = test_data["texts"]
            self.test_labels = test_data["labels"]
            logger.info(f"Loaded test set with {len(self.test_texts)} samples")
        else:
            logger.warning(f"Test set file not found: {test_data_path}")
            self.test_texts, self.test_labels = [], []

        # Load the label encoding maps
        mapping_path = os.path.join(load_dir, "label_mapping.json")
        if os.path.exists(mapping_path):
            mapping = load_json(mapping_path)
            self.label_to_id = mapping["label_to_id"]
            self.id_to_label = mapping["id_to_label"]
            # Convert string keys back to integers (JSON serialization turns
            # all keys into strings)
            self.id_to_label = {int(k): v for k, v in self.id_to_label.items()}
            logger.info(f"Loaded label mapping with {len(self.label_to_id)} categories")

        # Load the dataset statistics
        stats_path = os.path.join(load_dir, "data_stats.json")
        if os.path.exists(stats_path):
            self.stats = load_json(stats_path)
            logger.info("Loaded dataset statistics")

        return {
            "train_texts": self.train_texts,
            "train_labels": self.train_labels,
            "val_texts": self.val_texts,
            "val_labels": self.val_labels,
            "test_texts": self.test_texts,
            "test_labels": self.test_labels,
            "stats": self.stats
        }
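    # Round-trip sketch (illustrative; `manager` is a hypothetical instance
    # and assumes an earlier load_and_split_data(..., save=True) call wrote
    # the splits to disk):
    #
    #   manager = DataManager()
    #   data = manager.load_data()   # reads the pickles plus
    #                                # label_mapping.json and data_stats.json
    #   texts, labels = data["train_texts"], data["train_labels"]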
def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]:
|
||
"""
|
||
获取指定数据集的标签分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
标签分布字典,键为类别名称,值为比例
|
||
"""
|
||
if dataset == "train":
|
||
labels = self.train_labels
|
||
elif dataset == "val":
|
||
labels = self.val_labels
|
||
elif dataset == "test":
|
||
labels = self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 计算标签分布
|
||
label_counter = Counter(labels)
|
||
num_samples = len(labels)
|
||
|
||
# 将数字标签转换为类别名称
|
||
distribution = {}
|
||
for label_id, count in label_counter.items():
|
||
label_name = self.id_to_label.get(label_id, str(label_id))
|
||
distribution[label_name] = count / num_samples * 100
|
||
|
||
return distribution
|
||
|
||
def visualize_label_distribution(self, dataset: str = "train",
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
可视化标签分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test', 'all'
|
||
save_path: 图表保存路径,默认为None(显示而不保存)
|
||
"""
|
||
plt.figure(figsize=(12, 8))
|
||
|
||
if dataset == "all":
|
||
# 显示所有数据集的标签分布
|
||
train_dist = self.get_label_distribution("train")
|
||
val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {}
|
||
test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {}
|
||
|
||
# 准备数据
|
||
categories = list(train_dist.keys())
|
||
train_values = [train_dist.get(cat, 0) for cat in categories]
|
||
val_values = [val_dist.get(cat, 0) for cat in categories]
|
||
test_values = [test_dist.get(cat, 0) for cat in categories]
|
||
|
||
# 绘制条形图
|
||
x = np.arange(len(categories))
|
||
width = 0.25
|
||
|
||
plt.bar(x - width, train_values, width, label="Training")
|
||
if val_values:
|
||
plt.bar(x, val_values, width, label="Validation")
|
||
if test_values:
|
||
plt.bar(x + width, test_values, width, label="Testing")
|
||
|
||
plt.xlabel("Categories")
|
||
plt.ylabel("Percentage (%)")
|
||
plt.title("Label Distribution Across Datasets")
|
||
plt.xticks(x, categories, rotation=45, ha="right")
|
||
plt.legend()
|
||
plt.tight_layout()
|
||
else:
|
||
# 显示单个数据集的标签分布
|
||
distribution = self.get_label_distribution(dataset)
|
||
|
||
# 按值排序
|
||
sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True)
|
||
categories = [item[0] for item in sorted_items]
|
||
values = [item[1] for item in sorted_items]
|
||
|
||
# 绘制条形图
|
||
plt.bar(categories, values, color='skyblue')
|
||
plt.xlabel("Categories")
|
||
plt.ylabel("Percentage (%)")
|
||
plt.title(f"Label Distribution in {dataset.capitalize()} Dataset")
|
||
plt.xticks(rotation=45, ha="right")
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图表
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"标签分布图已保存到 {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def visualize_text_length_distribution(self, dataset: str = "train",
|
||
bins: int = 50,
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
可视化文本长度分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
bins: 直方图的箱数
|
||
save_path: 图表保存路径,默认为None(显示而不保存)
|
||
"""
|
||
if dataset == "train":
|
||
texts = self.train_texts
|
||
elif dataset == "val":
|
||
texts = self.val_texts
|
||
elif dataset == "test":
|
||
texts = self.test_texts
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 计算文本长度
|
||
text_lengths = [len(text) for text in texts]
|
||
|
||
# 绘制直方图
|
||
plt.figure(figsize=(10, 6))
|
||
plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7)
|
||
|
||
# 计算并绘制一些统计量
|
||
avg_length = sum(text_lengths) / len(text_lengths)
|
||
median_length = np.median(text_lengths)
|
||
percentile_95 = np.percentile(text_lengths, 95)
|
||
|
||
plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}')
|
||
plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}')
|
||
plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1,
|
||
label=f'95th Percentile: {percentile_95:.1f}')
|
||
|
||
plt.xlabel('Text Length (characters)')
|
||
plt.ylabel('Frequency')
|
||
plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset')
|
||
plt.legend()
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图表
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"文本长度分布图已保存到 {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
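    # A common use of the 95th-percentile line plotted above (a design note,
    # not something this module enforces): choose the tokenizer's maximum
    # sequence length near that value, so roughly 95% of texts fit without
    # truncation while extreme outliers are clipped. Illustrative, assuming a
    # DataManager instance named `manager` with populated stats:
    #
    #   max_len = int(manager.stats["text_length"]["percentile_95"])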
    def get_data_summary(self) -> Dict[str, Any]:
        """
        Get a summary of the datasets.

        Returns:
            Dictionary containing the data summary.
        """
        # Basic information about the datasets
        summary = {
            "train_size": len(self.train_texts),
            "val_size": len(self.val_texts),
            "test_size": len(self.test_texts),
            "num_categories": len(self.label_to_id),
            "categories": list(self.label_to_id.keys()),
        }

        # Add the label distribution of the training set
        if len(self.train_texts) > 0:
            summary["train_label_distribution"] = self.get_label_distribution("train")

        # Add the label distribution of the validation set
        if len(self.val_texts) > 0:
            summary["val_label_distribution"] = self.get_label_distribution("val")

        # Add the label distribution of the test set
        if len(self.test_texts) > 0:
            summary["test_label_distribution"] = self.get_label_distribution("test")

        # Add further statistics, if available
        if self.stats:
            # Only include the key statistics
            if "text_length" in self.stats:
                summary["text_length_stats"] = self.stats["text_length"]

        return summary
def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame:
|
||
"""
|
||
将数据导出为Pandas DataFrame
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
Pandas DataFrame
|
||
"""
|
||
if dataset == "train":
|
||
texts = self.train_texts
|
||
labels_ids = self.train_labels
|
||
elif dataset == "val":
|
||
texts = self.val_texts
|
||
labels_ids = self.val_labels
|
||
elif dataset == "test":
|
||
texts = self.test_texts
|
||
labels_ids = self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 将数字标签转换为类别名称
|
||
labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids]
|
||
|
||
# 创建DataFrame
|
||
df = pd.DataFrame({
|
||
"text": texts,
|
||
"label_id": labels_ids,
|
||
"label": labels
|
||
})
|
||
|
||
return df
|
||
|
||
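    # Usage sketch (illustrative; `manager` is a hypothetical instance with
    # populated splits):
    #
    #   df = manager.export_to_pandas("train")
    #   df["label"].value_counts()            # quick per-category counts
    #   df.to_csv("train.csv", index=False)   # hand off to other tooling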
    def get_label_name(self, label_id: int) -> str:
        """
        Get the category name for a label ID.

        Args:
            label_id: Label ID.

        Returns:
            Category name.
        """
        return self.id_to_label.get(label_id, str(label_id))

    def get_label_id(self, label_name: str) -> int:
        """
        Get the label ID for a category name.

        Args:
            label_name: Category name.

        Returns:
            Label ID, or -1 if the category is unknown.
        """
        return self.label_to_id.get(label_name, -1)
def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]:
|
||
"""
|
||
获取指定数据集的文本和标签
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
(文本列表, 标签数组)的元组
|
||
"""
|
||
if dataset == "train":
|
||
return self.train_texts, self.train_labels
|
||
elif dataset == "val":
|
||
return self.val_texts, self.val_labels
|
||
elif dataset == "test":
|
||
return self.test_texts, self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}") |
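

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): it assumes the project's
    # config modules are populated and that data.dataloader.DataLoader can be
    # constructed without arguments -- adjust to your setup.
    loader = DataLoader()
    manager = DataManager()

    # Load 10% of the corpus, split it with the configured default val/test
    # fractions, and persist the splits under PROCESSED_DATA_DIR.
    splits = manager.load_and_split_data(loader, sample_ratio=0.1)
    print(manager.get_data_summary())

    # Later (e.g. in a training script), reload the persisted splits.
    manager2 = DataManager()
    manager2.load_data()
    train_texts, train_labels = manager2.get_data("train")
    print(f"{len(train_texts)} training samples")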