""" 数据加载模块:负责从文件系统加载原始文本数据 """ import os import glob import time from pathlib import Path from typing import List, Dict, Tuple, Optional, Any from concurrent.futures import ThreadPoolExecutor, as_completed import random import numpy as np from config.system_config import ( RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES, CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH ) from config.model_config import RANDOM_SEED from utils.logger import get_logger from utils.file_utils import read_text_file, read_files_parallel, list_files # 设置随机种子以保证可重复性 random.seed(RANDOM_SEED) np.random.seed(RANDOM_SEED) logger = get_logger("DataLoader") class DataLoader: """负责加载THUCNews数据集的类""" def __init__(self, data_dir: Optional[str] = None, categories: Optional[List[str]] = None, encoding: str = ENCODING, max_workers: int = DATA_LOADING_WORKERS, max_text_per_batch: int = MAX_TEXT_PER_BATCH): """ 初始化数据加载器 Args: data_dir: 数据目录,默认使用配置文件中的路径 categories: 要加载的类别列表,默认加载所有类别 encoding: 文件编码 max_workers: 最大工作线程数 max_text_per_batch: 每批处理的最大文本数量 """ self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR self.categories = categories if categories else CATEGORIES self.encoding = encoding self.max_workers = max_workers self.max_text_per_batch = max_text_per_batch # 验证数据目录是否存在 if not self.data_dir.exists(): raise FileNotFoundError(f"数据目录不存在: {self.data_dir}") # 验证类别是否存在 for category in self.categories: category_dir = self.data_dir / category if not category_dir.exists(): logger.warning(f"类别目录不存在: {category_dir}") # 存储类别目录的映射 self.category_dirs = { category: self.data_dir / category for category in self.categories if (self.data_dir / category).exists() } # 记录类别文件数量 self.category_file_counts = {} # 统计并记录每个类别的文件数量 self._count_files() logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件") def _count_files(self) -> None: """统计每个类别的文件数量""" for category, category_dir in self.category_dirs.items(): files = list(category_dir.glob("*.txt")) self.category_file_counts[category] = len(files) logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件") def get_file_paths(self, category: Optional[str] = None, sample_ratio: float = 1.0, shuffle: bool = True) -> List[Tuple[str, str]]: """ 获取指定类别的文件路径列表 Args: category: 类别名称,如果为None则获取所有类别 sample_ratio: 采样比例,默认为1.0(全部) shuffle: 是否打乱文件顺序 Returns: 包含(文件路径, 类别)元组的列表 """ file_paths = [] # 确定要处理的类别 categories_to_process = [category] if category else self.categories # 获取每个类别的文件路径 for cat in categories_to_process: if cat in self.category_dirs: category_dir = self.category_dirs[cat] cat_files = list(category_dir.glob("*.txt")) # 采样 if sample_ratio < 1.0: sample_size = int(len(cat_files) * sample_ratio) if shuffle: cat_files = random.sample(cat_files, sample_size) else: cat_files = cat_files[:sample_size] # 添加文件路径和对应的类别 file_paths.extend([(str(file), cat) for file in cat_files]) # 打乱全局顺序(如果需要) if shuffle: random.shuffle(file_paths) return file_paths def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]: """ 加载指定路径的文本内容 Args: file_paths: 包含(文件路径, 类别)元组的列表 Returns: 包含(文本内容, 类别)元组的列表 """ start_time = time.time() texts_with_labels = [] # 提取文件路径列表 paths = [path for path, _ in file_paths] labels = [label for _, label in file_paths] # 并行加载文本内容 contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding) # 将内容与标签配对 for content, label in zip(contents, labels): if content: # 确保内容不为空 texts_with_labels.append((content, label)) elapsed = time.time() - start_time logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f} 秒") 
return texts_with_labels def load_data(self, categories: Optional[List[str]] = None, sample_ratio: float = 1.0, shuffle: bool = True, return_generator: bool = False) -> Any: """ 加载指定类别的所有数据 Args: categories: 要加载的类别列表,默认为所有类别 sample_ratio: 采样比例,默认为1.0(全部) shuffle: 是否打乱数据顺序 return_generator: 是否返回生成器(批量加载) Returns: 如果return_generator为False,返回包含(文本内容, 类别)元组的列表 如果return_generator为True,返回一个生成器,每次产生一批数据 """ # 确定要处理的类别 cats_to_process = categories if categories else self.categories # 验证类别是否存在 for cat in cats_to_process: if cat not in self.category_dirs: logger.warning(f"类别 {cat} 不存在,将被忽略") # 筛选存在的类别 cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs] # 获取所有文件路径 all_file_paths = [] for cat in cats_to_process: cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle) all_file_paths.extend(cat_files) # 打乱全局顺序(如果需要) if shuffle: random.shuffle(all_file_paths) # 如果需要返回生成器,分批次加载数据 if return_generator: def data_generator(): for i in range(0, len(all_file_paths), self.max_text_per_batch): batch_paths = all_file_paths[i:i + self.max_text_per_batch] batch_data = self.load_texts(batch_paths) yield batch_data return data_generator() # 否则,一次性加载所有数据 return self.load_texts(all_file_paths) def load_balanced_data(self, n_per_category: int = 1000, categories: Optional[List[str]] = None, shuffle: bool = True) -> List[Tuple[str, str]]: """ 加载平衡的数据集(每个类别的样本数量相同) Args: n_per_category: 每个类别加载的样本数量 categories: 要加载的类别列表,默认为所有类别 shuffle: 是否打乱数据顺序 Returns: 包含(文本内容, 类别)元组的列表 """ # 确定要处理的类别 cats_to_process = categories if categories else self.categories cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs] balanced_data = [] for cat in cats_to_process: # 获取该类别的文件路径 cat_files = self.get_file_paths(cat, shuffle=shuffle) # 限制数量 cat_files = cat_files[:n_per_category] # 加载文本 cat_data = self.load_texts(cat_files) balanced_data.extend(cat_data) # 打乱全局顺序(如果需要) if shuffle: random.shuffle(balanced_data) return balanced_data def get_category_distribution(self) -> Dict[str, int]: """ 获取数据集的类别分布 Returns: 包含各类别样本数量的字典 """ return self.category_file_counts def get_data_stats(self) -> Dict[str, Any]: """ 获取数据集的统计信息 Returns: 包含统计信息的字典 """ # 计算总样本数 total_samples = sum(self.category_file_counts.values()) # 计算各类别占比 category_percentages = { cat: count / total_samples * 100 for cat, count in self.category_file_counts.items() } # 采样几个文件计算平均文本长度 sample_files = [] for cat in self.categories: if cat in self.category_dirs: cat_files = list((self.data_dir / cat).glob("*.txt")) if cat_files: # 每个类别最多采样10个文件 sample_files.extend(random.sample(cat_files, min(10, len(cat_files)))) # 加载采样的文件内容 sample_contents = [] for file_path in sample_files: content = read_text_file(str(file_path), encoding=self.encoding) if content: sample_contents.append(content) # 计算平均文本长度(字符数) avg_char_length = sum(len(content) for content in sample_contents) / len( sample_contents) if sample_contents else 0 # 返回统计信息 return { "total_samples": total_samples, "category_counts": self.category_file_counts, "category_percentages": category_percentages, "average_text_length": avg_char_length, "categories": self.categories, "num_categories": len(self.categories), }
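

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): assumes the THUCNews corpus is laid out
# as RAW_DATA_DIR/<category>/*.txt and that the config/ and utils/ helpers
# imported above are available. Adapt paths and ratios to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    loader = DataLoader()

    # Inspect the class distribution before loading anything into memory.
    print(loader.get_category_distribution())

    # Eagerly load a shuffled 1% sample of every category as (text, label) pairs.
    sample = loader.load_data(sample_ratio=0.01, shuffle=True)
    print(f"Loaded {len(sample)} sampled texts")

    # Stream the full corpus in batches of max_text_per_batch texts to keep
    # peak memory bounded, e.g. when feeding a preprocessing pipeline.
    for batch in loader.load_data(return_generator=True):
        print(f"Batch of {len(batch)} texts")
        break  # remove this break to iterate over the whole corpus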