2025-03-08 01:34:36 +08:00

392 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API接口模块提供REST API接口
"""
import os
import sys
import json
import time
from typing import List, Dict, Tuple, Optional, Any, Union
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import asyncio
import pandas as pd
import io
# 将项目根目录添加到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
logger = get_logger("API")
# 数据模型
class TextItem(BaseModel):
text: str
id: Optional[str] = None
class BatchPredictRequest(BaseModel):
texts: List[TextItem]
top_k: Optional[int] = 1
class BatchFileRequest(BaseModel):
file_paths: List[str]
top_k: Optional[int] = 1
class ModelInfo(BaseModel):
id: str
name: str
type: str
num_classes: int
created_time: str
file_size: str
# 应用实例
app = FastAPI(
title="中文文本分类系统API",
description="提供中文文本分类功能的REST API",
version="1.0.0"
)
# 允许跨域请求
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局对象
predictor = None
def get_predictor() -> Predictor:
"""
获取或创建全局Predictor实例
Returns:
Predictor实例
"""
global predictor
if predictor is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise HTTPException(status_code=500, detail="未找到可用的模型")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"API加载模型: {model_path}")
# 加载模型
model = ModelFactory.load_model(model_path)
# 创建分词器
tokenizer = ChineseTokenizer()
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
class_names=CATEGORIES,
batch_size=64
)
return predictor
@app.get("/")
async def root():
"""API根路径"""
return {"message": "欢迎使用中文文本分类系统API"}
@app.post("/predict/text")
async def predict_text(text: str = Form(...), top_k: int = Form(1)):
"""
预测单条文本
Args:
text: 要预测的文本
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info(f"接收到文本预测请求,文本长度: {len(text)}")
try:
# 获取预测器
predictor = get_predictor()
# 预测
start_time = time.time()
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
prediction_time = time.time() - start_time
# 构建响应
response = {
"success": True,
"predictions": result if top_k > 1 else [result],
"time": prediction_time
}
return response
except Exception as e:
logger.error(f"预测文本时出错: {e}")
raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}")
@app.post("/predict/batch")
async def predict_batch(request: BatchPredictRequest):
"""
批量预测文本
Args:
request: 包含文本列表和参数的请求
Returns:
批量预测结果
"""
texts = [item.text for item in request.texts]
ids = [item.id or str(i) for i, item in enumerate(request.texts)]
logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")
try:
# 获取预测器
predictor = get_predictor()
# 预测
start_time = time.time()
results = predictor.predict_batch(
texts=texts,
return_top_k=request.top_k,
return_probabilities=True
)
prediction_time = time.time() - start_time
# 构建响应
response = {
"success": True,
"total": len(texts),
"time": prediction_time,
"results": {}
}
# 将结果关联到ID
for i, (id_val, result) in enumerate(zip(ids, results)):
response["results"][id_val] = result
return response
except Exception as e:
logger.error(f"批量预测文本时出错: {e}")
raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}")
@app.post("/predict/file")
async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
"""
预测文件内容
Args:
file: 上传的文件
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info(f"接收到文件预测请求,文件名: {file.filename}")
try:
# 读取文件内容
content = await file.read()
# 根据文件类型处理内容
if file.filename.endswith(('.txt', '.md')):
# 文本文件
text = content.decode('utf-8')
# 获取预测器
predictor = get_predictor()
# 预测
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
# 构建响应
response = {
"success": True,
"filename": file.filename,
"predictions": result if top_k > 1 else [result]
}
return response
elif file.filename.endswith(('.csv', '.xls', '.xlsx')):
# 表格文件
if file.filename.endswith('.csv'):
df = pd.read_csv(io.BytesIO(content))
else:
df = pd.read_excel(io.BytesIO(content))
# 查找可能的文本列
text_columns = [col for col in df.columns if df[col].dtype == 'object']
if not text_columns:
raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")
# 使用第一个文本列
text_column = text_columns[0]
texts = df[text_column].fillna('').tolist()
# 获取预测器
predictor = get_predictor()
# 批量预测
results = predictor.predict_batch(
texts=texts,
return_top_k=top_k,
return_probabilities=True
)
# 构建响应
response = {
"success": True,
"filename": file.filename,
"text_column": text_column,
"total": len(texts),
"results": results
}
return response
else:
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")
except Exception as e:
logger.error(f"预测文件时出错: {e}")
raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}")
@app.get("/models")
async def list_models():
"""
列出可用的模型
Returns:
可用模型列表
"""
try:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
# 转换为响应格式
models = []
for info in models_info:
models.append(ModelInfo(
id=os.path.basename(info['path']),
name=info['name'],
type=info['type'],
num_classes=info['num_classes'],
created_time=info['created_time'],
file_size=info['file_size']
))
return {"models": models}
except Exception as e:
logger.error(f"获取模型列表时出错: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}")
@app.get("/categories")
async def list_categories():
"""
列出支持的类别
Returns:
支持的类别列表
"""
try:
return {"categories": CATEGORIES}
except Exception as e:
logger.error(f"获取类别列表时出错: {e}")
raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}")
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""
记录请求日志
Args:
request: 请求对象
call_next: 下一个处理函数
Returns:
响应对象
"""
start_time = time.time()
# 记录请求信息
logger.info(f"请求: {request.method} {request.url}")
# 处理请求
response = await call_next(request)
# 记录响应信息
process_time = time.time() - start_time
logger.info(f"响应: {response.status_code} ({process_time:.2f}s)")
return response
def run_server(host: str = "0.0.0.0", port: int = 8000):
"""
运行API服务器
Args:
host: 主机地址
port: 端口号
"""
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
# 解析命令行参数
import argparse
parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
parser.add_argument("--port", type=int, default=8000, help="服务器端口号")
args = parser.parse_args()
# 运行服务器
run_server(host=args.host, port=args.port)