2025-03-08 01:34:36 +08:00

261 lines
8.2 KiB
Python

return render_template('predict_text.html', text=text, predictions=predictions)
except Exception as e:
logger.error(f"预测文本时出错: {e}")
flash(f'预测失败: {str(e)}')
return render_template('predict_text.html', text=text)
return render_template('predict_text.html')
@app.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
    """Render the file-prediction page and classify an uploaded file.

    GET renders the empty upload form. POST expects a ``file`` upload plus
    the optional form fields ``top_k`` (default 3) and ``text_column``
    (default ``'text'``):

    * ``.txt`` uploads are read as UTF-8 and classified as a single text.
    * ``.csv`` / ``.xls`` / ``.xlsx`` uploads are loaded with pandas and the
      first 100 values of ``text_column`` are classified as a batch.

    Validation and processing failures are reported through ``flash`` and
    the form is rendered again.

    Returns:
        The rendered ``predict_file.html`` template.
    """
    max_records = 100    # upper bound on table rows sent to the predictor
    preview_chars = 100  # characters of each text echoed back in the result table
    table_extensions = ('.csv', '.xls', '.xlsx')

    if request.method != 'POST':
        return render_template('predict_file.html')

    # The request carries no file part at all.
    if 'file' not in request.files:
        flash('未选择文件')
        return render_template('predict_file.html')

    file = request.files['file']
    # The file part exists but the user did not pick a file.
    if file.filename == '':
        flash('未选择文件')
        return render_template('predict_file.html')

    if not (file and allowed_file(file.filename)):
        flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
        return render_template('predict_file.html')

    # Validate top_k up front: int() on a non-numeric value used to raise
    # outside any try block and surface as an HTTP 500. An empty field falls
    # back to the default.
    try:
        top_k = int(request.form.get('top_k') or 3)
    except (TypeError, ValueError):
        top_k = 0
    if top_k < 1:
        flash('top_k 必须是正整数')
        return render_template('predict_file.html')

    text_column = request.form.get('text_column', 'text')

    try:
        # Sanitize the client-supplied name before using it in a path.
        filename = secure_filename(file.filename)
        original_ext = os.path.splitext(file.filename)[1].lower()
        # NOTE(review): assuming werkzeug's secure_filename, a name whose stem
        # is entirely non-ASCII (e.g. "文档.txt") is reduced to "txt", losing
        # the dot. Restore a known extension so the type dispatch below works.
        if (original_ext in ('.txt',) + table_extensions
                and not filename.lower().endswith(original_ext)):
            filename = f"upload{original_ext}"
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)

        # Compare extensions case-insensitively so "DATA.CSV" is accepted.
        lowered = filename.lower()

        if lowered.endswith('.txt'):
            # Plain-text file: the whole content is one sample.
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()

            predictor = get_predictor()
            result = predictor.predict(
                text=text,
                return_top_k=top_k,
                return_probabilities=True
            )
            # Always hand the template a list of predictions.
            predictions = result if top_k > 1 else [result]
            return render_template(
                'predict_file.html',
                filename=filename,
                file_type='text',
                predictions=predictions
            )

        if lowered.endswith(table_extensions):
            # Tabular file: one sample per row of the chosen column.
            if lowered.endswith('.csv'):
                df = pd.read_csv(file_path)
            else:
                df = pd.read_excel(file_path)

            if text_column not in df.columns:
                flash(f"文件中没有找到列: {text_column}")
                return render_template('predict_file.html')

            # Cast to str so numeric or mixed-type cells do not break the
            # len()/slicing used for the preview below.
            texts = df[text_column].fillna('').astype(str).tolist()
            batch = texts[:max_records]

            predictor = get_predictor()
            results = predictor.predict_batch(
                texts=batch,
                return_top_k=top_k,
                return_probabilities=True
            )

            batch_results = [
                {
                    'id': index,
                    'text': (sample[:preview_chars] + '...'
                             if len(sample) > preview_chars else sample),
                    'predictions': result if top_k > 1 else [result],
                }
                for index, (sample, result) in enumerate(zip(batch, results))
            ]
            return render_template(
                'predict_file.html',
                filename=filename,
                file_type='table',
                total_records=len(texts),
                processed_records=len(batch),
                batch_results=batch_results
            )

        flash(f"不支持的文件类型: {filename}")
        return render_template('predict_file.html')
    except Exception as e:
        # Request boundary: report any processing failure to the user
        # instead of returning a 500 page.
        logger.error(f"处理文件时出错: {e}")
        flash(f'处理失败: {str(e)}')
        return render_template('predict_file.html')
@app.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
    """Render the batch-prediction page and classify newline-separated texts.

    GET renders the empty form. POST reads the ``texts`` form field (one
    sample per line; blank lines are dropped) and the optional ``top_k``
    field (default 3), runs ``predict_batch`` on the samples and renders one
    result row per sample, numbered from 1.

    Validation and prediction failures are reported through ``flash`` and
    the form is rendered again with the submitted text preserved.

    Returns:
        The rendered ``batch_predict.html`` template.
    """
    if request.method != 'POST':
        return render_template('batch_predict.html')

    texts = request.form.get('texts', '')
    if not texts:
        flash('请输入文本内容')
        return render_template('batch_predict.html')

    # Validate top_k explicitly: int() on a non-numeric value used to raise
    # outside any try block and surface as an HTTP 500. An empty field falls
    # back to the default.
    try:
        top_k = int(request.form.get('top_k') or 3)
    except (TypeError, ValueError):
        top_k = 0
    if top_k < 1:
        flash('top_k 必须是正整数')
        return render_template('batch_predict.html', texts=texts)

    # One sample per non-blank line.
    text_list = [line.strip() for line in texts.split('\n') if line.strip()]
    if not text_list:
        flash('请输入有效的文本内容')
        return render_template('batch_predict.html', texts=texts)

    try:
        predictor = get_predictor()
        results = predictor.predict_batch(
            texts=text_list,
            return_top_k=top_k,
            return_probabilities=True
        )

        # Pair each input with its result; ids are 1-based for display.
        batch_results = [
            {
                'id': number,
                'text': sample,
                'predictions': result if top_k > 1 else [result],
            }
            for number, (sample, result) in enumerate(zip(text_list, results), start=1)
        ]
        return render_template(
            'batch_predict.html',
            texts=texts,
            batch_results=batch_results
        )
    except Exception as e:
        # Request boundary: report the failure to the user instead of a 500.
        logger.error(f"批量预测时出错: {e}")
        flash(f'预测失败: {str(e)}')
        return render_template('batch_predict.html', texts=texts)
@app.route('/models')
def list_models():
    """Render the model-list page.

    The models come from ``ModelFactory.get_available_models()``. If that
    call or the rendering raises, the error is logged and flashed, and the
    page is rendered with an empty model list instead.
    """
    try:
        available = ModelFactory.get_available_models()
        page = render_template('models.html', models=available)
    except Exception as e:
        logger.error(f"获取模型列表时出错: {e}")
        flash(f'获取模型列表失败: {str(e)}')
        page = render_template('models.html', models=[])
    return page
@app.route('/about')
def about():
    """Render the "about" page from ``about.html``."""
    page = render_template('about.html')
    return page
@app.errorhandler(404)
def page_not_found(e):
    """Handle 404 errors by rendering ``404.html`` with status code 404."""
    status_code = 404
    body = render_template('404.html')
    return body, status_code
@app.errorhandler(500)
def internal_server_error(e):
    """Handle 500 errors by rendering ``500.html`` with status code 500."""
    status_code = 500
    body = render_template('500.html')
    return body, status_code
# 过滤器
@app.template_filter('format_time')
def format_time_filter(timestamp):
    """Reformat a ``YYYY-MM-DD HH:MM:SS`` string for display.

    Args:
        timestamp: Value passed from the template.

    Returns:
        The timestamp formatted as ``YYYY年MM月DD日 HH:MM`` when the input is a
        string in the expected format; otherwise the input unchanged.
    """
    if not isinstance(timestamp, str):
        return timestamp
    try:
        parsed = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
        # The previous format, "%d%H:%M", fused the day and the hour
        # ("0801:34"); separate them with "日" and a space.
        return parsed.strftime("%Y年%m月%d日 %H:%M")
    except ValueError:
        # Best effort: show unparseable strings as-is. ValueError also covers
        # UnicodeError, so a platform encoding failure in strftime still
        # degrades gracefully, while unrelated errors are no longer swallowed
        # by a bare ``except``.
        return timestamp
@app.template_filter('truncate_text')
def truncate_text_filter(text, length=100):
    """Shorten a string to at most ``length`` characters plus an ellipsis.

    Args:
        text: Value passed from the template.
        length: Maximum number of characters kept before ``'...'`` is added.

    Returns:
        ``text`` unchanged when it is not a string or already fits within
        ``length``; otherwise its first ``length`` characters followed by
        ``'...'``.
    """
    # Pass non-string values (e.g. None) through untouched instead of raising
    # TypeError from len()/slicing while a template is rendering.
    if not isinstance(text, str):
        return text
    if len(text) <= length:
        return text
    return text[:length] + '...'
@app.template_filter('format_percent')
def format_percent_filter(value):
    """Format a numeric fraction as a percentage with two decimals.

    ``int`` and ``float`` values are multiplied by 100 and given a ``%``
    suffix; any other value is returned unchanged.
    """
    if not isinstance(value, (int, float)):
        return value
    percent = value * 100
    return f"{percent:.2f}%"
def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
    """
    Start the web server by calling ``app.run``.

    Args:
        host: Address to bind to.
        port: Port to listen on.
        debug: Whether to enable debug mode.
    """
    server_options = {"host": host, "port": port, "debug": debug}
    app.run(**server_options)
if __name__ == "__main__":
    # Script entry point: read host/port/debug from the command line.
    import argparse

    cli = argparse.ArgumentParser(description="中文文本分类系统Web应用")
    cli.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
    cli.add_argument("--port", type=int, default=5000, help="服务器端口号")
    cli.add_argument("--debug", action="store_true", help="是否开启调试模式")
    options = cli.parse_args()

    # Hand the parsed options to the server launcher.
    run_server(host=options.host, port=options.port, debug=options.debug)