# NOTE(review): removed repository-viewer residue accidentally captured with
# the source (timestamp, line/size counts, "Raw Blame History", and a Unicode
# warning banner) — bare prose at module top level is a Python syntax error.
"""
数据加载模块:负责从文件系统加载原始文本数据
"""
import os
import glob
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import random
import numpy as np
from config.system_config import (
RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES,
CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH
)
from config.model_config import RANDOM_SEED
from utils.logger import get_logger
from utils.file_utils import read_text_file, read_files_parallel, list_files
# Seed both RNGs so sampling/shuffling is reproducible across runs.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
logger = get_logger("DataLoader")
class DataLoader:
    """Loads the THUCNews corpus from disk.

    Expects the layout ``data_dir/<category>/*.txt`` — one sub-directory per
    category, one text file per article.
    """

    def __init__(self, data_dir: Optional[str] = None,
                 categories: Optional[List[str]] = None,
                 encoding: str = ENCODING,
                 max_workers: int = DATA_LOADING_WORKERS,
                 max_text_per_batch: int = MAX_TEXT_PER_BATCH):
        """
        Initialize the data loader.

        Args:
            data_dir: Data directory; defaults to the configured RAW_DATA_DIR.
            categories: Categories to load; defaults to all configured ones.
            encoding: Text file encoding.
            max_workers: Maximum number of worker threads for parallel reads.
            max_text_per_batch: Maximum number of texts loaded per batch.

        Raises:
            FileNotFoundError: If the data directory does not exist.
        """
        self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR
        self.categories = categories if categories else CATEGORIES
        self.encoding = encoding
        self.max_workers = max_workers
        self.max_text_per_batch = max_text_per_batch
        # A missing root directory is fatal ...
        if not self.data_dir.exists():
            raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
        # ... but a missing category directory is tolerated with a warning.
        for category in self.categories:
            category_dir = self.data_dir / category
            if not category_dir.exists():
                logger.warning(f"类别目录不存在: {category_dir}")
        # Keep only the categories whose directory actually exists.
        self.category_dirs = {
            category: self.data_dir / category
            for category in self.categories
            if (self.data_dir / category).exists()
        }
        # Per-category .txt file counts, filled in by _count_files().
        self.category_file_counts: Dict[str, int] = {}
        self._count_files()
        logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件")

    def _count_files(self) -> None:
        """Count the .txt files in each existing category directory."""
        for category, category_dir in self.category_dirs.items():
            files = list(category_dir.glob("*.txt"))
            self.category_file_counts[category] = len(files)
            logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件")

    def get_file_paths(self, category: Optional[str] = None,
                       sample_ratio: float = 1.0,
                       shuffle: bool = True) -> List[Tuple[str, str]]:
        """
        Collect (file path, category) pairs for one or all categories.

        Args:
            category: Category name; ``None`` means every known category.
            sample_ratio: Fraction of files to keep per category (1.0 = all).
            shuffle: Whether to randomize order — both the per-category
                down-sample and the final global ordering.

        Returns:
            List of ``(file_path, category)`` tuples.
        """
        file_paths = []
        # Resolve the set of categories to scan.
        categories_to_process = [category] if category else self.categories
        for cat in categories_to_process:
            if cat in self.category_dirs:
                category_dir = self.category_dirs[cat]
                cat_files = list(category_dir.glob("*.txt"))
                # Down-sample the category if requested.
                if sample_ratio < 1.0:
                    sample_size = int(len(cat_files) * sample_ratio)
                    if shuffle:
                        cat_files = random.sample(cat_files, sample_size)
                    else:
                        cat_files = cat_files[:sample_size]
                file_paths.extend([(str(file), cat) for file in cat_files])
        # Randomize the global order (if requested).
        if shuffle:
            random.shuffle(file_paths)
        return file_paths

    def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """
        Read the given files and pair each text with its label.

        Args:
            file_paths: List of ``(file_path, category)`` tuples.

        Returns:
            List of ``(text_content, category)`` tuples; files whose content
            came back empty/unreadable are silently dropped.
        """
        start_time = time.time()
        texts_with_labels = []
        paths = [path for path, _ in file_paths]
        labels = [label for _, label in file_paths]
        # Read all files in parallel; `contents` is aligned with `paths`.
        contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding)
        for content, label in zip(contents, labels):
            if content:  # drop empty/unreadable files
                texts_with_labels.append((content, label))
        elapsed = time.time() - start_time
        # Fixed: the original log line omitted the time unit (seconds).
        logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f} 秒")
        return texts_with_labels

    def load_data(self, categories: Optional[List[str]] = None,
                  sample_ratio: float = 1.0,
                  shuffle: bool = True,
                  return_generator: bool = False) -> Any:
        """
        Load all data for the given categories.

        Args:
            categories: Categories to load; defaults to all categories.
            sample_ratio: Sampling fraction per category (1.0 = everything).
            shuffle: Whether to shuffle the data order.
            return_generator: Whether to return a generator (batched loading).

        Returns:
            If ``return_generator`` is False: a list of
            ``(text_content, category)`` tuples.
            If ``return_generator`` is True: a generator yielding one batch
            (list of such tuples) at a time, at most
            ``max_text_per_batch`` items per batch.
        """
        cats_to_process = categories if categories else self.categories
        # Warn about unknown categories, then drop them.
        for cat in cats_to_process:
            if cat not in self.category_dirs:
                logger.warning(f"类别 {cat} 不存在,将被忽略")
        cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
        # Gather file paths category by category.
        all_file_paths = []
        for cat in cats_to_process:
            cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle)
            all_file_paths.extend(cat_files)
        # Randomize the global order (if requested).
        if shuffle:
            random.shuffle(all_file_paths)
        # Batched lazy loading: yield one batch of texts at a time.
        if return_generator:
            def data_generator():
                for i in range(0, len(all_file_paths), self.max_text_per_batch):
                    batch_paths = all_file_paths[i:i + self.max_text_per_batch]
                    batch_data = self.load_texts(batch_paths)
                    yield batch_data
            return data_generator()
        # Otherwise load everything eagerly.
        return self.load_texts(all_file_paths)

    def load_balanced_data(self, n_per_category: int = 1000,
                           categories: Optional[List[str]] = None,
                           shuffle: bool = True) -> List[Tuple[str, str]]:
        """
        Load a balanced dataset (the same number of samples per category).

        Args:
            n_per_category: Number of samples to load per category.
            categories: Categories to load; defaults to all categories.
            shuffle: Whether to shuffle the data order.

        Returns:
            List of ``(text_content, category)`` tuples.
        """
        cats_to_process = categories if categories else self.categories
        cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
        balanced_data = []
        for cat in cats_to_process:
            # With shuffle=True this is a random n_per_category subset;
            # otherwise it is the first n_per_category files in glob order.
            cat_files = self.get_file_paths(cat, shuffle=shuffle)
            cat_files = cat_files[:n_per_category]
            cat_data = self.load_texts(cat_files)
            balanced_data.extend(cat_data)
        # Randomize the global order (if requested).
        if shuffle:
            random.shuffle(balanced_data)
        return balanced_data

    def get_category_distribution(self) -> Dict[str, int]:
        """
        Get the category distribution of the dataset.

        Returns:
            Dict mapping each category to its file count.
        """
        return self.category_file_counts

    def get_data_stats(self) -> Dict[str, Any]:
        """
        Compute summary statistics for the dataset.

        Returns:
            Dict with the total sample count, per-category counts and
            percentages, an estimated average text length (characters,
            from at most 10 sampled files per category), the category
            list, and the number of categories.
        """
        total_samples = sum(self.category_file_counts.values())
        # Fixed: guard the empty dataset — the original raised
        # ZeroDivisionError when no files were found at all.
        category_percentages = {
            cat: (count / total_samples * 100) if total_samples else 0.0
            for cat, count in self.category_file_counts.items()
        }
        # Sample a few files to estimate the average text length.
        sample_files = []
        for cat in self.categories:
            if cat in self.category_dirs:
                cat_files = list(self.category_dirs[cat].glob("*.txt"))
                if cat_files:
                    # At most 10 files sampled per category.
                    sample_files.extend(random.sample(cat_files, min(10, len(cat_files))))
        sample_contents = []
        for file_path in sample_files:
            content = read_text_file(str(file_path), encoding=self.encoding)
            if content:
                sample_contents.append(content)
        # Average character length over the sampled, non-empty files.
        avg_char_length = (sum(len(content) for content in sample_contents) / len(sample_contents)
                           if sample_contents else 0)
        return {
            "total_samples": total_samples,
            "category_counts": self.category_file_counts,
            "category_percentages": category_percentages,
            "average_text_length": avg_char_length,
            "categories": self.categories,
            "num_categories": len(self.categories),
        }