"""Prediction script: run a trained classifier on a single text or on a file."""

import argparse
import json
import logging
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

# Make the project root importable so the package-style imports below resolve
# when this file is executed directly as a script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from config.system_config import (
    CATEGORIES, CLASSIFIERS_DIR
)
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import read_text_file

logger = get_logger("Prediction")

def predict_text(text: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
    """Predict the category of a single text.

    Args:
        text: The text to classify.
        model_path: Path to a saved model; when None, the first entry of
            ModelFactory.get_available_models() is used (assumed to be the
            most recent model — TODO confirm ordering).
        output_path: Where to save the result. A ``.json`` suffix triggers
            JSON output; anything else gets a simple CSV. None disables saving.
        top_k: Number of highest-probability classes to return.

    Returns:
        The predictor's result: a single dict (keys ``class``/``probability``)
        when ``top_k == 1``, otherwise a ranked list of such dicts.

    Raises:
        ValueError: If ``model_path`` is None and no model is available.
    """
    logger.info("开始预测文本")

    # 1. Resolve and load the model.
    if model_path is None:
        models_info = ModelFactory.get_available_models()
        if not models_info:
            raise ValueError("未找到可用的模型")
        # First entry is taken as the latest model.
        model_path = models_info[0]['path']

    logger.info(f"加载模型: {model_path}")
    model = ModelFactory.load_model(model_path)

    # 2. Build the tokenizer and look for a vectorizer saved next to the model.
    tokenizer = ChineseTokenizer()

    vectorizer = None
    model_dir = os.path.dirname(model_path)
    for model_type in ("cnn", "rnn", "transformer"):
        vectorizer_path = os.path.join(model_dir, f"vectorizer_{model_type}.pkl")
        if os.path.exists(vectorizer_path):
            vectorizer = SequenceVectorizer()
            vectorizer.load(vectorizer_path)
            logger.info(f"加载向量化器: {vectorizer_path}")
            break

    predictor = Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=CATEGORIES
    )

    # 3. Run the prediction.
    result = predictor.predict(
        text=text,
        return_top_k=top_k,
        return_probabilities=True
    )

    # 4. Log the outcome. With top_k > 1 the predictor returns a ranked list;
    # otherwise a single dict.
    if top_k > 1:
        logger.info("预测结果:")
        for i, pred in enumerate(result):
            logger.info(f"  {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
    else:
        logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")

    # 5. Optionally persist the result (JSON, or a minimal CSV otherwise).
    if output_path:
        if output_path.endswith('.json'):
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        else:
            with open(output_path, 'w', encoding='utf-8') as f:
                if top_k > 1:
                    f.write("rank,class,probability\n")
                    for i, pred in enumerate(result):
                        f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
                else:
                    f.write("class,probability\n")
                    f.write(f"{result['class']},{result['probability']}\n")

        logger.info(f"结果已保存到: {output_path}")

    return result

def predict_file(file_path: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
    """Predict the content of a file.

    ``.txt`` files are read whole and delegated to :func:`predict_text`;
    tabular files (``.csv``/``.xls``/``.xlsx``) are batch-predicted row by
    row using the first object-dtype column as the text column.

    Args:
        file_path: Path of the file to predict.
        model_path: Path to a saved model; when None the first available
            model is used (assumed most recent — TODO confirm ordering).
        output_path: Where to save results; None disables saving.
        top_k: Number of highest-probability classes to return.

    Returns:
        For ``.txt`` input, whatever :func:`predict_text` returns; for
        tabular input, a list of per-row record dicts.
        NOTE(review): the annotated return type ``Dict[str, Any]`` does not
        match the list returned on the tabular path — confirm intent.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file type is unsupported, no text column is
            found, or no model is available.
    """
    logger.info(f"开始预测文件: {file_path}")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    if file_path.endswith('.txt'):
        # Plain text file: read it and reuse the single-text path.
        text = read_text_file(file_path)
        return predict_text(text, model_path, output_path, top_k)

    elif file_path.endswith(('.csv', '.xls', '.xlsx')):
        # Tabular file; pandas is imported lazily so the txt path stays cheap.
        import pandas as pd

        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Heuristic: treat object-dtype columns as candidate text columns.
        text_columns = [col for col in df.columns if df[col].dtype == 'object']
        if not text_columns:
            raise ValueError("文件中没有找到可能的文本列")

        # Use the first text-like column.
        text_column = text_columns[0]
        logger.info(f"使用文本列: {text_column}")

        # 1. Resolve and load the model.
        if model_path is None:
            models_info = ModelFactory.get_available_models()
            if not models_info:
                raise ValueError("未找到可用的模型")
            # First entry is taken as the latest model.
            model_path = models_info[0]['path']

        logger.info(f"加载模型: {model_path}")
        model = ModelFactory.load_model(model_path)

        # 2. Build the tokenizer and look for a vectorizer next to the model.
        tokenizer = ChineseTokenizer()

        vectorizer = None
        model_dir = os.path.dirname(model_path)
        for model_type in ("cnn", "rnn", "transformer"):
            vectorizer_path = os.path.join(model_dir, f"vectorizer_{model_type}.pkl")
            if os.path.exists(vectorizer_path):
                vectorizer = SequenceVectorizer()
                vectorizer.load(vectorizer_path)
                logger.info(f"加载向量化器: {vectorizer_path}")
                break

        predictor = Predictor(
            model=model,
            tokenizer=tokenizer,
            vectorizer=vectorizer,
            class_names=CATEGORIES
        )

        # 3. Batch processor for row-wise prediction.
        batch_processor = BatchProcessor(
            predictor=predictor,
            batch_size=64
        )

        # 4. Predict every row. The output format is inferred from the
        # output path's extension (NOTE(review): a dot-less path yields the
        # whole path as "format" — confirm BatchProcessor tolerates that).
        result_df = batch_processor.process_dataframe(
            df=df,
            text_column=text_column,
            output_path=output_path,
            return_top_k=top_k,
            format=output_path.split('.')[-1] if output_path else 'csv'
        )

        logger.info(f"已处理 {len(result_df)} 行数据")

        return result_df.to_dict(orient='records')

    else:
        raise ValueError(f"不支持的文件类型: {file_path}")

if __name__ == "__main__":
    # CLI entry point: predict either an inline text or a file.
    parser = argparse.ArgumentParser(description="使用模型预测")
    parser.add_argument("--model_path", help="模型路径")
    parser.add_argument("--text", help="要预测的文本")
    parser.add_argument("--file", help="要预测的文件")
    parser.add_argument("--output", help="输出文件")
    parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")

    args = parser.parse_args()

    # At least one input source is required.
    if not args.text and not args.file:
        parser.error("请提供要预测的文本或文件")

    # --text takes precedence when both are supplied.
    if args.text:
        predict_text(args.text, args.model_path, args.output, args.top_k)
    else:
        predict_file(args.file, args.model_path, args.output, args.top_k)