# routes/classify.py
"""Flask blueprint with the document classification endpoints.

Routes cover single-file and archive uploads, listing and downloading the
classified documents of the logged-in user, classifying raw text, and
listing the known categories.
"""
import io
import logging
import os
import shutil
import time
import uuid
import zipfile

import mysql.connector
import rarfile
from flask import Blueprint, request, jsonify, current_app, send_file, g, session
from werkzeug.utils import secure_filename

from utils.model_service import text_classifier
from utils.db import get_db, close_db

# Blueprint for every route in this module.
classify_bp = Blueprint('classify', __name__, url_prefix='/api/classify')

# Module-level logger.
logger = logging.getLogger(__name__)

# Accepted upload extensions (compared lower-cased, without the dot).
ALLOWED_TEXT_EXTENSIONS = {'txt'}
ALLOWED_ARCHIVE_EXTENSIONS = {'zip', 'rar'}

# Upper bound on the request size of an archive upload (10MB).
MAX_ARCHIVE_BYTES = 10 * 1024 * 1024

# Page sizes a client may request; anything else falls back to 10.
ALLOWED_PAGE_SIZES = (10, 25, 50, 100)


def _has_allowed_extension(filename, allowed):
    """Return True when ``filename`` has an extension contained in ``allowed``."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed


def _read_text_with_fallback(path):
    """Read ``path`` as UTF-8 text, retrying as GBK when UTF-8 decoding fails.

    Raises:
        UnicodeDecodeError: if the file is valid in neither encoding.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, 'r', encoding='gbk') as f:
            return f.read()


def _remove_file_quietly(path):
    """Delete ``path`` if present; log (never raise) when deletion fails.

    Used from ``finally`` blocks, where an exception would replace the
    response that is already being returned.
    """
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning("清理临时文件失败: %s: %s", path, e)


def allowed_text_file(filename):
    """Check whether ``filename`` has an allowed text-file extension."""
    return _has_allowed_extension(filename, ALLOWED_TEXT_EXTENSIONS)


def allowed_archive_file(filename):
    """Check whether ``filename`` has an allowed archive extension."""
    return _has_allowed_extension(filename, ALLOWED_ARCHIVE_EXTENSIONS)


def get_next_document_number(category):
    """Return the next document number for a category.

    The number is derived from the second ``-``-separated field of the
    ``stored_filename`` values that start with ``"<category>-"``.

    Args:
        category (str): Document category.

    Returns:
        int: Highest existing number plus one, or 1 when there is no record
        yet or when the lookup fails.
    """
    # NOTE(review): the number is read without any locking, so two concurrent
    # uploads may obtain the same value; a category name containing '-' would
    # also shift the field parsed by SUBSTRING_INDEX -- confirm neither can
    # happen in practice.
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        query = """
            SELECT MAX(CAST(SUBSTRING_INDEX(SUBSTRING_INDEX(stored_filename, '-', 2), '-', -1) AS UNSIGNED)) as max_num
            FROM documents
            WHERE stored_filename LIKE %s
        """
        cursor.execute(query, (f"{category}-%",))
        result = cursor.fetchone()
        if not result or result['max_num'] is None:
            return 1
        return result['max_num'] + 1
    except Exception as e:
        logger.error(f"获取下一个文档编号时出错: {str(e)}")
        return 1
    finally:
        cursor.close()


def save_classified_document(user_id, original_filename, category, content, file_size=None):
    """Store a classified document on disk and record it in the database.

    The file is written to ``<UPLOAD_FOLDER>/<category>/`` under the name
    ``<category>-<4-digit number>-<sanitized original name>``.

    Args:
        user_id (int): Owner of the document.
        original_filename (str): File name as uploaded.
        category (str): Category assigned by the classifier.
        content (str): Document text; written as UTF-8.
        file_size (int, optional): Size to record; measured from the written
            file when None.

    Returns:
        tuple: ``(True, stored_filename)`` on success, otherwise
        ``(False, error_message)``.
    """
    db = None
    cursor = None
    created_path = None  # set only for a file this call newly created
    try:
        next_num = get_next_document_number(category)

        # secure_filename() can strip a name down to nothing (e.g. a name made
        # only of non-ASCII characters), which would leave a stored name
        # ending in '-'.
        safe_original_name = secure_filename(original_filename) or 'document.txt'
        new_filename = f"{category}-{next_num:04d}-{safe_original_name}"

        db = get_db()
        cursor = db.cursor()

        # Resolve the category before touching the disk, so an unknown
        # category does not leave an orphan file behind.
        cursor.execute("SELECT id FROM categories WHERE name = %s", (category,))
        category_result = cursor.fetchone()
        if not category_result:
            return False, "类别不存在"
        category_id = category_result[0]

        uploads_dir = os.path.join(current_app.root_path, current_app.config['UPLOAD_FOLDER'])
        category_dir = os.path.join(uploads_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        file_path = os.path.join(category_dir, new_filename)

        existed_before = os.path.exists(file_path)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        if not existed_before:
            created_path = file_path

        if file_size is None:
            file_size = os.path.getsize(file_path)

        insert_query = """
            INSERT INTO documents
            (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
        """
        cursor.execute(
            insert_query,
            (user_id, original_filename, new_filename, file_path, file_size, category_id, '已分类')
        )
        db.commit()
        return True, new_filename
    except Exception as e:
        logger.error(f"保存分类文档时出错: {str(e)}")
        if db is not None:
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.error(f"保存分类文档时出错: {str(rollback_error)}")
        # Without a database row the new file is unreachable; remove it.
        _remove_file_quietly(created_path)
        return False, str(e)
    finally:
        if cursor is not None:
            cursor.close()


@classify_bp.route('/single', methods=['POST'])
def classify_single_file():
    """Upload one ``.txt`` file, classify it and store the result.

    Returns a JSON body with ``filename``, ``category``, ``confidence`` and
    ``stored_filename`` on success, or ``success: False`` with an ``error``.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']

    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']

    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400

    if not allowed_text_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400

    temp_path = None
    try:
        # Stage the upload under a random name so concurrent requests never
        # share a temp file.
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"{uuid.uuid4().hex}.txt")
        file.save(temp_path)

        try:
            file_content = _read_text_with_fallback(temp_path)
        except UnicodeDecodeError as e:
            return jsonify({
                "success": False,
                "error": f"文件编码错误,请确保文件为UTF-8或GBK编码: {str(e)}"
            }), 400

        result = text_classifier.classify_text(file_content)
        if not result['success']:
            return jsonify({"success": False, "error": result['error']}), 500

        file_size = os.path.getsize(temp_path)
        save_success, message = save_classified_document(
            user_id, file.filename, result['category'], file_content, file_size
        )
        if not save_success:
            return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500

        return jsonify({
            "success": True,
            "filename": file.filename,
            "category": result['category'],
            "confidence": result['confidence'],
            "stored_filename": message
        })
    except Exception as e:
        logger.error(f"文件处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500
    finally:
        # Runs on every exit path, including the early error returns above.
        _remove_file_quietly(temp_path)


@classify_bp.route('/batch', methods=['POST'])
def classify_batch_files():
    """Upload a zip/rar archive and classify every ``.txt`` file inside it.

    Returns a JSON body whose ``results`` holds the ``total``, ``success``
    and ``failed`` counters, per-category counts and the list of failed
    files with their error messages.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']

    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']

    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400

    if not allowed_archive_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400

    # content_length is None when the client sends no Content-Length header.
    content_length = request.content_length
    if content_length is not None and content_length > MAX_ARCHIVE_BYTES:
        return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400

    archive_path = None
    extract_dir = None
    try:
        temp_dir = os.path.join(current_app.root_path, 'temp')
        extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
        os.makedirs(extract_dir, exist_ok=True)

        # The extension was validated above; a random base name avoids
        # collisions between concurrent uploads of same-named archives.
        file_extension = file.filename.rsplit('.', 1)[1].lower()
        archive_path = os.path.join(temp_dir, f"{uuid.uuid4().hex}.{file_extension}")
        file.save(archive_path)

        # NOTE(review): the archive is untrusted input and the total
        # extracted size is not limited here -- confirm whether that matters.
        if file_extension == 'zip':
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
        elif file_extension == 'rar':
            with rarfile.RarFile(archive_path, 'r') as rar_ref:
                rar_ref.extractall(extract_dir)

        results = {
            "total": 0,
            "success": 0,
            "failed": 0,
            "categories": {},
            "failed_files": []
        }

        # Walk the extracted tree; a failure on one file is recorded and
        # does not stop the rest of the batch.
        for root, dirs, files in os.walk(extract_dir):
            for filename in files:
                if not filename.lower().endswith('.txt'):
                    continue
                file_path = os.path.join(root, filename)
                results["total"] += 1
                try:
                    file_content = _read_text_with_fallback(file_path)

                    result = text_classifier.classify_text(file_content)
                    if not result['success']:
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": result['error']
                        })
                        continue

                    category = result['category']
                    results["categories"][category] = results["categories"].get(category, 0) + 1

                    file_size = os.path.getsize(file_path)
                    save_success, message = save_classified_document(
                        user_id, filename, category, file_content, file_size
                    )
                    if save_success:
                        results["success"] += 1
                    else:
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": message
                        })
                except Exception as e:
                    results["failed"] += 1
                    results["failed_files"].append({
                        "filename": filename,
                        "error": str(e)
                    })

        return jsonify({
            "success": True,
            "archive_name": file.filename,
            "results": results
        })
    except Exception as e:
        logger.error(f"压缩包处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500
    finally:
        _remove_file_quietly(archive_path)
        if extract_dir:
            shutil.rmtree(extract_dir, ignore_errors=True)


@classify_bp.route('/documents', methods=['GET'])
def get_classified_documents():
    """List the logged-in user's classified documents, paginated.

    Query parameters: ``category`` (default ``all``), ``page`` (default 1)
    and ``per_page`` (one of 10/25/50/100, default 10).
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']

    category = request.args.get('category', 'all')
    # type=int yields the default instead of raising on non-numeric input.
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('per_page', 10, type=int)

    if page < 1:
        page = 1  # a page below 1 would produce a negative OFFSET
    if per_page not in ALLOWED_PAGE_SIZES:
        per_page = 10

    offset = (page - 1) * per_page

    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Only fixed SQL text is interpolated into the statements below;
        # every user-supplied value travels through params.
        where_clause = "WHERE d.user_id = %s AND d.status = '已分类'"
        params = [user_id]
        if category != 'all':
            where_clause += " AND c.name = %s"
            params.append(category)

        count_query = f"""
            SELECT COUNT(*) as total
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
        """
        cursor.execute(count_query, params)
        total_count = cursor.fetchone()['total']

        # Ceiling division.
        total_pages = (total_count + per_page - 1) // per_page

        query = f"""
            SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
                   c.name as category, d.upload_time, d.classified_time
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
            ORDER BY d.classified_time DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, params + [per_page, offset])
        documents = cursor.fetchall()

        cursor.execute("SELECT name FROM categories ORDER BY name")
        categories = [row['name'] for row in cursor.fetchall()]

        return jsonify({
            "success": True,
            "documents": documents,
            "pagination": {
                "total": total_count,
                "per_page": per_page,
                "current_page": page,
                "total_pages": total_pages
            },
            "categories": categories,
            "current_category": category
        })
    except Exception as e:
        logger.error(f"获取文档列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500
    finally:
        cursor.close()


# The rule must declare document_id; without the URL variable Flask calls the
# view with no arguments and every request fails with a TypeError.
@classify_bp.route('/download/<int:document_id>', methods=['GET'])
def download_document(document_id):
    """Send one classified document of the logged-in user as an attachment.

    Args:
        document_id (int): Primary key of the document, taken from the URL.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']

    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Filtering on user_id keeps users from fetching each other's files.
        query = """
            SELECT file_path, original_filename, stored_filename
            FROM documents
            WHERE id = %s AND user_id = %s
        """
        cursor.execute(query, (document_id, user_id))
        document = cursor.fetchone()

        if not document:
            return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404

        if not os.path.exists(document['file_path']):
            return jsonify({"success": False, "error": "文件不存在"}), 404

        return send_file(
            document['file_path'],
            as_attachment=True,
            download_name=document['original_filename'],
            mimetype='text/plain'
        )
    except Exception as e:
        logger.error(f"下载文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        cursor.close()


@classify_bp.route('/download-multiple', methods=['POST'])
def download_multiple_documents():
    """Send several documents of the logged-in user as one zip attachment.

    Expects a JSON body ``{"document_ids": [...]}``. Documents whose file is
    missing on disk are left out of the archive.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']

    # silent=True turns a missing or malformed JSON body into None, so the
    # client receives the JSON error below rather than a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'document_ids' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400

    document_ids = data['document_ids']
    if not isinstance(document_ids, list) or not document_ids:
        return jsonify({"success": False, "error": "文档ID列表无效"}), 400

    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # One placeholder per id; the ids themselves are passed as params.
        placeholders = ', '.join(['%s'] * len(document_ids))
        query = f"""
            SELECT id, file_path, original_filename
            FROM documents
            WHERE id IN ({placeholders}) AND user_id = %s
        """
        cursor.execute(query, list(document_ids) + [user_id])
        documents = cursor.fetchall()

        if not documents:
            return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404

        # Build the archive in memory: nothing is left on disk afterwards and
        # simultaneous requests cannot collide on a shared temp file name.
        zip_buffer = io.BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zipf:
            for doc in documents:
                if not os.path.exists(doc['file_path']):
                    continue
                arcname = doc['original_filename']
                if arcname in used_names:
                    # Two documents may share an original name; prefix the
                    # document id so neither overwrites the other on extract.
                    arcname = f"{doc['id']}_{arcname}"
                used_names.add(arcname)
                zipf.write(doc['file_path'], arcname=arcname)
        zip_buffer.seek(0)

        zip_filename = f"documents_{int(time.time())}.zip"
        return send_file(
            zip_buffer,
            as_attachment=True,
            download_name=zip_filename,
            mimetype='application/zip'
        )
    except Exception as e:
        logger.error(f"下载多个文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        cursor.close()


@classify_bp.route('/classify-text', methods=['POST'])
def classify_text_directly():
    """Classify raw text from a JSON body ``{"text": ...}`` without storing it."""
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400

    text = data['text']
    # Reject non-string values (null, numbers, lists) along with blank text;
    # calling .strip() on them would raise outside the try block below.
    if not isinstance(text, str) or not text.strip():
        return jsonify({"success": False, "error": "文本内容不能为空"}), 400

    try:
        result = text_classifier.classify_text(text)
        if not result['success']:
            return jsonify({"success": False, "error": result['error']}), 500

        return jsonify({
            "success": True,
            "category": result['category'],
            "confidence": result['confidence'],
            "all_confidences": result['all_confidences']
        })
    except Exception as e:
        logger.error(f"文本分类过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500


@classify_bp.route('/categories', methods=['GET'])
def get_categories():
    """Return every category (id, name, description) ordered by name."""
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        cursor.execute("SELECT id, name, description FROM categories ORDER BY name")
        categories = cursor.fetchall()
        return jsonify({
            "success": True,
            "categories": categories
        })
    except Exception as e:
        logger.error(f"获取类别列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500
    finally:
        cursor.close()