143 lines
4.9 KiB
Python
143 lines
4.9 KiB
Python
"""
|
|
主入口文件:整合系统的所有功能,提供命令行接口
|
|
"""
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import logging
|
|
from typing import List, Optional
|
|
|
|
# 设置日志
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
logger = logging.getLogger("main")
|
|
|
|
# 命令行参数
|
|
parser = argparse.ArgumentParser(description="中文文本分类系统")
|
|
subparsers = parser.add_subparsers(dest="command", help="命令")
|
|
|
|
# 训练命令
|
|
train_parser = subparsers.add_parser("train", help="训练模型")
|
|
train_parser.add_argument("--data_dir", help="数据目录")
|
|
train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
|
|
train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
|
|
train_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
|
train_parser.add_argument("--save_dir", help="模型保存目录")
|
|
|
|
# 评估命令
|
|
evaluate_parser = subparsers.add_parser("evaluate", help="评估模型")
|
|
evaluate_parser.add_argument("--model_path", required=True, help="模型路径")
|
|
evaluate_parser.add_argument("--data_dir", help="数据目录")
|
|
evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
|
evaluate_parser.add_argument("--output_dir", help="评估结果输出目录")
|
|
|
|
# 预测命令
|
|
predict_parser = subparsers.add_parser("predict", help="使用模型预测")
|
|
predict_parser.add_argument("--model_path", help="模型路径")
|
|
predict_parser.add_argument("--text", help="要预测的文本")
|
|
predict_parser.add_argument("--file", help="要预测的文件")
|
|
predict_parser.add_argument("--output", help="输出文件")
|
|
|
|
# Web服务命令
|
|
web_parser = subparsers.add_parser("web", help="启动Web服务")
|
|
web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
|
|
web_parser.add_argument("--port", type=int, default=5000, help="服务器端口")
|
|
web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
|
|
|
|
# API服务命令
|
|
api_parser = subparsers.add_parser("api", help="启动API服务")
|
|
api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
|
|
api_parser.add_argument("--port", type=int, default=8000, help="服务器端口")
|
|
|
|
# CLI命令
|
|
cli_parser = subparsers.add_parser("cli", help="启动命令行接口")
|
|
cli_parser.add_argument("--model_path", help="模型路径")
|
|
cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式")
|
|
|
|
|
|
def main():
|
|
"""主函数"""
|
|
args = parser.parse_args()
|
|
|
|
# 如果没有指定命令,显示帮助信息
|
|
if not args.command:
|
|
parser.print_help()
|
|
return 0
|
|
|
|
# 根据命令调用相应的功能
|
|
if args.command == "train":
|
|
# 导入训练模块
|
|
from scripts.train import train_model
|
|
|
|
# 调用训练功能
|
|
train_model(
|
|
data_dir=args.data_dir,
|
|
model_type=args.model_type,
|
|
epochs=args.epochs,
|
|
batch_size=args.batch_size,
|
|
save_dir=args.save_dir
|
|
)
|
|
|
|
elif args.command == "evaluate":
|
|
# 导入评估模块
|
|
from scripts.evaluate import evaluate_model
|
|
|
|
# 调用评估功能
|
|
evaluate_model(
|
|
model_path=args.model_path,
|
|
data_dir=args.data_dir,
|
|
batch_size=args.batch_size,
|
|
output_dir=args.output_dir
|
|
)
|
|
|
|
elif args.command == "predict":
|
|
# 导入预测模块
|
|
from scripts.predict import predict_text, predict_file
|
|
|
|
# 根据输入类型调用相应的预测功能
|
|
if args.text:
|
|
predict_text(args.text, args.model_path, args.output)
|
|
elif args.file:
|
|
predict_file(args.file, args.model_path, args.output)
|
|
else:
|
|
logger.error("请提供要预测的文本或文件")
|
|
return 1
|
|
|
|
elif args.command == "web":
|
|
# 导入Web服务模块
|
|
from interface.web.app import run_server
|
|
|
|
# 启动Web服务
|
|
run_server(host=args.host, port=args.port, debug=args.debug)
|
|
|
|
elif args.command == "api":
|
|
# 导入API服务模块
|
|
from interface.api import run_server
|
|
|
|
# 启动API服务
|
|
run_server(host=args.host, port=args.port)
|
|
|
|
elif args.command == "cli":
|
|
# 导入CLI模块
|
|
from interface.cli import main as cli_main
|
|
|
|
# 将命令行参数转换为CLI模块可接受的格式
|
|
sys.argv = ["interface/cli.py"]
|
|
|
|
if args.interactive:
|
|
sys.argv.append("interactive")
|
|
if args.model_path:
|
|
sys.argv.extend(["--model_path", args.model_path])
|
|
elif args.model_path:
|
|
sys.argv.extend(["list", "--model_path", args.model_path])
|
|
else:
|
|
sys.argv.append("list")
|
|
|
|
# 调用CLI主函数
|
|
return cli_main()
|
|
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|