261 lines
8.2 KiB
Python
261 lines
8.2 KiB
Python
return render_template('predict_text.html', text=text, predictions=predictions)
|
|
|
|
except Exception as e:
|
|
logger.error(f"预测文本时出错: {e}")
|
|
flash(f'预测失败: {str(e)}')
|
|
return render_template('predict_text.html', text=text)
|
|
|
|
return render_template('predict_text.html')
|
|
|
|
|
|
@app.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
    """File prediction page.

    GET renders the empty upload form. POST expects an upload in the ``file``
    field plus the optional form fields ``top_k`` (default 3) and
    ``text_column`` (default ``'text'``).

    * ``.txt`` uploads are read as one UTF-8 document and classified as a
      single text.
    * ``.csv`` / ``.xls`` / ``.xlsx`` uploads are loaded with pandas and the
      first 100 values of ``text_column`` are classified as one batch.

    Every failure is reported to the user through ``flash`` and the upload
    form is rendered again; no exception escapes to the 500 handler.
    """
    if request.method != 'POST':
        return render_template('predict_file.html')

    # The request carried no "file" part at all.
    if 'file' not in request.files:
        flash('未选择文件')
        return render_template('predict_file.html')

    file = request.files['file']

    # The form was submitted without the user picking a file.
    if file.filename == '':
        flash('未选择文件')
        return render_template('predict_file.html')

    if not (file and allowed_file(file.filename)):
        flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
        return render_template('predict_file.html')

    # A malformed top_k falls back to the default and a non-positive one is
    # raised to 1, instead of failing the whole request.
    try:
        top_k = max(1, int(request.form.get('top_k', 3)))
    except (TypeError, ValueError):
        top_k = 3
    text_column = request.form.get('text_column', 'text')

    table_extensions = ('.csv', '.xls', '.xlsx')
    supported_extensions = ('.txt',) + table_extensions

    try:
        # Sanitise the client-supplied name before using it in a path.
        filename = secure_filename(file.filename)
        extension = os.path.splitext(filename)[1].lower()
        if extension not in supported_extensions:
            # secure_filename() drops non-ASCII characters, so a name such as
            # "报告.txt" collapses to "txt" and loses its extension. Recover
            # the extension from the original name and store the upload under
            # a neutral, safe name.
            original_extension = os.path.splitext(file.filename)[1].lower()
            if original_extension in supported_extensions:
                extension = original_extension
                filename = f'upload{extension}'

        if extension not in supported_extensions:
            flash(f"不支持的文件类型: {filename}")
            return render_template('predict_file.html')

        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)

        if extension == '.txt':
            return _predict_text_upload(file_path, filename, top_k)
        return _predict_table_upload(
            file_path, filename, extension, text_column, top_k
        )

    except Exception as e:
        # Request boundary: report any processing failure to the user.
        logger.error(f"处理文件时出错: {e}")
        flash(f'处理失败: {str(e)}')
        return render_template('predict_file.html')


def _predict_text_upload(file_path, filename, top_k):
    """Classify a saved ``.txt`` upload as one document and render the result."""
    # "utf-8-sig" reads plain UTF-8 as well, but strips a leading BOM so it
    # does not end up in the text handed to the model.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    result = get_predictor().predict(
        text=text,
        return_top_k=top_k,
        return_probabilities=True,
    )

    # The template receives a list; a top_k == 1 result is wrapped to match.
    # NOTE(review): presumably the predictor returns a list only when
    # top_k > 1 -- confirm against the predictor implementation.
    predictions = result if top_k > 1 else [result]

    return render_template(
        'predict_file.html',
        filename=filename,
        file_type='text',
        predictions=predictions,
    )


def _predict_table_upload(file_path, filename, extension, text_column, top_k,
                          max_records=100, preview_length=100):
    """Classify the first ``max_records`` rows of a saved table upload and render them."""
    if extension == '.csv':
        df = pd.read_csv(file_path)
    else:
        df = pd.read_excel(file_path)

    # The requested text column must exist in the table.
    if text_column not in df.columns:
        flash(f"文件中没有找到列: {text_column}")
        return render_template('predict_file.html')

    # Empty cells become '' and every remaining cell is coerced to str, so
    # numeric columns do not break len()/slicing below.
    texts = df[text_column].fillna('').astype(str).tolist()
    selected = texts[:max_records]

    results = get_predictor().predict_batch(
        texts=selected,
        return_top_k=top_k,
        return_probabilities=True,
    )

    batch_results = [
        {
            'id': index,
            # Show only a short preview of long texts.
            'text': (text[:preview_length] + '...'
                     if len(text) > preview_length else text),
            'predictions': result if top_k > 1 else [result],
        }
        for index, (text, result) in enumerate(zip(selected, results))
    ]

    return render_template(
        'predict_file.html',
        filename=filename,
        file_type='table',
        total_records=len(texts),
        processed_records=min(max_records, len(texts)),
        batch_results=batch_results,
    )
|
|
|
|
|
|
@app.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
    """Batch prediction page.

    GET renders the empty form. POST reads the ``texts`` form field (one text
    per line) and the optional ``top_k`` field (default 3), classifies every
    non-blank line in one batch and renders the results. Failures are
    reported through ``flash`` and the form is rendered again with the
    submitted text preserved.
    """
    if request.method != 'POST':
        return render_template('batch_predict.html')

    texts = request.form.get('texts', '')

    # A malformed top_k falls back to the default and a non-positive one is
    # raised to 1, instead of failing the whole request.
    try:
        top_k = max(1, int(request.form.get('top_k', 3)))
    except (TypeError, ValueError):
        top_k = 3

    if not texts:
        flash('请输入文本内容')
        return render_template('batch_predict.html')

    # One text per line; blank and whitespace-only lines are ignored.
    text_list = [line.strip() for line in texts.split('\n') if line.strip()]

    if not text_list:
        flash('请输入有效的文本内容')
        return render_template('batch_predict.html', texts=texts)

    try:
        results = get_predictor().predict_batch(
            texts=text_list,
            return_top_k=top_k,
            return_probabilities=True,
        )

        # Ids shown to the user are 1-based. The template receives a list of
        # predictions per text; a top_k == 1 result is wrapped to match.
        batch_results = [
            {
                'id': number,
                'text': text,
                'predictions': result if top_k > 1 else [result],
            }
            for number, (text, result) in enumerate(zip(text_list, results), start=1)
        ]

        return render_template(
            'batch_predict.html',
            texts=texts,
            batch_results=batch_results,
        )

    except Exception as e:
        # Request boundary: report any prediction failure to the user.
        logger.error(f"批量预测时出错: {e}")
        flash(f'预测失败: {str(e)}')
        return render_template('batch_predict.html', texts=texts)
|
|
|
|
|
|
@app.route('/models')
def list_models():
    """Model list page.

    Renders ``models.html`` with whatever ``ModelFactory.get_available_models()``
    returns. If fetching or rendering fails, the error is logged and flashed
    and the page is rendered with an empty model list instead.
    """
    try:
        available_models = ModelFactory.get_available_models()
        page = render_template('models.html', models=available_models)
    except Exception as exc:
        logger.error(f"获取模型列表时出错: {exc}")
        flash(f'获取模型列表失败: {str(exc)}')
        page = render_template('models.html', models=[])
    return page
|
|
|
|
|
|
@app.route('/about')
def about():
    """About page: renders the ``about.html`` template."""
    page = render_template('about.html')
    return page
|
|
|
|
|
|
@app.errorhandler(404)
def page_not_found(e):
    """404 handler: renders ``404.html`` with a 404 status code."""
    status_code = 404
    body = render_template('404.html')
    return body, status_code
|
|
|
|
|
|
@app.errorhandler(500)
def internal_server_error(e):
    """500 handler: renders ``500.html`` with a 500 status code."""
    status_code = 500
    body = render_template('500.html')
    return body, status_code
|
|
|
|
|
|
# 过滤器
|
|
@app.template_filter('format_time')
def format_time_filter(timestamp):
    """Template filter: reformat a ``YYYY-MM-DD HH:MM:SS`` timestamp string.

    Args:
        timestamp: Value to format. Only ``str`` values are processed.

    Returns:
        The timestamp rendered as ``YYYY年MM月DD日 HH:MM`` when it is a string
        in the expected format; otherwise the input value unchanged.
    """
    if not isinstance(timestamp, str):
        return timestamp

    try:
        dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
        return dt.strftime("%Y年%m月%d日 %H:%M")
    except ValueError:
        # Best effort: a string in another format is shown as-is. ValueError
        # also covers UnicodeError, should strftime fail to encode the
        # non-ASCII format. Unlike a bare ``except``, this does not swallow
        # KeyboardInterrupt or SystemExit.
        return timestamp
|
|
|
|
|
|
@app.template_filter('truncate_text')
def truncate_text_filter(text, length=100):
    """Template filter: shorten ``text`` to at most ``length`` characters.

    Args:
        text: The text to shorten. ``None`` is treated as an empty string.
        length: Maximum number of characters kept before the ellipsis.

    Returns:
        ``text`` unchanged when it is no longer than ``length``; otherwise
        its first ``length`` characters followed by ``'...'``.
    """
    # A missing value in the template context must not fail the render.
    if text is None:
        return ''
    if len(text) <= length:
        return text
    return text[:length] + '...'
|
|
|
|
|
|
@app.template_filter('format_percent')
def format_percent_filter(value):
    """Template filter: render a fraction as a percentage with two decimals.

    Args:
        value: The value to format, e.g. ``0.1234``.

    Returns:
        ``"12.34%"``-style text for any real number; any other value is
        returned unchanged.
    """
    import numbers

    # numbers.Real covers int and float as before, and also real scalar
    # types registered with the ABC (e.g. numpy.float32), which a plain
    # isinstance(value, (int, float)) check would pass through unformatted.
    if isinstance(value, numbers.Real):
        return f"{float(value) * 100:.2f}%"
    return value
|
|
|
|
|
|
def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
    """
    Start the web server by calling ``app.run``.

    Args:
        host: Host address to bind to.
        port: Port number to listen on.
        debug: Whether to enable debug mode.
    """
    server_options = {"host": host, "port": port, "debug": debug}
    app.run(**server_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the server settings from the command line.
    import argparse

    parser = argparse.ArgumentParser(description="中文文本分类系统Web应用")
    for flag, options in (
        ("--host", {"default": "0.0.0.0", "help": "服务器主机地址"}),
        ("--port", {"type": int, "default": 5000, "help": "服务器端口号"}),
        ("--debug", {"action": "store_true", "help": "是否开启调试模式"}),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Hand the parsed settings to the server.
    run_server(host=args.host, port=args.port, debug=args.debug)
|