# routes/classify.py import os import time import uuid import zipfile import rarfile import shutil from flask import Blueprint, request, jsonify, current_app, send_file, g, session from werkzeug.utils import secure_filename import mysql.connector from utils.model_service import text_classifier # Assuming text_classifier is initialized correctly elsewhere from utils.db import get_db, close_db import logging # 创建蓝图 classify_bp = Blueprint('classify', __name__, url_prefix='/api/classify') # 设置日志 logger = logging.getLogger(__name__) # 允许的文件扩展名 ALLOWED_TEXT_EXTENSIONS = {'txt'} ALLOWED_ARCHIVE_EXTENSIONS = {'zip', 'rar'} def allowed_text_file(filename): """检查文本文件扩展名是否允许""" return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_TEXT_EXTENSIONS def allowed_archive_file(filename): """检查压缩文件扩展名是否允许""" return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_ARCHIVE_EXTENSIONS def get_next_document_number(category): """获取给定类别的下一个文档编号 Args: category (str): 文档类别 Returns: int: 下一个编号 """ db = get_db() cursor = db.cursor(dictionary=True) try: # 查询该类别的最大编号 query = """ SELECT MAX(CAST(SUBSTRING_INDEX(SUBSTRING_INDEX(stored_filename, '-', 2), '-', -1) AS UNSIGNED)) as max_num FROM documents WHERE stored_filename LIKE %s """ cursor.execute(query, (f"{category}-%",)) result = cursor.fetchone() # 如果没有记录,或者最大编号为None,返回1,否则返回最大编号+1 if not result or result['max_num'] is None: return 1 else: # Ensure it's an int before adding return int(result['max_num']) + 1 except Exception as e: logger.error(f"获取下一个文档编号时出错: {str(e)}") return 1 finally: cursor.close() def save_classified_document(user_id, original_filename, category, content, file_size=None): """保存分类后的文档 Args: user_id (int): 用户ID original_filename (str): 原始文件名 category (str): 分类类别 content (str): 文档内容 file_size (int, optional): 文件大小,如果为None则自动计算 Returns: tuple: (成功标志, 存储的文件名或错误信息) """ db = None # Initialize db to None cursor = None # Initialize cursor to None try: # 获取下一个文档编号 next_num = get_next_document_number(category) # 安全处理文件名 - ensure it doesn't become empty safe_original_name_base = os.path.splitext(secure_filename(original_filename))[0] if not safe_original_name_base: safe_original_name_base = "untitled" # Default if name becomes empty safe_original_name = f"{safe_original_name_base}.txt" # Ensure .txt extension # 生成新文件名 (类别-编号-原文件名) formatted_num = f"{next_num:04d}" # 确保编号格式为4位数 new_filename = f"{category}-{formatted_num}-{safe_original_name}" # 确保uploads目录存在 uploads_dir = os.path.join(current_app.root_path, current_app.config['UPLOAD_FOLDER']) category_dir = os.path.join(uploads_dir, category) os.makedirs(category_dir, exist_ok=True) # 文件完整路径 file_path = os.path.join(category_dir, new_filename) # 保存文件 with open(file_path, 'w', encoding='utf-8') as f: f.write(content) # 如果文件大小未提供,则计算 if file_size is None: file_size = os.path.getsize(file_path) # 获取db连接 db = get_db() cursor = db.cursor() # 获取分类类别ID category_query = "SELECT id FROM categories WHERE name = %s" cursor.execute(category_query, (category,)) category_result = cursor.fetchone() if not category_result: logger.error(f"类别 '{category}' 在数据库中未找到。") # Optionally create the category if needed, or return error return False, f"类别 '{category}' 不存在" category_id = category_result[0] # 插入数据库记录 insert_query = """ INSERT INTO documents (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time, upload_time) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW()) """ # Added upload_time=NOW() assuming it's desired cursor.execute( insert_query, (user_id, original_filename, new_filename, file_path, file_size, category_id, '已分类') ) # 提交事务 db.commit() logger.info(f"成功保存文档: {new_filename} (ID: {cursor.lastrowid})") return True, new_filename except Exception as e: # Rollback transaction if error occurs if db: db.rollback() logger.exception(f"保存分类文档 '{original_filename}' 时出错: {str(e)}") # Use logger.exception for traceback return False, str(e) finally: if cursor: cursor.close() # Note: db connection is usually closed via @app.teardown_appcontext @classify_bp.route('/single', methods=['POST']) def classify_single_file(): """单文件上传和分类API""" # 检查用户是否登录 if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] # 检查是否上传了文件 if 'file' not in request.files: return jsonify({"success": False, "error": "没有文件"}), 400 file = request.files['file'] original_filename = file.filename # Store original name early # 检查文件名 if original_filename == '': return jsonify({"success": False, "error": "未选择文件"}), 400 # 检查文件类型 if not allowed_text_file(original_filename): return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400 temp_path = None # Initialize to ensure it's defined for cleanup try: # 创建临时文件以供处理 temp_dir = os.path.join(current_app.root_path, 'temp') os.makedirs(temp_dir, exist_ok=True) # Use secure_filename for the temp file name part as well temp_filename = f"{uuid.uuid4().hex}-{secure_filename(original_filename)}" temp_path = os.path.join(temp_dir, temp_filename) # 保存上传文件到临时位置 file.save(temp_path) file_size = os.path.getsize(temp_path) # Get size after saving # 读取文件内容 - Try multiple encodings file_content = None encodings_to_try = ['utf-8', 'gbk', 'gb18030'] for enc in encodings_to_try: try: with open(temp_path, 'r', encoding=enc) as f: file_content = f.read() logger.info(f"成功以 {enc} 读取临时文件: {temp_path}") break except UnicodeDecodeError: logger.warning(f"使用 {enc} 解码临时文件失败: {temp_path}") continue except Exception as read_err: # Catch other potential read errors logger.error(f"读取临时文件时出错 ({enc}): {read_err}") # Decide if you want to stop or try next encoding continue if file_content is None: raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。") # 调用模型进行分类 result = text_classifier.classify_text(file_content) if not result.get('success', False): # Check if key exists and is True error_msg = result.get('error', '未知的分类错误') logger.error(f"文本分类失败 for {original_filename}: {error_msg}") return jsonify({"success": False, "error": error_msg}), 500 category = result['category'] # 保存分类后的文档 save_success, message = save_classified_document( user_id, original_filename, # Use the original filename here category, file_content, file_size ) if not save_success: return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500 # 返回分类结果 return jsonify({ "success": True, "filename": original_filename, "category": category, "confidence": result.get('confidence'), # Use .get for safety "stored_filename": message }) except Exception as e: logger.exception(f"处理单文件 '{original_filename}' 过程中发生错误: {str(e)}") # Log exception with traceback return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500 finally: # 清理临时文件 if temp_path and os.path.exists(temp_path): try: os.remove(temp_path) logger.info(f"已删除临时文件: {temp_path}") except OSError as rm_err: logger.error(f"删除临时文件失败 {temp_path}: {rm_err}") @classify_bp.route('/batch', methods=['POST']) def classify_batch_files(): """批量文件上传和分类API(压缩包处理)""" if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] if 'file' not in request.files: return jsonify({"success": False, "error": "没有文件"}), 400 file = request.files['file'] original_archive_name = file.filename if original_archive_name == '': return jsonify({"success": False, "error": "未选择文件"}), 400 if not allowed_archive_file(original_archive_name): return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400 # Consider adding file size check here if needed (e.g., request.content_length) # if request.content_length > current_app.config['MAX_CONTENT_LENGTH']: # Example # return jsonify({"success": False, "error": "文件过大"}), 413 temp_dir = os.path.join(current_app.root_path, 'temp') extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}") archive_path = None # --- 修改点 1: 将 total_attempted 改回 total --- results = { "total": 0, # <--- 使用 'total' "success": 0, "failed": 0, "categories": {}, "failed_files": [] } try: os.makedirs(extract_dir, exist_ok=True) archive_path = os.path.join(temp_dir, secure_filename(original_archive_name)) file.save(archive_path) logger.info(f"已保存上传的压缩文件: {archive_path}") # 解压文件 file_extension = original_archive_name.rsplit('.', 1)[1].lower() if file_extension == 'zip': with zipfile.ZipFile(archive_path, 'r') as zip_ref: zip_ref.extractall(extract_dir) logger.info(f"已解压ZIP文件到: {extract_dir}") elif file_extension == 'rar': # Ensure rarfile is installed and unrar executable is available try: with rarfile.RarFile(archive_path, 'r') as rar_ref: rar_ref.extractall(extract_dir) logger.info(f"已解压RAR文件到: {extract_dir}") except rarfile.NeedFirstVolume: logger.error(f"RAR文件是分卷压缩的一部分,需要第一个分卷: {original_archive_name}") raise ValueError("不支持分卷压缩的RAR文件") except rarfile.BadRarFile: logger.error(f"RAR文件损坏或格式错误: {original_archive_name}") raise ValueError("RAR文件损坏或格式错误") except Exception as rar_err: # Catch other rar errors (like missing unrar) logger.error(f"解压RAR文件时出错: {rar_err}") raise ValueError(f"解压RAR文件失败: {rar_err}") # 递归处理所有txt文件 for root, dirs, files in os.walk(extract_dir): # --- 开始过滤 macOS 文件夹 --- if '__MACOSX' in root.split(os.path.sep): logger.info(f"Skipping macOS metadata directory: {root}") dirs[:] = [] files[:] = [] continue # --- 结束过滤 macOS 文件夹 --- for filename in files: # --- 开始过滤 macOS 文件 --- if filename.startswith('._') or filename == '.DS_Store': logger.info(f"Skipping macOS metadata file: {filename} in {root}") continue # --- 结束过滤 macOS 文件 --- # Process only allowed text files if allowed_text_file(filename): file_path = os.path.join(root, filename) # --- 修改点 2: 将 total_attempted 改回 total --- results["total"] += 1 # <--- 使用 'total' try: # 读取文件内容 - Try multiple encodings file_content = None encodings_to_try = ['utf-8', 'gbk', 'gb18030'] for enc in encodings_to_try: try: with open(file_path, 'r', encoding=enc) as f: file_content = f.read() logger.info(f"成功以 {enc} 读取文件: {file_path}") break except UnicodeDecodeError: logger.warning(f"使用 {enc} 解码文件失败: {file_path}") continue except Exception as read_err: logger.error(f"读取文件时发生其他错误 ({enc}) {file_path}: {read_err}") continue if file_content is None: raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。") # 调用模型进行分类 result = text_classifier.classify_text(file_content) if not result.get('success', False): error_msg = result.get('error', '未知的分类错误') raise Exception(f"分类失败: {error_msg}") category = result['category'] results["categories"][category] = results["categories"].get(category, 0) + 1 file_size = os.path.getsize(file_path) save_success, message = save_classified_document( user_id, filename, category, file_content, file_size ) if save_success: results["success"] += 1 else: raise Exception(f"保存失败: {message}") except Exception as e: logger.error(f"处理文件 '{filename}' 失败: {str(e)}") results["failed"] += 1 results["failed_files"].append({ "filename": filename, "error": str(e) }) # 返回处理结果 return jsonify({ "success": True, "archive_name": original_archive_name, "results": results # <--- 确保返回的 results 包含 'total' 键 }) except Exception as e: logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}") return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500 finally: # 清理临时文件和目录 (保持不变) if archive_path and os.path.exists(archive_path): try: os.remove(archive_path) logger.info(f"已删除临时压缩文件: {archive_path}") except OSError as rm_err: logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}") if os.path.exists(extract_dir): try: shutil.rmtree(extract_dir) logger.info(f"已删除临时解压目录: {extract_dir}") except OSError as rmtree_err: logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}") @classify_bp.route('/documents', methods=['GET']) def get_classified_documents(): """获取已分类的文档列表""" if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] try: category = request.args.get('category', 'all') page = int(request.args.get('page', 1)) per_page = int(request.args.get('per_page', 10)) if per_page not in [10, 25, 50, 100]: per_page = 10 if page < 1: page = 1 offset = (page - 1) * per_page db = get_db() cursor = db.cursor(dictionary=True) # Use dictionary cursor # Get available categories first cursor.execute("SELECT name FROM categories ORDER BY name") available_categories = [row['name'] for row in cursor.fetchall()] # Build query params = [user_id] base_query = """ FROM documents d JOIN categories c ON d.category_id = c.id WHERE d.user_id = %s AND d.status = '已分类' """ if category != 'all' and category in available_categories: base_query += " AND c.name = %s" params.append(category) # Count query count_query = f"SELECT COUNT(*) as total {base_query}" cursor.execute(count_query, params) total_count_result = cursor.fetchone() total_count = total_count_result['total'] if total_count_result else 0 total_pages = (total_count + per_page - 1) // per_page if per_page > 0 else 0 # Data query data_query = f""" SELECT d.id, d.original_filename, d.stored_filename, d.file_size, c.name as category, d.upload_time, d.classified_time {base_query} ORDER BY d.classified_time DESC LIMIT %s OFFSET %s """ params.extend([per_page, offset]) cursor.execute(data_query, params) documents = cursor.fetchall() # Format dates if they are datetime objects (optional, depends on driver) for doc in documents: if hasattr(doc.get('upload_time'), 'isoformat'): doc['upload_time'] = doc['upload_time'].isoformat() if hasattr(doc.get('classified_time'), 'isoformat'): doc['classified_time'] = doc['classified_time'].isoformat() return jsonify({ "success": True, "documents": documents, "pagination": { "total": total_count, "per_page": per_page, "current_page": page, "total_pages": total_pages }, "categories": available_categories, # Send the fetched list "current_category": category }) except Exception as e: logger.exception(f"获取文档列表时出错: {str(e)}") # Log traceback return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500 finally: # Cursor closing handled by context usually, but explicit close is safe if 'cursor' in locals() and cursor: cursor.close() @classify_bp.route('/download/', methods=['GET']) def download_document(document_id): """下载已分类的文档""" if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] cursor = None try: db = get_db() cursor = db.cursor(dictionary=True) query = """ SELECT file_path, original_filename FROM documents WHERE id = %s AND user_id = %s AND status = '已分类' """ cursor.execute(query, (document_id, user_id)) document = cursor.fetchone() if not document: return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404 file_path = document['file_path'] download_name = document['original_filename'] if not os.path.exists(file_path): logger.error(f"请求下载的文件在服务器上不存在: {file_path} (Doc ID: {document_id})") # Update status in DB? # update_status_query = "UPDATE documents SET status = '文件丢失' WHERE id = %s" # cursor.execute(update_status_query, (document_id,)) # db.commit() return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404 logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}") return send_file( file_path, as_attachment=True, download_name=download_name, # Use original filename for download mimetype='text/plain' # Assuming text, adjust if other types allowed ) except Exception as e: logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}") return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500 finally: if cursor: cursor.close() @classify_bp.route('/download-multiple', methods=['POST']) def download_multiple_documents(): """下载多个文档(打包为zip)""" if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 user_id = session['user_id'] cursor = None zip_path = None # Initialize for finally block try: data = request.get_json() if not data or 'document_ids' not in data: return jsonify({"success": False, "error": "缺少必要参数 'document_ids'"}), 400 document_ids = data['document_ids'] if not isinstance(document_ids, list) or not document_ids: return jsonify({"success": False, "error": "文档ID列表无效"}), 400 # Sanitize IDs to be integers try: document_ids = [int(doc_id) for doc_id in document_ids] except ValueError: return jsonify({"success": False, "error": "文档ID必须是数字"}), 400 db = get_db() cursor = db.cursor(dictionary=True) # Query valid documents for the user placeholders = ', '.join(['%s'] * len(document_ids)) query = f""" SELECT id, file_path, original_filename FROM documents WHERE id IN ({placeholders}) AND user_id = %s AND status = '已分类' """ params = document_ids + [user_id] cursor.execute(query, params) documents = cursor.fetchall() if not documents: return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404 # Create temp zip file temp_dir = os.path.join(current_app.root_path, 'temp') os.makedirs(temp_dir, exist_ok=True) zip_filename = f"documents_{user_id}_{int(time.time())}.zip" zip_path = os.path.join(temp_dir, zip_filename) files_added = 0 with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: for doc in documents: file_path = doc['file_path'] original_filename = doc['original_filename'] if os.path.exists(file_path): zipf.write(file_path, arcname=original_filename) files_added += 1 logger.info(f"添加到ZIP: {original_filename} (来自 {file_path})") else: logger.warning(f"跳过丢失的文件: {original_filename} (路径: {file_path}, Doc ID: {doc['id']})") # Optionally add a readme to the zip indicating missing files if files_added == 0: # This case means all selected files were missing on disk return jsonify({"success": False, "error": "所有选中的文件在服务器上都已丢失"}), 404 logger.info(f"用户 {user_id} 请求下载 {files_added} 个文档打包为 {zip_filename}") # Use after_this_request to delete the temp file after sending @current_app.after_request def remove_file(response): try: if zip_path and os.path.exists(zip_path): os.remove(zip_path) logger.info(f"已删除临时ZIP文件: {zip_path}") except Exception as error: logger.error(f"删除临时ZIP文件错误 {zip_path}: {error}") return response return send_file( zip_path, as_attachment=True, download_name=zip_filename, mimetype='application/zip' ) except Exception as e: logger.exception(f"下载多个文档时出错: {str(e)}") # Clean up zip file if created but error occurred before sending if zip_path and os.path.exists(zip_path): try: os.remove(zip_path) except OSError: pass return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500 finally: if cursor: cursor.close() @classify_bp.route('/classify-text', methods=['POST']) def classify_text_directly(): """直接对文本进行分类(不保存文件)""" if 'user_id' not in session: return jsonify({"success": False, "error": "请先登录"}), 401 try: data = request.get_json() if not data or 'text' not in data: return jsonify({"success": False, "error": "缺少必要参数 'text'"}), 400 text = data['text'] # Basic validation if not isinstance(text, str) or not text.strip(): return jsonify({"success": False, "error": "文本内容不能为空或无效"}), 400 # Limit text length? MAX_TEXT_LENGTH = 10000 # Example limit if len(text) > MAX_TEXT_LENGTH: return jsonify({"success": False, "error": f"文本过长,最大支持 {MAX_TEXT_LENGTH} 字符"}), 413 # 调用模型进行分类 # Ensure model is initialized - consider adding a check or lazy loading if not text_classifier or not text_classifier.is_initialized: logger.warning("文本分类器未初始化,尝试初始化...") if not text_classifier.initialize(): logger.error("文本分类器初始化失败。") return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503 result = text_classifier.classify_text(text) if not result.get('success', False): error_msg = result.get('error', '未知的分类错误') logger.error(f"直接文本分类失败: {error_msg}") return jsonify({"success": False, "error": error_msg}), 500 # 返回分类结果 return jsonify({ "success": True, "category": result.get('category'), "confidence": result.get('confidence'), "all_confidences": result.get('all_confidences') # Include all confidences if available }) except Exception as e: logger.exception(f"直接文本分类过程中发生错误: {str(e)}") return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500 @classify_bp.route('/categories', methods=['GET']) def get_categories(): """获取所有分类类别""" cursor = None try: db = get_db() cursor = db.cursor(dictionary=True) cursor.execute("SELECT id, name, description FROM categories ORDER BY name") categories = cursor.fetchall() return jsonify({ "success": True, "categories": categories }) except Exception as e: logger.exception(f"获取类别列表时出错: {str(e)}") return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500 finally: if cursor: cursor.close()