"""
|
||
批处理模块:实现批量处理大规模文本数据
|
||
"""
|
||
import os
|
||
import time
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator
|
||
import concurrent.futures
|
||
from tqdm import tqdm
|
||
import glob
|
||
import json
|
||
|
||
from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file, ensure_dir
|
||
from inference.predictor import Predictor
|
||
|
||
logger = get_logger("BatchProcessor")
|
||
|
||
|
||
class BatchProcessor:
    """Batch processor responsible for processing large volumes of text data."""

    def __init__(self, predictor: Predictor,
                 batch_size: int = 64,
                 max_workers: int = DATA_LOADING_WORKERS,
                 max_batch_queue: int = 10):
        """
        Initialize the batch processor.

        Args:
            predictor: Predictor instance
            batch_size: Batch size
            max_workers: Maximum number of worker threads
            max_batch_queue: Maximum length of the batch queue
        """
        self.predictor = predictor
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.max_batch_queue = max_batch_queue

        logger.info(f"Initialized batch processor, batch size: {batch_size}, max workers: {max_workers}")

    def _extract_text_from_file(self, file_path: str) -> str:
        """
        Extract text from a file.

        Args:
            file_path: Path to the file

        Returns:
            The text content
        """
        return read_text_file(file_path, encoding=ENCODING)

    def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]:
        """
        Generate batches of texts.

        Args:
            texts: List of texts
            batch_size: Batch size

        Returns:
            Iterator over batches of texts
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i:i + batch_size]
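
    # Illustrative note: for a list of 5 texts and batch_size=2, the generator
    # above yields the slices [0:2], [2:4], and [4:5], so the final batch may
    # be smaller than batch_size.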

    def process_files(self, file_paths: List[str], output_path: Optional[str] = None,
                      return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
        """
        Process a batch of files.

        Args:
            file_paths: List of file paths
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'

        Returns:
            DataFrame of prediction results
        """
        logger.info(f"Starting batch processing of {len(file_paths)} files")
        start_time = time.time()

        # Read files in parallel using a thread pool
        texts = []
        file_names = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path
                              for file_path in file_paths}

            for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(file_paths),
                               desc="Reading files"):
                file_path = future_to_file[future]
                try:
                    text = future.result()
                    if text:
                        texts.append(text)
                        file_names.append(os.path.basename(file_path))
                except Exception as e:
                    logger.error(f"Error processing file {file_path}: {e}")

        # Batch prediction
        all_predictions = []

        for batch in tqdm(self._batch_generator(texts, self.batch_size),
                          total=(len(texts) + self.batch_size - 1) // self.batch_size, desc="Predicting"):
            predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k, return_probabilities=True)
            all_predictions.extend(predictions)

        # Assemble results
        if return_top_k > 1:
            # Multiple classes per text
            results = []
            for i, preds in enumerate(all_predictions):
                for j, pred in enumerate(preds):
                    results.append({
                        'file_name': file_names[i],
                        'rank': j + 1,
                        'predicted_class': pred['class'],
                        'probability': pred['probability']
                    })
            df = pd.DataFrame(results)
        else:
            # Single class per text
            df = pd.DataFrame({
                'file_name': file_names,
                'predicted_class': [pred['class'] for pred in all_predictions],
                'probability': [pred['probability'] for pred in all_predictions]
            })

        # Save results
        if output_path:
            if format.lower() == 'csv':
                df.to_csv(output_path, index=False, encoding='utf-8')
            elif format.lower() == 'json':
                # Convert to nested JSON format
                if return_top_k > 1:
                    # Group by file, then build the nested structure
                    result = {}
                    for file_name in df['file_name'].unique():
                        sub_df = df[df['file_name'] == file_name]
                        predictions = []
                        for _, row in sub_df.iterrows():
                            predictions.append({
                                'class': row['predicted_class'],
                                'probability': row['probability']
                            })
                        result[file_name] = {
                            'predictions': predictions
                        }
                else:
                    # Build the JSON directly
                    result = {}
                    for _, row in df.iterrows():
                        result[row['file_name']] = {
                            'predicted_class': row['predicted_class'],
                            'probability': row['probability']
                        }

                # Save as JSON
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
            else:
                raise ValueError(f"Unsupported output format: {format}")

            logger.info(f"Predictions saved to: {output_path}")

        processing_time = time.time() - start_time
        logger.info(f"Batch processing finished: {len(texts)} files processed in {processing_time:.2f} seconds")

        return df
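
    # For reference, the DataFrame returned above has columns
    # file_name / rank / predicted_class / probability when return_top_k > 1
    # (one row per candidate class), and file_name / predicted_class /
    # probability when return_top_k == 1. The nested JSON output mirrors this
    # as {file_name: {"predictions": [...]}} or
    # {file_name: {"predicted_class": ..., "probability": ...}}.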

    def process_directory(self, directory: str, pattern: str = "*.txt",
                          output_path: Optional[str] = None,
                          return_top_k: int = 1, format: str = 'csv',
                          recursive: bool = True) -> pd.DataFrame:
        """
        Process all matching files in a directory.

        Args:
            directory: Directory path
            pattern: Filename pattern to match
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'
            recursive: Whether to search subdirectories recursively

        Returns:
            DataFrame of prediction results
        """
        # Collect file paths matching the pattern
        if recursive:
            file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True)
        else:
            file_paths = glob.glob(os.path.join(directory, pattern))

        if not file_paths:
            logger.warning(f"No files matching pattern {pattern} found in directory {directory}")
            return pd.DataFrame()

        logger.info(f"Found {len(file_paths)} files matching pattern {pattern} in directory {directory}")

        # Delegate to process_files
        return self.process_files(file_paths, output_path, return_top_k, format)

    def process_dataframe(self, df: pd.DataFrame, text_column: str,
                          id_column: Optional[str] = None,
                          output_path: Optional[str] = None,
                          return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
        """
        Process the texts in a DataFrame.

        Args:
            df: Input DataFrame
            text_column: Name of the text column
            id_column: Name of the ID column; if None, the index is used
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'

        Returns:
            DataFrame of prediction results
        """
        # Extract texts and IDs
        texts = df[text_column].tolist()

        if id_column:
            ids = df[id_column].tolist()
        else:
            ids = df.index.tolist()

        # Batch prediction
        result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k)

        # Save results
        if output_path:
            if format.lower() == 'csv':
                result_df.to_csv(output_path, index=False, encoding='utf-8')
            elif format.lower() == 'json':
                # Delegate nested JSON output to the predictor
                self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json')
            else:
                raise ValueError(f"Unsupported output format: {format}")

            logger.info(f"Predictions saved to: {output_path}")

        return result_df

    def process_large_file(self, file_path: str, output_path: Optional[str] = None,
                           return_top_k: int = 1, format: str = 'csv',
                           chunk_size: int = MAX_TEXT_PER_BATCH,
                           delimiter: str = '\n\n') -> None:
        """
        Process a large text file; the file is read and processed in chunks.

        Args:
            file_path: Path to the file
            output_path: Output file path; if None, results are not saved
            return_top_k: Return the top-k classes with the highest probabilities
            format: Output format, 'csv' or 'json'
            chunk_size: Size of each chunk (number of texts)
            delimiter: Delimiter between texts
        """
        logger.info(f"Starting to process large file: {file_path}")
        start_time = time.time()

        # Read the file content
        with open(file_path, 'r', encoding=ENCODING) as f:
            content = f.read()

        # Split into individual texts
        texts = content.split(delimiter)
        texts = [text.strip() for text in texts if text.strip()]

        logger.info(f"File contains {len(texts)} texts")

        # Create the output file
        if output_path:
            if format.lower() == 'csv':
                # Write the CSV header
                if return_top_k > 1:
                    header = "text_id,text,rank,predicted_class,probability\n"
                else:
                    header = "text_id,text,predicted_class,probability\n"

                with open(output_path, 'w', encoding=ENCODING) as f:
                    f.write(header)
            elif format.lower() == 'json':
                # Start the JSON file
                with open(output_path, 'w', encoding=ENCODING) as f:
                    f.write('{\n')

        # Process in chunks
        total_chunks = (len(texts) + chunk_size - 1) // chunk_size

        for i in range(0, len(texts), chunk_size):
            chunk = texts[i:i + chunk_size]
            chunk_ids = list(range(i, i + len(chunk)))

            logger.info(f"Processing chunk {i // chunk_size + 1}/{total_chunks} with {len(chunk)} texts")

            # Batch prediction
            result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k)

            # Append to the output file
            if output_path:
                if format.lower() == 'csv':
                    result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False)
                elif format.lower() == 'json':
                    # Convert to JSON and append
                    if return_top_k > 1:
                        # Group by ID, then build the nested structure
                        for id_val in result_df['id'].unique():
                            sub_df = result_df[result_df['id'] == id_val]
                            predictions = []
                            for _, row in sub_df.iterrows():
                                predictions.append({
                                    'class': row['predicted_class'],
                                    'probability': float(row['probability'])
                                })

                            json_str = f'  "{id_val}": {{\n'
                            json_str += f'    "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n'
                            json_str += f'    "predictions": {json.dumps(predictions, ensure_ascii=False)}\n'
                            json_str += '  },'

                            with open(output_path, 'a', encoding=ENCODING) as f:
                                f.write(json_str + '\n')
                    else:
                        # Build each entry directly
                        for _, row in result_df.iterrows():
                            json_str = f'  "{row["id"]}": {{\n'
                            json_str += f'    "text": {json.dumps(row["text"], ensure_ascii=False)},\n'
                            json_str += f'    "predicted_class": {json.dumps(row["predicted_class"], ensure_ascii=False)},\n'
                            json_str += f'    "probability": {float(row["probability"])}\n'
                            json_str += '  },'

                            with open(output_path, 'a', encoding=ENCODING) as f:
                                f.write(json_str + '\n')

        # Finalize the JSON file
        if output_path and format.lower() == 'json':
            with open(output_path, 'a', encoding=ENCODING) as f:
                f.write('}\n')

            # Strip the trailing comma after the last entry so the JSON stays valid
            with open(output_path, 'r', encoding=ENCODING) as f:
                content = f.read()

            content = content.rstrip('\n}')
            content = content.rstrip(',')
            content += '\n}\n'

            with open(output_path, 'w', encoding=ENCODING) as f:
                f.write(content)

        processing_time = time.time() - start_time
        logger.info(f"Finished processing large file: {len(texts)} texts processed in {processing_time:.2f} seconds")
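

# Minimal usage sketch (illustrative, not part of the module API): shows how a
# Predictor might be wired into BatchProcessor to classify every .txt file in a
# directory. The Predictor construction and the paths below are assumptions;
# adjust them to the actual inference.predictor API and your own data layout.
if __name__ == "__main__":
    predictor = Predictor()  # assumed default construction; adapt as needed
    processor = BatchProcessor(predictor, batch_size=32)
    results_df = processor.process_directory(
        "data/raw_texts",            # hypothetical input directory
        pattern="*.txt",
        output_path="predictions.csv",
        return_top_k=3,
        format="csv",
    )
    print(results_df.head())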