""" 主入口文件:整合系统的所有功能,提供命令行接口 """ import os import sys import argparse import logging from typing import List, Optional # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("main") # 命令行参数 parser = argparse.ArgumentParser(description="中文文本分类系统") subparsers = parser.add_subparsers(dest="command", help="命令") # 训练命令 train_parser = subparsers.add_parser("train", help="训练模型") train_parser.add_argument("--data_dir", help="数据目录") train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型") train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数") train_parser.add_argument("--batch_size", type=int, default=64, help="批大小") train_parser.add_argument("--save_dir", help="模型保存目录") # 评估命令 evaluate_parser = subparsers.add_parser("evaluate", help="评估模型") evaluate_parser.add_argument("--model_path", required=True, help="模型路径") evaluate_parser.add_argument("--data_dir", help="数据目录") evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小") evaluate_parser.add_argument("--output_dir", help="评估结果输出目录") # 预测命令 predict_parser = subparsers.add_parser("predict", help="使用模型预测") predict_parser.add_argument("--model_path", help="模型路径") predict_parser.add_argument("--text", help="要预测的文本") predict_parser.add_argument("--file", help="要预测的文件") predict_parser.add_argument("--output", help="输出文件") # Web服务命令 web_parser = subparsers.add_parser("web", help="启动Web服务") web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机") web_parser.add_argument("--port", type=int, default=5000, help="服务器端口") web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式") # API服务命令 api_parser = subparsers.add_parser("api", help="启动API服务") api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机") api_parser.add_argument("--port", type=int, default=8000, help="服务器端口") # CLI命令 cli_parser = subparsers.add_parser("cli", help="启动命令行接口") cli_parser.add_argument("--model_path", help="模型路径") cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式") def main(): """主函数""" args = parser.parse_args() # 如果没有指定命令,显示帮助信息 if not args.command: parser.print_help() return 0 # 根据命令调用相应的功能 if args.command == "train": # 导入训练模块 from scripts.train import train_model # 调用训练功能 train_model( data_dir=args.data_dir, model_type=args.model_type, epochs=args.epochs, batch_size=args.batch_size, save_dir=args.save_dir ) elif args.command == "evaluate": # 导入评估模块 from scripts.evaluate import evaluate_model # 调用评估功能 evaluate_model( model_path=args.model_path, data_dir=args.data_dir, batch_size=args.batch_size, output_dir=args.output_dir ) elif args.command == "predict": # 导入预测模块 from scripts.predict import predict_text, predict_file # 根据输入类型调用相应的预测功能 if args.text: predict_text(args.text, args.model_path, args.output) elif args.file: predict_file(args.file, args.model_path, args.output) else: logger.error("请提供要预测的文本或文件") return 1 elif args.command == "web": # 导入Web服务模块 from interface.web.app import run_server # 启动Web服务 run_server(host=args.host, port=args.port, debug=args.debug) elif args.command == "api": # 导入API服务模块 from interface.api import run_server # 启动API服务 run_server(host=args.host, port=args.port) elif args.command == "cli": # 导入CLI模块 from interface.cli import main as cli_main # 将命令行参数转换为CLI模块可接受的格式 sys.argv = ["interface/cli.py"] if args.interactive: sys.argv.append("interactive") if args.model_path: sys.argv.extend(["--model_path", args.model_path]) elif args.model_path: sys.argv.extend(["list", "--model_path", args.model_path]) else: sys.argv.append("list") # 调用CLI主函数 return cli_main() return 0 if __name__ == "__main__": sys.exit(main())