2025-03-08 01:34:36 +08:00

143 lines
4.9 KiB
Python

"""
主入口文件:整合系统的所有功能,提供命令行接口
"""
import os
import sys
import argparse
import logging
from typing import List, Optional
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("main")
# 命令行参数
parser = argparse.ArgumentParser(description="中文文本分类系统")
subparsers = parser.add_subparsers(dest="command", help="命令")
# 训练命令
train_parser = subparsers.add_parser("train", help="训练模型")
train_parser.add_argument("--data_dir", help="数据目录")
train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
train_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
train_parser.add_argument("--save_dir", help="模型保存目录")
# 评估命令
evaluate_parser = subparsers.add_parser("evaluate", help="评估模型")
evaluate_parser.add_argument("--model_path", required=True, help="模型路径")
evaluate_parser.add_argument("--data_dir", help="数据目录")
evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
evaluate_parser.add_argument("--output_dir", help="评估结果输出目录")
# 预测命令
predict_parser = subparsers.add_parser("predict", help="使用模型预测")
predict_parser.add_argument("--model_path", help="模型路径")
predict_parser.add_argument("--text", help="要预测的文本")
predict_parser.add_argument("--file", help="要预测的文件")
predict_parser.add_argument("--output", help="输出文件")
# Web服务命令
web_parser = subparsers.add_parser("web", help="启动Web服务")
web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
web_parser.add_argument("--port", type=int, default=5000, help="服务器端口")
web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
# API服务命令
api_parser = subparsers.add_parser("api", help="启动API服务")
api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
api_parser.add_argument("--port", type=int, default=8000, help="服务器端口")
# CLI命令
cli_parser = subparsers.add_parser("cli", help="启动命令行接口")
cli_parser.add_argument("--model_path", help="模型路径")
cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式")
def main():
"""主函数"""
args = parser.parse_args()
# 如果没有指定命令,显示帮助信息
if not args.command:
parser.print_help()
return 0
# 根据命令调用相应的功能
if args.command == "train":
# 导入训练模块
from scripts.train import train_model
# 调用训练功能
train_model(
data_dir=args.data_dir,
model_type=args.model_type,
epochs=args.epochs,
batch_size=args.batch_size,
save_dir=args.save_dir
)
elif args.command == "evaluate":
# 导入评估模块
from scripts.evaluate import evaluate_model
# 调用评估功能
evaluate_model(
model_path=args.model_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
output_dir=args.output_dir
)
elif args.command == "predict":
# 导入预测模块
from scripts.predict import predict_text, predict_file
# 根据输入类型调用相应的预测功能
if args.text:
predict_text(args.text, args.model_path, args.output)
elif args.file:
predict_file(args.file, args.model_path, args.output)
else:
logger.error("请提供要预测的文本或文件")
return 1
elif args.command == "web":
# 导入Web服务模块
from interface.web.app import run_server
# 启动Web服务
run_server(host=args.host, port=args.port, debug=args.debug)
elif args.command == "api":
# 导入API服务模块
from interface.api import run_server
# 启动API服务
run_server(host=args.host, port=args.port)
elif args.command == "cli":
# 导入CLI模块
from interface.cli import main as cli_main
# 将命令行参数转换为CLI模块可接受的格式
sys.argv = ["interface/cli.py"]
if args.interactive:
sys.argv.append("interactive")
if args.model_path:
sys.argv.extend(["--model_path", args.model_path])
elif args.model_path:
sys.argv.extend(["list", "--model_path", args.model_path])
else:
sys.argv.append("list")
# 调用CLI主函数
return cli_main()
return 0
if __name__ == "__main__":
sys.exit(main())