2025-03-08 01:34:36 +08:00

167 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Web routing module: defines the routes of the web application.
"""
from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash
from werkzeug.utils import secure_filename
import os
import pandas as pd
import io
import time
from typing import List, Dict, Tuple, Optional
from interface.web.app import get_predictor, allowed_file, app
# Create the blueprint that every route in this module is attached to.
bp = Blueprint('routes', __name__)
@bp.route('/')
def index():
    """Home page: render the ``index.html`` template for the root URL."""
    return render_template('index.html')
@bp.route('/predict_text', methods=['GET', 'POST'])
def predict_text():
    """Single-text prediction page.

    GET renders the empty form. POST reads the ``text`` and ``top_k`` form
    fields, runs the predictor on the text and re-renders the page with the
    predictions. Validation problems and predictor failures are reported to
    the user through ``flash`` instead of raising.

    Returns:
        The rendered ``predict_text.html`` template.
    """
    if request.method != 'POST':
        return render_template('predict_text.html')

    text = request.form.get('text', '')
    # ``type=int`` makes a missing, empty or non-numeric ``top_k`` fall back
    # to the default instead of raising ValueError; clamp to at least 1 so
    # the predictor never receives a zero or negative count.
    top_k = max(1, request.form.get('top_k', 3, type=int))

    # Treat whitespace-only input the same as no input.
    if not text.strip():
        flash('请输入文本内容')
        return render_template('predict_text.html')

    try:
        predictor = get_predictor()
        result = predictor.predict(
            text=text,
            return_top_k=top_k,
            return_probabilities=True,
        )
        # The template always receives a list; a top-1 result is wrapped.
        # NOTE(review): presumably predict() returns a single item when
        # return_top_k == 1 -- confirm against the predictor.
        predictions = result if top_k > 1 else [result]
        return render_template('predict_text.html', text=text, predictions=predictions)
    except Exception as e:
        # Route boundary: surface the failure to the user and keep their text.
        flash(f'预测失败: {e}')
        return render_template('predict_text.html', text=text)
@bp.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
    """File prediction page.

    GET renders the upload form. POST validates the uploaded ``file`` part,
    stores it under ``app.config['UPLOAD_FOLDER']`` using a sanitized name and
    re-renders the page with that name. No prediction is run here yet; the
    original author notes that the processing logic duplicates ``app.py`` and
    should be extracted into a shared function.

    Returns:
        The rendered ``predict_file.html`` template.
    """
    if request.method != 'POST':
        return render_template('predict_file.html')

    file = request.files.get('file')
    # A missing form part and an empty file name both mean nothing was chosen.
    if file is None or file.filename == '':
        flash('未选择文件')
        return render_template('predict_file.html')

    if not (file and allowed_file(file.filename)):
        flash('不支持的文件类型')
        return render_template('predict_file.html')

    # TODO: the ``top_k`` and ``text_column`` form fields are not consumed
    # until the shared processing logic is wired in, so they are not parsed.
    try:
        # secure_filename() strips path separators and non-ASCII characters,
        # so the result can be shorter than the original name or even empty.
        filename = secure_filename(file.filename)
        if not filename:
            # An empty name would make the save target the upload directory.
            flash('文件名无效')
            return render_template('predict_file.html')

        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)
        return render_template('predict_file.html', filename=filename)
    except Exception as e:
        # Route boundary: report the failure instead of a bare 500 page.
        flash(f'处理失败: {e}')
        return render_template('predict_file.html')
@bp.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
    """Batch prediction page.

    GET renders the empty form. POST reads the ``texts`` field (one text per
    line) and ``top_k``, runs the predictor on every non-blank line and
    re-renders the page with one result entry per text. Validation problems
    and predictor failures are reported through ``flash``.

    Returns:
        The rendered ``batch_predict.html`` template.
    """
    if request.method != 'POST':
        return render_template('batch_predict.html')

    texts = request.form.get('texts', '')
    # ``type=int`` makes a missing, empty or non-numeric ``top_k`` fall back
    # to the default instead of raising ValueError; clamp to at least 1.
    top_k = max(1, request.form.get('top_k', 3, type=int))

    if not texts:
        flash('请输入文本内容')
        return render_template('batch_predict.html')

    # One text per line; strip() also removes the '\r' of CRLF submissions.
    stripped_lines = (line.strip() for line in texts.split('\n'))
    text_list = [line for line in stripped_lines if line]
    if not text_list:
        flash('请输入有效的文本内容')
        return render_template('batch_predict.html', texts=texts)

    try:
        predictor = get_predictor()
        results = predictor.predict_batch(
            texts=text_list,
            return_top_k=top_k,
            return_probabilities=True,
        )
        # Pair each input with its result; ids are 1-based for display.
        # A top-1 result is wrapped so the template always gets a list.
        batch_results = [
            {
                'id': position,
                'text': source_text,
                'predictions': result if top_k > 1 else [result],
            }
            for position, (source_text, result) in enumerate(zip(text_list, results), start=1)
        ]
        return render_template(
            'batch_predict.html',
            texts=texts,
            batch_results=batch_results,
        )
    except Exception as e:
        # Route boundary: surface the failure and keep the submitted texts.
        flash(f'预测失败: {e}')
        return render_template('batch_predict.html', texts=texts)
@bp.route('/models')
def list_models():
    """Model list page: render the ``models.html`` template."""
    return render_template('models.html')
@bp.route('/about')
def about():
    """About page: render the ``about.html`` template."""
    return render_template('about.html')
# Blueprint registration hook.
def init_app(app):
    """Register this module's blueprint ``bp`` on the given application.

    Args:
        app: Object exposing ``register_blueprint``. Inside this function the
            parameter shadows the ``app`` imported at module level.
    """
    app.register_blueprint(bp)