261 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			261 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| return render_template('predict_text.html', text=text, predictions=predictions)
 | |
| 
 | |
| except Exception as e:
 | |
| logger.error(f"预测文本时出错: {e}")
 | |
| flash(f'预测失败: {str(e)}')
 | |
| return render_template('predict_text.html', text=text)
 | |
| 
 | |
| return render_template('predict_text.html')
 | |
| 
 | |
| 
 | |
# Only the first N rows of an uploaded table are classified per request.
_MAX_TABLE_ROWS = 100
# Extensions handled by the table branch (read with pandas).
_TABLE_EXTENSIONS = ('csv', 'xls', 'xlsx')


def _parse_top_k(raw, default=3):
    """Parse the ``top_k`` form field into an int >= 1.

    Falls back to ``default`` when the value is missing or not an integer, so
    a malformed form field no longer escapes as an unhandled ``ValueError``.
    """
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    return max(1, value)


def _extension_of(name):
    """Return the lower-cased extension of ``name`` without the dot, or ''."""
    return name.rsplit('.', 1)[1].lower() if '.' in name else ''


def _wrap_predictions(result, top_k):
    """Return ``result`` as a list: wrap it when ``top_k`` is 1.

    Assumes the predictor returns a single prediction for ``top_k == 1`` and
    a list otherwise -- TODO confirm against the predictor implementation.
    """
    return result if top_k > 1 else [result]


def _remove_upload(path):
    """Delete a temporary upload; failures are logged, never raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # file.save() failed before the file was created
    except OSError as e:
        logger.warning(f"删除上传文件失败: {path}: {e}")


def _predict_text_file(file_path, filename, top_k):
    """Classify the whole content of an uploaded ``.txt`` file."""
    # utf-8-sig reads plain UTF-8 too, but strips a leading BOM that would
    # otherwise become part of the text sent to the model.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    result = get_predictor().predict(
        text=text,
        return_top_k=top_k,
        return_probabilities=True
    )

    return render_template(
        'predict_file.html',
        filename=filename,
        file_type='text',
        predictions=_wrap_predictions(result, top_k)
    )


def _predict_table_file(file_path, filename, ext, top_k, text_column):
    """Classify up to ``_MAX_TABLE_ROWS`` cells of ``text_column`` in a CSV/Excel file."""
    if ext == 'csv':
        df = pd.read_csv(file_path)
    else:
        df = pd.read_excel(file_path)

    if text_column not in df.columns:
        flash(f"文件中没有找到列: {text_column}")
        return render_template('predict_file.html')

    # astype(str): cells pandas parsed as numbers/dates would otherwise break
    # the len()/slicing below and are not valid text input for the model.
    texts = df[text_column].fillna('').astype(str).tolist()
    batch = texts[:_MAX_TABLE_ROWS]

    results = get_predictor().predict_batch(
        texts=batch,
        return_top_k=top_k,
        return_probabilities=True
    )

    batch_results = [
        {
            'id': i,
            # Long texts are shortened for display only.
            'text': text[:100] + '...' if len(text) > 100 else text,
            'predictions': _wrap_predictions(result, top_k),
        }
        for i, (text, result) in enumerate(zip(batch, results))
    ]

    return render_template(
        'predict_file.html',
        filename=filename,
        file_type='table',
        total_records=len(texts),
        processed_records=min(_MAX_TABLE_ROWS, len(texts)),
        batch_results=batch_results
    )


@app.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
    """File prediction page.

    GET renders the upload form. POST accepts one uploaded file:

    - ``.txt``: the whole file is classified as one text.
    - ``.csv``/``.xls``/``.xlsx``: the ``text_column`` form field (default
      ``'text'``) names the column; its first ``_MAX_TABLE_ROWS`` cells are
      classified in one batch.

    ``top_k`` (default 3) is the number of labels requested per text. Any
    problem is reported with ``flash`` and the form is re-rendered.
    """
    import uuid

    if request.method != 'POST':
        return render_template('predict_file.html')

    if 'file' not in request.files:
        flash('未选择文件')
        return render_template('predict_file.html')

    file = request.files['file']

    # The browser sends an empty filename when no file was picked.
    if file.filename == '':
        flash('未选择文件')
        return render_template('predict_file.html')

    if not (file and allowed_file(file.filename)):
        flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
        return render_template('predict_file.html')

    top_k = _parse_top_k(request.form.get('top_k', 3))
    text_column = request.form.get('text_column', 'text')

    # Sanitised name, used only for display.
    filename = secure_filename(file.filename)
    # Take the type from the ORIGINAL name: secure_filename() drops non-ASCII
    # characters, so e.g. '测试.txt' becomes 'txt' and loses its extension.
    ext = _extension_of(file.filename)
    if ext != 'txt' and ext not in _TABLE_EXTENSIONS:
        flash(f"不支持的文件类型: {filename}")
        return render_template('predict_file.html')

    # Store under a random name: concurrent uploads with the same name cannot
    # overwrite each other. `ext` is from a fixed whitelist, so the path is safe.
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{ext}")
    try:
        file.save(file_path)
        if ext == 'txt':
            return _predict_text_file(file_path, filename, top_k)
        return _predict_table_file(file_path, filename, ext, top_k, text_column)
    except Exception as e:
        logger.exception(f"处理文件时出错: {e}")
        flash(f'处理失败: {str(e)}')
        return render_template('predict_file.html')
    finally:
        # The upload is only needed while this request runs; remove it so the
        # upload folder does not grow without bound.
        _remove_upload(file_path)
 | |
| 
 | |
| 
 | |
@app.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
    """Batch prediction page.

    On POST, every non-blank line of the ``texts`` form field is stripped and
    classified as a separate text via ``predict_batch``. ``top_k`` (default 3)
    is the number of labels requested per text. Results are numbered from 1
    in input order. Errors are reported with ``flash``. GET renders the empty
    form.
    """
    if request.method == 'POST':
        texts = request.form.get('texts', '')
        # int() on a malformed field used to raise outside the try below and
        # surface as an HTTP 500. Fall back to the default and keep it >= 1.
        try:
            top_k = max(1, int(request.form.get('top_k', 3)))
        except (TypeError, ValueError):
            top_k = 3

        if not texts:
            flash('请输入文本内容')
            return render_template('batch_predict.html')

        # One text per line; blank / whitespace-only lines are ignored.
        text_list = [line.strip() for line in texts.split('\n') if line.strip()]

        if not text_list:
            flash('请输入有效的文本内容')
            return render_template('batch_predict.html', texts=texts)

        try:
            results = get_predictor().predict_batch(
                texts=text_list,
                return_top_k=top_k,
                return_probabilities=True
            )

            # With top_k == 1 each result is wrapped in a list so 'predictions'
            # is list-shaped either way (presumably predict_batch yields a bare
            # prediction per text in that case -- confirm with the predictor).
            batch_results = [
                {
                    'id': i,
                    'text': text,
                    'predictions': result if top_k > 1 else [result],
                }
                for i, (text, result) in enumerate(zip(text_list, results), start=1)
            ]

            return render_template(
                'batch_predict.html',
                texts=texts,
                batch_results=batch_results
            )

        except Exception as e:
            logger.exception(f"批量预测时出错: {e}")
            flash(f'预测失败: {str(e)}')
            return render_template('batch_predict.html', texts=texts)

    return render_template('batch_predict.html')
 | |
| 
 | |
| 
 | |
@app.route('/models')
def list_models():
    """Render the model list page.

    The list comes from ``ModelFactory.get_available_models()``. If fetching it
    or rendering the page raises, the error is logged and flashed, and the
    page is rendered again with an empty model list.
    """
    try:
        available = ModelFactory.get_available_models()
        return render_template('models.html', models=available)
    except Exception as err:
        logger.error(f"获取模型列表时出错: {err}")
        flash(f'获取模型列表失败: {str(err)}')
        return render_template('models.html', models=[])
 | |
| 
 | |
| 
 | |
@app.route('/about')
def about():
    """Render the static "about" page."""
    page = 'about.html'
    return render_template(page)
 | |
| 
 | |
| 
 | |
@app.errorhandler(404)
def page_not_found(e):
    """Handle 404 errors: render the custom not-found page with status 404."""
    body = render_template('404.html')
    status = 404
    return body, status
 | |
| 
 | |
| 
 | |
@app.errorhandler(500)
def internal_server_error(e):
    """Handle 500 errors: render the custom server-error page with status 500."""
    body = render_template('500.html')
    status = 500
    return body, status
 | |
| 
 | |
| 
 | |
# Template filters
@app.template_filter('format_time')
def format_time_filter(timestamp):
    """Format a timestamp for display as 'YYYY年MM月DD日 HH:MM'.

    Accepts a ``datetime`` instance or a string in ``"%Y-%m-%d %H:%M:%S"``
    layout. Strings in any other layout, and values of any other type, are
    returned unchanged.
    """
    try:
        if isinstance(timestamp, datetime):
            moment = timestamp
        elif isinstance(timestamp, str):
            moment = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
        else:
            return timestamp
        return moment.strftime("%Y年%m月%d日 %H:%M")
    except ValueError:
        # Unparseable string (ValueError also covers UnicodeEncodeError from
        # strftime): show the raw value instead of failing the whole page.
        # Unlike the former bare `except:`, this no longer hides unrelated bugs.
        return timestamp
 | |
| 
 | |
| 
 | |
@app.template_filter('truncate_text')
def truncate_text_filter(text, length=100):
    """Shorten ``text`` to at most ``length`` characters, appending '...' if cut.

    ``None`` renders as an empty string, and other non-string values are
    converted with ``str()`` first. Previously such values raised
    ``TypeError`` (from ``len()`` or ``+ '...'``) and aborted the template.
    """
    if text is None:
        return ''
    if not isinstance(text, str):
        text = str(text)
    if len(text) <= length:
        return text
    return text[:length] + '...'
 | |
| 
 | |
| 
 | |
@app.template_filter('format_percent')
def format_percent_filter(value):
    """Format a ratio as a percentage with two decimals (0.1234 -> '12.34%').

    Any real number is accepted, including NumPy scalars such as
    ``numpy.float32``. Those are not ``float`` subclasses, but NumPy registers
    them as ``numbers.Real``. ``bool`` values (an ``int`` subclass) and
    non-numeric values are returned unchanged.
    """
    import numbers

    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        # float() gives a uniform '.2f' formatting path for every Real type.
        return f"{float(value) * 100:.2f}%"
    return value
 | |
| 
 | |
| 
 | |
def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
    """
    Run the web server with Flask's built-in ``app.run``.

    Args:
        host: interface address to bind to.
        port: TCP port to listen on.
        debug: whether to enable Flask debug mode.
    """
    # Flask debug mode enables the Werkzeug interactive debugger, which can run
    # arbitrary code. Combined with the default host 0.0.0.0, it is reachable
    # from the network, so make that combination loudly visible.
    if debug and host not in ("127.0.0.1", "localhost", "::1"):
        logger.warning(
            f"已在非本地地址 {host} 上开启调试模式: Werkzeug 交互式调试器可执行任意代码, "
            f"请勿在不受信任的网络中使用"
        )
    app.run(host=host, port=port, debug=debug)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: read host/port/debug options, then start the server.
    import argparse

    cli = argparse.ArgumentParser(description="中文文本分类系统Web应用")
    cli.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
    cli.add_argument("--port", type=int, default=5000, help="服务器端口号")
    cli.add_argument("--debug", action="store_true", help="是否开启调试模式")

    # The parsed namespace has exactly host/port/debug, matching run_server's keywords.
    run_server(**vars(cli.parse_args()))
 | 
