"""
API module: exposes the Chinese text classification system as a REST API.
"""

import os
import sys
import json
import time
from typing import List, Dict, Tuple, Optional, Any, Union

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import asyncio
import pandas as pd
import io

# Add the project root directory to sys.path so the project packages below can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger

logger = get_logger("API")


# ---------------------------------------------------------------------------
# Request / response data models
# ---------------------------------------------------------------------------

class TextItem(BaseModel):
    """A single text to classify, with an optional caller-supplied ID."""

    text: str
    id: Optional[str] = None


class BatchPredictRequest(BaseModel):
    """Request body for ``POST /predict/batch``."""

    texts: List[TextItem]
    # NOTE(review): Optional allows an explicit null, which is forwarded to the
    # predictor as return_top_k=None — confirm the Predictor accepts that.
    top_k: Optional[int] = 1


class BatchFileRequest(BaseModel):
    """Request body listing file paths (no endpoint in this module uses it)."""

    file_paths: List[str]
    top_k: Optional[int] = 1


class ModelInfo(BaseModel):
    """Summary of one available model, as returned by ``GET /models``."""

    id: str
    name: str
    type: str
    num_classes: int
    created_time: str
    file_size: str


# ---------------------------------------------------------------------------
# Application instance
# ---------------------------------------------------------------------------

app = FastAPI(
    title="中文文本分类系统API",
    description="提供中文文本分类功能的REST API",
    version="1.0.0"
)

# Allow cross-origin requests.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins if this API is exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide Predictor, created lazily on first use by get_predictor().
predictor = None


def get_predictor() -> Predictor:
    """
    Return the global Predictor instance, creating it on first use.

    On the first call, the first entry returned by
    ``ModelFactory.get_available_models()`` is loaded. It is wrapped in a
    Predictor together with a ChineseTokenizer, the configured CATEGORIES and a
    batch size of 64. Later calls return the cached instance.

    Returns:
        The shared Predictor instance.

    Raises:
        HTTPException: 500 if no model is available.
    """
    global predictor
    if predictor is None:
        models_info = ModelFactory.get_available_models()
        if not models_info:
            raise HTTPException(status_code=500, detail="未找到可用的模型")

        # Use the "latest" model.
        # NOTE(review): this assumes get_available_models() lists models
        # newest-first — confirm against ModelFactory.
        model_path = models_info[0]['path']
        logger.info(f"API加载模型: {model_path}")

        model = ModelFactory.load_model(model_path)
        tokenizer = ChineseTokenizer()
        predictor = Predictor(
            model=model,
            tokenizer=tokenizer,
            class_names=CATEGORIES,
            batch_size=64
        )
    return predictor


@app.get("/")
async def root():
    """API root path: returns a welcome message."""
    return {"message": "欢迎使用中文文本分类系统API"}


@app.post("/predict/text")
async def predict_text(text: str = Form(...), top_k: int = Form(1)):
    """
    Classify a single text submitted as a form field.

    Args:
        text: The text to classify.
        top_k: Number of highest-probability classes to return.

    Returns:
        Dict with ``success``, ``predictions`` (always a list) and ``time``
        (prediction duration in seconds).

    Raises:
        HTTPException: propagated unchanged from get_predictor (e.g. no model);
            500 for any other failure during prediction.
    """
    logger.info(f"接收到文本预测请求,文本长度: {len(text)}")

    try:
        predictor = get_predictor()

        # NOTE(review): predictor.predict is a synchronous call inside an async
        # endpoint, so it runs on the event-loop thread and blocks other
        # requests while it executes.
        start_time = time.perf_counter()
        result = predictor.predict(
            text=text,
            return_top_k=top_k,
            return_probabilities=True
        )
        prediction_time = time.perf_counter() - start_time

        return {
            "success": True,
            # Wrap the single top-1 result so "predictions" is always a list.
            # NOTE(review): assumes predict() returns a list only when
            # return_top_k > 1 — confirm against Predictor.
            "predictions": result if top_k > 1 else [result],
            "time": prediction_time
        }
    except HTTPException:
        # Already carries the intended status/detail; do not rewrap as a generic 500.
        raise
    except Exception as e:
        logger.error(f"预测文本时出错: {e}")
        raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}") from e


@app.post("/predict/batch")
async def predict_batch(request: BatchPredictRequest):
    """
    Classify a batch of texts from a JSON body.

    Args:
        request: The texts to classify (each with an optional ID) and ``top_k``.

    Returns:
        Dict with ``success``, ``total``, ``time`` (seconds) and ``results``,
        which maps each item's ID to its prediction. Items without an ID (or
        with an empty ID) are keyed by their index in the request.

    Raises:
        HTTPException: propagated unchanged from get_predictor (e.g. no model);
            500 for any other failure during prediction.
    """
    texts = [item.text for item in request.texts]
    # NOTE(review): IDs are not checked for uniqueness. A duplicate ID, or an
    # explicit ID equal to another item's index, makes a later result
    # overwrite an earlier one in "results".
    ids = [item.id or str(i) for i, item in enumerate(request.texts)]

    logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")

    try:
        predictor = get_predictor()

        start_time = time.perf_counter()
        results = predictor.predict_batch(
            texts=texts,
            return_top_k=request.top_k,
            return_probabilities=True
        )
        prediction_time = time.perf_counter() - start_time

        return {
            "success": True,
            "total": len(texts),
            "time": prediction_time,
            # Associate each result with its ID, in request order.
            "results": dict(zip(ids, results))
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"批量预测文本时出错: {e}")
        raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}") from e


@app.post("/predict/file")
async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
    """
    Classify the contents of an uploaded file.

    ``.txt`` / ``.md`` files are decoded as UTF-8 and classified as a single
    text. For ``.csv`` / ``.xls`` / ``.xlsx`` files, the first text-like
    column is selected and every row is classified. Extensions are matched
    case-insensitively.

    Args:
        file: The uploaded file.
        top_k: Number of highest-probability classes to return.

    Returns:
        For text files: ``success``, ``filename`` and ``predictions`` (a list).
        For tables: ``success``, ``filename``, ``text_column``, ``total`` and
        ``results``.

    Raises:
        HTTPException: 400 for an unsupported type, a text file that is not
            valid UTF-8, or a table without a text column; errors raised by
            get_predictor propagate unchanged; 500 for any other failure.
    """
    logger.info(f"接收到文件预测请求,文件名: {file.filename}")

    # The filename may be missing; lower-case it so ".CSV" / ".TXT" are accepted.
    lower_name = (file.filename or "").lower()

    try:
        content = await file.read()

        if lower_name.endswith(('.txt', '.md')):
            # Plain-text file: classify the whole content as one text.
            try:
                text = content.decode('utf-8')
            except UnicodeDecodeError as e:
                raise HTTPException(status_code=400, detail=f"文件不是有效的UTF-8文本: {file.filename}") from e

            predictor = get_predictor()
            result = predictor.predict(
                text=text,
                return_top_k=top_k,
                return_probabilities=True
            )
            return {
                "success": True,
                "filename": file.filename,
                "predictions": result if top_k > 1 else [result]
            }

        if lower_name.endswith(('.csv', '.xls', '.xlsx')):
            # Tabular file: classify every row of the first text-like column.
            if lower_name.endswith('.csv'):
                df = pd.read_csv(io.BytesIO(content))
            else:
                df = pd.read_excel(io.BytesIO(content))

            # Strings may be stored as object dtype or as pandas' dedicated
            # string dtype, depending on the pandas version and options.
            text_columns = [
                col for col, dtype in df.dtypes.items()
                if pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(dtype)
            ]
            if not text_columns:
                raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")

            text_column = text_columns[0]
            # Object columns may mix in non-string values (numbers, dates);
            # coerce every cell to str so the predictor only receives text.
            texts = df[text_column].fillna('').astype(str).tolist()

            predictor = get_predictor()
            results = predictor.predict_batch(
                texts=texts,
                return_top_k=top_k,
                return_probabilities=True
            )
            return {
                "success": True,
                "filename": file.filename,
                "text_column": text_column,
                "total": len(texts),
                "results": results
            }

        raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")
    except HTTPException:
        # Keep the intended 4xx/5xx status instead of converting it into a generic 500.
        raise
    except Exception as e:
        logger.error(f"预测文件时出错: {e}")
        raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}") from e


@app.get("/models")
async def list_models():
    """
    List the available models.

    Returns:
        ``{"models": [...]}``, a list of ModelInfo built from
        ``ModelFactory.get_available_models()``. Each model's ``id`` is the
        basename of its path.

    Raises:
        HTTPException: 500 if the model list cannot be retrieved or converted.
    """
    try:
        models = [
            ModelInfo(
                id=os.path.basename(info['path']),
                name=info['name'],
                type=info['type'],
                num_classes=info['num_classes'],
                created_time=info['created_time'],
                file_size=info['file_size']
            )
            for info in ModelFactory.get_available_models()
        ]
        return {"models": models}
    except Exception as e:
        logger.error(f"获取模型列表时出错: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}") from e


@app.get("/categories")
async def list_categories():
    """
    List the supported categories.

    Returns:
        ``{"categories": CATEGORIES}`` from the system configuration.
    """
    try:
        return {"categories": CATEGORIES}
    except Exception as e:
        logger.error(f"获取类别列表时出错: {e}")
        raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}") from e
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    HTTP middleware that logs each request and the resulting status and duration.

    Args:
        request: The incoming request.
        call_next: Passes the request to the next handler and returns its response.

    Returns:
        The downstream response, unchanged.
    """
    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()

    logger.info(f"请求: {request.method} {request.url}")

    response = await call_next(request)

    process_time = time.perf_counter() - start_time
    logger.info(f"响应: {response.status_code} ({process_time:.2f}s)")

    return response


def run_server(host: str = "0.0.0.0", port: int = 8000):
    """
    Run the API application with uvicorn.

    Args:
        host: Host address to bind.
        port: Port number to listen on.
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # Parse command-line arguments and start the server.
    import argparse

    parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
    parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
    parser.add_argument("--port", type=int, default=8000, help="服务器端口号")
    args = parser.parse_args()

    run_server(host=args.host, port=args.port)