"""
|
||
命令行界面模块:提供命令行交互功能
|
||
"""
|
||
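# Illustrative invocations (assuming this file is run as a script named
# cli.py; the subcommands and flags are defined in main() below, and the
# file names here are placeholders):
#
#   python cli.py list
#   python cli.py predict "这是一条新闻文本" --top_k 3
#   python cli.py predict article.txt --file --output result.json
#   python cli.py batch data.csv --text_column text --output results.csv
#   python cli.py interactive
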
import argparse
import json
import os
import sys
from typing import List, Optional, Tuple

import pandas as pd

# Add the project root directory to sys.path so the package imports below
# resolve when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from models.base_model import TextClassificationModel
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import ensure_dir, read_text_file

logger = get_logger("CLI")


def load_model_and_components(model_path: Optional[str] = None,
                              tokenizer_path: Optional[str] = None,
                              vectorizer_path: Optional[str] = None,
                              class_names: Optional[List[str]] = None) -> Tuple[
        TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]:
    """
    Load the model and its companion components.

    Args:
        model_path: Path to the model; if None, the most recent model is used.
        tokenizer_path: Path to the tokenizer; if None, a new tokenizer is created.
        vectorizer_path: Path to the vectorizer; if None, no vectorizer is used.
        class_names: List of class names; if None, CATEGORIES is used.

    Returns:
        A (model, tokenizer, vectorizer) tuple.
    """
    # Resolve the model path
    if model_path is None:
        # Get the list of available models
        models_info = ModelFactory.get_available_models()

        if not models_info:
            raise ValueError("No available models found; please specify a model path")

        # Use the most recent model
        model_path = models_info[0]['path']
        logger.info(f"Using the most recent model: {model_path}")

    # Load the model
    model = ModelFactory.load_model(model_path)

    # Load or create the tokenizer
    if tokenizer_path:
        tokenizer = ChineseTokenizer()  # Should really be loaded from the file; simplified here
        logger.info(f"Loaded tokenizer: {tokenizer_path}")
    else:
        tokenizer = ChineseTokenizer()
        logger.info("Created a new tokenizer")

    # Load the vectorizer
    vectorizer = None
    if vectorizer_path:
        vectorizer = SequenceVectorizer()  # Constructed empty, then restored from file below
        vectorizer.load(vectorizer_path)
        logger.info(f"Loaded vectorizer: {vectorizer_path}")

    return model, tokenizer, vectorizer


def predict_text(args):
    """Handle the single-text prediction command."""
    # Load the model and companion components
    model, tokenizer, vectorizer = load_model_and_components(
        args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
    )

    # Create the predictor
    predictor = Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=args.class_names or CATEGORIES,
        batch_size=args.batch_size
    )

    # Get the text to classify
    text = args.text

    # If a file path was supplied instead of literal text, read the file
    if args.file and os.path.exists(text):
        text = read_text_file(text)

    # Predict
    result = predictor.predict(
        text=text,
        return_top_k=args.top_k,
        return_probabilities=True
    )

    # Print the result
    if args.top_k > 1:
        print("\nPredictions:")
        for i, pred in enumerate(result):
            print(f"{i + 1}. {pred['class']} (probability: {pred['probability']:.4f})")
    else:
        print(f"\nPrediction: {result['class']} (probability: {result['probability']:.4f})")

    # Save the result if requested
    if args.output:
        if args.output.endswith('.json'):
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        else:
            with open(args.output, 'w', encoding='utf-8') as f:
                if args.top_k > 1:
                    f.write("rank,class,probability\n")
                    for i, pred in enumerate(result):
                        f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
                else:
                    f.write("class,probability\n")
                    f.write(f"{result['class']},{result['probability']}\n")

        print(f"Results saved to: {args.output}")


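# For reference, a `predict` run with --top_k 3 and a non-JSON --output writes
# CSV shaped like the following (class names and probabilities are
# illustrative, not real model output):
#
#   rank,class,probability
#   1,体育,0.9312
#   2,娱乐,0.0411
#   3,科技,0.0127
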
def predict_batch(args):
    """Handle the batch text prediction command."""
    # Load the model and companion components
    model, tokenizer, vectorizer = load_model_and_components(
        args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
    )

    # Create the predictor
    predictor = Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=args.class_names or CATEGORIES,
        batch_size=args.batch_size
    )

    # Create the batch processor
    batch_processor = BatchProcessor(
        predictor=predictor,
        batch_size=args.batch_size,
        max_workers=args.workers
    )

    # Make sure the output directory exists (dirname is empty when the output
    # path is a bare filename, so guard against passing '' to ensure_dir)
    if args.output:
        output_dir = os.path.dirname(args.output)
        if output_dir:
            ensure_dir(output_dir)

    # Choose a processing method based on the input type
    if args.input_type == 'file' and os.path.isfile(args.input):
        # Single file
        if args.large_file:
            # Large file, processed in chunks
            batch_processor.process_large_file(
                file_path=args.input,
                output_path=args.output,
                return_top_k=args.top_k,
                format=args.format
            )
        else:
            # CSV or Excel file
            if args.input.endswith('.csv'):
                df = pd.read_csv(args.input, encoding='utf-8')
            elif args.input.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(args.input)
            else:
                print(f"Unsupported file format: {args.input}")
                return

            # Check that the text column exists
            if args.text_column not in df.columns:
                print(f"Text column '{args.text_column}' is not in the input file; "
                      f"available columns: {', '.join(df.columns)}")
                return

            # Process the DataFrame
            result_df = batch_processor.process_dataframe(
                df=df,
                text_column=args.text_column,
                id_column=args.id_column,
                output_path=args.output,
                return_top_k=args.top_k,
                format=args.format
            )

            # Print summary statistics (the per-class breakdown only applies
            # to single-label output, i.e. top_k == 1)
            print(f"\nProcessed {len(result_df)} texts")
            if args.top_k == 1:
                print("Class distribution:")
                class_counts = result_df['predicted_class'].value_counts()
                for cls, count in class_counts.items():
                    print(f"  {cls}: {count} ({count / len(result_df) * 100:.1f}%)")

    elif args.input_type == 'dir' and os.path.isdir(args.input):
        # Directory
        result_df = batch_processor.process_directory(
            directory=args.input,
            pattern=args.pattern,
            output_path=args.output,
            return_top_k=args.top_k,
            format=args.format,
            recursive=args.recursive
        )

        # Print summary statistics
        if not result_df.empty:
            print(f"\nProcessed {len(result_df)} files")
            if args.top_k == 1:
                print("Class distribution:")
                class_counts = result_df['predicted_class'].value_counts()
                for cls, count in class_counts.items():
                    print(f"  {cls}: {count} ({count / len(result_df) * 100:.1f}%)")

    else:
        print(f"Invalid input: {args.input}")


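# Illustrative `batch` invocation over a directory of .txt files (the script
# name cli.py and the paths are placeholders; the flags are defined in main()
# below):
#
#   python cli.py batch ./corpus --input_type dir --pattern "*.txt" \
#       --recursive --workers 8 --output results/predictions.json --format json
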
def list_models(args):
    """List the available models."""
    models_info = ModelFactory.get_available_models()

    if not models_info:
        print("No available models found")
        return

    print(f"Found {len(models_info)} available models:")
    for i, info in enumerate(models_info):
        print(f"\n{i + 1}. {info['name']} ({info['type']})")
        print(f"   Path: {info['path']}")
        print(f"   Created: {info['created_time']}")
        print(f"   Number of classes: {info['num_classes']}")
        print(f"   File size: {info['file_size']}")


def interactive_mode(args):
    """Interactive classification mode."""
    print("Starting interactive mode...")

    # Load the model and companion components
    model, tokenizer, vectorizer = load_model_and_components(
        args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
    )

    # Create the predictor
    predictor = Predictor(
        model=model,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        class_names=args.class_names or CATEGORIES,
        batch_size=args.batch_size
    )

    print("\nModel loaded; interactive text classification is ready")
    print("Type 'quit' or 'exit' to leave interactive mode\n")

    while True:
        try:
            # Read user input
            text = input("Enter text to classify: ")

            # Check for an exit command
            if text.lower() in ['quit', 'exit', 'q']:
                print("Leaving interactive mode")
                break

            # Skip empty input
            if not text.strip():
                continue

            # Predict
            result = predictor.predict(
                text=text,
                return_top_k=args.top_k,
                return_probabilities=True
            )

            # Print the result
            if args.top_k > 1:
                print("\nPredictions:")
                for i, pred in enumerate(result):
                    print(f"{i + 1}. {pred['class']} (probability: {pred['probability']:.4f})")
            else:
                print(f"\nPrediction: {result['class']} (probability: {result['probability']:.4f})")

            print()  # blank line between rounds

        except KeyboardInterrupt:
            print("\nLeaving interactive mode")
            break
        except Exception as e:
            print(f"Error while processing: {e}")


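# A sample interactive session might look like this (the input sentence and
# all output values are illustrative, not real model output):
#
#   Enter text to classify: 国足世预赛名单公布
#   Predictions:
#   1. 体育 (probability: 0.9412)
#   2. 娱乐 (probability: 0.0301)
#   3. 社会 (probability: 0.0143)
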
def main():
    """Main entry point: parse command-line arguments and dispatch."""
    parser = argparse.ArgumentParser(description="Chinese text classification system command-line tool")

    # Create subcommands
    subparsers = parser.add_subparsers(dest="command", help="subcommand")

    # Single-text prediction command
    predict_parser = subparsers.add_parser("predict", help="Predict a single text")
    predict_parser.add_argument("text", help="Text to classify, or a path to a text file")
    predict_parser.add_argument("--file", action="store_true", help="Treat the text argument as a file path")
    predict_parser.add_argument("--model_path", help="Model path; defaults to the most recent model")
    predict_parser.add_argument("--tokenizer_path", help="Tokenizer path; defaults to creating a new tokenizer")
    predict_parser.add_argument("--vectorizer_path", help="Vectorizer path")
    predict_parser.add_argument("--class_names", nargs="+", help="List of class names")
    predict_parser.add_argument("--top_k", type=int, default=1, help="Return the k most probable classes")
    predict_parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    predict_parser.add_argument("--output", help="File path for saving the prediction result")
    predict_parser.set_defaults(func=predict_text)

    # Batch prediction command
    batch_parser = subparsers.add_parser("batch", help="Predict texts in batch")
    batch_parser.add_argument("input", help="Input file or directory path")
    batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="Input type")
    batch_parser.add_argument("--text_column", default="text", help="Name of the text column in a CSV/Excel file")
    batch_parser.add_argument("--id_column", help="Name of the ID column in a CSV/Excel file")
    batch_parser.add_argument("--pattern", default="*.txt", help="File matching pattern")
    batch_parser.add_argument("--recursive", action="store_true", help="Process subdirectories recursively")
    batch_parser.add_argument("--large_file", action="store_true", help="Process a large text file in chunks")
    batch_parser.add_argument("--model_path", help="Model path; defaults to the most recent model")
    batch_parser.add_argument("--tokenizer_path", help="Tokenizer path; defaults to creating a new tokenizer")
    batch_parser.add_argument("--vectorizer_path", help="Vectorizer path")
    batch_parser.add_argument("--class_names", nargs="+", help="List of class names")
    batch_parser.add_argument("--top_k", type=int, default=1, help="Return the k most probable classes")
    batch_parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    batch_parser.add_argument("--workers", type=int, default=4, help="Number of worker threads")
    batch_parser.add_argument("--output", required=True, help="Output file path")
    batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="Output format")
    batch_parser.set_defaults(func=predict_batch)

    # List available models command
    list_parser = subparsers.add_parser("list", help="List the available models")
    list_parser.set_defaults(func=list_models)

    # Interactive mode command
    interactive_parser = subparsers.add_parser("interactive", help="Start interactive classification mode")
    interactive_parser.add_argument("--model_path", help="Model path; defaults to the most recent model")
    interactive_parser.add_argument("--tokenizer_path", help="Tokenizer path; defaults to creating a new tokenizer")
    interactive_parser.add_argument("--vectorizer_path", help="Vectorizer path")
    interactive_parser.add_argument("--class_names", nargs="+", help="List of class names")
    interactive_parser.add_argument("--top_k", type=int, default=3, help="Return the k most probable classes")
    interactive_parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    interactive_parser.set_defaults(func=interactive_mode)

    # Parse arguments
    args = parser.parse_args()

    # If no subcommand was given, show the help text
    if not hasattr(args, 'func'):
        parser.print_help()
        return 0

    # Run the command
    try:
        args.func(args)
    except Exception as e:
        logger.error(f"Error while running the command: {e}")
        print(f"Error while running the command: {e}")
        return 1

    return 0


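# The subcommand handlers can also be driven programmatically by building a
# namespace whose attributes mirror the parser arguments above (a sketch; the
# sample text is a placeholder, not project data):
#
#   from argparse import Namespace
#   predict_text(Namespace(text="一条新闻文本", file=False, model_path=None,
#                          tokenizer_path=None, vectorizer_path=None,
#                          class_names=None, top_k=1, batch_size=64,
#                          output=None))
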
if __name__ == "__main__":
    sys.exit(main())