"""
批处理模块:实现批量处理大规模文本数据
"""
import os
import time
import glob
import json
import concurrent.futures
from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator

import pandas as pd
import numpy as np
from tqdm import tqdm

from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH
from utils.logger import get_logger
from utils.file_utils import read_text_file, ensure_dir
from inference.predictor import Predictor

logger = get_logger("BatchProcessor")


class BatchProcessor:
    """Batch processor responsible for processing large-scale text data in batches."""

    def __init__(self, predictor: Predictor,
                 batch_size: int = 64,
                 max_workers: int = DATA_LOADING_WORKERS,
                 max_batch_queue: int = 10):
        """
        Initialize the batch processor.

        Args:
            predictor: Predictor instance
            batch_size: Batch size
            max_workers: Maximum number of worker threads
            max_batch_queue: Maximum length of the batch queue
        """
        self.predictor = predictor
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.max_batch_queue = max_batch_queue

        logger.info(f"Initialized batch processor, batch size: {batch_size}, "
                    f"max workers: {max_workers}")

    def _extract_text_from_file(self, file_path: str) -> str:
        """
        Extract text from a file.

        Args:
            file_path: Path to the file

        Returns:
            Text content of the file
        """
        return read_text_file(file_path, encoding=ENCODING)

    def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]:
        """
        Generate batches of texts.

        Args:
            texts: List of texts
            batch_size: Batch size

        Returns:
            Generator yielding batches of texts
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i:i + batch_size]
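
    # Illustrative note: with batch_size=64, a list of 150 texts is yielded as
    # batches of 64, 64 and 22, i.e. ceil(150 / 64) = 3 batches. The same ceiling
    # division, (len(texts) + batch_size - 1) // batch_size, is used below to
    # compute the progress-bar total.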

    def process_files(self, file_paths: List[str], output_path: Optional[str] = None,
                      return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
        """
        Process a batch of files.

        Args:
            file_paths: List of file paths
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'

        Returns:
            DataFrame of prediction results
        """
        logger.info(f"Starting batch processing of {len(file_paths)} files")
        start_time = time.time()

        # Read files in parallel using a thread pool
        texts = []
        file_names = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path
                              for file_path in file_paths}

            for future in tqdm(concurrent.futures.as_completed(future_to_file),
                               total=len(file_paths), desc="Reading files"):
                file_path = future_to_file[future]
                try:
                    text = future.result()
                    if text:
                        texts.append(text)
                        file_names.append(os.path.basename(file_path))
                except Exception as e:
                    logger.error(f"Error while processing file {file_path}: {e}")

        # Batch prediction
        all_predictions = []

        for batch in tqdm(self._batch_generator(texts, self.batch_size),
                          total=(len(texts) + self.batch_size - 1) // self.batch_size,
                          desc="Predicting"):
            predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k,
                                                       return_probabilities=True)
            all_predictions.extend(predictions)

        # Assemble results
        if return_top_k > 1:
            # Multiple classes per text
            results = []
            for i, preds in enumerate(all_predictions):
                for j, pred in enumerate(preds):
                    results.append({
                        'file_name': file_names[i],
                        'rank': j + 1,
                        'predicted_class': pred['class'],
                        'probability': pred['probability']
                    })
            df = pd.DataFrame(results)
        else:
            # Single class per text
            df = pd.DataFrame({
                'file_name': file_names,
                'predicted_class': [pred['class'] for pred in all_predictions],
                'probability': [pred['probability'] for pred in all_predictions]
            })

        # Save results
        if output_path:
            if format.lower() == 'csv':
                df.to_csv(output_path, index=False, encoding='utf-8')
            elif format.lower() == 'json':
                # Convert to a nested JSON structure
                if return_top_k > 1:
                    # Group rows per file, then nest the predictions
                    result = {}
                    for file_name in df['file_name'].unique():
                        sub_df = df[df['file_name'] == file_name]
                        predictions = []
                        for _, row in sub_df.iterrows():
                            predictions.append({
                                'class': row['predicted_class'],
                                'probability': row['probability']
                            })
                        result[file_name] = {
                            'predictions': predictions
                        }
                else:
                    # Build the JSON mapping directly
                    result = {}
                    for _, row in df.iterrows():
                        result[row['file_name']] = {
                            'predicted_class': row['predicted_class'],
                            'probability': row['probability']
                        }

                # Save as JSON
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
            else:
                raise ValueError(f"Unsupported output format: {format}")

            logger.info(f"Predictions saved to: {output_path}")

        processing_time = time.time() - start_time
        logger.info(f"Batch processing finished, {len(texts)} files processed "
                    f"in {processing_time:.2f} seconds")

        return df
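
    # Minimal usage sketch for process_files (illustrative; the file paths and
    # output path below are hypothetical):
    #
    #     processor = BatchProcessor(predictor)
    #     df = processor.process_files(
    #         ["docs/a.txt", "docs/b.txt"],
    #         output_path="predictions.json",
    #         return_top_k=3,
    #         format="json",
    #     )
    #
    # The returned DataFrame has one row per file, or one row per file/rank pair
    # when return_top_k > 1.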

    def process_directory(self, directory: str, pattern: str = "*.txt",
                          output_path: Optional[str] = None,
                          return_top_k: int = 1, format: str = 'csv',
                          recursive: bool = True) -> pd.DataFrame:
        """
        Process all matching files in a directory.

        Args:
            directory: Directory path
            pattern: File name pattern to match
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'
            recursive: Whether to process subdirectories recursively

        Returns:
            DataFrame of prediction results
        """
        # Collect file paths matching the pattern
        if recursive:
            file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True)
        else:
            file_paths = glob.glob(os.path.join(directory, pattern))

        if not file_paths:
            logger.warning(f"No files matching pattern {pattern} found in directory {directory}")
            return pd.DataFrame()

        logger.info(f"Found {len(file_paths)} files matching pattern {pattern} in directory {directory}")

        # Delegate to process_files
        return self.process_files(file_paths, output_path, return_top_k, format)

    def process_dataframe(self, df: pd.DataFrame, text_column: str,
                          id_column: Optional[str] = None,
                          output_path: Optional[str] = None,
                          return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
        """
        Process the texts in a DataFrame.

        Args:
            df: Input DataFrame
            text_column: Name of the text column
            id_column: Name of the ID column; if None, the index is used
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'

        Returns:
            DataFrame of prediction results
        """
        # Extract texts and IDs
        texts = df[text_column].tolist()

        if id_column:
            ids = df[id_column].tolist()
        else:
            ids = df.index.tolist()

        # Batch prediction
        result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k)

        # Save results
        if output_path:
            if format.lower() == 'csv':
                result_df.to_csv(output_path, index=False, encoding='utf-8')
            elif format.lower() == 'json':
                # Nested JSON output is delegated to the predictor
                self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json')
            else:
                raise ValueError(f"Unsupported output format: {format}")

            logger.info(f"Predictions saved to: {output_path}")

        return result_df
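
    # Minimal usage sketch for process_dataframe (illustrative; the column names
    # below are hypothetical):
    #
    #     data = pd.DataFrame({"doc_id": [1, 2], "content": ["text one", "text two"]})
    #     result = processor.process_dataframe(data, text_column="content",
    #                                          id_column="doc_id", return_top_k=1)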

    def process_large_file(self, file_path: str, output_path: Optional[str] = None,
                           return_top_k: int = 1, format: str = 'csv',
                           chunk_size: int = MAX_TEXT_PER_BATCH,
                           delimiter: str = '\n\n') -> None:
        """
        Process a large text file; the file is read and processed in chunks.

        Args:
            file_path: Path to the file
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'
            chunk_size: Size of each chunk (number of texts)
            delimiter: Delimiter between texts
        """
        logger.info(f"Starting to process large file: {file_path}")
        start_time = time.time()

        # Read the file content
        with open(file_path, 'r', encoding=ENCODING) as f:
            content = f.read()

        # Split into individual texts
        texts = content.split(delimiter)
        texts = [text.strip() for text in texts if text.strip()]

        logger.info(f"The file contains {len(texts)} texts")

        # Create the output file
        if output_path:
            if format.lower() == 'csv':
                # Write the CSV header
                if return_top_k > 1:
                    header = "text_id,text,rank,predicted_class,probability\n"
                else:
                    header = "text_id,text,predicted_class,probability\n"
                with open(output_path, 'w', encoding=ENCODING) as f:
                    f.write(header)
            elif format.lower() == 'json':
                # Start the JSON file
                with open(output_path, 'w', encoding=ENCODING) as f:
                    f.write('{\n')

        # Process chunk by chunk
        total_chunks = (len(texts) + chunk_size - 1) // chunk_size

        for i in range(0, len(texts), chunk_size):
            chunk = texts[i:i + chunk_size]
            chunk_ids = list(range(i, i + len(chunk)))

            logger.info(f"Processing chunk {i // chunk_size + 1}/{total_chunks} with {len(chunk)} texts")

            # Batch prediction
            result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k)

            # Append to the output file
            if output_path:
                if format.lower() == 'csv':
                    result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False)
                elif format.lower() == 'json':
                    # Convert to JSON entries and append them
                    if return_top_k > 1:
                        # Group rows per text, then nest the predictions
                        for id_val in result_df['id'].unique():
                            sub_df = result_df[result_df['id'] == id_val]
                            predictions = []
                            for _, row in sub_df.iterrows():
                                predictions.append({
                                    'class': row['predicted_class'],
                                    'probability': float(row['probability'])
                                })
                            json_str = f'  "{id_val}": {{\n'
                            json_str += f'    "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n'
                            json_str += f'    "predictions": {json.dumps(predictions, ensure_ascii=False)}\n'
                            json_str += '  },'
                            with open(output_path, 'a', encoding=ENCODING) as f:
                                f.write(json_str + '\n')
                    else:
                        # Build the JSON entries directly
                        for _, row in result_df.iterrows():
                            json_str = f'  "{row["id"]}": {{\n'
                            json_str += f'    "text": {json.dumps(row["text"], ensure_ascii=False)},\n'
                            json_str += f'    "predicted_class": "{row["predicted_class"]}",\n'
                            json_str += f'    "probability": {float(row["probability"])}\n'
                            json_str += '  },'
                            with open(output_path, 'a', encoding=ENCODING) as f:
                                f.write(json_str + '\n')

        # Finalize the JSON file
        if output_path and format.lower() == 'json':
            with open(output_path, 'a', encoding=ENCODING) as f:
                f.write('}\n')

            # Remove the trailing comma after the last entry so the JSON stays valid
            with open(output_path, 'r', encoding=ENCODING) as f:
                content = f.read()

            content = content.rstrip('\n}')
            content = content.rstrip(',')
            content += '\n}\n'

            with open(output_path, 'w', encoding=ENCODING) as f:
                f.write(content)

        processing_time = time.time() - start_time
        logger.info(f"Finished processing large file, {len(texts)} texts processed "
                    f"in {processing_time:.2f} seconds")