# NOTE: removed non-source residue (repository web-UI timestamps, "Raw Blame
# History" chrome, and the ambiguous-Unicode warning banner) that was
# accidentally captured above the module docstring; it is not valid Python.
"""
预测脚本:使用模型进行预测
"""
import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import json
# 将项目根目录添加到系统路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from config.system_config import (
CATEGORIES, CLASSIFIERS_DIR
)
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import read_text_file
logger = get_logger("Prediction")
def predict_text(text: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None,
                 top_k: int = 3) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Predict the category of a single text.

    Args:
        text: Text to classify.
        model_path: Path to a saved model; if None, the first entry returned by
            ``ModelFactory.get_available_models()`` is used (presumably the
            newest — TODO confirm ordering against ModelFactory).
        output_path: Where to save the result. A ``.json`` suffix produces JSON;
            any other suffix produces CSV. If None, nothing is saved.
        top_k: Number of highest-probability classes to return.

    Returns:
        A single prediction dict when ``top_k == 1``, otherwise a list of
        prediction dicts (each with ``class`` and ``probability`` keys).

    Raises:
        ValueError: If no trained model is available.
    """
    logger.info("开始预测文本")

    # 1. Resolve and load the model.
    if model_path is None:
        models_info = ModelFactory.get_available_models()
        if not models_info:
            raise ValueError("未找到可用的模型")
        model_path = models_info[0]['path']
    logger.info(f"加载模型: {model_path}")
    model = ModelFactory.load_model(model_path)

    # 2. Build the tokenizer and look for a serialized vectorizer saved next to
    #    the model; the first matching model-type suffix wins.
    tokenizer = ChineseTokenizer()
    vectorizer = None
    for model_type in ["cnn", "rnn", "transformer"]:
        vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
        if os.path.exists(vectorizer_path):
            vectorizer = SequenceVectorizer()
            vectorizer.load(vectorizer_path)
            logger.info(f"加载向量化器: {vectorizer_path}")
            break

    predictor = Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=CATEGORIES
    )

    # 3. Predict. With return_top_k > 1 the predictor returns a list of
    #    predictions; with top_k == 1 a single dict — TODO confirm against
    #    Predictor.predict.
    result = predictor.predict(
        text=text,
        return_top_k=top_k,
        return_probabilities=True
    )

    # 4. Log the outcome.
    if top_k > 1:
        logger.info("预测结果:")
        for i, pred in enumerate(result):
            logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
    else:
        logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")

    # 5. Optionally persist the result.
    if output_path:
        if output_path.endswith('.json'):
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        else:
            # CSV output. Use the csv module (instead of raw f-string writes)
            # so class names containing commas or quotes are escaped correctly.
            import csv
            with open(output_path, 'w', encoding='utf-8', newline='') as f:
                writer = csv.writer(f)
                if top_k > 1:
                    writer.writerow(["rank", "class", "probability"])
                    for i, pred in enumerate(result):
                        writer.writerow([i + 1, pred['class'], pred['probability']])
                else:
                    writer.writerow(["class", "probability"])
                    writer.writerow([result['class'], result['probability']])
        logger.info(f"结果已保存到: {output_path}")
    return result
def predict_file(file_path: str, model_path: Optional[str] = None,
                 output_path: Optional[str] = None,
                 top_k: int = 3) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Predict categories for the contents of a file.

    ``.txt`` files are read whole and delegated to :func:`predict_text`;
    ``.csv``/``.xls``/``.xlsx`` files are loaded with pandas, the first
    object-dtype column is treated as the text column, and every row is
    predicted in batches.

    Args:
        file_path: Path of the file to predict.
        model_path: Path to a saved model; if None, the first entry returned by
            ``ModelFactory.get_available_models()`` is used.
        output_path: Where to save the results; if None, nothing is saved.
        top_k: Number of highest-probability classes to return per text.

    Returns:
        For ``.txt`` input, whatever :func:`predict_text` returns; for tabular
        input, a list of per-row record dicts.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no model is available, no text column is found, or the
            file type is unsupported.
    """
    logger.info(f"开始预测文件: {file_path}")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    if file_path.endswith('.txt'):
        # Plain text: treat the whole file as a single document.
        text = read_text_file(file_path)
        return predict_text(text, model_path, output_path, top_k)
    elif file_path.endswith(('.csv', '.xls', '.xlsx')):
        # Tabular input; pandas is imported lazily so .txt prediction does not
        # require it.
        import pandas as pd
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Heuristic: the first object-dtype column is assumed to hold the text.
        text_columns = [col for col in df.columns if df[col].dtype == 'object']
        if not text_columns:
            raise ValueError("文件中没有找到可能的文本列")
        text_column = text_columns[0]
        logger.info(f"使用文本列: {text_column}")

        # 1. Resolve and load the model (same strategy as predict_text).
        if model_path is None:
            models_info = ModelFactory.get_available_models()
            if not models_info:
                raise ValueError("未找到可用的模型")
            model_path = models_info[0]['path']
        logger.info(f"加载模型: {model_path}")
        model = ModelFactory.load_model(model_path)

        # 2. Tokenizer plus an optional vectorizer saved beside the model.
        tokenizer = ChineseTokenizer()
        vectorizer = None
        for model_type in ["cnn", "rnn", "transformer"]:
            vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
            if os.path.exists(vectorizer_path):
                vectorizer = SequenceVectorizer()
                vectorizer.load(vectorizer_path)
                logger.info(f"加载向量化器: {vectorizer_path}")
                break

        predictor = Predictor(
            model=model,
            tokenizer=tokenizer,
            vectorizer=vectorizer,
            class_names=CATEGORIES
        )

        # 3. Batch over the dataframe; the output format is inferred from the
        #    output path's extension, defaulting to csv.
        batch_processor = BatchProcessor(
            predictor=predictor,
            batch_size=64
        )
        result_df = batch_processor.process_dataframe(
            df=df,
            text_column=text_column,
            output_path=output_path,
            return_top_k=top_k,
            format=output_path.split('.')[-1] if output_path else 'csv'
        )
        logger.info(f"已处理 {len(result_df)} 行数据")
        return result_df.to_dict(orient='records')
    else:
        raise ValueError(f"不支持的文件类型: {file_path}")
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="使用模型预测")
parser.add_argument("--model_path", help="模型路径")
parser.add_argument("--text", help="要预测的文本")
parser.add_argument("--file", help="要预测的文件")
parser.add_argument("--output", help="输出文件")
parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
args = parser.parse_args()
# 检查输入
if not args.text and not args.file:
parser.error("请提供要预测的文本或文件")
# 预测
if args.text:
predict_text(args.text, args.model_path, args.output, args.top_k)
else:
predict_file(args.file, args.model_path, args.output, args.top_k)