167 lines
4.8 KiB
Python
167 lines
4.8 KiB
Python
"""
Web路由模块:定义Web应用的路由
(Web routes module: defines the URL routes of the web application.)
"""
|
||
from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash
|
||
from werkzeug.utils import secure_filename
|
||
import os
|
||
import pandas as pd
|
||
import io
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional
|
||
|
||
from interface.web.app import get_predictor, allowed_file, app
|
||
|
||
# Blueprint that collects every page route declared in this module.
bp = Blueprint(name='routes', import_name=__name__)
|
||
|
||
|
||
@bp.route('/')
def index():
    """Render the landing (home) page."""
    page = 'index.html'
    return render_template(page)
|
||
|
||
|
||
@bp.route('/predict_text', methods=['GET', 'POST'])
def predict_text():
    """Single-text prediction page.

    GET renders the empty form. POST reads ``text`` and ``top_k`` from the
    submitted form, calls ``get_predictor().predict(...)`` and renders the
    predictions into ``predict_text.html``. When the text is empty or the
    prediction raises, a flash message is shown and the form is re-rendered.
    """
    if request.method != 'POST':
        return render_template('predict_text.html')

    text = request.form.get('text', '')

    # Parsing used to happen outside any try block, so a non-numeric top_k
    # produced an HTTP 500. Fall back to the default and clamp to >= 1.
    try:
        top_k = max(1, int(request.form.get('top_k', 3)))
    except (TypeError, ValueError):
        top_k = 3

    # Whitespace-only input is treated the same as empty input.
    if not text.strip():
        flash('请输入文本内容')
        return render_template('predict_text.html')

    try:
        predictor = get_predictor()
        result = predictor.predict(
            text=text,
            return_top_k=top_k,
            return_probabilities=True,
        )
    except Exception as e:
        # Request-boundary handler: keep the traceback in the server log and
        # show a short message to the user.
        app.logger.exception('Text prediction failed')
        flash(f'预测失败: {str(e)}')
        return render_template('predict_text.html', text=text)

    # Assumes predict() returns a list when top_k > 1 and a single item
    # otherwise; the template always receives a list. TODO confirm against
    # the predictor API.
    predictions = result if top_k > 1 else [result]
    return render_template('predict_text.html', text=text, predictions=predictions)
|
||
|
||
|
||
@bp.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
    """File prediction page.

    GET renders the upload form. POST validates the uploaded ``file`` field,
    checks its type with ``allowed_file`` and saves it under
    ``app.config['UPLOAD_FOLDER']``. Prediction on the file contents is not
    implemented here yet; the page is re-rendered with the saved file name.
    """
    if request.method != 'POST':
        return render_template('predict_file.html')

    file = request.files.get('file')
    # The filename is '' when no file was chosen. It can also be None for a
    # multipart part without a filename; the old `== ''` check missed that case.
    if file is None or not file.filename:
        flash('未选择文件')
        return render_template('predict_file.html')

    if not allowed_file(file.filename):
        flash('不支持的文件类型')
        return render_template('predict_file.html')

    # secure_filename drops non-ASCII characters, so '数据.csv' becomes 'csv'
    # (extension lost) and a name like '../..' becomes ''. In those cases,
    # fall back to a timestamped name that keeps the sanitized extension.
    filename = secure_filename(file.filename)
    ext = secure_filename(os.path.splitext(file.filename)[1])  # '.csv' -> 'csv'
    if not filename or (ext and not filename.endswith(f'.{ext}')):
        stem = f'upload_{int(time.time())}'
        filename = f'{stem}.{ext}' if ext else stem

    try:
        upload_folder = app.config['UPLOAD_FOLDER']
        # Create the folder on first use instead of failing inside file.save().
        os.makedirs(upload_folder, exist_ok=True)
        file.save(os.path.join(upload_folder, filename))
    except Exception as e:
        app.logger.exception('Saving uploaded file failed')
        flash(f'处理失败: {str(e)}')
        return render_template('predict_file.html')

    # NOTE: this route does not process the file yet. app.py's predict_file
    # holds that logic and should be extracted into a shared helper. Until
    # then, the 'top_k' and 'text_column' form fields are intentionally unread.
    return render_template('predict_file.html', filename=filename)
|
||
|
||
|
||
@bp.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
    """Batch prediction page.

    GET renders the empty form. POST reads ``texts`` (one input per line)
    and ``top_k``, runs ``get_predictor().predict_batch(...)`` on the
    non-blank lines and renders one numbered result row per line. Empty
    input or a prediction error re-renders the form with a flash message.
    """
    if request.method != 'POST':
        return render_template('batch_predict.html')

    texts = request.form.get('texts', '')

    # A non-numeric top_k used to raise ValueError outside any try block
    # (HTTP 500). Fall back to the default and clamp to >= 1.
    try:
        top_k = max(1, int(request.form.get('top_k', 3)))
    except (TypeError, ValueError):
        top_k = 3

    if not texts:
        flash('请输入文本内容')
        return render_template('batch_predict.html')

    # One input per non-blank line. strip() also removes the '\r' left by
    # CRLF line endings.
    text_list = [line.strip() for line in texts.split('\n') if line.strip()]
    if not text_list:
        flash('请输入有效的文本内容')
        return render_template('batch_predict.html', texts=texts)

    try:
        predictor = get_predictor()
        results = predictor.predict_batch(
            texts=text_list,
            return_top_k=top_k,
            return_probabilities=True,
        )
        # Pair each input with its result instead of indexing text_list, so a
        # length mismatch cannot raise IndexError. Assumes each result is a list
        # when top_k > 1 and a single item otherwise (TODO confirm).
        batch_results = [
            {
                'id': row_id,
                'text': line,
                'predictions': result if top_k > 1 else [result],
            }
            for row_id, (line, result) in enumerate(zip(text_list, results), start=1)
        ]
    except Exception as e:
        # Request-boundary handler: log the traceback and show a short message.
        app.logger.exception('Batch prediction failed')
        flash(f'预测失败: {str(e)}')
        return render_template('batch_predict.html', texts=texts)

    return render_template(
        'batch_predict.html',
        texts=texts,
        batch_results=batch_results,
    )
|
||
|
||
|
||
@bp.route('/models')
def list_models():
    """Render the model list page."""
    page = 'models.html'
    return render_template(page)
|
||
|
||
|
||
@bp.route('/about')
def about():
    """Render the static "about" page."""
    page = 'about.html'
    return render_template(page)
|
||
|
||
|
||
# Blueprint registration hook.
def init_app(app):
    """Attach this module's routes blueprint to a Flask application.

    Args:
        app: the Flask application on which ``bp`` is registered. This
            parameter shadows the module-level ``app`` import inside this
            function.
    """
    app.register_blueprint(blueprint=bp)
|