From 24953f68dff690d30f09842c358172bb05d8191a Mon Sep 17 00:00:00 2001
From: superlishunqin <852326703@qq.com>
Date: Mon, 31 Mar 2025 03:06:26 +0800
Subject: [PATCH] fix_bug

---
 .idea/vcs.xml          |   6 +
 app.py                 |   3 +-
 routes/classify.py     | 612 +++++++++++++++++++++++------------------
 utils/model_service.py |  98 +++++--
 4 files changed, 421 insertions(+), 298 deletions(-)
 create mode 100644 .idea/vcs.xml

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/app.py b/app.py
index 460222f..82ce587 100644
--- a/app.py
+++ b/app.py
@@ -60,5 +60,4 @@ def internal_server_error(e):
     return render_template('error.html', error='服务器内部错误'), 500

 if __name__ == '__main__':
-    app.run(debug=True, host='0.0.0.0', port=5009)
-
+    app.run(debug=True, host='0.0.0.0', port=50004)
\ No newline at end of file
diff --git a/routes/classify.py b/routes/classify.py
index 7ac0f98..09d38a6 100644
--- a/routes/classify.py
+++ b/routes/classify.py
@@ -8,7 +8,7 @@ import shutil
 from flask import Blueprint, request, jsonify, current_app, send_file, g, session
 from werkzeug.utils import secure_filename
 import mysql.connector
-from utils.model_service import text_classifier
+from utils.model_service import text_classifier  # Assuming text_classifier is initialized correctly elsewhere
 from utils.db import get_db, close_db
 import logging

@@ -59,7 +59,8 @@ def get_next_document_number(category):
         if not result or result['max_num'] is None:
             return 1
         else:
-            return result['max_num'] + 1
+            # Ensure it's an int before adding
+            return int(result['max_num']) + 1
     except Exception as e:
         logger.error(f"获取下一个文档编号时出错: {str(e)}")
         return 1
@@ -80,12 +81,17 @@ def save_classified_document(user_id, original_filename, category, content, file
     Returns:
         tuple: (成功标志, 存储的文件名或错误信息)
     """
+    db = None  # Initialize db to None
+    cursor = None  # Initialize cursor to None
     try:
         # 获取下一个文档编号
         next_num = get_next_document_number(category)

-        # 安全处理文件名
-        safe_original_name = secure_filename(original_filename)
+        # 安全处理文件名 - ensure it doesn't become empty
+        safe_original_name_base = os.path.splitext(secure_filename(original_filename))[0]
+        if not safe_original_name_base:
+            safe_original_name_base = "untitled"  # Default if name becomes empty
+        safe_original_name = f"{safe_original_name_base}.txt"  # Ensure .txt extension

         # 生成新文件名 (类别-编号-原文件名)
         formatted_num = f"{next_num:04d}"  # 确保编号格式为4位数
@@ -117,16 +123,18 @@
         category_result = cursor.fetchone()

         if not category_result:
-            return False, "类别不存在"
+            logger.error(f"类别 '{category}' 在数据库中未找到。")
+            # Optionally create the category if needed, or return error
+            return False, f"类别 '{category}' 不存在"

         category_id = category_result[0]

         # 插入数据库记录
         insert_query = """
-        INSERT INTO documents
-        (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time)
-        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
-        """
+        INSERT INTO documents
+        (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time, upload_time)
+        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
+        """  # Added upload_time=NOW() assuming it's desired

         cursor.execute(
             insert_query,
@@ -135,12 +143,19 @@

         # 提交事务
         db.commit()
-
+        logger.info(f"成功保存文档: {new_filename} (ID: {cursor.lastrowid})")
         return True, new_filename
     except Exception as e:
-        logger.error(f"保存分类文档时出错: {str(e)}")
+        # Rollback transaction if error occurs
+        if db:
+            db.rollback()
+        logger.exception(f"保存分类文档 '{original_filename}' 时出错: {str(e)}")  # Use logger.exception for traceback
         return False, str(e)
+    finally:
+        if cursor:
+            cursor.close()
+        # Note: db connection is usually closed via @app.teardown_appcontext
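+# Filename scheme produced by save_classified_document (illustrative values):
+# category "体育", next_num 7 and an upload named "report.txt" would be stored
+# as "体育-0007-report.txt", following the 类别-编号-原文件名 pattern with the
+# 4-digit formatted_num above.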
@@ -157,203 +172,214 @@


 @classify_bp.route('/single', methods=['POST'])
 def classify_single_file():
     """单文件上传和分类API"""
     # 检查用户是否登录
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']

     # 检查是否上传了文件
     if 'file' not in request.files:
         return jsonify({"success": False, "error": "没有文件"}), 400

     file = request.files['file']
+    original_filename = file.filename  # Store original name early

     # 检查文件名
-    if file.filename == '':
+    if original_filename == '':
         return jsonify({"success": False, "error": "未选择文件"}), 400

     # 检查文件类型
-    if not allowed_text_file(file.filename):
+    if not allowed_text_file(original_filename):
         return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400

+    temp_path = None  # Initialize to ensure it's defined for cleanup
     try:
         # 创建临时文件以供处理
         temp_dir = os.path.join(current_app.root_path, 'temp')
         os.makedirs(temp_dir, exist_ok=True)
-        temp_filename = f"{uuid.uuid4().hex}.txt"
+        # Use secure_filename for the temp file name part as well
+        temp_filename = f"{uuid.uuid4().hex}-{secure_filename(original_filename)}"
         temp_path = os.path.join(temp_dir, temp_filename)

         # 保存上传文件到临时位置
         file.save(temp_path)
+        file_size = os.path.getsize(temp_path)  # Get size after saving
+
+        # 读取文件内容 - Try multiple encodings
+        file_content = None
+        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
+        for enc in encodings_to_try:
+            try:
+                with open(temp_path, 'r', encoding=enc) as f:
+                    file_content = f.read()
+                logger.info(f"成功以 {enc} 读取临时文件: {temp_path}")
+                break
+            except UnicodeDecodeError:
+                logger.warning(f"使用 {enc} 解码临时文件失败: {temp_path}")
+                continue
+            except Exception as read_err:  # Catch other potential read errors
+                logger.error(f"读取临时文件时出错 ({enc}): {read_err}")
+                # Decide if you want to stop or try next encoding
+                continue
+
+        if file_content is None:
+            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")

-        # 读取文件内容
-        with open(temp_path, 'r', encoding='utf-8') as f:
-            file_content = f.read()

         # 调用模型进行分类
         result = text_classifier.classify_text(file_content)

-        if not result['success']:
-            return jsonify({"success": False, "error": result['error']}), 500
+        if not result.get('success', False):  # Check if key exists and is True
+            error_msg = result.get('error', '未知的分类错误')
+            logger.error(f"文本分类失败 for {original_filename}: {error_msg}")
+            return jsonify({"success": False, "error": error_msg}), 500
+
+        category = result['category']

         # 保存分类后的文档
-        file_size = os.path.getsize(temp_path)
         save_success, message = save_classified_document(
             user_id,
-            file.filename,
-            result['category'],
+            original_filename,  # Use the original filename here
+            category,
             file_content,
             file_size
         )

-        # 清理临时文件
-        if os.path.exists(temp_path):
-            os.remove(temp_path)
-
         if not save_success:
             return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500

         # 返回分类结果
         return jsonify({
             "success": True,
-            "filename": file.filename,
-            "category": result['category'],
-            "confidence": result['confidence'],
+            "filename": original_filename,
+            "category": category,
+            "confidence": result.get('confidence'),  # Use .get for safety
             "stored_filename": message
         })

-    except UnicodeDecodeError:
-        # 尝试GBK编码
-        try:
-            with open(temp_path, 'r', encoding='gbk') as f:
-                file_content = f.read()
-
-            # 调用模型进行分类
-            result = text_classifier.classify_text(file_content)
-
-            if not result['success']:
-                return jsonify({"success": False, "error": result['error']}), 500
-
-            # 保存分类后的文档
-            file_size = os.path.getsize(temp_path)
-            save_success, message = save_classified_document(
-                user_id,
-                file.filename,
-                result['category'],
-                file_content,
-                file_size
-            )
-
-            # 清理临时文件
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
-            if not save_success:
-                return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500
-
-            # 返回分类结果
-            return jsonify({
-                "success": True,
-                "filename": file.filename,
-                "category": result['category'],
-                "confidence": result['confidence'],
-                "stored_filename": message
-            })
-        except Exception as e:
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-            return jsonify({"success": False, "error": f"文件编码错误,请确保文件为UTF-8或GBK编码: {str(e)}"}), 400
-
     except Exception as e:
-        # 确保清理临时文件
-        if 'temp_path' in locals() and os.path.exists(temp_path):
-            os.remove(temp_path)
-        logger.error(f"文件处理过程中发生错误: {str(e)}")
+        logger.exception(f"处理单文件 '{original_filename}' 过程中发生错误: {str(e)}")  # Log exception with traceback
         return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500
+    finally:
+        # 清理临时文件
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+                logger.info(f"已删除临时文件: {temp_path}")
+            except OSError as rm_err:
+                logger.error(f"删除临时文件失败 {temp_path}: {rm_err}")
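+# Example request for the endpoint above (hypothetical host/port and URL
+# prefix, depending on how classify_bp is registered in app.py):
+#   curl -X POST -F "file=@sample.txt" http://localhost:50004/classify/single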


 @classify_bp.route('/batch', methods=['POST'])
 def classify_batch_files():
     """批量文件上传和分类API(压缩包处理)"""
-    # 检查用户是否登录
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']

-    # 检查是否上传了文件
     if 'file' not in request.files:
         return jsonify({"success": False, "error": "没有文件"}), 400

     file = request.files['file']
+    original_archive_name = file.filename

-    # 检查文件名
-    if file.filename == '':
+    if original_archive_name == '':
         return jsonify({"success": False, "error": "未选择文件"}), 400

-    # 检查文件类型
-    if not allowed_archive_file(file.filename):
+    if not allowed_archive_file(original_archive_name):
         return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400

-    # 检查文件大小
-    if request.content_length > 10 * 1024 * 1024:  # 10MB
-        return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400
+    # Consider adding file size check here if needed (e.g., request.content_length)
+    # if request.content_length > current_app.config['MAX_CONTENT_LENGTH']:  # Example
+    #     return jsonify({"success": False, "error": "文件过大"}), 413
+
+    temp_dir = os.path.join(current_app.root_path, 'temp')
+    extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
+    archive_path = None
+    # --- 修改点 1: 将 total_attempted 改回 total ---
+    results = {
+        "total": 0,  # <--- 使用 'total'
+        "success": 0,
+        "failed": 0,
+        "categories": {},
+        "failed_files": []
+    }

     try:
-        # 创建临时目录
-        temp_dir = os.path.join(current_app.root_path, 'temp')
-        extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
         os.makedirs(extract_dir, exist_ok=True)
-
-        # 保存上传的压缩文件
-        archive_path = os.path.join(temp_dir, secure_filename(file.filename))
+        archive_path = os.path.join(temp_dir, secure_filename(original_archive_name))
         file.save(archive_path)
+        logger.info(f"已保存上传的压缩文件: {archive_path}")

         # 解压文件
-        file_extension = file.filename.rsplit('.', 1)[1].lower()
+        file_extension = original_archive_name.rsplit('.', 1)[1].lower()
         if file_extension == 'zip':
             with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                 zip_ref.extractall(extract_dir)
+            logger.info(f"已解压ZIP文件到: {extract_dir}")
         elif file_extension == 'rar':
-            with rarfile.RarFile(archive_path, 'r') as rar_ref:
-                rar_ref.extractall(extract_dir)
+            # Ensure rarfile is installed and unrar executable is available
+            try:
+                with rarfile.RarFile(archive_path, 'r') as rar_ref:
+                    rar_ref.extractall(extract_dir)
+                logger.info(f"已解压RAR文件到: {extract_dir}")
+            except rarfile.NeedFirstVolume:
+                logger.error(f"RAR文件是分卷压缩的一部分,需要第一个分卷: {original_archive_name}")
+                raise ValueError("不支持分卷压缩的RAR文件")
+            except rarfile.BadRarFile:
+                logger.error(f"RAR文件损坏或格式错误: {original_archive_name}")
+                raise ValueError("RAR文件损坏或格式错误")
+            except Exception as rar_err:  # Catch other rar errors (like missing unrar)
+                logger.error(f"解压RAR文件时出错: {rar_err}")
+                raise ValueError(f"解压RAR文件失败: {rar_err}")

-        # 处理结果统计
-        results = {
-            "total": 0,
-            "success": 0,
-            "failed": 0,
-            "categories": {},
-            "failed_files": []
-        }

         # 递归处理所有txt文件
         for root, dirs, files in os.walk(extract_dir):
+            # --- 开始过滤 macOS 文件夹 ---
+            if '__MACOSX' in root.split(os.path.sep):
+                logger.info(f"Skipping macOS metadata directory: {root}")
+                dirs[:] = []
+                files[:] = []
+                continue
+            # --- 结束过滤 macOS 文件夹 ---
+
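+            # Why slice assignment above: os.walk() consults the same 'dirs'
+            # list object when deciding which subdirectories to visit, so
+            # 'dirs[:] = []' prunes the walk in place, whereas 'dirs = []'
+            # would merely rebind the local name and prune nothing.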
             for filename in files:
+                # --- 开始过滤 macOS 文件 ---
+                if filename.startswith('._') or filename == '.DS_Store':
+                    logger.info(f"Skipping macOS metadata file: {filename} in {root}")
+                    continue
+                # --- 结束过滤 macOS 文件 ---
+
+                # Process only allowed text files
-                if filename.lower().endswith('.txt'):
+                if allowed_text_file(filename):
                     file_path = os.path.join(root, filename)
-                    results["total"] += 1
+                    # --- 修改点 2: 将 total_attempted 改回 total ---
+                    results["total"] += 1  # <--- 使用 'total'

                     try:
-                        # 读取文件内容
-                        try:
-                            with open(file_path, 'r', encoding='utf-8') as f:
-                                file_content = f.read()
-                        except UnicodeDecodeError:
-                            # 尝试GBK编码
-                            with open(file_path, 'r', encoding='gbk') as f:
-                                file_content = f.read()
+                        # 读取文件内容 - Try multiple encodings
+                        file_content = None
+                        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
+                        for enc in encodings_to_try:
+                            try:
+                                with open(file_path, 'r', encoding=enc) as f:
+                                    file_content = f.read()
+                                logger.info(f"成功以 {enc} 读取文件: {file_path}")
+                                break
+                            except UnicodeDecodeError:
+                                logger.warning(f"使用 {enc} 解码文件失败: {file_path}")
+                                continue
+                            except Exception as read_err:
+                                logger.error(f"读取文件时发生其他错误 ({enc}) {file_path}: {read_err}")
+                                continue
+
+                        if file_content is None:
+                            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")

                         # 调用模型进行分类
                         result = text_classifier.classify_text(file_content)

-                        if not result['success']:
-                            results["failed"] += 1
-                            results["failed_files"].append({
-                                "filename": filename,
-                                "error": result['error']
-                            })
-                            continue
+                        if not result.get('success', False):
+                            error_msg = result.get('error', '未知的分类错误')
+                            raise Exception(f"分类失败: {error_msg}")

                         category = result['category']
-                        # 统计类别数量
-                        if category not in results["categories"]:
-                            results["categories"][category] = 0
-                        results["categories"][category] += 1
+                        results["categories"][category] = results["categories"].get(category, 0) + 1

-                        # 保存分类后的文档
                         file_size = os.path.getsize(file_path)
                         save_success, message = save_classified_document(
                             user_id,
@@ -366,106 +392,108 @@

                         if save_success:
                             results["success"] += 1
                         else:
-                            results["failed"] += 1
-                            results["failed_files"].append({
-                                "filename": filename,
-                                "error": message
-                            })
+                            raise Exception(f"保存失败: {message}")

                     except Exception as e:
+                        logger.error(f"处理文件 '{filename}' 失败: {str(e)}")
                         results["failed"] += 1
                         results["failed_files"].append({
                             "filename": filename,
                             "error": str(e)
                         })

-        # 清理临时文件
-        if os.path.exists(archive_path):
-            os.remove(archive_path)
-        if os.path.exists(extract_dir):
-            shutil.rmtree(extract_dir)
-
         # 返回处理结果
         return jsonify({
             "success": True,
-            "archive_name": file.filename,
-            "results": results
+            "archive_name": original_archive_name,
+            "results": results  # <--- 确保返回的 results 包含 'total' 键
         })

     except Exception as e:
-        # 确保清理临时文件
-        if 'archive_path' in locals() and os.path.exists(archive_path):
-            os.remove(archive_path)
-        if 'extract_dir' in locals() and os.path.exists(extract_dir):
-            shutil.rmtree(extract_dir)
-
-        logger.error(f"压缩包处理过程中发生错误: {str(e)}")
+        logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}")
         return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500
+    finally:
+        # 清理临时文件和目录 (保持不变)
+        if archive_path and os.path.exists(archive_path):
+            try:
+                os.remove(archive_path)
+                logger.info(f"已删除临时压缩文件: {archive_path}")
+            except OSError as rm_err:
+                logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}")
+        if os.path.exists(extract_dir):
+            try:
+                shutil.rmtree(extract_dir)
+                logger.info(f"已删除临时解压目录: {extract_dir}")
+            except OSError as rmtree_err:
+                logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}")
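+# Example request for the endpoint above (hypothetical host/port and URL prefix):
+#   curl -X POST -F "file=@corpus.zip" http://localhost:50004/classify/batch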
results + "archive_name": original_archive_name, + "results": results # <--- 确保返回的 results 包含 'total' 键 }) except Exception as e: - # 确保清理临时文件 - if 'archive_path' in locals() and os.path.exists(archive_path): - os.remove(archive_path) - if 'extract_dir' in locals() and os.path.exists(extract_dir): - shutil.rmtree(extract_dir) - - logger.error(f"压缩包处理过程中发生错误: {str(e)}") + logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}") return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500 + finally: + # 清理临时文件和目录 (保持不变) + if archive_path and os.path.exists(archive_path): + try: + os.remove(archive_path) + logger.info(f"已删除临时压缩文件: {archive_path}") + except OSError as rm_err: + logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}") + if os.path.exists(extract_dir): + try: + shutil.rmtree(extract_dir) + logger.info(f"已删除临时解压目录: {extract_dir}") + except OSError as rmtree_err: + logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}") + @classify_bp.route('/documents', methods=['GET']) def get_classified_documents(): """获取已分类的文档列表""" - # 检查用户是否登录 if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] - # 获取查询参数 - category = request.args.get('category', 'all') - page = int(request.args.get('page', 1)) - per_page = int(request.args.get('per_page', 10)) - - # 验证每页条数 - if per_page not in [10, 25, 50, 100]: - per_page = 10 - - # 计算偏移量 - offset = (page - 1) * per_page - - db = get_db() - cursor = db.cursor(dictionary=True) - try: - # 构建查询条件 - where_clause = "WHERE d.user_id = %s AND d.status = '已分类'" - params = [user_id] + category = request.args.get('category', 'all') + page = int(request.args.get('page', 1)) + per_page = int(request.args.get('per_page', 10)) - if category != 'all': - where_clause += " AND c.name = %s" + if per_page not in [10, 25, 50, 100]: + per_page = 10 + if page < 1: + page = 1 + + offset = (page - 1) * per_page + + db = get_db() + cursor = db.cursor(dictionary=True) # Use dictionary cursor + + # Get available categories first + cursor.execute("SELECT name FROM categories ORDER BY name") + available_categories = [row['name'] for row in cursor.fetchall()] + + # Build query + params = [user_id] + base_query = """ + FROM documents d + JOIN categories c ON d.category_id = c.id + WHERE d.user_id = %s AND d.status = '已分类' + """ + if category != 'all' and category in available_categories: + base_query += " AND c.name = %s" params.append(category) - # 查询总记录数 - count_query = f""" - SELECT COUNT(*) as total - FROM documents d - JOIN categories c ON d.category_id = c.id - {where_clause} - """ + # Count query + count_query = f"SELECT COUNT(*) as total {base_query}" cursor.execute(count_query, params) - total_count = cursor.fetchone()['total'] + total_count_result = cursor.fetchone() + total_count = total_count_result['total'] if total_count_result else 0 - # 计算总页数 - total_pages = (total_count + per_page - 1) // per_page + total_pages = (total_count + per_page - 1) // per_page if per_page > 0 else 0 - # 查询分页数据 - query = f""" - SELECT d.id, d.original_filename, d.stored_filename, d.file_size, + # Data query + data_query = f""" + SELECT d.id, d.original_filename, d.stored_filename, d.file_size, c.name as category, d.upload_time, d.classified_time - FROM documents d - JOIN categories c ON d.category_id = c.id - {where_clause} + {base_query} ORDER BY d.classified_time DESC LIMIT %s OFFSET %s """ params.extend([per_page, offset]) - cursor.execute(query, params) + cursor.execute(data_query, params) documents = 

-        # 查询分页数据
-        query = f"""
-        SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
+        # Data query
+        data_query = f"""
+            SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
                    c.name as category, d.upload_time, d.classified_time
-        FROM documents d
-        JOIN categories c ON d.category_id = c.id
-        {where_clause}
+            {base_query}
             ORDER BY d.classified_time DESC
             LIMIT %s OFFSET %s
         """
         params.extend([per_page, offset])
-        cursor.execute(query, params)
-        documents = cursor.fetchall()
+        cursor.execute(data_query, params)
+        documents = cursor.fetchall()

-        # 获取所有可用类别
-        cursor.execute("SELECT name FROM categories ORDER BY name")
-        categories = [row['name'] for row in cursor.fetchall()]
+        # Format dates if they are datetime objects (optional, depends on driver)
+        for doc in documents:
+            if hasattr(doc.get('upload_time'), 'isoformat'):
+                doc['upload_time'] = doc['upload_time'].isoformat()
+            if hasattr(doc.get('classified_time'), 'isoformat'):
+                doc['classified_time'] = doc['classified_time'].isoformat()

         return jsonify({
             "success": True,
@@ -476,36 +504,35 @@
                 "current_page": page,
                 "total_pages": total_pages
             },
-            "categories": categories,
+            "categories": available_categories,  # Send the fetched list
             "current_category": category
         })

     except Exception as e:
-        logger.error(f"获取文档列表时出错: {str(e)}")
+        logger.exception(f"获取文档列表时出错: {str(e)}")  # Log traceback
         return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500
-
     finally:
-        cursor.close()
+        # Cursor closing handled by context usually, but explicit close is safe
+        if 'cursor' in locals() and cursor:
+            cursor.close()


-@classify_bp.route('/download/', methods=['GET'])
+@classify_bp.route('/download/<int:document_id>', methods=['GET'])
 def download_document(document_id):
     """下载已分类的文档"""
-    # 检查用户是否登录
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']
+    cursor = None
     try:
+        db = get_db()
+        cursor = db.cursor(dictionary=True)
+
         query = """
-        SELECT file_path, original_filename, stored_filename
+            SELECT file_path, original_filename
             FROM documents
-        WHERE id = %s AND user_id = %s
+            WHERE id = %s AND user_id = %s AND status = '已分类'
         """
         cursor.execute(query, (document_id, user_id))
         document = cursor.fetchone()
@@ -513,136 +540,187 @@
         if not document:
             return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404

-        # 检查文件是否存在
-        if not os.path.exists(document['file_path']):
-            return jsonify({"success": False, "error": "文件不存在"}), 404
+        file_path = document['file_path']
+        download_name = document['original_filename']

-        # 返回文件下载
+        if not os.path.exists(file_path):
+            logger.error(f"请求下载的文件在服务器上不存在: {file_path} (Doc ID: {document_id})")
+            # Update status in DB?
+            # update_status_query = "UPDATE documents SET status = '文件丢失' WHERE id = %s"
+            # cursor.execute(update_status_query, (document_id,))
+            # db.commit()
+            return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404
+
+        logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}")
         return send_file(
-            document['file_path'],
+            file_path,
             as_attachment=True,
-            download_name=document['original_filename'],
-            mimetype='text/plain'
+            download_name=download_name,  # Use original filename for download
+            mimetype='text/plain'  # Assuming text, adjust if other types allowed
         )

     except Exception as e:
-        logger.error(f"下载文档时出错: {str(e)}")
+        logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}")
         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
-
     finally:
-        cursor.close()
+        if cursor:
+            cursor.close()
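+# Note: send_file's download_name parameter requires Flask >= 2.0; on older
+# Flask versions the equivalent keyword is attachment_filename.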
+ # update_status_query = "UPDATE documents SET status = '文件丢失' WHERE id = %s" + # cursor.execute(update_status_query, (document_id,)) + # db.commit() + return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404 + + logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}") return send_file( - document['file_path'], + file_path, as_attachment=True, - download_name=document['original_filename'], - mimetype='text/plain' + download_name=download_name, # Use original filename for download + mimetype='text/plain' # Assuming text, adjust if other types allowed ) except Exception as e: - logger.error(f"下载文档时出错: {str(e)}") + logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}") return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500 - finally: - cursor.close() + if cursor: + cursor.close() @classify_bp.route('/download-multiple', methods=['POST']) def download_multiple_documents(): """下载多个文档(打包为zip)""" - # 检查用户是否登录 if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 - + user_id = session['user_id'] - - # 获取请求数据 - data = request.get_json() - if not data or 'document_ids' not in data: - return jsonify({"success": False, "error": "缺少必要参数"}), 400 - - document_ids = data['document_ids'] - if not isinstance(document_ids, list) or not document_ids: - return jsonify({"success": False, "error": "文档ID列表无效"}), 400 - - db = get_db() - cursor = db.cursor(dictionary=True) - + cursor = None + zip_path = None # Initialize for finally block + try: - # 创建临时目录用于存放zip文件 - temp_dir = os.path.join(current_app.root_path, 'temp') - os.makedirs(temp_dir, exist_ok=True) - - # 创建临时ZIP文件 - zip_filename = f"documents_{int(time.time())}.zip" - zip_path = os.path.join(temp_dir, zip_filename) - - # 查询所有符合条件的文档 + data = request.get_json() + if not data or 'document_ids' not in data: + return jsonify({"success": False, "error": "缺少必要参数 'document_ids'"}), 400 + + document_ids = data['document_ids'] + if not isinstance(document_ids, list) or not document_ids: + return jsonify({"success": False, "error": "文档ID列表无效"}), 400 + # Sanitize IDs to be integers + try: + document_ids = [int(doc_id) for doc_id in document_ids] + except ValueError: + return jsonify({"success": False, "error": "文档ID必须是数字"}), 400 + + db = get_db() + cursor = db.cursor(dictionary=True) + + # Query valid documents for the user placeholders = ', '.join(['%s'] * len(document_ids)) query = f""" SELECT id, file_path, original_filename FROM documents - WHERE id IN ({placeholders}) AND user_id = %s + WHERE id IN ({placeholders}) AND user_id = %s AND status = '已分类' """ params = document_ids + [user_id] cursor.execute(query, params) documents = cursor.fetchall() - + if not documents: - return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404 - - # 创建ZIP文件并添加文档 - with zipfile.ZipFile(zip_path, 'w') as zipf: + return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404 + + # Create temp zip file + temp_dir = os.path.join(current_app.root_path, 'temp') + os.makedirs(temp_dir, exist_ok=True) + zip_filename = f"documents_{user_id}_{int(time.time())}.zip" + zip_path = os.path.join(temp_dir, zip_filename) + + files_added = 0 + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: for doc in documents: - if os.path.exists(doc['file_path']): - # 添加文件到zip,使用原始文件名 - zipf.write(doc['file_path'], arcname=doc['original_filename']) - - # 返回ZIP文件下载 + file_path = doc['file_path'] + original_filename = doc['original_filename'] + if os.path.exists(file_path): + zipf.write(file_path, 
         cursor.execute(query, params)
         documents = cursor.fetchall()
-
+
         if not documents:
-            return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404
-
-        # 创建ZIP文件并添加文档
-        with zipfile.ZipFile(zip_path, 'w') as zipf:
+            return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404
+
+        # Create temp zip file
+        temp_dir = os.path.join(current_app.root_path, 'temp')
+        os.makedirs(temp_dir, exist_ok=True)
+        zip_filename = f"documents_{user_id}_{int(time.time())}.zip"
+        zip_path = os.path.join(temp_dir, zip_filename)
+
+        files_added = 0
+        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
             for doc in documents:
-                if os.path.exists(doc['file_path']):
-                    # 添加文件到zip,使用原始文件名
-                    zipf.write(doc['file_path'], arcname=doc['original_filename'])
-
-        # 返回ZIP文件下载
+                file_path = doc['file_path']
+                original_filename = doc['original_filename']
+                if os.path.exists(file_path):
+                    zipf.write(file_path, arcname=original_filename)
+                    files_added += 1
+                    logger.info(f"添加到ZIP: {original_filename} (来自 {file_path})")
+                else:
+                    logger.warning(f"跳过丢失的文件: {original_filename} (路径: {file_path}, Doc ID: {doc['id']})")
+                    # Optionally add a readme to the zip indicating missing files
+
+        if files_added == 0:
+            # This case means all selected files were missing on disk
+            return jsonify({"success": False, "error": "所有选中的文件在服务器上都已丢失"}), 404
+
+        logger.info(f"用户 {user_id} 请求下载 {files_added} 个文档打包为 {zip_filename}")
+        # Use after_this_request to delete the temp file after sending
+        # (current_app.after_request would register a permanent hook on every
+        # call; after_this_request runs once for this request only)
+        from flask import after_this_request
+
+        @after_this_request
+        def remove_file(response):
+            try:
+                if zip_path and os.path.exists(zip_path):
+                    os.remove(zip_path)
+                    logger.info(f"已删除临时ZIP文件: {zip_path}")
+            except Exception as error:
+                logger.error(f"删除临时ZIP文件错误 {zip_path}: {error}")
+            return response
+
         return send_file(
             zip_path,
             as_attachment=True,
             download_name=zip_filename,
             mimetype='application/zip'
         )
-
     except Exception as e:
-        logger.error(f"下载多个文档时出错: {str(e)}")
+        logger.exception(f"下载多个文档时出错: {str(e)}")
+        # Clean up zip file if created but error occurred before sending
+        if zip_path and os.path.exists(zip_path):
+            try: os.remove(zip_path)
+            except OSError: pass
         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
-
     finally:
-        cursor.close()
+        if cursor:
+            cursor.close()


 @classify_bp.route('/classify-text', methods=['POST'])
 def classify_text_directly():
     """直接对文本进行分类(不保存文件)"""
-    # 检查用户是否登录
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

-    # 获取请求数据
-    data = request.get_json()
-    if not data or 'text' not in data:
-        return jsonify({"success": False, "error": "缺少必要参数"}), 400
-
-    text = data['text']
-    if not text.strip():
-        return jsonify({"success": False, "error": "文本内容不能为空"}), 400
-
     try:
+        data = request.get_json()
+        if not data or 'text' not in data:
+            return jsonify({"success": False, "error": "缺少必要参数 'text'"}), 400
+
+        text = data['text']
+        # Basic validation
+        if not isinstance(text, str) or not text.strip():
+            return jsonify({"success": False, "error": "文本内容不能为空或无效"}), 400
+
+        # Limit text length?
+        MAX_TEXT_LENGTH = 10000  # Example limit
+        if len(text) > MAX_TEXT_LENGTH:
+            return jsonify({"success": False, "error": f"文本过长,最大支持 {MAX_TEXT_LENGTH} 字符"}), 413
+
+        # 调用模型进行分类
+        # Ensure model is initialized - consider adding a check or lazy loading
+        if not text_classifier or not text_classifier.is_initialized:
+            logger.warning("文本分类器未初始化,尝试初始化...")
+            if not text_classifier.initialize():
+                logger.error("文本分类器初始化失败。")
+                return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503
+
         result = text_classifier.classify_text(text)

-        if not result['success']:
-            return jsonify({"success": False, "error": result['error']}), 500
+        if not result.get('success', False):
+            error_msg = result.get('error', '未知的分类错误')
+            logger.error(f"直接文本分类失败: {error_msg}")
+            return jsonify({"success": False, "error": error_msg}), 500

         # 返回分类结果
         return jsonify({
             "success": True,
-            "category": result['category'],
-            "confidence": result['confidence'],
-            "all_confidences": result['all_confidences']
+            "category": result.get('category'),
+            "confidence": result.get('confidence'),
+            "all_confidences": result.get('all_confidences')  # Include all confidences if available
         })

     except Exception as e:
-        logger.error(f"文本分类过程中发生错误: {str(e)}")
+        logger.exception(f"直接文本分类过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500
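+# Example request for the endpoint above (hypothetical host/port and URL prefix):
+#   curl -X POST http://localhost:50004/classify/classify-text \
+#        -H "Content-Type: application/json" \
+#        -d '{"text": "昨晚的比赛中主队以3比1获胜"}'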


 @classify_bp.route('/categories', methods=['GET'])
 def get_categories():
     """获取所有分类类别"""
-    db = get_db()
-    cursor = db.cursor(dictionary=True)
-
+    cursor = None
     try:
+        db = get_db()
+        cursor = db.cursor(dictionary=True)
         cursor.execute("SELECT id, name, description FROM categories ORDER BY name")
         categories = cursor.fetchall()

@@ -652,8 +730,8 @@
         })

     except Exception as e:
-        logger.error(f"获取类别列表时出错: {str(e)}")
+        logger.exception(f"获取类别列表时出错: {str(e)}")
         return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500
-
     finally:
-        cursor.close()
+        if cursor:
+            cursor.close()
diff --git a/utils/model_service.py b/utils/model_service.py
index 7f8937a..9b88659 100644
--- a/utils/model_service.py
+++ b/utils/model_service.py
@@ -3,9 +3,12 @@ import os
 import jieba
 import numpy as np
 import pickle
-from tensorflow.keras.models import load_model
-from tensorflow.keras.preprocessing.sequence import pad_sequences
+import tensorflow as tf
+# from tensorflow import keras  # tf.keras is preferred
+from keras.models import load_model  # Keep if specifically needed, else use tf.keras.models.load_model
+from keras.preprocessing.sequence import pad_sequences  # Keep if specifically needed, else use tf.keras.preprocessing.sequence.pad_sequences
 import logging
+import h5py  # Moved import here as it's used conditionally


 class TextClassificationModel:
@@ -34,16 +37,42 @@
         # 设置日志
         self.logger = logging.getLogger(__name__)
+        logging.basicConfig(level=logging.INFO)  # Basic logging setup if not configured elsewhere

     def initialize(self):
-        """初始化并加载模型和分词器"""
+        """初始化并加载模型和分词器"""  # <--- Corrected Indentation Starts Here
         try:
             self.logger.info("开始加载文本分类模型...")
-            # 加载模型
-            self.model = load_model(self.model_path)
-            self.logger.info("模型加载成功")
+
+            # 优先尝试加载 HDF5 格式 (.h5),因为文件名是 .h5
+            try:
+                self.logger.info(f"尝试以 HDF5 格式加载模型: {self.model_path}")
+                # For H5 files, direct load_model is usually sufficient if saved correctly.
+                # compile=False is often needed if you don't need training features immediately.
+                self.model = tf.keras.models.load_model(self.model_path, compile=False)
+                self.logger.info("HDF5 模型加载成功")
+
+            except Exception as h5_exc:
+                self.logger.warning(f"HDF5 格式加载失败 ({h5_exc}),尝试以 SavedModel 格式加载...")
+                # 如果 HDF5 加载失败,再尝试 SavedModel 格式 (通常是一个目录,而不是 .h5 文件)
+                # This might fail if model_path truly points to an h5 file.
+                try:
+                    self.model = tf.keras.models.load_model(
+                        self.model_path,
+                        compile=False  # Usually false for inference
+                        # custom_objects can be added here if needed
+                        # options=tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')  # Usually not needed unless specific TF distribution setup
+                    )
+                    self.logger.info("SavedModel 格式加载成功")
+                except Exception as sm_exc:
+                    self.logger.error(f"SavedModel 格式加载也失败 ({sm_exc}). 无法加载模型。")
+                    # Consider adding the fallback JSON+weights logic here if needed,
+                    # but it's less common now.
+                    # Re-raising or handling the error appropriately
+                    raise ValueError(f"无法加载模型文件: {self.model_path}. H5 Error: {h5_exc}, SavedModel Error: {sm_exc}")

             # 加载tokenizer
+            self.logger.info(f"开始加载 Tokenizer: {self.tokenizer_path}")
             with open(self.tokenizer_path, 'rb') as handle:
                 self.tokenizer = pickle.load(handle)
             self.logger.info("Tokenizer加载成功")
@@ -51,8 +80,9 @@
             self.is_initialized = True
             self.logger.info("模型初始化完成")
             return True
+
         except Exception as e:
-            self.logger.error(f"模型初始化失败: {str(e)}")
+            self.logger.exception(f"模型初始化过程中发生严重错误: {str(e)}")  # Use logger.exception to include traceback
             self.is_initialized = False
             return False
@@ -80,9 +110,12 @@
             dict: 分类结果,包含类别标签和置信度
         """
         if not self.is_initialized:
+            self.logger.warning("模型尚未初始化,尝试现在初始化...")
             success = self.initialize()
             if not success:
+                self.logger.error("分类前初始化模型失败。")
                 return {"success": False, "error": "模型初始化失败"}
+            self.logger.info("模型初始化成功,继续分类。")

         try:
             # 文本预处理
@@ -92,17 +125,23 @@
             sequence = self.tokenizer.texts_to_sequences([processed_text])

             # 填充序列
-            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post")
+            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(  # Using tf.keras path
+                sequence, maxlen=self.max_length, padding="post"
+            )

             # 预测
             predictions = self.model.predict(padded_sequence)
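+            # Shape note: padded_sequence is (1, max_length) and predictions is
+            # (1, len(self.CATEGORIES)) of per-class scores (assuming a softmax
+            # output layer), so predictions[0] holds one score per category.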

             # 获取预测类别索引和置信度
             predicted_index = np.argmax(predictions, axis=1)[0]
-            confidence = float(predictions[0][predicted_index])
+            confidence = float(predictions[0][predicted_index])  # Convert numpy float to python float

             # 获取预测类别标签
-            predicted_label = self.CATEGORIES[predicted_index]
+            if predicted_index < len(self.CATEGORIES):
+                predicted_label = self.CATEGORIES[predicted_index]
+            else:
+                self.logger.warning(f"预测索引 {predicted_index} 超出类别列表范围!")
+                predicted_label = "未知类别"  # Handle out-of-bounds index

             # 获取所有类别的置信度
             all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])}
@@ -114,8 +153,8 @@
                 "all_confidences": all_confidences
             }
         except Exception as e:
-            self.logger.error(f"文本分类过程中发生错误: {str(e)}")
-            return {"success": False, "error": str(e)}
+            self.logger.exception(f"文本分类过程中发生错误: {str(e)}")  # Use logger.exception
+            return {"success": False, "error": f"分类错误: {str(e)}"}

     def classify_file(self, file_path):
         """对文件内容进行分类
@@ -126,28 +165,29 @@
         Returns:
             dict: 分类结果,包含类别标签和置信度
         """
-        try:
-            # 读取文件内容
-            with open(file_path, 'r', encoding='utf-8') as f:
-                text = f.read().strip()
-
-            # 调用文本分类函数
-            return self.classify_text(text)
-
-        except UnicodeDecodeError:
-            # 如果UTF-8解码失败,尝试其他编码
+        text = None
+        encodings_to_try = ['utf-8', 'gbk', 'gb18030']  # Common encodings
+        for enc in encodings_to_try:
             try:
-                with open(file_path, 'r', encoding='gbk') as f:
+                with open(file_path, 'r', encoding=enc) as f:
                     text = f.read().strip()
-                return self.classify_text(text)
+                self.logger.info(f"成功以 {enc} 编码读取文件: {file_path}")
+                break  # Exit loop if read successful
+            except UnicodeDecodeError:
+                self.logger.warning(f"使用 {enc} 解码文件失败: {file_path}")
+                continue  # Try next encoding
             except Exception as e:
-                return {"success": False, "error": f"文件解码失败: {str(e)}"}
+                self.logger.error(f"读取文件时发生其他错误 ({enc}): {str(e)}")
+                return {"success": False, "error": f"文件读取错误 ({enc}): {str(e)}"}

-        except Exception as e:
-            self.logger.error(f"文件处理过程中发生错误: {str(e)}")
-            return {"success": False, "error": f"文件处理错误: {str(e)}"}
+        if text is None:
+            self.logger.error(f"尝试所有编码后仍无法读取文件: {file_path}")
+            return {"success": False, "error": f"文件解码失败,尝试的编码: {encodings_to_try}"}
+
+        # 调用文本分类函数
+        return self.classify_text(text)


 # 创建单例实例,避免重复加载模型
+# Consider lazy initialization if the model is large and not always needed immediately
 text_classifier = TextClassificationModel()
-