"""
|
||
数据加载模块:负责从文件系统加载原始文本数据
|
||
"""
|
||
import os
|
||
import glob
|
||
import time
|
||
from pathlib import Path
|
||
from typing import List, Dict, Tuple, Optional, Any
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import random
|
||
import numpy as np
|
||
|
||
from config.system_config import (
|
||
RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES,
|
||
CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH
|
||
)
|
||
from config.model_config import RANDOM_SEED
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file, read_files_parallel, list_files
|
||
|
||
# 设置随机种子以保证可重复性
|
||
random.seed(RANDOM_SEED)
|
||
np.random.seed(RANDOM_SEED)
|
||
|
||
logger = get_logger("DataLoader")
|
||
|
||
|
||
class DataLoader:
    """Loads the THUCNews dataset."""

    def __init__(self, data_dir: Optional[str] = None,
                 categories: Optional[List[str]] = None,
                 encoding: str = ENCODING,
                 max_workers: int = DATA_LOADING_WORKERS,
                 max_text_per_batch: int = MAX_TEXT_PER_BATCH):
        """
        Initialize the data loader.

        Args:
            data_dir: Data directory; defaults to the path from the config file.
            categories: Categories to load; defaults to all categories.
            encoding: File encoding.
            max_workers: Maximum number of worker threads.
            max_text_per_batch: Maximum number of texts to process per batch.
        """
        self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR
        self.categories = categories if categories else CATEGORIES
        self.encoding = encoding
        self.max_workers = max_workers
        self.max_text_per_batch = max_text_per_batch

        # Make sure the data directory exists
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory does not exist: {self.data_dir}")

        # Warn about any requested category whose directory is missing
        for category in self.categories:
            category_dir = self.data_dir / category
            if not category_dir.exists():
                logger.warning(f"Category directory does not exist: {category_dir}")

        # Map each existing category to its directory
        self.category_dirs = {
            category: self.data_dir / category
            for category in self.categories
            if (self.data_dir / category).exists()
        }

        # File count per category, filled in by _count_files
        self.category_file_counts = {}

        # Count and log the number of files in each category
        self._count_files()
        logger.info(f"Initialization complete: found {sum(self.category_file_counts.values())} text files in total")

    def _count_files(self) -> None:
        """Count the number of files in each category."""
        for category, category_dir in self.category_dirs.items():
            files = list(category_dir.glob("*.txt"))
            self.category_file_counts[category] = len(files)
            logger.info(f"Category [{category}] contains {len(files)} text files")

    def get_file_paths(self, category: Optional[str] = None,
                       sample_ratio: float = 1.0,
                       shuffle: bool = True) -> List[Tuple[str, str]]:
        """
        Get the file paths for the given category.

        Args:
            category: Category name; if None, all categories are used.
            sample_ratio: Sampling ratio; defaults to 1.0 (everything).
            shuffle: Whether to shuffle the file order.

        Returns:
            A list of (file path, category) tuples.
        """
        file_paths = []

        # Determine which categories to process
        categories_to_process = [category] if category else self.categories

        # Collect the file paths of each category
        for cat in categories_to_process:
            if cat in self.category_dirs:
                category_dir = self.category_dirs[cat]
                cat_files = list(category_dir.glob("*.txt"))

                # Subsample if requested
                if sample_ratio < 1.0:
                    sample_size = int(len(cat_files) * sample_ratio)
                    if shuffle:
                        cat_files = random.sample(cat_files, sample_size)
                    else:
                        cat_files = cat_files[:sample_size]

                # Record each file path together with its category
                file_paths.extend([(str(file), cat) for file in cat_files])

        # Shuffle the global order if requested
        if shuffle:
            random.shuffle(file_paths)

        return file_paths

    def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """
        Load the text content at the given paths.

        Args:
            file_paths: A list of (file path, category) tuples.

        Returns:
            A list of (text content, category) tuples.
        """
        start_time = time.time()
        texts_with_labels = []

        # Split the paths from their labels
        paths = [path for path, _ in file_paths]
        labels = [label for _, label in file_paths]

        # Load the file contents in parallel
        contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding)

        # Pair each content with its label, skipping empty files
        for content, label in zip(contents, labels):
            if content:
                texts_with_labels.append((content, label))

        elapsed = time.time() - start_time
        logger.info(f"Loaded {len(texts_with_labels)} texts in {elapsed:.2f} seconds")

        return texts_with_labels

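    # A minimal usage sketch of the two-step pipeline above, assuming the
    # THUCNews category directories from config.system_config are in place
    # ("体育"/sports is only an illustrative category name; use any key
    # present in self.category_dirs):
    #
    #     loader = DataLoader()
    #     paths = loader.get_file_paths("体育", sample_ratio=0.01)
    #     samples = loader.load_texts(paths)   # -> [(text, "体育"), ...]
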
    def load_data(self, categories: Optional[List[str]] = None,
                  sample_ratio: float = 1.0,
                  shuffle: bool = True,
                  return_generator: bool = False) -> Any:
        """
        Load all data for the given categories.

        Args:
            categories: Categories to load; defaults to all categories.
            sample_ratio: Sampling ratio; defaults to 1.0 (everything).
            shuffle: Whether to shuffle the data.
            return_generator: Whether to return a generator (batched loading).

        Returns:
            If return_generator is False, a list of (text content, category) tuples.
            If return_generator is True, a generator that yields one batch of
            data at a time.
        """
        # Determine which categories to process
        cats_to_process = categories if categories else self.categories

        # Warn about categories that do not exist
        for cat in cats_to_process:
            if cat not in self.category_dirs:
                logger.warning(f"Category {cat} does not exist and will be ignored")

        # Keep only the categories that actually exist
        cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]

        # Collect the file paths of every category
        all_file_paths = []
        for cat in cats_to_process:
            cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle)
            all_file_paths.extend(cat_files)

        # Shuffle the global order if requested
        if shuffle:
            random.shuffle(all_file_paths)

        # If a generator was requested, load the data batch by batch
        if return_generator:
            def data_generator():
                for i in range(0, len(all_file_paths), self.max_text_per_batch):
                    batch_paths = all_file_paths[i:i + self.max_text_per_batch]
                    yield self.load_texts(batch_paths)

            return data_generator()

        # Otherwise load everything at once
        return self.load_texts(all_file_paths)

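    # Batched loading sketch: with return_generator=True, load_data yields
    # lists of (text, category) tuples with at most max_text_per_batch items
    # each, so the whole corpus never has to sit in memory at once:
    #
    #     loader = DataLoader()
    #     for batch in loader.load_data(sample_ratio=0.1, return_generator=True):
    #         process(batch)   # process() is a placeholder, not part of this module
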
    def load_balanced_data(self, n_per_category: int = 1000,
                           categories: Optional[List[str]] = None,
                           shuffle: bool = True) -> List[Tuple[str, str]]:
        """
        Load a balanced dataset (the same number of samples per category).

        Args:
            n_per_category: Number of samples to load per category.
            categories: Categories to load; defaults to all categories.
            shuffle: Whether to shuffle the data.

        Returns:
            A list of (text content, category) tuples.
        """
        # Determine which categories to process
        cats_to_process = categories if categories else self.categories
        cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]

        balanced_data = []

        for cat in cats_to_process:
            # Get the file paths for this category
            cat_files = self.get_file_paths(cat, shuffle=shuffle)

            # Cap the number of files
            cat_files = cat_files[:n_per_category]

            # Load the texts
            cat_data = self.load_texts(cat_files)
            balanced_data.extend(cat_data)

        # Shuffle the global order if requested
        if shuffle:
            random.shuffle(balanced_data)

        return balanced_data

    def get_category_distribution(self) -> Dict[str, int]:
        """
        Get the category distribution of the dataset.

        Returns:
            A dict mapping each category to its number of samples.
        """
        return self.category_file_counts

    def get_data_stats(self) -> Dict[str, Any]:
        """
        Get summary statistics for the dataset.

        Returns:
            A dict of statistics.
        """
        # Total number of samples
        total_samples = sum(self.category_file_counts.values())

        # Percentage of samples per category (guard against an empty dataset)
        category_percentages = {
            cat: count / total_samples * 100 if total_samples else 0.0
            for cat, count in self.category_file_counts.items()
        }

        # Sample a few files to estimate the average text length
        sample_files = []
        for cat in self.categories:
            if cat in self.category_dirs:
                cat_files = list((self.data_dir / cat).glob("*.txt"))
                if cat_files:
                    # Sample at most 10 files per category
                    sample_files.extend(random.sample(cat_files, min(10, len(cat_files))))

        # Load the contents of the sampled files
        sample_contents = []
        for file_path in sample_files:
            content = read_text_file(str(file_path), encoding=self.encoding)
            if content:
                sample_contents.append(content)

        # Average text length in characters
        avg_char_length = (sum(len(content) for content in sample_contents) / len(sample_contents)
                           if sample_contents else 0)

        return {
            "total_samples": total_samples,
            "category_counts": self.category_file_counts,
            "category_percentages": category_percentages,
            "average_text_length": avg_char_length,
            "categories": self.categories,
            "num_categories": len(self.categories),
        }
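

# Minimal smoke test, assuming RAW_DATA_DIR points at an extracted THUCNews
# tree (one sub-directory per category containing *.txt files); the numbers
# below are illustrative, not tuned defaults.
if __name__ == "__main__":
    loader = DataLoader()

    # Overall statistics: counts, per-category percentages, average length
    stats = loader.get_data_stats()
    print(f"{stats['total_samples']} samples across {stats['num_categories']} categories")

    # A small balanced sample: 5 texts per category, globally shuffled
    balanced = loader.load_balanced_data(n_per_category=5)
    print(f"Balanced sample size: {len(balanced)}")

    # Stream 1% of the corpus batch by batch
    for batch in loader.load_data(sample_ratio=0.01, return_generator=True):
        print(f"Loaded a batch of {len(batch)} texts")
        break  # one batch is enough for a smoke test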