"""
|
||
预测器模块:实现模型预测功能,支持单条和批量文本预测
|
||
"""
|
||
import os
|
||
import time
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import pandas as pd
|
||
import json
|
||
|
||
from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY
|
||
from models.base_model import TextClassificationModel
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Predictor")
|
||
|
||
|
||
class Predictor:
    """Predictor responsible for loading a model and running inference."""

    def __init__(self, model: TextClassificationModel,
                 tokenizer: Optional[ChineseTokenizer] = None,
                 vectorizer: Optional[SequenceVectorizer] = None,
                 class_names: Optional[List[str]] = None,
                 max_sequence_length: int = 500,
                 batch_size: Optional[int] = None):
        """
        Initialize the predictor.

        Args:
            model: A trained model instance.
            tokenizer: Tokenizer instance; if None, a new tokenizer is created.
            vectorizer: Text vectorizer instance; if None, the tokenized text
                is passed to the model directly.
            class_names: List of class names; if None, ID_TO_CATEGORY is used.
            max_sequence_length: Maximum sequence length.
            batch_size: Batch size; if None, the model's default is used.
        """
        self.model = model
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.vectorizer = vectorizer
        self.class_names = class_names

        if class_names is None and hasattr(model, 'num_classes'):
            # If the model exposes the number of classes, derive the class
            # names from ID_TO_CATEGORY.
            self.class_names = [ID_TO_CATEGORY.get(i, str(i))
                                for i in range(model.num_classes)]

        self.max_sequence_length = max_sequence_length
        self.batch_size = batch_size or getattr(model, 'batch_size', 32)

        logger.info(f"Initialized predictor, batch size: {self.batch_size}")

    def preprocess_text(self, text: str) -> Any:
        """
        Preprocess a single text.

        Args:
            text: Raw text.

        Returns:
            The preprocessed text representation.
        """
        # Tokenize
        tokenized_text = self.tokenizer.tokenize(text, return_string=True)

        # Apply vectorization if a vectorizer is available
        if self.vectorizer is not None:
            return self.vectorizer.transform([tokenized_text])[0]

        return tokenized_text

    def preprocess_texts(self, texts: List[str]) -> Any:
        """
        Preprocess a batch of texts.

        Args:
            texts: List of raw texts.

        Returns:
            The preprocessed batch of text representations.
        """
        # Tokenize
        tokenized_texts = [self.tokenizer.tokenize(text, return_string=True)
                           for text in texts]

        # Apply vectorization if a vectorizer is available
        if self.vectorizer is not None:
            return self.vectorizer.transform(tokenized_texts)

        return tokenized_texts

    def predict(self, text: str, return_top_k: int = 1,
                return_probabilities: bool = False) -> Union[str, Dict, List]:
        """
        Predict the category of a single text.

        Args:
            text: Raw text.
            return_top_k: Return the k classes with the highest probability.
            return_probabilities: Whether to include probability values.

        Returns:
            The prediction result; its format depends on the arguments.
        """
        # Preprocess the text
        processed_text = self.preprocess_text(text)

        # Add a batch dimension
        if isinstance(processed_text, str):
            input_data = np.array([processed_text])
        else:
            input_data = np.expand_dims(processed_text, axis=0)

        # Run inference
        start_time = time.time()
        predictions = self.model.predict(input_data)
        prediction_time = time.time() - start_time

        # Collect the top-k predictions
        if return_top_k > 1:
            top_indices = np.argsort(predictions[0])[::-1][:return_top_k]
            top_probs = predictions[0][top_indices]

            if self.class_names:
                top_classes = [self.class_names[idx] for idx in top_indices]
            else:
                top_classes = [str(idx) for idx in top_indices]

            if return_probabilities:
                return [{'class': cls, 'probability': float(prob)}
                        for cls, prob in zip(top_classes, top_probs)]
            else:
                return top_classes
        else:
            # Take the single highest-probability class
            pred_idx = np.argmax(predictions[0])
            pred_prob = float(predictions[0][pred_idx])

            if self.class_names:
                pred_class = self.class_names[pred_idx]
            else:
                pred_class = str(pred_idx)

            if return_probabilities:
                return {'class': pred_class, 'probability': pred_prob,
                        'time': prediction_time}
            else:
                return pred_class

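    # Illustrative return shapes for predict() (class names and probability
    # values below are hypothetical; actual labels come from self.class_names):
    #   predict(text)                                            -> '体育'
    #   predict(text, return_probabilities=True)                 -> {'class': '体育', 'probability': 0.91, 'time': 0.03}
    #   predict(text, return_top_k=3, return_probabilities=True) -> [{'class': '体育', 'probability': 0.91}, ...]
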
    def predict_batch(self, texts: List[str], return_top_k: int = 1,
                      return_probabilities: bool = False) -> List:
        """
        Predict categories for a batch of texts.

        Args:
            texts: List of raw texts.
            return_top_k: Return the k classes with the highest probability.
            return_probabilities: Whether to include probability values.

        Returns:
            A list of prediction results.
        """
        # Handle the empty-list case
        if not texts:
            return []

        # Preprocess the texts
        processed_texts = self.preprocess_texts(texts)

        # Run inference
        start_time = time.time()
        predictions = self.model.predict(processed_texts, batch_size=self.batch_size)
        prediction_time = time.time() - start_time

        # Post-process the predictions
        results = []

        for pred in predictions:
            if return_top_k > 1:
                top_indices = np.argsort(pred)[::-1][:return_top_k]
                top_probs = pred[top_indices]

                if self.class_names:
                    top_classes = [self.class_names[idx] for idx in top_indices]
                else:
                    top_classes = [str(idx) for idx in top_indices]

                if return_probabilities:
                    results.append([{'class': cls, 'probability': float(prob)}
                                    for cls, prob in zip(top_classes, top_probs)])
                else:
                    results.append(top_classes)
            else:
                # Take the single highest-probability class
                pred_idx = np.argmax(pred)
                pred_prob = float(pred[pred_idx])

                if self.class_names:
                    pred_class = self.class_names[pred_idx]
                else:
                    pred_class = str(pred_idx)

                if return_probabilities:
                    results.append({'class': pred_class, 'probability': pred_prob})
                else:
                    results.append(pred_class)

        logger.info(f"Batch prediction of {len(texts)} texts finished in {prediction_time:.2f}s")

        return results

    def predict_to_dataframe(self, texts: List[str],
                             text_ids: Optional[List[Union[str, int]]] = None,
                             return_top_k: int = 1) -> pd.DataFrame:
        """
        Predict a batch of texts and return the results as a DataFrame.

        Args:
            texts: List of raw texts.
            text_ids: List of text IDs; if None, list indices are used.
            return_top_k: Return the k classes with the highest probability.

        Returns:
            A DataFrame of prediction results.
        """
        # Run batch prediction
        predictions = self.predict_batch(texts, return_top_k=return_top_k,
                                         return_probabilities=True)

        # Build the DataFrame
        if text_ids is None:
            text_ids = list(range(len(texts)))

        if return_top_k > 1:
            # Multiple classes per text: one row per (text, rank) pair
            results = []
            for i, preds in enumerate(predictions):
                for j, pred in enumerate(preds):
                    results.append({
                        'id': text_ids[i],
                        'text': texts[i],
                        'rank': j + 1,
                        'predicted_class': pred['class'],
                        'probability': pred['probability']
                    })
            df = pd.DataFrame(results)
        else:
            # Single class per text: one row per text
            df = pd.DataFrame({
                'id': text_ids,
                'text': texts,
                'predicted_class': [pred['class'] for pred in predictions],
                'probability': [pred['probability'] for pred in predictions]
            })

        return df

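    # Shape of the DataFrame returned above (column names taken from the code):
    #   return_top_k == 1 -> columns: id, text, predicted_class, probability
    #   return_top_k > 1  -> columns: id, text, rank, predicted_class, probability
    #                        (one row per (text, rank) pair)
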
    def save_predictions(self, texts: List[str],
                         output_path: str,
                         text_ids: Optional[List[Union[str, int]]] = None,
                         return_top_k: int = 1,
                         format: str = 'csv') -> str:
        """
        Predict a batch of texts and save the results.

        Args:
            texts: List of raw texts.
            output_path: Output file path.
            text_ids: List of text IDs; if None, list indices are used.
            return_top_k: Return the k classes with the highest probability.
            format: Output format, 'csv' or 'json'.

        Returns:
            The output file path.
        """
        # Get the predictions as a DataFrame
        df = self.predict_to_dataframe(texts, text_ids, return_top_k)

        # Save the results
        if format.lower() == 'csv':
            df.to_csv(output_path, index=False, encoding='utf-8')
        elif format.lower() == 'json':
            # Convert to a nested JSON structure
            if return_top_k > 1:
                # Group rows by ID, then nest the per-rank predictions
                result = {}
                for id_val in df['id'].unique():
                    sub_df = df[df['id'] == id_val]
                    predictions = []
                    for _, row in sub_df.iterrows():
                        predictions.append({
                            'class': row['predicted_class'],
                            'probability': row['probability']
                        })
                    result[str(id_val)] = {
                        'text': sub_df.iloc[0]['text'],
                        'predictions': predictions
                    }
            else:
                # Build the JSON directly, one entry per row
                result = {}
                for _, row in df.iterrows():
                    result[str(row['id'])] = {
                        'text': row['text'],
                        'predicted_class': row['predicted_class'],
                        'probability': row['probability']
                    }

            # Write the JSON file
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        else:
            raise ValueError(f"Unsupported output format: {format}")

        logger.info(f"Predictions saved to: {output_path}")

        return output_path
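

# Minimal usage sketch. Assumptions are marked: how a trained
# TextClassificationModel is restored depends on this project's model API
# (left as a placeholder below), and a fitted SequenceVectorizer is only
# required if the model expects integer sequences rather than token strings.
if __name__ == "__main__":
    # Placeholder: restore a trained model using this project's actual loading code.
    model = ...  # e.g. a trained TextClassificationModel instance

    predictor = Predictor(model=model)

    # Single text: top-3 classes with probabilities
    print(predictor.predict("这是一条新闻文本", return_top_k=3,
                            return_probabilities=True))

    # Batch prediction saved to CSV
    predictor.save_predictions(["文本一", "文本二"], "predictions.csv",
                               format='csv')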