"""
Prediction script: run predictions with a trained model.

Exposes two entry points, ``predict_text`` for a single string and
``predict_file`` for a ``.txt`` file or a tabular file (``.csv``, ``.xls``,
``.xlsx``), plus a small command-line interface.
"""

import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union

import numpy as np
import json

# Add the project root directory to the module search path so that the
# project-local packages imported below can be resolved when this file is
# run as a script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from config.system_config import (
    CATEGORIES, CLASSIFIERS_DIR
)
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import read_text_file

logger = get_logger("Prediction")

# Model-type suffixes probed, in this order, when looking for a saved
# vectorizer file next to the model.
_VECTORIZER_MODEL_TYPES = ("cnn", "rnn", "transformer")

# Extensions handled as tabular input by ``predict_file``.
_TABLE_EXTENSIONS = ('.csv', '.xls', '.xlsx')

# Batch size handed to ``BatchProcessor`` for tabular input.
_BATCH_SIZE = 64

# Output format used when none can be derived from the output path.
_DEFAULT_OUTPUT_FORMAT = 'csv'


def _build_predictor(model_path: Optional[str] = None) -> Predictor:
    """
    Load a model and its companion objects and wrap them in a ``Predictor``.

    Args:
        model_path: Path of the model to load. If None, the first entry
            returned by ``ModelFactory.get_available_models()`` is used.

    Returns:
        A ``Predictor`` built from the loaded model, a ``ChineseTokenizer``,
        the vectorizer found next to the model (or None) and ``CATEGORIES``.

    Raises:
        ValueError: If ``model_path`` is None and no model is available.
    """
    if model_path is None:
        models_info = ModelFactory.get_available_models()
        if not models_info:
            raise ValueError("未找到可用的模型")
        # NOTE(review): the first entry is presumably the most recent
        # model -- confirm the ordering of get_available_models().
        model_path = models_info[0]['path']

    logger.info(f"加载模型: {model_path}")
    model = ModelFactory.load_model(model_path)

    tokenizer = ChineseTokenizer()

    # Use the first vectorizer file that exists in the model's directory.
    # If none is found the predictor is created with ``vectorizer=None``.
    vectorizer = None
    model_dir = os.path.dirname(model_path)
    for model_type in _VECTORIZER_MODEL_TYPES:
        vectorizer_path = os.path.join(model_dir, f"vectorizer_{model_type}.pkl")
        if os.path.exists(vectorizer_path):
            vectorizer = SequenceVectorizer()
            vectorizer.load(vectorizer_path)
            logger.info(f"加载向量化器: {vectorizer_path}")
            break

    return Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=CATEGORIES
    )


def _output_format(output_path: Optional[str]) -> str:
    """
    Derive the batch output format from the extension of ``output_path``.

    ``os.path.splitext`` is used instead of ``split('.')`` so that a dot in
    a directory name (``runs.v2/out``) or a path without any extension
    (``./results``) is not mistaken for a format. Falls back to ``'csv'``
    when there is no path or no extension.
    """
    if not output_path:
        return _DEFAULT_OUTPUT_FORMAT
    extension = os.path.splitext(output_path)[1]
    # ``extension`` includes the leading dot; a bare "." means "no format".
    return extension[1:] if len(extension) > 1 else _DEFAULT_OUTPUT_FORMAT


def _save_text_result(result: Any, output_path: str, top_k: int) -> None:
    """
    Write a single-text prediction result to ``output_path``.

    A path ending in ``.json`` gets a JSON dump of ``result``; any other
    path gets CSV-style text with a header row.
    """
    if output_path.endswith('.json'):
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
    else:
        with open(output_path, 'w', encoding='utf-8') as f:
            if top_k > 1:
                f.write("rank,class,probability\n")
                for i, pred in enumerate(result):
                    f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
            else:
                f.write("class,probability\n")
                f.write(f"{result['class']},{result['probability']}\n")

    logger.info(f"结果已保存到: {output_path}")


def predict_text(text: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None,
                 top_k: int = 3) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Predict the category of a single piece of text.

    Args:
        text: The text to classify.
        model_path: Path of the model to use. If None, the first available
            model reported by ``ModelFactory`` is used.
        output_path: File to write the result to; nothing is saved if None.
            A ``.json`` suffix selects JSON output, anything else CSV text.
        top_k: Number of highest-probability categories to request.

    Returns:
        Whatever ``Predictor.predict`` returns. This function treats it as
        a sequence of ``{'class', 'probability'}`` mappings when
        ``top_k > 1`` and as a single such mapping otherwise.

    Raises:
        ValueError: If no model path is given and no model is available.
    """
    logger.info("开始预测文本")

    predictor = _build_predictor(model_path)

    result = predictor.predict(
        text=text,
        return_top_k=top_k,
        return_probabilities=True
    )

    # Log the outcome; the result shape depends on top_k (see Returns).
    if top_k > 1:
        logger.info("预测结果:")
        for i, pred in enumerate(result):
            logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
    else:
        logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")

    if output_path:
        _save_text_result(result, output_path, top_k)

    return result


def predict_file(file_path: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None,
                 top_k: int = 3) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Predict categories for the contents of a file.

    A ``.txt`` file is read as one text and forwarded to ``predict_text``.
    A ``.csv``/``.xls``/``.xlsx`` file is loaded with pandas and its first
    ``object``-dtype column is classified row by row through
    ``BatchProcessor``. Extensions are matched case-insensitively.

    Args:
        file_path: Path of the input file.
        model_path: Path of the model to use. If None, the first available
            model reported by ``ModelFactory`` is used.
        output_path: File to write results to; nothing is saved if None.
        top_k: Number of highest-probability categories to request.

    Returns:
        For ``.txt`` input, the result of ``predict_text``. For tabular
        input, the processed DataFrame converted to a list of row dicts.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file type is unsupported, the table has no
            ``object``-dtype column, or no model is available.
    """
    logger.info(f"开始预测文件: {file_path}")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    # Compare extensions on a lower-cased copy so "DATA.TXT" is accepted;
    # the original path is still what gets opened.
    lowered_path = file_path.lower()

    if lowered_path.endswith('.txt'):
        text = read_text_file(file_path)
        return predict_text(text, model_path, output_path, top_k)

    if not lowered_path.endswith(_TABLE_EXTENSIONS):
        raise ValueError(f"不支持的文件类型: {file_path}")

    # pandas is imported lazily: it is only needed for tabular input.
    import pandas as pd

    if lowered_path.endswith('.csv'):
        df = pd.read_csv(file_path)
    else:
        df = pd.read_excel(file_path)

    # Candidate text columns are those with dtype ``object``.
    text_columns = [col for col in df.columns if df[col].dtype == 'object']
    if not text_columns:
        raise ValueError("文件中没有找到可能的文本列")

    text_column = text_columns[0]
    logger.info(f"使用文本列: {text_column}")

    predictor = _build_predictor(model_path)

    batch_processor = BatchProcessor(
        predictor=predictor,
        batch_size=_BATCH_SIZE
    )

    result_df = batch_processor.process_dataframe(
        df=df,
        text_column=text_column,
        output_path=output_path,
        return_top_k=top_k,
        format=_output_format(output_path)
    )

    logger.info(f"已处理 {len(result_df)} 行数据")

    return result_df.to_dict(orient='records')


def main() -> None:
    """
    Command-line entry point.

    Parses the arguments and dispatches to ``predict_text`` when ``--text``
    is given, otherwise to ``predict_file``. Exits with a usage error when
    neither ``--text`` nor ``--file`` is supplied.
    """
    parser = argparse.ArgumentParser(description="使用模型预测")
    parser.add_argument("--model_path", help="模型路径")
    parser.add_argument("--text", help="要预测的文本")
    parser.add_argument("--file", help="要预测的文件")
    parser.add_argument("--output", help="输出文件")
    parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")

    args = parser.parse_args()

    if not args.text and not args.file:
        parser.error("请提供要预测的文本或文件")

    # ``--text`` takes precedence when both inputs are provided.
    if args.text:
        predict_text(args.text, args.model_path, args.output, args.top_k)
    else:
        predict_file(args.file, args.model_path, args.output, args.top_k)


if __name__ == "__main__":
    main()