""" 预测器模块:实现模型预测功能,支持单条和批量文本预测 """ import os import time import numpy as np import tensorflow as tf from typing import List, Dict, Tuple, Optional, Any, Union import pandas as pd import json from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY from models.base_model import TextClassificationModel from preprocessing.tokenization import ChineseTokenizer from preprocessing.vectorizer import SequenceVectorizer from utils.logger import get_logger logger = get_logger("Predictor") class Predictor: """预测器,负责加载模型和进行预测""" def __init__(self, model: TextClassificationModel, tokenizer: Optional[ChineseTokenizer] = None, vectorizer: Optional[SequenceVectorizer] = None, class_names: Optional[List[str]] = None, max_sequence_length: int = 500, batch_size: Optional[int] = None): """ 初始化预测器 Args: model: 已训练的模型实例 tokenizer: 分词器实例,如果为None则创建一个新的分词器 vectorizer: 文本向量化器实例,如果为None则表示模型直接接收序列 class_names: 类别名称列表,如果为None则使用ID_TO_CATEGORY max_sequence_length: 最大序列长度 batch_size: 批大小,如果为None则使用模型默认值 """ self.model = model self.tokenizer = tokenizer or ChineseTokenizer() self.vectorizer = vectorizer self.class_names = class_names if class_names is None and hasattr(model, 'num_classes'): # 如果模型具有类别数量信息,从ID_TO_CATEGORY获取类别名称 self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)] self.max_sequence_length = max_sequence_length self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32) logger.info(f"初始化预测器,批大小: {self.batch_size}") def preprocess_text(self, text: str) -> Any: """ 预处理单条文本 Args: text: 原始文本 Returns: 预处理后的文本表示 """ # 分词 tokenized_text = self.tokenizer.tokenize(text, return_string=True) # 如果有向量化器,应用向量化 if self.vectorizer is not None: return self.vectorizer.transform([tokenized_text])[0] return tokenized_text def preprocess_texts(self, texts: List[str]) -> Any: """ 批量预处理文本 Args: texts: 原始文本列表 Returns: 预处理后的批量文本表示 """ # 分词 tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts] # 如果有向量化器,应用向量化 if self.vectorizer is not None: return self.vectorizer.transform(tokenized_texts) return tokenized_texts def predict(self, text: str, return_top_k: int = 1, return_probabilities: bool = False) -> Union[str, Dict, List]: """ 预测单条文本的类别 Args: text: 原始文本 return_top_k: 返回概率最高的前k个类别 return_probabilities: 是否返回概率值 Returns: 预测结果,格式取决于参数设置 """ # 预处理文本 processed_text = self.preprocess_text(text) # 添加批次维度 if isinstance(processed_text, str): input_data = np.array([processed_text]) else: input_data = np.expand_dims(processed_text, axis=0) # 预测 start_time = time.time() predictions = self.model.predict(input_data) prediction_time = time.time() - start_time # 获取前k个预测结果 if return_top_k > 1: top_indices = np.argsort(predictions[0])[::-1][:return_top_k] top_probs = predictions[0][top_indices] if self.class_names: top_classes = [self.class_names[idx] for idx in top_indices] else: top_classes = [str(idx) for idx in top_indices] if return_probabilities: return [{'class': cls, 'probability': float(prob)} for cls, prob in zip(top_classes, top_probs)] else: return top_classes else: # 获取最高概率的类别 pred_idx = np.argmax(predictions[0]) pred_prob = float(predictions[0][pred_idx]) if self.class_names: pred_class = self.class_names[pred_idx] else: pred_class = str(pred_idx) if return_probabilities: return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time} else: return pred_class def predict_batch(self, texts: List[str], return_top_k: int = 1, return_probabilities: bool = False) -> List: """ 批量预测文本类别 Args: texts: 原始文本列表 return_top_k: 返回概率最高的前k个类别 return_probabilities: 是否返回概率值 Returns: 预测结果列表 """ # 空列表检查 if not texts: return [] # 预处理文本 processed_texts = self.preprocess_texts(texts) # 预测 start_time = time.time() predictions = self.model.predict(processed_texts, batch_size=self.batch_size) prediction_time = time.time() - start_time # 处理预测结果 results = [] for i, pred in enumerate(predictions): if return_top_k > 1: top_indices = np.argsort(pred)[::-1][:return_top_k] top_probs = pred[top_indices] if self.class_names: top_classes = [self.class_names[idx] for idx in top_indices] else: top_classes = [str(idx) for idx in top_indices] if return_probabilities: results.append([{'class': cls, 'probability': float(prob)} for cls, prob in zip(top_classes, top_probs)]) else: results.append(top_classes) else: # 获取最高概率的类别 pred_idx = np.argmax(pred) pred_prob = float(pred[pred_idx]) if self.class_names: pred_class = self.class_names[pred_idx] else: pred_class = str(pred_idx) if return_probabilities: results.append({'class': pred_class, 'probability': pred_prob}) else: results.append(pred_class) logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f} 秒") return results def predict_to_dataframe(self, texts: List[str], text_ids: Optional[List[Union[str, int]]] = None, return_top_k: int = 1) -> pd.DataFrame: """ 批量预测并返回DataFrame Args: texts: 原始文本列表 text_ids: 文本ID列表,如果为None则使用索引 return_top_k: 返回概率最高的前k个类别 Returns: 预测结果DataFrame """ # 预测 predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True) # 创建DataFrame if text_ids is None: text_ids = list(range(len(texts))) if return_top_k > 1: # 多个类别的情况 results = [] for i, preds in enumerate(predictions): for j, pred in enumerate(preds): results.append({ 'id': text_ids[i], 'text': texts[i], 'rank': j + 1, 'predicted_class': pred['class'], 'probability': pred['probability'] }) df = pd.DataFrame(results) else: # 单个类别的情况 df = pd.DataFrame({ 'id': text_ids, 'text': texts, 'predicted_class': [pred['class'] for pred in predictions], 'probability': [pred['probability'] for pred in predictions] }) return df def save_predictions(self, texts: List[str], output_path: str, text_ids: Optional[List[Union[str, int]]] = None, return_top_k: int = 1, format: str = 'csv') -> str: """ 批量预测并保存结果 Args: texts: 原始文本列表 output_path: 输出文件路径 text_ids: 文本ID列表,如果为None则使用索引 return_top_k: 返回概率最高的前k个类别 format: 输出格式,'csv'或'json' Returns: 输出文件路径 """ # 获取预测结果DataFrame df = self.predict_to_dataframe(texts, text_ids, return_top_k) # 保存结果 if format.lower() == 'csv': df.to_csv(output_path, index=False, encoding='utf-8') elif format.lower() == 'json': # 转换为嵌套的JSON格式 if return_top_k > 1: # 分组后转换为嵌套格式 result = {} for id_val in df['id'].unique(): sub_df = df[df['id'] == id_val] predictions = [] for _, row in sub_df.iterrows(): predictions.append({ 'class': row['predicted_class'], 'probability': row['probability'] }) result[str(id_val)] = { 'text': sub_df.iloc[0]['text'], 'predictions': predictions } else: # 直接构建JSON result = {} for _, row in df.iterrows(): result[str(row['id'])] = { 'text': row['text'], 'predicted_class': row['predicted_class'], 'probability': row['probability'] } # 保存为JSON with open(output_path, 'w', encoding='utf-8') as f: json.dump(result, f, ensure_ascii=False, indent=2) else: raise ValueError(f"不支持的输出格式: {format}") logger.info(f"预测结果已保存到: {output_path}") return output_path