""" 命令行界面模块:提供命令行交互功能 """ import argparse import os import sys import pandas as pd from typing import List, Dict, Tuple, Optional, Any, Union import json # 将项目根目录添加到sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from config.system_config import CLASSIFIERS_DIR, CATEGORIES from models.model_factory import ModelFactory from models.base_model import TextClassificationModel from preprocessing.tokenization import ChineseTokenizer from preprocessing.vectorizer import SequenceVectorizer from inference.predictor import Predictor from inference.batch_processor import BatchProcessor from utils.logger import get_logger from utils.file_utils import ensure_dir, read_text_file logger = get_logger("CLI") def load_model_and_components(model_path: Optional[str] = None, tokenizer_path: Optional[str] = None, vectorizer_path: Optional[str] = None, class_names: Optional[List[str]] = None) -> Tuple[ TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]: """ 加载模型和相关组件 Args: model_path: 模型路径,如果为None则使用最新的模型 tokenizer_path: 分词器路径,如果为None则创建一个新的分词器 vectorizer_path: 向量化器路径,如果为None则不使用向量化器 class_names: 类别名称列表,如果为None则使用CATEGORIES Returns: (模型, 分词器, 向量化器)的元组 """ # 加载模型 if model_path is None: # 获取可用模型列表 models_info = ModelFactory.get_available_models() if not models_info: raise ValueError("未找到可用的模型,请指定模型路径") # 使用最新的模型 model_path = models_info[0]['path'] logger.info(f"使用最新的模型: {model_path}") # 加载模型 model = ModelFactory.load_model(model_path) # 加载或创建分词器 if tokenizer_path: tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理 logger.info(f"已加载分词器: {tokenizer_path}") else: tokenizer = ChineseTokenizer() logger.info("已创建新的分词器") # 加载向量化器 vectorizer = None if vectorizer_path: vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理 vectorizer.load(vectorizer_path) logger.info(f"已加载向量化器: {vectorizer_path}") return model, tokenizer, vectorizer def predict_text(args): """处理单条文本预测命令""" # 加载模型和组件 model, tokenizer, vectorizer = load_model_and_components( args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names ) # 创建预测器 predictor = Predictor( model=model, tokenizer=tokenizer, vectorizer=vectorizer, class_names=args.class_names or CATEGORIES, batch_size=args.batch_size ) # 获取文本 text = args.text # 如果提供的是文件路径而非文本内容 if args.file and os.path.exists(text): text = read_text_file(text) # 预测 result = predictor.predict( text=text, return_top_k=args.top_k, return_probabilities=True ) # 输出结果 if args.top_k > 1: print("\n预测结果:") for i, pred in enumerate(result): print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})") else: print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})") # 保存结果 if args.output: if args.output.endswith('.json'): with open(args.output, 'w', encoding='utf-8') as f: json.dump(result, f, ensure_ascii=False, indent=2) else: with open(args.output, 'w', encoding='utf-8') as f: if args.top_k > 1: f.write("rank,class,probability\n") for i, pred in enumerate(result): f.write(f"{i + 1},{pred['class']},{pred['probability']}\n") else: f.write(f"class,probability\n") f.write(f"{result['class']},{result['probability']}\n") print(f"结果已保存到: {args.output}") def predict_batch(args): """处理批量文本预测命令""" # 加载模型和组件 model, tokenizer, vectorizer = load_model_and_components( args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names ) # 创建预测器 predictor = Predictor( model=model, tokenizer=tokenizer, vectorizer=vectorizer, class_names=args.class_names or CATEGORIES, batch_size=args.batch_size ) # 创建批处理器 batch_processor = BatchProcessor( predictor=predictor, batch_size=args.batch_size, max_workers=args.workers ) # 确保输出目录存在 if args.output: ensure_dir(os.path.dirname(args.output)) # 根据输入类型选择处理方法 if args.input_type == 'file' and os.path.isfile(args.input): # 单个文件 if args.large_file: # 大型文件,分块处理 batch_processor.process_large_file( file_path=args.input, output_path=args.output, return_top_k=args.top_k, format=args.format ) else: # CSV或Excel文件 if args.input.endswith('.csv'): df = pd.read_csv(args.input, encoding='utf-8') elif args.input.endswith(('.xls', '.xlsx')): df = pd.read_excel(args.input) else: print(f"不支持的文件格式: {args.input}") return # 检查文本列是否存在 if args.text_column not in df.columns: print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}") return # 处理DataFrame result_df = batch_processor.process_dataframe( df=df, text_column=args.text_column, id_column=args.id_column, output_path=args.output, return_top_k=args.top_k, format=args.format ) # 输出结果统计 print(f"\n已处理 {len(result_df)} 条文本") print("类别分布:") if args.top_k == 1: class_counts = result_df['predicted_class'].value_counts() for cls, count in class_counts.items(): print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)") elif args.input_type == 'dir' and os.path.isdir(args.input): # 目录 result_df = batch_processor.process_directory( directory=args.input, pattern=args.pattern, output_path=args.output, return_top_k=args.top_k, format=args.format, recursive=args.recursive ) # 输出结果统计 if not result_df.empty: print(f"\n已处理 {len(result_df)} 个文件") print("类别分布:") if args.top_k == 1: class_counts = result_df['predicted_class'].value_counts() for cls, count in class_counts.items(): print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)") else: print(f"无效的输入: {args.input}") def list_models(args): """列出可用的模型""" models_info = ModelFactory.get_available_models() if not models_info: print("未找到可用的模型") return print(f"找到 {len(models_info)} 个可用模型:") for i, info in enumerate(models_info): print(f"\n{i + 1}. {info['name']} ({info['type']})") print(f" 路径: {info['path']}") print(f" 创建时间: {info['created_time']}") print(f" 类别数: {info['num_classes']}") print(f" 文件大小: {info['file_size']}") def interactive_mode(args): """交互模式""" print("启动交互模式...") # 加载模型和组件 model, tokenizer, vectorizer = load_model_and_components( args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names ) # 创建预测器 predictor = Predictor( model=model, tokenizer=tokenizer, vectorizer=vectorizer, class_names=args.class_names or CATEGORIES, batch_size=args.batch_size ) print("\n模型已加载,可以开始交互式文本分类") print("输入 'quit' 或 'exit' 退出交互模式\n") while True: try: # 获取用户输入 text = input("请输入要分类的文本: ") # 检查是否退出 if text.lower() in ['quit', 'exit', 'q']: print("退出交互模式") break # 空输入 if not text.strip(): continue # 预测 result = predictor.predict( text=text, return_top_k=args.top_k, return_probabilities=True ) # 输出结果 if args.top_k > 1: print("\n预测结果:") for i, pred in enumerate(result): print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})") else: print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})") print() # 空行 except KeyboardInterrupt: print("\n退出交互模式") break except Exception as e: print(f"处理过程中出错: {e}") def main(): """主函数,解析命令行参数并调用相应的功能""" parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具") # 创建子命令 subparsers = parser.add_subparsers(dest="command", help="子命令") # 预测单条文本命令 predict_parser = subparsers.add_parser("predict", help="预测单条文本") predict_parser.add_argument("text", help="要预测的文本或文本文件路径") predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径") predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") predict_parser.add_argument("--vectorizer_path", help="向量化器路径") predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表") predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别") predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小") predict_parser.add_argument("--output", help="保存预测结果的文件路径") predict_parser.set_defaults(func=predict_text) # 批量预测命令 batch_parser = subparsers.add_parser("batch", help="批量预测文本") batch_parser.add_argument("input", help="输入文件或目录路径") batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型") batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名") batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名") batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式") batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录") batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件") batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") batch_parser.add_argument("--vectorizer_path", help="向量化器路径") batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表") batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别") batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小") batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数") batch_parser.add_argument("--output", required=True, help="输出文件路径") batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="输出格式") batch_parser.set_defaults(func=predict_batch) # 列出可用模型命令 list_parser = subparsers.add_parser("list", help="列出可用的模型") list_parser.set_defaults(func=list_models) # 交互模式命令 interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式") interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") interactive_parser.add_argument("--vectorizer_path", help="向量化器路径") interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表") interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别") interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小") interactive_parser.set_defaults(func=interactive_mode) # 解析参数 args = parser.parse_args() # 如果没有指定命令,显示帮助 if not hasattr(args, 'func'): parser.print_help() return # 执行命令 try: args.func(args) except Exception as e: logger.error(f"执行命令时出错: {e}") print(f"执行命令时出错: {e}") return 1 return 0 if __name__ == "__main__": sys.exit(main())