return render_template('predict_text.html', text=text, predictions=predictions) except Exception as e: logger.error(f"预测文本时出错: {e}") flash(f'预测失败: {str(e)}') return render_template('predict_text.html', text=text) return render_template('predict_text.html') @app.route('/predict_file', methods=['GET', 'POST']) def predict_file(): """文件预测页面""" if request.method == 'POST': # 检查是否有文件上传 if 'file' not in request.files: flash('未选择文件') return render_template('predict_file.html') file = request.files['file'] # 如果用户没有选择文件 if file.filename == '': flash('未选择文件') return render_template('predict_file.html') if file and allowed_file(file.filename): # 获取参数 top_k = int(request.form.get('top_k', 3)) text_column = request.form.get('text_column', 'text') try: # 安全地保存文件 filename = secure_filename(file.filename) file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) file.save(file_path) # 处理文件类型 if filename.endswith('.txt'): # 文本文件 with open(file_path, 'r', encoding='utf-8') as f: text = f.read() # 获取预测器 predictor = get_predictor() # 预测 result = predictor.predict( text=text, return_top_k=top_k, return_probabilities=True ) # 准备结果 predictions = result if top_k > 1 else [result] return render_template( 'predict_file.html', filename=filename, file_type='text', predictions=predictions ) elif filename.endswith(('.csv', '.xls', '.xlsx')): # 表格文件 if filename.endswith('.csv'): df = pd.read_csv(file_path) else: df = pd.read_excel(file_path) # 检查文本列是否存在 if text_column not in df.columns: flash(f"文件中没有找到列: {text_column}") return render_template('predict_file.html') # 提取文本 texts = df[text_column].fillna('').tolist() # 获取预测器 predictor = get_predictor() # 批量预测 results = predictor.predict_batch( texts=texts[:100], # 仅处理前100条记录 return_top_k=top_k, return_probabilities=True ) # 准备结果 batch_results = [] for i, result in enumerate(results): batch_results.append({ 'id': i, 'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i], 'predictions': result if top_k > 1 else [result] }) return render_template( 'predict_file.html', filename=filename, file_type='table', total_records=len(texts), processed_records=min(100, len(texts)), batch_results=batch_results ) else: flash(f"不支持的文件类型: {filename}") return render_template('predict_file.html') except Exception as e: logger.error(f"处理文件时出错: {e}") flash(f'处理失败: {str(e)}') return render_template('predict_file.html') else: flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}') return render_template('predict_file.html') return render_template('predict_file.html') @app.route('/batch_predict', methods=['GET', 'POST']) def batch_predict(): """批量预测页面""" if request.method == 'POST': texts = request.form.get('texts', '') top_k = int(request.form.get('top_k', 3)) if not texts: flash('请输入文本内容') return render_template('batch_predict.html') # 分割文本 text_list = [text.strip() for text in texts.split('\n') if text.strip()] if not text_list: flash('请输入有效的文本内容') return render_template('batch_predict.html', texts=texts) try: # 获取预测器 predictor = get_predictor() # 批量预测 results = predictor.predict_batch( texts=text_list, return_top_k=top_k, return_probabilities=True ) # 准备结果 batch_results = [] for i, result in enumerate(results): batch_results.append({ 'id': i + 1, 'text': text_list[i], 'predictions': result if top_k > 1 else [result] }) return render_template( 'batch_predict.html', texts=texts, batch_results=batch_results ) except Exception as e: logger.error(f"批量预测时出错: {e}") flash(f'预测失败: {str(e)}') return render_template('batch_predict.html', texts=texts) return render_template('batch_predict.html') @app.route('/models') def list_models(): """模型列表页面""" try: # 获取可用模型列表 models_info = ModelFactory.get_available_models() return render_template('models.html', models=models_info) except Exception as e: logger.error(f"获取模型列表时出错: {e}") flash(f'获取模型列表失败: {str(e)}') return render_template('models.html', models=[]) @app.route('/about') def about(): """关于页面""" return render_template('about.html') @app.errorhandler(404) def page_not_found(e): """404错误处理""" return render_template('404.html'), 404 @app.errorhandler(500) def internal_server_error(e): """500错误处理""" return render_template('500.html'), 500 # 过滤器 @app.template_filter('format_time') def format_time_filter(timestamp): """格式化时间戳""" if isinstance(timestamp, str): try: dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S") return dt.strftime("%Y年%m月%d日 %H:%M") except: return timestamp return timestamp @app.template_filter('truncate_text') def truncate_text_filter(text, length=100): """截断文本""" if len(text) <= length: return text return text[:length] + '...' @app.template_filter('format_percent') def format_percent_filter(value): """格式化百分比""" if isinstance(value, (int, float)): return f"{value * 100:.2f}%" return value def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False): """ 运行Web服务器 Args: host: 主机地址 port: 端口号 debug: 是否开启调试模式 """ app.run(host=host, port=port, debug=debug) if __name__ == "__main__": # 解析命令行参数 import argparse parser = argparse.ArgumentParser(description="中文文本分类系统Web应用") parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址") parser.add_argument("--port", type=int, default=5000, help="服务器端口号") parser.add_argument("--debug", action="store_true", help="是否开启调试模式") args = parser.parse_args() # 运行服务器 run_server(host=args.host, port=args.port, debug=args.debug)