2025-03-08 01:34:36 +08:00

379 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
命令行界面模块:提供命令行交互功能
"""
import argparse
import os
import sys
import pandas as pd
from typing import List, Dict, Tuple, Optional, Any, Union
import json
# 将项目根目录添加到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from models.base_model import TextClassificationModel
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import ensure_dir, read_text_file
logger = get_logger("CLI")
def load_model_and_components(model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
vectorizer_path: Optional[str] = None,
class_names: Optional[List[str]] = None) -> Tuple[
TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]:
"""
加载模型和相关组件
Args:
model_path: 模型路径如果为None则使用最新的模型
tokenizer_path: 分词器路径如果为None则创建一个新的分词器
vectorizer_path: 向量化器路径如果为None则不使用向量化器
class_names: 类别名称列表如果为None则使用CATEGORIES
Returns:
(模型, 分词器, 向量化器)的元组
"""
# 加载模型
if model_path is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise ValueError("未找到可用的模型,请指定模型路径")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"使用最新的模型: {model_path}")
# 加载模型
model = ModelFactory.load_model(model_path)
# 加载或创建分词器
if tokenizer_path:
tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理
logger.info(f"已加载分词器: {tokenizer_path}")
else:
tokenizer = ChineseTokenizer()
logger.info("已创建新的分词器")
# 加载向量化器
vectorizer = None
if vectorizer_path:
vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理
vectorizer.load(vectorizer_path)
logger.info(f"已加载向量化器: {vectorizer_path}")
return model, tokenizer, vectorizer
def predict_text(args):
"""处理单条文本预测命令"""
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
# 获取文本
text = args.text
# 如果提供的是文件路径而非文本内容
if args.file and os.path.exists(text):
text = read_text_file(text)
# 预测
result = predictor.predict(
text=text,
return_top_k=args.top_k,
return_probabilities=True
)
# 输出结果
if args.top_k > 1:
print("\n预测结果:")
for i, pred in enumerate(result):
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
else:
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
# 保存结果
if args.output:
if args.output.endswith('.json'):
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
with open(args.output, 'w', encoding='utf-8') as f:
if args.top_k > 1:
f.write("rank,class,probability\n")
for i, pred in enumerate(result):
f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
else:
f.write(f"class,probability\n")
f.write(f"{result['class']},{result['probability']}\n")
print(f"结果已保存到: {args.output}")
def predict_batch(args):
"""处理批量文本预测命令"""
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
# 创建批处理器
batch_processor = BatchProcessor(
predictor=predictor,
batch_size=args.batch_size,
max_workers=args.workers
)
# 确保输出目录存在
if args.output:
ensure_dir(os.path.dirname(args.output))
# 根据输入类型选择处理方法
if args.input_type == 'file' and os.path.isfile(args.input):
# 单个文件
if args.large_file:
# 大型文件,分块处理
batch_processor.process_large_file(
file_path=args.input,
output_path=args.output,
return_top_k=args.top_k,
format=args.format
)
else:
# CSV或Excel文件
if args.input.endswith('.csv'):
df = pd.read_csv(args.input, encoding='utf-8')
elif args.input.endswith(('.xls', '.xlsx')):
df = pd.read_excel(args.input)
else:
print(f"不支持的文件格式: {args.input}")
return
# 检查文本列是否存在
if args.text_column not in df.columns:
print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}")
return
# 处理DataFrame
result_df = batch_processor.process_dataframe(
df=df,
text_column=args.text_column,
id_column=args.id_column,
output_path=args.output,
return_top_k=args.top_k,
format=args.format
)
# 输出结果统计
print(f"\n已处理 {len(result_df)} 条文本")
print("类别分布:")
if args.top_k == 1:
class_counts = result_df['predicted_class'].value_counts()
for cls, count in class_counts.items():
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
elif args.input_type == 'dir' and os.path.isdir(args.input):
# 目录
result_df = batch_processor.process_directory(
directory=args.input,
pattern=args.pattern,
output_path=args.output,
return_top_k=args.top_k,
format=args.format,
recursive=args.recursive
)
# 输出结果统计
if not result_df.empty:
print(f"\n已处理 {len(result_df)} 个文件")
print("类别分布:")
if args.top_k == 1:
class_counts = result_df['predicted_class'].value_counts()
for cls, count in class_counts.items():
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
else:
print(f"无效的输入: {args.input}")
def list_models(args):
"""列出可用的模型"""
models_info = ModelFactory.get_available_models()
if not models_info:
print("未找到可用的模型")
return
print(f"找到 {len(models_info)} 个可用模型:")
for i, info in enumerate(models_info):
print(f"\n{i + 1}. {info['name']} ({info['type']})")
print(f" 路径: {info['path']}")
print(f" 创建时间: {info['created_time']}")
print(f" 类别数: {info['num_classes']}")
print(f" 文件大小: {info['file_size']}")
def interactive_mode(args):
"""交互模式"""
print("启动交互模式...")
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
print("\n模型已加载,可以开始交互式文本分类")
print("输入 'quit''exit' 退出交互模式\n")
while True:
try:
# 获取用户输入
text = input("请输入要分类的文本: ")
# 检查是否退出
if text.lower() in ['quit', 'exit', 'q']:
print("退出交互模式")
break
# 空输入
if not text.strip():
continue
# 预测
result = predictor.predict(
text=text,
return_top_k=args.top_k,
return_probabilities=True
)
# 输出结果
if args.top_k > 1:
print("\n预测结果:")
for i, pred in enumerate(result):
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
else:
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
print() # 空行
except KeyboardInterrupt:
print("\n退出交互模式")
break
except Exception as e:
print(f"处理过程中出错: {e}")
def main():
"""主函数,解析命令行参数并调用相应的功能"""
parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具")
# 创建子命令
subparsers = parser.add_subparsers(dest="command", help="子命令")
# 预测单条文本命令
predict_parser = subparsers.add_parser("predict", help="预测单条文本")
predict_parser.add_argument("text", help="要预测的文本或文本文件路径")
predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径")
predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
predict_parser.add_argument("--vectorizer_path", help="向量化器路径")
predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
predict_parser.add_argument("--output", help="保存预测结果的文件路径")
predict_parser.set_defaults(func=predict_text)
# 批量预测命令
batch_parser = subparsers.add_parser("batch", help="批量预测文本")
batch_parser.add_argument("input", help="输入文件或目录路径")
batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型")
batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名")
batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名")
batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式")
batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录")
batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件")
batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
batch_parser.add_argument("--vectorizer_path", help="向量化器路径")
batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数")
batch_parser.add_argument("--output", required=True, help="输出文件路径")
batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="输出格式")
batch_parser.set_defaults(func=predict_batch)
# 列出可用模型命令
list_parser = subparsers.add_parser("list", help="列出可用的模型")
list_parser.set_defaults(func=list_models)
# 交互模式命令
interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式")
interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
interactive_parser.add_argument("--vectorizer_path", help="向量化器路径")
interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小")
interactive_parser.set_defaults(func=interactive_mode)
# 解析参数
args = parser.parse_args()
# 如果没有指定命令,显示帮助
if not hasattr(args, 'func'):
parser.print_help()
return
# 执行命令
try:
args.func(args)
except Exception as e:
logger.error(f"执行命令时出错: {e}")
print(f"执行命令时出错: {e}")
return 1
return 0
if __name__ == "__main__":
sys.exit(main())