fix_bug
parent a8a0beb277
commit 24953f68df

.idea/vcs.xml (generated, new file, 6 additions)
| @@ -0,0 +1,6 @@ | ||||
| +<?xml version="1.0" encoding="UTF-8"?> | ||||
| +<project version="4"> | ||||
| +  <component name="VcsDirectoryMappings"> | ||||
| +    <mapping directory="" vcs="Git" /> | ||||
| +  </component> | ||||
| +</project> | ||||

app.py (3 changes)
| @@ -60,5 +60,4 @@ def internal_server_error(e): | ||||
|     return render_template('error.html', error='服务器内部错误'), 500 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| -    app.run(debug=True, host='0.0.0.0', port=5009) | ||||
| -
 | ||||
| +    app.run(debug=True, host='0.0.0.0', port=50004) | ||||
| @@ -8,7 +8,7 @@ import shutil | ||||
| from flask import Blueprint, request, jsonify, current_app, send_file, g, session | ||||
| from werkzeug.utils import secure_filename | ||||
| import mysql.connector | ||||
| -from utils.model_service import text_classifier | ||||
| +from utils.model_service import text_classifier # Assuming text_classifier is initialized correctly elsewhere | ||||
| from utils.db import get_db, close_db | ||||
| import logging | ||||
| 
 | ||||
| @@ -59,7 +59,8 @@ def get_next_document_number(category): | ||||
|         if not result or result['max_num'] is None: | ||||
|             return 1 | ||||
|         else: | ||||
| -            return result['max_num'] + 1 | ||||
| +            # Ensure it's an int before adding | ||||
| +            return int(result['max_num']) + 1 | ||||
|     except Exception as e: | ||||
|         logger.error(f"获取下一个文档编号时出错: {str(e)}") | ||||
|         return 1 | ||||
| @@ -80,12 +81,17 @@ def save_classified_document(user_id, original_filename, category, content, file | ||||
|     Returns: | ||||
|         tuple: (成功标志, 存储的文件名或错误信息) | ||||
|     """ | ||||
|     db = None # Initialize db to None | ||||
|     cursor = None # Initialize cursor to None | ||||
|     try: | ||||
|         # 获取下一个文档编号 | ||||
|         next_num = get_next_document_number(category) | ||||
| 
 | ||||
| -        # 安全处理文件名 | ||||
| -        safe_original_name = secure_filename(original_filename) | ||||
| +        # 安全处理文件名 - ensure it doesn't become empty | ||||
| +        safe_original_name_base = os.path.splitext(secure_filename(original_filename))[0] | ||||
| +        if not safe_original_name_base: | ||||
| +            safe_original_name_base = "untitled" # Default if name becomes empty | ||||
| +        safe_original_name = f"{safe_original_name_base}.txt" # Ensure .txt extension | ||||
| 
 | ||||
|         # 生成新文件名 (类别-编号-原文件名) | ||||
|         formatted_num = f"{next_num:04d}"  # 确保编号格式为4位数 | ||||
| @@ -117,16 +123,18 @@ def save_classified_document(user_id, original_filename, category, content, file | ||||
|         category_result = cursor.fetchone() | ||||
| 
 | ||||
|         if not category_result: | ||||
| -            return False, "类别不存在" | ||||
| +            logger.error(f"类别 '{category}' 在数据库中未找到。") | ||||
| +            # Optionally create the category if needed, or return error | ||||
| +            return False, f"类别 '{category}' 不存在" | ||||
| 
 | ||||
|         category_id = category_result[0] | ||||
| 
 | ||||
|         # 插入数据库记录 | ||||
|         insert_query = """ | ||||
| -        INSERT INTO documents  | ||||
| -        (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time) | ||||
| -        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW()) | ||||
| -        """ | ||||
| +        INSERT INTO documents | ||||
| +        (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time, upload_time) | ||||
| +        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW()) | ||||
| +        """ # Added upload_time=NOW() assuming it's desired | ||||
| 
 | ||||
|         cursor.execute( | ||||
|             insert_query, | ||||
| @@ -135,12 +143,19 @@ def save_classified_document(user_id, original_filename, category, content, file | ||||
| 
 | ||||
|         # 提交事务 | ||||
|         db.commit() | ||||
| 
 | ||||
|         logger.info(f"成功保存文档: {new_filename} (ID: {cursor.lastrowid})") | ||||
|         return True, new_filename | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"保存分类文档时出错: {str(e)}") | ||||
| +        # Rollback transaction if error occurs | ||||
| +        if db: | ||||
| +            db.rollback() | ||||
| +        logger.exception(f"保存分类文档 '{original_filename}' 时出错: {str(e)}") # Use logger.exception for traceback | ||||
|         return False, str(e) | ||||
|     finally: | ||||
|         if cursor: | ||||
|             cursor.close() | ||||
|         # Note: db connection is usually closed via @app.teardown_appcontext | ||||
| 
 | ||||
| 
 | ||||
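The commit/rollback/cleanup shape introduced in save_classified_document above, reduced to a standalone skeleton; the function name and the connection-factory argument are illustrative, not part of this commit:

def write_row(get_conn, sql, params):
    """Commit on success, roll back on any failure, always release the cursor."""
    conn = get_conn()
    cursor = None
    try:
        cursor = conn.cursor()
        cursor.execute(sql, params)
        conn.commit()
        return True, cursor.lastrowid
    except Exception as exc:
        conn.rollback()  # undo the partial write
        return False, str(exc)
    finally:
        if cursor:
            cursor.close()
        # the connection itself stays open, mirroring Flask's teardown-managed db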
| @classify_bp.route('/single', methods=['POST']) | ||||
| @@ -157,203 +172,214 @@ def classify_single_file(): | ||||
|         return jsonify({"success": False, "error": "没有文件"}), 400 | ||||
| 
 | ||||
|     file = request.files['file'] | ||||
|     original_filename = file.filename # Store original name early | ||||
| 
 | ||||
|     # 检查文件名 | ||||
| -    if file.filename == '': | ||||
| +    if original_filename == '': | ||||
|         return jsonify({"success": False, "error": "未选择文件"}), 400 | ||||
| 
 | ||||
|     # 检查文件类型 | ||||
| -    if not allowed_text_file(file.filename): | ||||
| +    if not allowed_text_file(original_filename): | ||||
|         return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400 | ||||
| 
 | ||||
|     temp_path = None # Initialize to ensure it's defined for cleanup | ||||
|     try: | ||||
|         # 创建临时文件以供处理 | ||||
|         temp_dir = os.path.join(current_app.root_path, 'temp') | ||||
|         os.makedirs(temp_dir, exist_ok=True) | ||||
| 
 | ||||
| -        temp_filename = f"{uuid.uuid4().hex}.txt" | ||||
| +        # Use secure_filename for the temp file name part as well | ||||
| +        temp_filename = f"{uuid.uuid4().hex}-{secure_filename(original_filename)}" | ||||
|         temp_path = os.path.join(temp_dir, temp_filename) | ||||
| 
 | ||||
|         # 保存上传文件到临时位置 | ||||
|         file.save(temp_path) | ||||
|         file_size = os.path.getsize(temp_path) # Get size after saving | ||||
| 
 | ||||
|         # 读取文件内容 - Try multiple encodings | ||||
|         file_content = None | ||||
|         encodings_to_try = ['utf-8', 'gbk', 'gb18030'] | ||||
|         for enc in encodings_to_try: | ||||
|             try: | ||||
|                 with open(temp_path, 'r', encoding=enc) as f: | ||||
|                     file_content = f.read() | ||||
|                 logger.info(f"成功以 {enc} 读取临时文件: {temp_path}") | ||||
|                 break | ||||
|             except UnicodeDecodeError: | ||||
|                 logger.warning(f"使用 {enc} 解码临时文件失败: {temp_path}") | ||||
|                 continue | ||||
|             except Exception as read_err: # Catch other potential read errors | ||||
|                 logger.error(f"读取临时文件时出错 ({enc}): {read_err}") | ||||
|                 # Decide if you want to stop or try next encoding | ||||
|                 continue | ||||
| 
 | ||||
|         if file_content is None: | ||||
|             raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。") | ||||
| 
 | ||||
| -        # 读取文件内容 | ||||
| -        with open(temp_path, 'r', encoding='utf-8') as f: | ||||
| -            file_content = f.read() | ||||
| 
 | ||||
|         # 调用模型进行分类 | ||||
|         result = text_classifier.classify_text(file_content) | ||||
| 
 | ||||
| -        if not result['success']: | ||||
| -            return jsonify({"success": False, "error": result['error']}), 500 | ||||
| +        if not result.get('success', False): # Check if key exists and is True | ||||
| +            error_msg = result.get('error', '未知的分类错误') | ||||
| +            logger.error(f"文本分类失败 for {original_filename}: {error_msg}") | ||||
| +            return jsonify({"success": False, "error": error_msg}), 500 | ||||
| 
 | ||||
|         category = result['category'] | ||||
| 
 | ||||
|         # 保存分类后的文档 | ||||
| -        file_size = os.path.getsize(temp_path) | ||||
|         save_success, message = save_classified_document( | ||||
|             user_id, | ||||
| -            file.filename, | ||||
| -            result['category'], | ||||
| +            original_filename, # Use the original filename here | ||||
| +            category, | ||||
|             file_content, | ||||
|             file_size | ||||
|         ) | ||||
| 
 | ||||
|         # 清理临时文件 | ||||
|         if os.path.exists(temp_path): | ||||
|             os.remove(temp_path) | ||||
| 
 | ||||
|         if not save_success: | ||||
|             return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500 | ||||
| 
 | ||||
|         # 返回分类结果 | ||||
|         return jsonify({ | ||||
|             "success": True, | ||||
| -            "filename": file.filename, | ||||
| -            "category": result['category'], | ||||
| -            "confidence": result['confidence'], | ||||
| +            "filename": original_filename, | ||||
| +            "category": category, | ||||
| +            "confidence": result.get('confidence'), # Use .get for safety | ||||
|             "stored_filename": message | ||||
|         }) | ||||
| 
 | ||||
| -    except UnicodeDecodeError: | ||||
| -        # 尝试GBK编码 | ||||
| -        try: | ||||
| -            with open(temp_path, 'r', encoding='gbk') as f: | ||||
| -                file_content = f.read() | ||||
| -
 | ||||
| -            # 调用模型进行分类 | ||||
| -            result = text_classifier.classify_text(file_content) | ||||
| -
 | ||||
| -            if not result['success']: | ||||
| -                return jsonify({"success": False, "error": result['error']}), 500 | ||||
| -
 | ||||
| -            # 保存分类后的文档 | ||||
| -            file_size = os.path.getsize(temp_path) | ||||
| -            save_success, message = save_classified_document( | ||||
| -                user_id, | ||||
| -                file.filename, | ||||
| -                result['category'], | ||||
| -                file_content, | ||||
| -                file_size | ||||
| -            ) | ||||
| -
 | ||||
| -            # 清理临时文件 | ||||
| -            if os.path.exists(temp_path): | ||||
| -                os.remove(temp_path) | ||||
| -
 | ||||
| -            if not save_success: | ||||
| -                return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500 | ||||
| -
 | ||||
| -            # 返回分类结果 | ||||
| -            return jsonify({ | ||||
| -                "success": True, | ||||
| -                "filename": file.filename, | ||||
| -                "category": result['category'], | ||||
| -                "confidence": result['confidence'], | ||||
| -                "stored_filename": message | ||||
| -            }) | ||||
| -        except Exception as e: | ||||
| -            if os.path.exists(temp_path): | ||||
| -                os.remove(temp_path) | ||||
| -            return jsonify({"success": False, "error": f"文件编码错误,请确保文件为UTF-8或GBK编码: {str(e)}"}), 400 | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        # 确保清理临时文件 | ||||
| -        if 'temp_path' in locals() and os.path.exists(temp_path): | ||||
| -            os.remove(temp_path) | ||||
| -        logger.error(f"文件处理过程中发生错误: {str(e)}") | ||||
| +        logger.exception(f"处理单文件 '{original_filename}' 过程中发生错误: {str(e)}") # Log exception with traceback | ||||
|         return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500 | ||||
|     finally: | ||||
|         # 清理临时文件 | ||||
|         if temp_path and os.path.exists(temp_path): | ||||
|             try: | ||||
|                 os.remove(temp_path) | ||||
|                 logger.info(f"已删除临时文件: {temp_path}") | ||||
|             except OSError as rm_err: | ||||
|                 logger.error(f"删除临时文件失败 {temp_path}: {rm_err}") | ||||
| 
 | ||||
| 
 | ||||
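The multi-encoding read loop added above, and repeated in the batch route below, amounts to this small helper; a sketch only, where the helper name and default encoding list are assumptions, not part of the commit:

import logging

logger = logging.getLogger(__name__)

def read_text_with_fallback(path, encodings=('utf-8', 'gbk', 'gb18030')):
    """Return the file's text using the first encoding that decodes it, else None."""
    for enc in encodings:
        try:
            with open(path, 'r', encoding=enc) as f:
                text = f.read()
            logger.info("decoded %s as %s", path, enc)
            return text
        except UnicodeDecodeError:
            logger.warning("failed to decode %s as %s", path, enc)
    return None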
| @classify_bp.route('/batch', methods=['POST']) | ||||
| def classify_batch_files(): | ||||
|     """批量文件上传和分类API(压缩包处理)""" | ||||
|     # 检查用户是否登录 | ||||
|     if 'user_id' not in session: | ||||
|         return jsonify({"success": False, "error": "请先登录"}), 401 | ||||
| 
 | ||||
|     user_id = session['user_id'] | ||||
| 
 | ||||
|     # 检查是否上传了文件 | ||||
|     if 'file' not in request.files: | ||||
|         return jsonify({"success": False, "error": "没有文件"}), 400 | ||||
| 
 | ||||
|     file = request.files['file'] | ||||
|     original_archive_name = file.filename | ||||
| 
 | ||||
|     # 检查文件名 | ||||
| -    if file.filename == '': | ||||
| +    if original_archive_name == '': | ||||
|         return jsonify({"success": False, "error": "未选择文件"}), 400 | ||||
| 
 | ||||
|     # 检查文件类型 | ||||
| -    if not allowed_archive_file(file.filename): | ||||
| +    if not allowed_archive_file(original_archive_name): | ||||
|         return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400 | ||||
| 
 | ||||
| -    # 检查文件大小 | ||||
| -    if request.content_length > 10 * 1024 * 1024:  # 10MB | ||||
| -        return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400 | ||||
| +    # Consider adding file size check here if needed (e.g., request.content_length) | ||||
| +    # if request.content_length > current_app.config['MAX_CONTENT_LENGTH']: # Example | ||||
| +    #     return jsonify({"success": False, "error": "文件过大"}), 413 | ||||
| 
 | ||||
|     temp_dir = os.path.join(current_app.root_path, 'temp') | ||||
|     extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}") | ||||
|     archive_path = None | ||||
|     # --- 修改点 1: 将 total_attempted 改回 total --- | ||||
|     results = { | ||||
|         "total": 0, # <--- 使用 'total' | ||||
|         "success": 0, | ||||
|         "failed": 0, | ||||
|         "categories": {}, | ||||
|         "failed_files": [] | ||||
|     } | ||||
| 
 | ||||
|     try: | ||||
|         # 创建临时目录 | ||||
| -        temp_dir = os.path.join(current_app.root_path, 'temp') | ||||
| -        extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}") | ||||
|         os.makedirs(extract_dir, exist_ok=True) | ||||
| 
 | ||||
|         # 保存上传的压缩文件 | ||||
| -        archive_path = os.path.join(temp_dir, secure_filename(file.filename)) | ||||
| +        archive_path = os.path.join(temp_dir, secure_filename(original_archive_name)) | ||||
|         file.save(archive_path) | ||||
|         logger.info(f"已保存上传的压缩文件: {archive_path}") | ||||
| 
 | ||||
|         # 解压文件 | ||||
| -        file_extension = file.filename.rsplit('.', 1)[1].lower() | ||||
| +        file_extension = original_archive_name.rsplit('.', 1)[1].lower() | ||||
|         if file_extension == 'zip': | ||||
|             with zipfile.ZipFile(archive_path, 'r') as zip_ref: | ||||
|                 zip_ref.extractall(extract_dir) | ||||
|             logger.info(f"已解压ZIP文件到: {extract_dir}") | ||||
|         elif file_extension == 'rar': | ||||
| -            with rarfile.RarFile(archive_path, 'r') as rar_ref: | ||||
| -                rar_ref.extractall(extract_dir) | ||||
| +            # Ensure rarfile is installed and unrar executable is available | ||||
| +            try: | ||||
| +                with rarfile.RarFile(archive_path, 'r') as rar_ref: | ||||
| +                    rar_ref.extractall(extract_dir) | ||||
| +                logger.info(f"已解压RAR文件到: {extract_dir}") | ||||
| +            except rarfile.NeedFirstVolume: | ||||
| +                logger.error(f"RAR文件是分卷压缩的一部分,需要第一个分卷: {original_archive_name}") | ||||
| +                raise ValueError("不支持分卷压缩的RAR文件") | ||||
| +            except rarfile.BadRarFile: | ||||
| +                logger.error(f"RAR文件损坏或格式错误: {original_archive_name}") | ||||
| +                raise ValueError("RAR文件损坏或格式错误") | ||||
| +            except Exception as rar_err: # Catch other rar errors (like missing unrar) | ||||
| +                logger.error(f"解压RAR文件时出错: {rar_err}") | ||||
| +                raise ValueError(f"解压RAR文件失败: {rar_err}") | ||||
| 
 | ||||
| -        # 处理结果统计 | ||||
| -        results = { | ||||
| -            "total": 0, | ||||
| -            "success": 0, | ||||
| -            "failed": 0, | ||||
| -            "categories": {}, | ||||
| -            "failed_files": [] | ||||
| -        } | ||||
| 
 | ||||
|         # 递归处理所有txt文件 | ||||
|         for root, dirs, files in os.walk(extract_dir): | ||||
|             # --- 开始过滤 macOS 文件夹 --- | ||||
|             if '__MACOSX' in root.split(os.path.sep): | ||||
|                 logger.info(f"Skipping macOS metadata directory: {root}") | ||||
|                 dirs[:] = [] | ||||
|                 files[:] = [] | ||||
|                 continue | ||||
|             # --- 结束过滤 macOS 文件夹 --- | ||||
| 
 | ||||
|             for filename in files: | ||||
| -                if filename.lower().endswith('.txt'): | ||||
| +                # --- 开始过滤 macOS 文件 --- | ||||
| +                if filename.startswith('._') or filename == '.DS_Store': | ||||
| +                    logger.info(f"Skipping macOS metadata file: {filename} in {root}") | ||||
| +                    continue | ||||
| +                # --- 结束过滤 macOS 文件 --- | ||||
| 
 | ||||
| +                # Process only allowed text files | ||||
| +                if allowed_text_file(filename): | ||||
|                     file_path = os.path.join(root, filename) | ||||
| -                    results["total"] += 1 | ||||
| +                    # --- 修改点 2: 将 total_attempted 改回 total --- | ||||
| +                    results["total"] += 1 # <--- 使用 'total' | ||||
| 
 | ||||
|                     try: | ||||
| -                        # 读取文件内容 | ||||
| -                        try: | ||||
| -                            with open(file_path, 'r', encoding='utf-8') as f: | ||||
| -                                file_content = f.read() | ||||
| -                        except UnicodeDecodeError: | ||||
| -                            # 尝试GBK编码 | ||||
| -                            with open(file_path, 'r', encoding='gbk') as f: | ||||
| -                                file_content = f.read() | ||||
|                         # 读取文件内容 - Try multiple encodings | ||||
|                         file_content = None | ||||
|                         encodings_to_try = ['utf-8', 'gbk', 'gb18030'] | ||||
|                         for enc in encodings_to_try: | ||||
|                             try: | ||||
|                                 with open(file_path, 'r', encoding=enc) as f: | ||||
|                                     file_content = f.read() | ||||
|                                 logger.info(f"成功以 {enc} 读取文件: {file_path}") | ||||
|                                 break | ||||
|                             except UnicodeDecodeError: | ||||
|                                 logger.warning(f"使用 {enc} 解码文件失败: {file_path}") | ||||
|                                 continue | ||||
|                             except Exception as read_err: | ||||
|                                 logger.error(f"读取文件时发生其他错误 ({enc}) {file_path}: {read_err}") | ||||
|                                 continue | ||||
| 
 | ||||
|                         if file_content is None: | ||||
|                             raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。") | ||||
| 
 | ||||
|                         # 调用模型进行分类 | ||||
|                         result = text_classifier.classify_text(file_content) | ||||
| 
 | ||||
| -                        if not result['success']: | ||||
| -                            results["failed"] += 1 | ||||
| -                            results["failed_files"].append({ | ||||
| -                                "filename": filename, | ||||
| -                                "error": result['error'] | ||||
| -                            }) | ||||
| -                            continue | ||||
| +                        if not result.get('success', False): | ||||
| +                            error_msg = result.get('error', '未知的分类错误') | ||||
| +                            raise Exception(f"分类失败: {error_msg}") | ||||
| 
 | ||||
|                         category = result['category'] | ||||
| 
 | ||||
|                         # 统计类别数量 | ||||
| -                        if category not in results["categories"]: | ||||
| -                            results["categories"][category] = 0 | ||||
| -                        results["categories"][category] += 1 | ||||
| +                        results["categories"][category] = results["categories"].get(category, 0) + 1 | ||||
| 
 | ||||
|                         # 保存分类后的文档 | ||||
|                         file_size = os.path.getsize(file_path) | ||||
|                         save_success, message = save_classified_document( | ||||
|                             user_id, | ||||
| @@ -366,106 +392,108 @@ def classify_batch_files(): | ||||
|                         if save_success: | ||||
|                             results["success"] += 1 | ||||
|                         else: | ||||
| -                            results["failed"] += 1 | ||||
| -                            results["failed_files"].append({ | ||||
| -                                "filename": filename, | ||||
| -                                "error": message | ||||
| -                            }) | ||||
| +                            raise Exception(f"保存失败: {message}") | ||||
| 
 | ||||
|                     except Exception as e: | ||||
|                         logger.error(f"处理文件 '{filename}' 失败: {str(e)}") | ||||
|                         results["failed"] += 1 | ||||
|                         results["failed_files"].append({ | ||||
|                             "filename": filename, | ||||
|                             "error": str(e) | ||||
|                         }) | ||||
| 
 | ||||
| -        # 清理临时文件 | ||||
| -        if os.path.exists(archive_path): | ||||
| -            os.remove(archive_path) | ||||
| -        if os.path.exists(extract_dir): | ||||
| -            shutil.rmtree(extract_dir) | ||||
| 
 | ||||
|         # 返回处理结果 | ||||
|         return jsonify({ | ||||
|             "success": True, | ||||
| -            "archive_name": file.filename, | ||||
| -            "results": results | ||||
| +            "archive_name": original_archive_name, | ||||
| +            "results": results # <--- 确保返回的 results 包含 'total' 键 | ||||
|         }) | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        # 确保清理临时文件 | ||||
| -        if 'archive_path' in locals() and os.path.exists(archive_path): | ||||
| -            os.remove(archive_path) | ||||
| -        if 'extract_dir' in locals() and os.path.exists(extract_dir): | ||||
| -            shutil.rmtree(extract_dir) | ||||
| 
 | ||||
| -        logger.error(f"压缩包处理过程中发生错误: {str(e)}") | ||||
| +        logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}") | ||||
|         return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500 | ||||
| 
 | ||||
|     finally: | ||||
|         # 清理临时文件和目录 (保持不变) | ||||
|         if archive_path and os.path.exists(archive_path): | ||||
|             try: | ||||
|                 os.remove(archive_path) | ||||
|                 logger.info(f"已删除临时压缩文件: {archive_path}") | ||||
|             except OSError as rm_err: | ||||
|                 logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}") | ||||
|         if os.path.exists(extract_dir): | ||||
|             try: | ||||
|                 shutil.rmtree(extract_dir) | ||||
|                 logger.info(f"已删除临时解压目录: {extract_dir}") | ||||
|             except OSError as rmtree_err: | ||||
|                 logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}") | ||||
| 
 | ||||
| 
 | ||||
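The dirs[:] = [] assignment in the walk above works because os.walk hands out mutable lists and consults dirs to decide where to descend next; clearing or filtering it in place prunes whole subtrees. A self-contained illustration (the directory name is an example):

import os

for root, dirs, files in os.walk('extracted'):
    # In-place filtering: os.walk will not descend into the removed entries.
    dirs[:] = [d for d in dirs if d != '__MACOSX']
    txt_files = [f for f in files
                 if f.lower().endswith('.txt')
                 and not f.startswith('._') and f != '.DS_Store']
    print(root, txt_files)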
| @classify_bp.route('/documents', methods=['GET']) | ||||
| def get_classified_documents(): | ||||
|     """获取已分类的文档列表""" | ||||
|     # 检查用户是否登录 | ||||
|     if 'user_id' not in session: | ||||
|         return jsonify({"success": False, "error": "请先登录"}), 401 | ||||
| 
 | ||||
|     user_id = session['user_id'] | ||||
| 
 | ||||
| -    # 获取查询参数 | ||||
| -    category = request.args.get('category', 'all') | ||||
| -    page = int(request.args.get('page', 1)) | ||||
| -    per_page = int(request.args.get('per_page', 10)) | ||||
| -
 | ||||
| -    # 验证每页条数 | ||||
| -    if per_page not in [10, 25, 50, 100]: | ||||
| -        per_page = 10 | ||||
| -
 | ||||
| -    # 计算偏移量 | ||||
| -    offset = (page - 1) * per_page | ||||
| -
 | ||||
| -    db = get_db() | ||||
| -    cursor = db.cursor(dictionary=True) | ||||
| 
 | ||||
|     try: | ||||
| -        # 构建查询条件 | ||||
| -        where_clause = "WHERE d.user_id = %s AND d.status = '已分类'" | ||||
| -        params = [user_id] | ||||
| +        category = request.args.get('category', 'all') | ||||
| +        page = int(request.args.get('page', 1)) | ||||
| +        per_page = int(request.args.get('per_page', 10)) | ||||
| 
 | ||||
| -        if category != 'all': | ||||
| -            where_clause += " AND c.name = %s" | ||||
| +        if per_page not in [10, 25, 50, 100]: | ||||
| +            per_page = 10 | ||||
| +        if page < 1: | ||||
| +            page = 1 | ||||
| 
 | ||||
| +        offset = (page - 1) * per_page | ||||
| 
 | ||||
|         db = get_db() | ||||
|         cursor = db.cursor(dictionary=True) # Use dictionary cursor | ||||
| 
 | ||||
|         # Get available categories first | ||||
|         cursor.execute("SELECT name FROM categories ORDER BY name") | ||||
|         available_categories = [row['name'] for row in cursor.fetchall()] | ||||
| 
 | ||||
|         # Build query | ||||
|         params = [user_id] | ||||
|         base_query = """ | ||||
|         FROM documents d | ||||
|         JOIN categories c ON d.category_id = c.id | ||||
|         WHERE d.user_id = %s AND d.status = '已分类' | ||||
|         """ | ||||
|         if category != 'all' and category in available_categories: | ||||
|             base_query += " AND c.name = %s" | ||||
|             params.append(category) | ||||
| 
 | ||||
|         # 查询总记录数 | ||||
|         count_query = f""" | ||||
|         SELECT COUNT(*) as total | ||||
|         FROM documents d | ||||
|         JOIN categories c ON d.category_id = c.id | ||||
|         {where_clause} | ||||
|         """ | ||||
|         # Count query | ||||
|         count_query = f"SELECT COUNT(*) as total {base_query}" | ||||
|         cursor.execute(count_query, params) | ||||
|         total_count = cursor.fetchone()['total'] | ||||
|         total_count_result = cursor.fetchone() | ||||
|         total_count = total_count_result['total'] if total_count_result else 0 | ||||
| 
 | ||||
|         # 计算总页数 | ||||
|         total_pages = (total_count + per_page - 1) // per_page | ||||
|         total_pages = (total_count + per_page - 1) // per_page if per_page > 0 else 0 | ||||
| 
 | ||||
|         # 查询分页数据 | ||||
|         query = f""" | ||||
|         SELECT d.id, d.original_filename, d.stored_filename, d.file_size,  | ||||
|         # Data query | ||||
|         data_query = f""" | ||||
|         SELECT d.id, d.original_filename, d.stored_filename, d.file_size, | ||||
|                c.name as category, d.upload_time, d.classified_time | ||||
|         FROM documents d | ||||
|         JOIN categories c ON d.category_id = c.id | ||||
|         {where_clause} | ||||
|         {base_query} | ||||
|         ORDER BY d.classified_time DESC | ||||
|         LIMIT %s OFFSET %s | ||||
|         """ | ||||
|         params.extend([per_page, offset]) | ||||
|         cursor.execute(query, params) | ||||
|         cursor.execute(data_query, params) | ||||
|         documents = cursor.fetchall() | ||||
| 
 | ||||
|         # 获取所有可用类别 | ||||
|         cursor.execute("SELECT name FROM categories ORDER BY name") | ||||
|         categories = [row['name'] for row in cursor.fetchall()] | ||||
|         # Format dates if they are datetime objects (optional, depends on driver) | ||||
|         for doc in documents: | ||||
|             if hasattr(doc.get('upload_time'), 'isoformat'): | ||||
|                  doc['upload_time'] = doc['upload_time'].isoformat() | ||||
|             if hasattr(doc.get('classified_time'), 'isoformat'): | ||||
|                  doc['classified_time'] = doc['classified_time'].isoformat() | ||||
| 
 | ||||
| 
 | ||||
|         return jsonify({ | ||||
|             "success": True, | ||||
| @@ -476,36 +504,35 @@ def get_classified_documents(): | ||||
|                 "current_page": page, | ||||
|                 "total_pages": total_pages | ||||
|             }, | ||||
|             "categories": categories, | ||||
|             "categories": available_categories, # Send the fetched list | ||||
|             "current_category": category | ||||
|         }) | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"获取文档列表时出错: {str(e)}") | ||||
| +        logger.exception(f"获取文档列表时出错: {str(e)}") # Log traceback | ||||
|         return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500 | ||||
| 
 | ||||
|     finally: | ||||
| -        cursor.close() | ||||
| +        # Cursor closing handled by context usually, but explicit close is safe | ||||
| +        if 'cursor' in locals() and cursor: | ||||
| +            cursor.close() | ||||
| 
 | ||||
| 
 | ||||
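The pagination arithmetic used above (ceil-division for the page count, LIMIT/OFFSET for the slice), isolated with example numbers:

total_count, per_page, page = 42, 10, 3

total_pages = (total_count + per_page - 1) // per_page  # ceil(42 / 10) == 5
offset = (page - 1) * per_page                          # page 3 -> rows 20..29

assert (total_pages, offset) == (5, 20)
# corresponding SQL tail: ORDER BY ... LIMIT %s OFFSET %s with (per_page, offset)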
| @classify_bp.route('/download/<int:document_id>', methods=['GET']) | ||||
| def download_document(document_id): | ||||
|     """下载已分类的文档""" | ||||
|     # 检查用户是否登录 | ||||
|     if 'user_id' not in session: | ||||
|         return jsonify({"success": False, "error": "请先登录"}), 401 | ||||
| 
 | ||||
|     user_id = session['user_id'] | ||||
| 
 | ||||
|     db = get_db() | ||||
|     cursor = db.cursor(dictionary=True) | ||||
| 
 | ||||
|     cursor = None | ||||
|     try: | ||||
|         # 查询文档信息 | ||||
|         db = get_db() | ||||
|         cursor = db.cursor(dictionary=True) | ||||
| 
 | ||||
|         query = """ | ||||
|         SELECT file_path, original_filename, stored_filename | ||||
|         SELECT file_path, original_filename | ||||
|         FROM documents | ||||
|         WHERE id = %s AND user_id = %s | ||||
|         WHERE id = %s AND user_id = %s AND status = '已分类' | ||||
|         """ | ||||
|         cursor.execute(query, (document_id, user_id)) | ||||
|         document = cursor.fetchone() | ||||
| @@ -513,136 +540,187 @@ def download_document(document_id): | ||||
|         if not document: | ||||
|             return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404 | ||||
| 
 | ||||
|         # 检查文件是否存在 | ||||
|         if not os.path.exists(document['file_path']): | ||||
|             return jsonify({"success": False, "error": "文件不存在"}), 404 | ||||
|         file_path = document['file_path'] | ||||
|         download_name = document['original_filename'] | ||||
| 
 | ||||
|         # 返回文件下载 | ||||
|         if not os.path.exists(file_path): | ||||
|              logger.error(f"请求下载的文件在服务器上不存在: {file_path} (Doc ID: {document_id})") | ||||
|              # Update status in DB? | ||||
|              # update_status_query = "UPDATE documents SET status = '文件丢失' WHERE id = %s" | ||||
|              # cursor.execute(update_status_query, (document_id,)) | ||||
|              # db.commit() | ||||
|              return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404 | ||||
| 
 | ||||
|         logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}") | ||||
|         return send_file( | ||||
|             document['file_path'], | ||||
|             file_path, | ||||
|             as_attachment=True, | ||||
|             download_name=document['original_filename'], | ||||
|             mimetype='text/plain' | ||||
|             download_name=download_name, # Use original filename for download | ||||
|             mimetype='text/plain' # Assuming text, adjust if other types allowed | ||||
|         ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"下载文档时出错: {str(e)}") | ||||
| +        logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}") | ||||
|         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500 | ||||
| 
 | ||||
|     finally: | ||||
| -        cursor.close() | ||||
| +        if cursor: | ||||
| +            cursor.close() | ||||
| 
 | ||||
| 
 | ||||
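The bulk-download route below expands a Python list into a parameterized IN clause instead of string-formatting the values. A minimal sketch of that pattern; the table and the cursor are assumed to exist:

document_ids = [3, 7, 12]
user_id = 1

placeholders = ', '.join(['%s'] * len(document_ids))  # "%s, %s, %s"
query = ("SELECT id, file_path, original_filename FROM documents "
         f"WHERE id IN ({placeholders}) AND user_id = %s")
params = document_ids + [user_id]
# cursor.execute(query, params)  # the driver binds every value, nothing is interpolated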
| @classify_bp.route('/download-multiple', methods=['POST']) | ||||
| def download_multiple_documents(): | ||||
|     """下载多个文档(打包为zip)""" | ||||
|     # 检查用户是否登录 | ||||
|     if 'user_id' not in session: | ||||
|         return jsonify({"success": False, "error": "请先登录"}), 401 | ||||
|      | ||||
| 
 | ||||
|     user_id = session['user_id'] | ||||
|      | ||||
|     # 获取请求数据 | ||||
|     data = request.get_json() | ||||
|     if not data or 'document_ids' not in data: | ||||
|         return jsonify({"success": False, "error": "缺少必要参数"}), 400 | ||||
|      | ||||
|     document_ids = data['document_ids'] | ||||
|     if not isinstance(document_ids, list) or not document_ids: | ||||
|         return jsonify({"success": False, "error": "文档ID列表无效"}), 400 | ||||
|      | ||||
|     db = get_db() | ||||
|     cursor = db.cursor(dictionary=True) | ||||
|      | ||||
|     cursor = None | ||||
|     zip_path = None # Initialize for finally block | ||||
| 
 | ||||
|     try: | ||||
|         # 创建临时目录用于存放zip文件 | ||||
|         temp_dir = os.path.join(current_app.root_path, 'temp') | ||||
|         os.makedirs(temp_dir, exist_ok=True) | ||||
|          | ||||
|         # 创建临时ZIP文件 | ||||
|         zip_filename = f"documents_{int(time.time())}.zip" | ||||
|         zip_path = os.path.join(temp_dir, zip_filename) | ||||
|          | ||||
|         # 查询所有符合条件的文档 | ||||
|         data = request.get_json() | ||||
|         if not data or 'document_ids' not in data: | ||||
|             return jsonify({"success": False, "error": "缺少必要参数 'document_ids'"}), 400 | ||||
| 
 | ||||
|         document_ids = data['document_ids'] | ||||
|         if not isinstance(document_ids, list) or not document_ids: | ||||
|             return jsonify({"success": False, "error": "文档ID列表无效"}), 400 | ||||
|         # Sanitize IDs to be integers | ||||
|         try: | ||||
|             document_ids = [int(doc_id) for doc_id in document_ids] | ||||
|         except ValueError: | ||||
|             return jsonify({"success": False, "error": "文档ID必须是数字"}), 400 | ||||
| 
 | ||||
|         db = get_db() | ||||
|         cursor = db.cursor(dictionary=True) | ||||
| 
 | ||||
|         # Query valid documents for the user | ||||
|         placeholders = ', '.join(['%s'] * len(document_ids)) | ||||
|         query = f""" | ||||
|         SELECT id, file_path, original_filename | ||||
|         FROM documents | ||||
|         WHERE id IN ({placeholders}) AND user_id = %s | ||||
|         WHERE id IN ({placeholders}) AND user_id = %s AND status = '已分类' | ||||
|         """ | ||||
|         params = document_ids + [user_id] | ||||
|         cursor.execute(query, params) | ||||
|         documents = cursor.fetchall() | ||||
|          | ||||
| 
 | ||||
|         if not documents: | ||||
|             return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404 | ||||
|          | ||||
|         # 创建ZIP文件并添加文档 | ||||
|         with zipfile.ZipFile(zip_path, 'w') as zipf: | ||||
|             return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404 | ||||
| 
 | ||||
|         # Create temp zip file | ||||
|         temp_dir = os.path.join(current_app.root_path, 'temp') | ||||
|         os.makedirs(temp_dir, exist_ok=True) | ||||
|         zip_filename = f"documents_{user_id}_{int(time.time())}.zip" | ||||
|         zip_path = os.path.join(temp_dir, zip_filename) | ||||
| 
 | ||||
|         files_added = 0 | ||||
|         with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: | ||||
|             for doc in documents: | ||||
|                 if os.path.exists(doc['file_path']): | ||||
|                     # 添加文件到zip,使用原始文件名 | ||||
|                     zipf.write(doc['file_path'], arcname=doc['original_filename']) | ||||
|          | ||||
|         # 返回ZIP文件下载 | ||||
|                 file_path = doc['file_path'] | ||||
|                 original_filename = doc['original_filename'] | ||||
|                 if os.path.exists(file_path): | ||||
|                     zipf.write(file_path, arcname=original_filename) | ||||
|                     files_added += 1 | ||||
|                     logger.info(f"添加到ZIP: {original_filename} (来自 {file_path})") | ||||
|                 else: | ||||
|                     logger.warning(f"跳过丢失的文件: {original_filename} (路径: {file_path}, Doc ID: {doc['id']})") | ||||
|                     # Optionally add a readme to the zip indicating missing files | ||||
| 
 | ||||
|         if files_added == 0: | ||||
|              # This case means all selected files were missing on disk | ||||
|              return jsonify({"success": False, "error": "所有选中的文件在服务器上都已丢失"}), 404 | ||||
| 
 | ||||
| 
 | ||||
|         logger.info(f"用户 {user_id} 请求下载 {files_added} 个文档打包为 {zip_filename}") | ||||
|         # Use after_this_request to delete the temp file after sending | ||||
|         @after_this_request  # one-shot hook; needs 'from flask import after_this_request' (current_app.after_request would register a permanent app-wide handler) | ||||
|         def remove_file(response): | ||||
|             try: | ||||
|                 if zip_path and os.path.exists(zip_path): | ||||
|                     os.remove(zip_path) | ||||
|                     logger.info(f"已删除临时ZIP文件: {zip_path}") | ||||
|             except Exception as error: | ||||
|                 logger.error(f"删除临时ZIP文件错误 {zip_path}: {error}") | ||||
|             return response | ||||
| 
 | ||||
|         return send_file( | ||||
|             zip_path, | ||||
|             as_attachment=True, | ||||
|             download_name=zip_filename, | ||||
|             mimetype='application/zip' | ||||
|         ) | ||||
|      | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"下载多个文档时出错: {str(e)}") | ||||
| +        logger.exception(f"下载多个文档时出错: {str(e)}") | ||||
|         # Clean up zip file if created but error occurred before sending | ||||
|         if zip_path and os.path.exists(zip_path): | ||||
|              try: os.remove(zip_path) | ||||
|              except OSError: pass | ||||
|         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500 | ||||
|      | ||||
| 
 | ||||
|     finally: | ||||
| -        cursor.close() | ||||
| +        if cursor: | ||||
| +            cursor.close() | ||||
| 
 | ||||
| 
 | ||||
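flask.after_this_request, used for the ZIP cleanup above, registers a one-shot callback that runs after the current response is generated, unlike current_app.after_request, which would install a permanent app-wide hook. A minimal sketch with a hypothetical temp file:

import os
from flask import Flask, after_this_request, send_file

app = Flask(__name__)

@app.route('/report')
def report():
    tmp_path = '/tmp/report.zip'  # hypothetical artifact created for this request

    @after_this_request
    def cleanup(response):
        # Fires once, for this request only, after the response is built.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return response

    return send_file(tmp_path, as_attachment=True)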
| @classify_bp.route('/classify-text', methods=['POST']) | ||||
| def classify_text_directly(): | ||||
|     """直接对文本进行分类(不保存文件)""" | ||||
|     # 检查用户是否登录 | ||||
|     if 'user_id' not in session: | ||||
|         return jsonify({"success": False, "error": "请先登录"}), 401 | ||||
| 
 | ||||
|     # 获取请求数据 | ||||
|     data = request.get_json() | ||||
|     if not data or 'text' not in data: | ||||
|         return jsonify({"success": False, "error": "缺少必要参数"}), 400 | ||||
| 
 | ||||
|     text = data['text'] | ||||
|     if not text.strip(): | ||||
|         return jsonify({"success": False, "error": "文本内容不能为空"}), 400 | ||||
| 
 | ||||
|     try: | ||||
|         data = request.get_json() | ||||
|         if not data or 'text' not in data: | ||||
|             return jsonify({"success": False, "error": "缺少必要参数 'text'"}), 400 | ||||
| 
 | ||||
|         text = data['text'] | ||||
|         # Basic validation | ||||
|         if not isinstance(text, str) or not text.strip(): | ||||
|             return jsonify({"success": False, "error": "文本内容不能为空或无效"}), 400 | ||||
| 
 | ||||
|         # Limit text length? | ||||
|         MAX_TEXT_LENGTH = 10000 # Example limit | ||||
|         if len(text) > MAX_TEXT_LENGTH: | ||||
|              return jsonify({"success": False, "error": f"文本过长,最大支持 {MAX_TEXT_LENGTH} 字符"}), 413 | ||||
| 
 | ||||
| 
 | ||||
|         # 调用模型进行分类 | ||||
|         # Ensure model is initialized - consider adding a check or lazy loading | ||||
|         if not text_classifier or not text_classifier.is_initialized: | ||||
|              logger.warning("文本分类器未初始化,尝试初始化...") | ||||
|              if not text_classifier.initialize(): | ||||
|                  logger.error("文本分类器初始化失败。") | ||||
|                  return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503 | ||||
| 
 | ||||
|         result = text_classifier.classify_text(text) | ||||
| 
 | ||||
|         if not result['success']: | ||||
|             return jsonify({"success": False, "error": result['error']}), 500 | ||||
|         if not result.get('success', False): | ||||
|             error_msg = result.get('error', '未知的分类错误') | ||||
|             logger.error(f"直接文本分类失败: {error_msg}") | ||||
|             return jsonify({"success": False, "error": error_msg}), 500 | ||||
| 
 | ||||
|         # 返回分类结果 | ||||
|         return jsonify({ | ||||
|             "success": True, | ||||
|             "category": result['category'], | ||||
|             "confidence": result['confidence'], | ||||
|             "all_confidences": result['all_confidences'] | ||||
|             "category": result.get('category'), | ||||
|             "confidence": result.get('confidence'), | ||||
|             "all_confidences": result.get('all_confidences') # Include all confidences if available | ||||
|         }) | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"文本分类过程中发生错误: {str(e)}") | ||||
| +        logger.exception(f"直接文本分类过程中发生错误: {str(e)}") | ||||
|         return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500 | ||||
| 
 | ||||
| 
 | ||||
| @classify_bp.route('/categories', methods=['GET']) | ||||
| def get_categories(): | ||||
|     """获取所有分类类别""" | ||||
|     db = get_db() | ||||
|     cursor = db.cursor(dictionary=True) | ||||
| 
 | ||||
|     cursor = None | ||||
|     try: | ||||
|         db = get_db() | ||||
|         cursor = db.cursor(dictionary=True) | ||||
|         cursor.execute("SELECT id, name, description FROM categories ORDER BY name") | ||||
|         categories = cursor.fetchall() | ||||
| 
 | ||||
| @@ -652,8 +730,8 @@ def get_categories(): | ||||
|         }) | ||||
| 
 | ||||
|     except Exception as e: | ||||
| -        logger.error(f"获取类别列表时出错: {str(e)}") | ||||
| +        logger.exception(f"获取类别列表时出错: {str(e)}") | ||||
|         return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500 | ||||
| 
 | ||||
|     finally: | ||||
| -        cursor.close() | ||||
| +        if cursor: | ||||
| +            cursor.close() | ||||
|  | ||||
utils/model_service.py

| @@ -3,9 +3,12 @@ import os | ||||
| import jieba | ||||
| import numpy as np | ||||
| import pickle | ||||
| -from tensorflow.keras.models import load_model | ||||
| -from tensorflow.keras.preprocessing.sequence import pad_sequences | ||||
| +import tensorflow as tf | ||||
| +# from tensorflow import keras # tf.keras is preferred | ||||
| +from keras.models import load_model # Keep if specifically needed, else use tf.keras.models.load_model | ||||
| +from keras.preprocessing.sequence import pad_sequences # Keep if specifically needed, else use tf.keras.preprocessing.sequence.pad_sequences | ||||
| import logging | ||||
| import h5py # Moved import here as it's used conditionally | ||||
| 
 | ||||
| 
 | ||||
| class TextClassificationModel: | ||||
| @@ -34,16 +37,42 @@ class TextClassificationModel: | ||||
| 
 | ||||
|         # 设置日志 | ||||
|         self.logger = logging.getLogger(__name__) | ||||
|         logging.basicConfig(level=logging.INFO) # Basic logging setup if not configured elsewhere | ||||
| 
 | ||||
|     def initialize(self): | ||||
| -        """初始化并加载模型和分词器""" | ||||
| +        """初始化并加载模型和分词器""" # <--- Corrected Indentation Starts Here | ||||
|         try: | ||||
|             self.logger.info("开始加载文本分类模型...") | ||||
| -            # 加载模型 | ||||
| -            self.model = load_model(self.model_path) | ||||
| -            self.logger.info("模型加载成功") | ||||
| 
 | ||||
|             # 优先尝试加载 HDF5 格式 (.h5),因为文件名是 .h5 | ||||
|             try: | ||||
|                 self.logger.info(f"尝试以 HDF5 格式加载模型: {self.model_path}") | ||||
|                 # For H5 files, direct load_model is usually sufficient if saved correctly. | ||||
|                 # compile=False is often needed if you don't need training features immediately. | ||||
|                 self.model = tf.keras.models.load_model(self.model_path, compile=False) | ||||
|                 self.logger.info("HDF5 模型加载成功") | ||||
| 
 | ||||
|             except Exception as h5_exc: | ||||
|                 self.logger.warning(f"HDF5 格式加载失败 ({h5_exc}),尝试以 SavedModel 格式加载...") | ||||
|                 # 如果 HDF5 加载失败,再尝试 SavedModel 格式 (通常是一个目录,而不是 .h5 文件) | ||||
|                 # This might fail if model_path truly points to an h5 file. | ||||
|                 try: | ||||
|                     self.model = tf.keras.models.load_model( | ||||
|                         self.model_path, | ||||
|                         compile=False # Usually false for inference | ||||
|                         # custom_objects can be added here if needed | ||||
|                         # options=tf.saved_model.LoadOptions(experimental_io_device='/job:localhost') # Usually not needed unless specific TF distribution setup | ||||
|                     ) | ||||
|                     self.logger.info("SavedModel 格式加载成功") | ||||
|                 except Exception as sm_exc: | ||||
|                     self.logger.error(f"SavedModel 格式加载也失败 ({sm_exc}). 无法加载模型。") | ||||
|                     # Consider adding the fallback JSON+weights logic here if needed, | ||||
|                     # but it's less common now. | ||||
|                     # Re-raising or handling the error appropriately | ||||
|                     raise ValueError(f"无法加载模型文件: {self.model_path}. H5 Error: {h5_exc}, SavedModel Error: {sm_exc}") | ||||
| 
 | ||||
|             # 加载tokenizer | ||||
|             self.logger.info(f"开始加载 Tokenizer: {self.tokenizer_path}") | ||||
|             with open(self.tokenizer_path, 'rb') as handle: | ||||
|                 self.tokenizer = pickle.load(handle) | ||||
|             self.logger.info("Tokenizer加载成功") | ||||
| @@ -51,8 +80,9 @@ class TextClassificationModel: | ||||
|             self.is_initialized = True | ||||
|             self.logger.info("模型初始化完成") | ||||
|             return True | ||||
| 
 | ||||
|         except Exception as e: | ||||
| -            self.logger.error(f"模型初始化失败: {str(e)}") | ||||
| +            self.logger.exception(f"模型初始化过程中发生严重错误: {str(e)}") # Use logger.exception to include traceback | ||||
|             self.is_initialized = False | ||||
|             return False | ||||
| 
 | ||||
| @@ -80,9 +110,12 @@ class TextClassificationModel: | ||||
|             dict: 分类结果,包含类别标签和置信度 | ||||
|         """ | ||||
|         if not self.is_initialized: | ||||
|             self.logger.warning("模型尚未初始化,尝试现在初始化...") | ||||
|             success = self.initialize() | ||||
|             if not success: | ||||
|                 self.logger.error("分类前初始化模型失败。") | ||||
|                 return {"success": False, "error": "模型初始化失败"} | ||||
|             self.logger.info("模型初始化成功,继续分类。") | ||||
| 
 | ||||
|         try: | ||||
|             # 文本预处理 | ||||
| @@ -92,17 +125,23 @@ class TextClassificationModel: | ||||
|             sequence = self.tokenizer.texts_to_sequences([processed_text]) | ||||
| 
 | ||||
|             # 填充序列 | ||||
| -            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post") | ||||
| +            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences( # Using tf.keras path | ||||
| +                sequence, maxlen=self.max_length, padding="post" | ||||
| +            ) | ||||
| 
 | ||||
|             # 预测 | ||||
|             predictions = self.model.predict(padded_sequence) | ||||
| 
 | ||||
|             # 获取预测类别索引和置信度 | ||||
|             predicted_index = np.argmax(predictions, axis=1)[0] | ||||
| -            confidence = float(predictions[0][predicted_index]) | ||||
| +            confidence = float(predictions[0][predicted_index]) # Convert numpy float to python float | ||||
| 
 | ||||
|             # 获取预测类别标签 | ||||
| -            predicted_label = self.CATEGORIES[predicted_index] | ||||
| +            if predicted_index < len(self.CATEGORIES): | ||||
| +                predicted_label = self.CATEGORIES[predicted_index] | ||||
| +            else: | ||||
| +                self.logger.warning(f"预测索引 {predicted_index} 超出类别列表范围!") | ||||
| +                predicted_label = "未知类别" # Handle out-of-bounds index | ||||
| 
 | ||||
|             # 获取所有类别的置信度 | ||||
|             all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])} | ||||
| @@ -114,8 +153,8 @@ class TextClassificationModel: | ||||
|                 "all_confidences": all_confidences | ||||
|             } | ||||
|         except Exception as e: | ||||
| -            self.logger.error(f"文本分类过程中发生错误: {str(e)}") | ||||
| -            return {"success": False, "error": str(e)} | ||||
| +            self.logger.exception(f"文本分类过程中发生错误: {str(e)}") # Use logger.exception | ||||
| +            return {"success": False, "error": f"分类错误: {str(e)}"} | ||||
| 
 | ||||
|     def classify_file(self, file_path): | ||||
|         """对文件内容进行分类 | ||||
| @@ -126,28 +165,29 @@ class TextClassificationModel: | ||||
|         Returns: | ||||
|             dict: 分类结果,包含类别标签和置信度 | ||||
|         """ | ||||
| -        try: | ||||
| -            # 读取文件内容 | ||||
| -            with open(file_path, 'r', encoding='utf-8') as f: | ||||
| -                text = f.read().strip() | ||||
| -
 | ||||
| -            # 调用文本分类函数 | ||||
| -            return self.classify_text(text) | ||||
| -
 | ||||
| -        except UnicodeDecodeError: | ||||
| -            # 如果UTF-8解码失败,尝试其他编码 | ||||
| +        text = None | ||||
| +        encodings_to_try = ['utf-8', 'gbk', 'gb18030'] # Common encodings | ||||
| +        for enc in encodings_to_try: | ||||
|             try: | ||||
| -                with open(file_path, 'r', encoding='gbk') as f: | ||||
| +                with open(file_path, 'r', encoding=enc) as f: | ||||
|                     text = f.read().strip() | ||||
| -                return self.classify_text(text) | ||||
| +                self.logger.info(f"成功以 {enc} 编码读取文件: {file_path}") | ||||
| +                break # Exit loop if read successful | ||||
| +            except UnicodeDecodeError: | ||||
| +                self.logger.warning(f"使用 {enc} 解码文件失败: {file_path}") | ||||
| +                continue # Try next encoding | ||||
|             except Exception as e: | ||||
| -                return {"success": False, "error": f"文件解码失败: {str(e)}"} | ||||
| +                self.logger.error(f"读取文件时发生其他错误 ({enc}): {str(e)}") | ||||
| +                return {"success": False, "error": f"文件读取错误 ({enc}): {str(e)}"} | ||||
| 
 | ||||
| -        except Exception as e: | ||||
| -            self.logger.error(f"文件处理过程中发生错误: {str(e)}") | ||||
| -            return {"success": False, "error": f"文件处理错误: {str(e)}"} | ||||
| +        if text is None: | ||||
| +            self.logger.error(f"尝试所有编码后仍无法读取文件: {file_path}") | ||||
| +            return {"success": False, "error": f"文件解码失败,尝试的编码: {encodings_to_try}"} | ||||
| 
 | ||||
|         # 调用文本分类函数 | ||||
|         return self.classify_text(text) | ||||
| 
 | ||||
| 
 | ||||
| # 创建单例实例,避免重复加载模型 | ||||
| # Consider lazy initialization if the model is large and not always needed immediately | ||||
| text_classifier = TextClassificationModel() | ||||
| 
 | ||||
|  | ||||
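As a usage sketch of the module-level singleton (the sample text is illustrative; the result keys match what classify_text returns above):

from utils.model_service import text_classifier

result = text_classifier.classify_text("央行宣布下调金融机构存款准备金率0.5个百分点")
if result["success"]:
    print(result["category"], f"{result['confidence']:.2%}")
    # result["all_confidences"] maps every category name to its score
else:
    print("classification failed:", result["error"])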