2025-03-08 01:34:36 +08:00

317 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
预测器模块:实现模型预测功能,支持单条和批量文本预测
"""
import os
import time
import numpy as np
import tensorflow as tf
from typing import List, Dict, Tuple, Optional, Any, Union
import pandas as pd
import json
from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY
from models.base_model import TextClassificationModel
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from utils.logger import get_logger
logger = get_logger("Predictor")
class Predictor:
"""预测器,负责加载模型和进行预测"""
def __init__(self, model: TextClassificationModel,
tokenizer: Optional[ChineseTokenizer] = None,
vectorizer: Optional[SequenceVectorizer] = None,
class_names: Optional[List[str]] = None,
max_sequence_length: int = 500,
batch_size: Optional[int] = None):
"""
初始化预测器
Args:
model: 已训练的模型实例
tokenizer: 分词器实例如果为None则创建一个新的分词器
vectorizer: 文本向量化器实例如果为None则表示模型直接接收序列
class_names: 类别名称列表如果为None则使用ID_TO_CATEGORY
max_sequence_length: 最大序列长度
batch_size: 批大小如果为None则使用模型默认值
"""
self.model = model
self.tokenizer = tokenizer or ChineseTokenizer()
self.vectorizer = vectorizer
self.class_names = class_names
if class_names is None and hasattr(model, 'num_classes'):
# 如果模型具有类别数量信息从ID_TO_CATEGORY获取类别名称
self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)]
self.max_sequence_length = max_sequence_length
self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32)
logger.info(f"初始化预测器,批大小: {self.batch_size}")
def preprocess_text(self, text: str) -> Any:
"""
预处理单条文本
Args:
text: 原始文本
Returns:
预处理后的文本表示
"""
# 分词
tokenized_text = self.tokenizer.tokenize(text, return_string=True)
# 如果有向量化器,应用向量化
if self.vectorizer is not None:
return self.vectorizer.transform([tokenized_text])[0]
return tokenized_text
def preprocess_texts(self, texts: List[str]) -> Any:
"""
批量预处理文本
Args:
texts: 原始文本列表
Returns:
预处理后的批量文本表示
"""
# 分词
tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts]
# 如果有向量化器,应用向量化
if self.vectorizer is not None:
return self.vectorizer.transform(tokenized_texts)
return tokenized_texts
def predict(self, text: str, return_top_k: int = 1,
return_probabilities: bool = False) -> Union[str, Dict, List]:
"""
预测单条文本的类别
Args:
text: 原始文本
return_top_k: 返回概率最高的前k个类别
return_probabilities: 是否返回概率值
Returns:
预测结果,格式取决于参数设置
"""
# 预处理文本
processed_text = self.preprocess_text(text)
# 添加批次维度
if isinstance(processed_text, str):
input_data = np.array([processed_text])
else:
input_data = np.expand_dims(processed_text, axis=0)
# 预测
start_time = time.time()
predictions = self.model.predict(input_data)
prediction_time = time.time() - start_time
# 获取前k个预测结果
if return_top_k > 1:
top_indices = np.argsort(predictions[0])[::-1][:return_top_k]
top_probs = predictions[0][top_indices]
if self.class_names:
top_classes = [self.class_names[idx] for idx in top_indices]
else:
top_classes = [str(idx) for idx in top_indices]
if return_probabilities:
return [{'class': cls, 'probability': float(prob)}
for cls, prob in zip(top_classes, top_probs)]
else:
return top_classes
else:
# 获取最高概率的类别
pred_idx = np.argmax(predictions[0])
pred_prob = float(predictions[0][pred_idx])
if self.class_names:
pred_class = self.class_names[pred_idx]
else:
pred_class = str(pred_idx)
if return_probabilities:
return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time}
else:
return pred_class
def predict_batch(self, texts: List[str], return_top_k: int = 1,
return_probabilities: bool = False) -> List:
"""
批量预测文本类别
Args:
texts: 原始文本列表
return_top_k: 返回概率最高的前k个类别
return_probabilities: 是否返回概率值
Returns:
预测结果列表
"""
# 空列表检查
if not texts:
return []
# 预处理文本
processed_texts = self.preprocess_texts(texts)
# 预测
start_time = time.time()
predictions = self.model.predict(processed_texts, batch_size=self.batch_size)
prediction_time = time.time() - start_time
# 处理预测结果
results = []
for i, pred in enumerate(predictions):
if return_top_k > 1:
top_indices = np.argsort(pred)[::-1][:return_top_k]
top_probs = pred[top_indices]
if self.class_names:
top_classes = [self.class_names[idx] for idx in top_indices]
else:
top_classes = [str(idx) for idx in top_indices]
if return_probabilities:
results.append([{'class': cls, 'probability': float(prob)}
for cls, prob in zip(top_classes, top_probs)])
else:
results.append(top_classes)
else:
# 获取最高概率的类别
pred_idx = np.argmax(pred)
pred_prob = float(pred[pred_idx])
if self.class_names:
pred_class = self.class_names[pred_idx]
else:
pred_class = str(pred_idx)
if return_probabilities:
results.append({'class': pred_class, 'probability': pred_prob})
else:
results.append(pred_class)
logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f}")
return results
def predict_to_dataframe(self, texts: List[str],
text_ids: Optional[List[Union[str, int]]] = None,
return_top_k: int = 1) -> pd.DataFrame:
"""
批量预测并返回DataFrame
Args:
texts: 原始文本列表
text_ids: 文本ID列表如果为None则使用索引
return_top_k: 返回概率最高的前k个类别
Returns:
预测结果DataFrame
"""
# 预测
predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True)
# 创建DataFrame
if text_ids is None:
text_ids = list(range(len(texts)))
if return_top_k > 1:
# 多个类别的情况
results = []
for i, preds in enumerate(predictions):
for j, pred in enumerate(preds):
results.append({
'id': text_ids[i],
'text': texts[i],
'rank': j + 1,
'predicted_class': pred['class'],
'probability': pred['probability']
})
df = pd.DataFrame(results)
else:
# 单个类别的情况
df = pd.DataFrame({
'id': text_ids,
'text': texts,
'predicted_class': [pred['class'] for pred in predictions],
'probability': [pred['probability'] for pred in predictions]
})
return df
def save_predictions(self, texts: List[str],
output_path: str,
text_ids: Optional[List[Union[str, int]]] = None,
return_top_k: int = 1,
format: str = 'csv') -> str:
"""
批量预测并保存结果
Args:
texts: 原始文本列表
output_path: 输出文件路径
text_ids: 文本ID列表如果为None则使用索引
return_top_k: 返回概率最高的前k个类别
format: 输出格式,'csv''json'
Returns:
输出文件路径
"""
# 获取预测结果DataFrame
df = self.predict_to_dataframe(texts, text_ids, return_top_k)
# 保存结果
if format.lower() == 'csv':
df.to_csv(output_path, index=False, encoding='utf-8')
elif format.lower() == 'json':
# 转换为嵌套的JSON格式
if return_top_k > 1:
# 分组后转换为嵌套格式
result = {}
for id_val in df['id'].unique():
sub_df = df[df['id'] == id_val]
predictions = []
for _, row in sub_df.iterrows():
predictions.append({
'class': row['predicted_class'],
'probability': row['probability']
})
result[str(id_val)] = {
'text': sub_df.iloc[0]['text'],
'predictions': predictions
}
else:
# 直接构建JSON
result = {}
for _, row in df.iterrows():
result[str(row['id'])] = {
'text': row['text'],
'predicted_class': row['predicted_class'],
'probability': row['probability']
}
# 保存为JSON
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
raise ValueError(f"不支持的输出格式: {format}")
logger.info(f"预测结果已保存到: {output_path}")
return output_path