"""
API interface module: provides the REST API for the Chinese text-classification system.
"""
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
import uvicorn
|
||
import asyncio
|
||
import pandas as pd
|
||
import io
|
||
|
||
# 将项目根目录添加到sys.path
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
|
||
from models.model_factory import ModelFactory
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from inference.predictor import Predictor
|
||
from inference.batch_processor import BatchProcessor
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("API")
|
||
|
||
|
||
# ---- Request/response data models ----
class TextItem(BaseModel):
    """One text to classify, with an optional caller-supplied identifier."""
    text: str
    # Optional id; batch responses key results by it (index string when absent).
    id: Optional[str] = None
|
||
|
||
|
||
class BatchPredictRequest(BaseModel):
    """Request body for /predict/batch."""
    texts: List[TextItem]
    # Number of highest-probability classes to return per text.
    top_k: Optional[int] = 1
|
||
|
||
|
||
class BatchFileRequest(BaseModel):
    """Request body carrying server-side file paths for batch prediction.

    NOTE(review): no endpoint in this chunk consumes this model — confirm
    it is used elsewhere before removing.
    """
    file_paths: List[str]
    # Number of highest-probability classes to return per file.
    top_k: Optional[int] = 1
|
||
|
||
|
||
class ModelInfo(BaseModel):
    """Metadata describing one trained model (see /models endpoint)."""
    # Basename of the saved model file (see list_models).
    id: str
    name: str
    type: str
    num_classes: int
    created_time: str
    file_size: str
|
||
|
||
|
||
# ---- Application instance ----
app = FastAPI(
    title="中文文本分类系统API",
    description="提供中文文本分类功能的REST API",
    version="1.0.0"
)

# Allow cross-origin requests.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open — tighten the origin list for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide predictor instance, created lazily by get_predictor().
predictor = None
|
||
|
||
|
||
def get_predictor() -> Predictor:
    """
    Return the process-wide Predictor, building it on first use.

    The most recent available model is loaded and paired with a fresh
    Chinese tokenizer; subsequent calls reuse the cached instance.

    Returns:
        The shared Predictor instance.

    Raises:
        HTTPException: 500 when no trained model can be found.
    """
    global predictor

    # Fast path: already initialized.
    if predictor is not None:
        return predictor

    # Discover trained models on disk.
    models_info = ModelFactory.get_available_models()
    if not models_info:
        raise HTTPException(status_code=500, detail="未找到可用的模型")

    # First entry is the model to serve.
    model_path = models_info[0]['path']
    logger.info(f"API加载模型: {model_path}")

    predictor = Predictor(
        model=ModelFactory.load_model(model_path),
        tokenizer=ChineseTokenizer(),
        class_names=CATEGORIES,
        batch_size=64
    )
    return predictor
|
||
|
||
|
||
@app.get("/")
async def root():
    """API root: returns a static welcome message."""
    return {"message": "欢迎使用中文文本分类系统API"}
|
||
|
||
|
||
@app.post("/predict/text")
async def predict_text(text: str = Form(...), top_k: int = Form(1)):
    """
    Predict the category of a single text.

    Args:
        text: Text to classify.
        top_k: Number of highest-probability classes to return.

    Returns:
        Dict with success flag, prediction list, and elapsed time.

    Raises:
        HTTPException: 500 when no model is available or prediction fails.
    """
    logger.info(f"接收到文本预测请求,文本长度: {len(text)}")

    try:
        # Local name deliberately does not shadow the module-global `predictor`.
        pred = get_predictor()

        start_time = time.time()
        result = pred.predict(
            text=text,
            return_top_k=top_k,
            return_probabilities=True
        )
        prediction_time = time.time() - start_time

        return {
            "success": True,
            # predict() returns a list for top_k > 1, a single item otherwise;
            # normalize to a list either way.
            "predictions": result if top_k > 1 else [result],
            "time": prediction_time
        }

    except HTTPException:
        # Bug fix: the broad handler below used to catch deliberate HTTP
        # errors (e.g. "no model available" from get_predictor) and re-wrap
        # them, mangling the detail message. Let them through unchanged.
        raise
    except Exception as e:
        logger.error(f"预测文本时出错: {e}")
        raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}")
|
||
|
||
|
||
@app.post("/predict/batch")
async def predict_batch(request: BatchPredictRequest):
    """
    Predict categories for a batch of texts.

    Args:
        request: Request carrying the text items and top_k parameter.

    Returns:
        Dict mapping each item's id (or its positional index, as a string,
        when no id was supplied) to its prediction result.

    Raises:
        HTTPException: 500 when no model is available or prediction fails.
    """
    texts = [item.text for item in request.texts]
    # Fall back to the positional index when the caller gave no id.
    ids = [item.id or str(i) for i, item in enumerate(request.texts)]

    logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")

    try:
        pred = get_predictor()

        start_time = time.time()
        results = pred.predict_batch(
            texts=texts,
            return_top_k=request.top_k,
            return_probabilities=True
        )
        prediction_time = time.time() - start_time

        return {
            "success": True,
            "total": len(texts),
            "time": prediction_time,
            # Associate each result with its id (replaces the manual loop).
            "results": dict(zip(ids, results)),
        }

    except HTTPException:
        # Bug fix: pass deliberate HTTP errors through unchanged instead of
        # re-wrapping them as generic 500s in the handler below.
        raise
    except Exception as e:
        logger.error(f"批量预测文本时出错: {e}")
        raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}")
|
||
|
||
|
||
@app.post("/predict/file")
async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
    """
    Predict categories for the contents of an uploaded file.

    Plain-text files (.txt/.md) are classified as one document; tabular
    files (.csv/.xls/.xlsx) have their first object-dtype column
    classified row by row.

    Args:
        file: Uploaded file.
        top_k: Number of highest-probability classes to return.

    Returns:
        Dict with the prediction result(s).

    Raises:
        HTTPException: 400 for unsupported file types or tables with no
            text column; 500 for load/prediction failures.
    """
    logger.info(f"接收到文件预测请求,文件名: {file.filename}")

    try:
        content = await file.read()

        if file.filename.endswith(('.txt', '.md')):
            # Plain text: classify the whole file as a single document.
            # Assumes UTF-8 encoding — a decode failure surfaces as a 500.
            text = content.decode('utf-8')

            pred = get_predictor()

            result = pred.predict(
                text=text,
                return_top_k=top_k,
                return_probabilities=True
            )

            return {
                "success": True,
                "filename": file.filename,
                "predictions": result if top_k > 1 else [result]
            }

        elif file.filename.endswith(('.csv', '.xls', '.xlsx')):
            # Tabular data: load with pandas from the in-memory bytes.
            if file.filename.endswith('.csv'):
                df = pd.read_csv(io.BytesIO(content))
            else:
                df = pd.read_excel(io.BytesIO(content))

            # Candidate text columns are the object-dtype ones.
            text_columns = [col for col in df.columns if df[col].dtype == 'object']

            if not text_columns:
                raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")

            # Use the first text column; empty cells become empty strings.
            text_column = text_columns[0]
            texts = df[text_column].fillna('').tolist()

            pred = get_predictor()

            results = pred.predict_batch(
                texts=texts,
                return_top_k=top_k,
                return_probabilities=True
            )

            return {
                "success": True,
                "filename": file.filename,
                "text_column": text_column,
                "total": len(texts),
                "results": results
            }

        else:
            raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")

    except HTTPException:
        # Bug fix: the broad handler below used to swallow the deliberate
        # 400 errors raised above and re-emit them as 500s with a wrapped
        # message. Re-raise so clients receive the intended status code.
        raise
    except Exception as e:
        logger.error(f"预测文件时出错: {e}")
        raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}")
|
||
|
||
|
||
@app.get("/models")
async def list_models():
    """
    List the trained models available to the API.

    Returns:
        Dict with a list of ModelInfo entries.

    Raises:
        HTTPException: 500 when model discovery fails.
    """
    try:
        available = ModelFactory.get_available_models()

        # Convert the raw discovery records into the response schema.
        models = [
            ModelInfo(
                id=os.path.basename(info['path']),
                name=info['name'],
                type=info['type'],
                num_classes=info['num_classes'],
                created_time=info['created_time'],
                file_size=info['file_size'],
            )
            for info in available
        ]

        return {"models": models}

    except Exception as e:
        logger.error(f"获取模型列表时出错: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}")
|
||
|
||
|
||
@app.get("/categories")
async def list_categories():
    """
    List the classification categories the system supports.

    Returns:
        Dict with the configured category names.
    """
    try:
        return {"categories": CATEGORIES}
    except Exception as e:
        # Defensive: returning the constant should not raise; kept for
        # consistency with the other endpoints' error handling.
        logger.error(f"获取类别列表时出错: {e}")
        raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}")
|
||
|
||
|
||
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    HTTP middleware that logs each request and its response time.

    Args:
        request: Incoming request object.
        call_next: Downstream handler producing the response.

    Returns:
        The response returned by call_next.
    """
    started = time.time()

    # Log the incoming request line.
    logger.info(f"请求: {request.method} {request.url}")

    response = await call_next(request)

    # Log status code and wall-clock processing time.
    elapsed = time.time() - started
    logger.info(f"响应: {response.status_code} ({elapsed:.2f}s)")

    return response
|
||
|
||
|
||
def run_server(host: str = "0.0.0.0", port: int = 8000) -> None:
    """
    Run the API server with uvicorn (blocks until shutdown).

    Args:
        host: Host address to bind.
        port: Port number to listen on.
    """
    uvicorn.run(app, host=host, port=port)
|
||
|
||
|
||
if __name__ == "__main__":
    # Parse command-line arguments.
    import argparse

    parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
    parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
    parser.add_argument("--port", type=int, default=8000, help="服务器端口号")

    args = parser.parse_args()

    # Start the server (blocks until shutdown).
    run_server(host=args.host, port=args.port)
|