2025-04-02 21:47:17 +08:00

738 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# routes/classify.py
import os
import time
import uuid
import zipfile
import rarfile
import shutil
from flask import Blueprint, request, jsonify, current_app, send_file, g, session, after_this_request
from werkzeug.utils import secure_filename
import mysql.connector
from utils.model_service import text_classifier # Assuming text_classifier is initialized correctly elsewhere
from utils.db import get_db, close_db
import logging
# 创建蓝图
classify_bp = Blueprint('classify', __name__, url_prefix='/api/classify')
# 设置日志
logger = logging.getLogger(__name__)
# 允许的文件扩展名
ALLOWED_TEXT_EXTENSIONS = {'txt'}
ALLOWED_ARCHIVE_EXTENSIONS = {'zip', 'rar'}
def allowed_text_file(filename, allowed_extensions=None):
    """Return True if *filename* has an allowed text-file extension.

    Args:
        filename (str): Filename to check.
        allowed_extensions (set[str], optional): Lowercase extensions
            (without the dot) to accept. Defaults to
            ALLOWED_TEXT_EXTENSIONS, preserving the original behavior.

    Returns:
        bool: True when the filename has an extension and it is allowed.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_TEXT_EXTENSIONS
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
def allowed_archive_file(filename, allowed_extensions=None):
    """Return True if *filename* has an allowed archive extension.

    Args:
        filename (str): Filename to check.
        allowed_extensions (set[str], optional): Lowercase extensions
            (without the dot) to accept. Defaults to
            ALLOWED_ARCHIVE_EXTENSIONS, preserving the original behavior.

    Returns:
        bool: True when the filename has an extension and it is allowed.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_ARCHIVE_EXTENSIONS
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
def get_next_document_number(category):
    """Compute the next sequence number for documents of a category.

    Stored filenames follow the pattern ``<category>-<number>-<name>``;
    this pulls the numeric middle segment out of every matching row and
    returns one past the current maximum.

    Args:
        category (str): Document category name.

    Returns:
        int: Next available number (1 when no documents exist or on error).
    """
    connection = get_db()
    cur = connection.cursor(dictionary=True)
    try:
        # MAX over the second '-'-separated field of stored_filename,
        # cast to an unsigned integer.
        sql = """
        SELECT MAX(CAST(SUBSTRING_INDEX(SUBSTRING_INDEX(stored_filename, '-', 2), '-', -1) AS UNSIGNED)) as max_num
        FROM documents
        WHERE stored_filename LIKE %s
        """
        cur.execute(sql, (f"{category}-%",))
        row = cur.fetchone()
        # No rows (or all NULL) means this is the first document.
        if row and row['max_num'] is not None:
            return int(row['max_num']) + 1
        return 1
    except Exception as e:
        logger.error(f"获取下一个文档编号时出错: {str(e)}")
        return 1
    finally:
        cur.close()
def save_classified_document(user_id, original_filename, category, content, file_size=None):
    """Persist a classified document: write the text file and insert a DB row.

    Args:
        user_id (int): Owner's user id.
        original_filename (str): Name the user uploaded the file under.
        category (str): Classification category; must exist in `categories`.
        content (str): Full document text (written as UTF-8).
        file_size (int, optional): Size in bytes; computed from the written
            file when None.

    Returns:
        tuple: (True, stored filename) on success,
               (False, error message) on failure.
    """
    db = None
    cursor = None
    file_path = None
    file_written = False  # tracks whether a failure must clean up the file
    try:
        db = get_db()
        cursor = db.cursor()
        # Validate the category BEFORE touching the filesystem so a bad
        # category never leaves a stray file behind (previously the file
        # was written first and orphaned on this early return).
        cursor.execute("SELECT id FROM categories WHERE name = %s", (category,))
        category_result = cursor.fetchone()
        if not category_result:
            logger.error(f"类别 '{category}' 在数据库中未找到。")
            return False, f"类别 '{category}' 不存在"
        category_id = category_result[0]
        # NOTE(review): the number lookup and the INSERT below are not
        # atomic; concurrent saves in the same category could pick the
        # same number — confirm this is acceptable for current load.
        next_num = get_next_document_number(category)
        # Sanitize the original name; fall back to "untitled" when
        # sanitizing strips everything, and force a .txt extension.
        safe_original_name_base = os.path.splitext(secure_filename(original_filename))[0]
        if not safe_original_name_base:
            safe_original_name_base = "untitled"
        safe_original_name = f"{safe_original_name_base}.txt"
        # Stored name pattern: <category>-<4-digit number>-<name>.txt
        formatted_num = f"{next_num:04d}"
        new_filename = f"{category}-{formatted_num}-{safe_original_name}"
        # Make sure the per-category upload directory exists.
        uploads_dir = os.path.join(current_app.root_path, current_app.config['UPLOAD_FOLDER'])
        category_dir = os.path.join(uploads_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        file_path = os.path.join(category_dir, new_filename)
        # Write the document text to disk.
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        file_written = True
        if file_size is None:
            file_size = os.path.getsize(file_path)
        # Record the document row.
        insert_query = """
        INSERT INTO documents
        (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time, upload_time)
        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """
        cursor.execute(
            insert_query,
            (user_id, original_filename, new_filename, file_path, file_size, category_id, '已分类')
        )
        db.commit()
        logger.info(f"成功保存文档: {new_filename} (ID: {cursor.lastrowid})")
        return True, new_filename
    except Exception as e:
        if db:
            db.rollback()
        # Bug fix: remove the already-written file so disk and DB stay
        # consistent when the INSERT/commit fails.
        if file_written and file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as rm_err:
                logger.error(f"回滚时删除文件失败 {file_path}: {rm_err}")
        logger.exception(f"保存分类文档 '{original_filename}' 时出错: {str(e)}")
        return False, str(e)
    finally:
        if cursor:
            cursor.close()
        # The DB connection itself is closed via @app.teardown_appcontext.
@classify_bp.route('/single', methods=['POST'])
def classify_single_file():
    """Single-file upload-and-classify API.

    Expects a multipart form with a 'file' field holding a .txt file.
    The upload is saved to a temp path, decoded (trying utf-8, gbk,
    gb18030 in order), classified by the text model, then persisted via
    save_classified_document(). Responds with JSON; the temp file is
    always removed in the finally block.
    """
    # Require an authenticated session.
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    # A 'file' part must be present in the request.
    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']
    original_filename = file.filename  # keep the original name before any sanitizing
    # An empty filename means no file was actually selected.
    if original_filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400
    # Only .txt uploads are accepted on this endpoint.
    if not allowed_text_file(original_filename):
        return jsonify({"success": False, "error": "不支持的文件类型仅支持txt文件"}), 400
    temp_path = None  # defined up front so the finally block can always test it
    try:
        # Stage the upload in a temp directory for processing.
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        # UUID prefix avoids collisions; secure_filename sanitizes the suffix.
        temp_filename = f"{uuid.uuid4().hex}-{secure_filename(original_filename)}"
        temp_path = os.path.join(temp_dir, temp_filename)
        # Save the uploaded file to the temp location.
        file.save(temp_path)
        file_size = os.path.getsize(temp_path)  # size known only after saving
        # Read the content, trying multiple encodings in order.
        file_content = None
        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
        for enc in encodings_to_try:
            try:
                with open(temp_path, 'r', encoding=enc) as f:
                    file_content = f.read()
                logger.info(f"成功以 {enc} 读取临时文件: {temp_path}")
                break
            except UnicodeDecodeError:
                logger.warning(f"使用 {enc} 解码临时文件失败: {temp_path}")
                continue
            except Exception as read_err:
                # Non-decoding read errors: log and keep trying the
                # remaining encodings rather than aborting immediately.
                logger.error(f"读取临时文件时出错 ({enc}): {read_err}")
                continue
        if file_content is None:
            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")
        # Run the classification model on the decoded text.
        result = text_classifier.classify_text(file_content)
        if not result.get('success', False):  # key must exist AND be truthy
            error_msg = result.get('error', '未知的分类错误')
            logger.error(f"文本分类失败 for {original_filename}: {error_msg}")
            return jsonify({"success": False, "error": error_msg}), 500
        category = result['category']
        # Persist the classified document (file on disk + DB row).
        save_success, message = save_classified_document(
            user_id,
            original_filename,  # store under the user's original name
            category,
            file_content,
            file_size
        )
        if not save_success:
            return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500
        # Success: 'message' holds the stored filename.
        return jsonify({
            "success": True,
            "filename": original_filename,
            "category": category,
            "confidence": result.get('confidence'),  # .get: may be absent
            "stored_filename": message
        })
    except Exception as e:
        # logger.exception records the full traceback.
        logger.exception(f"处理单文件 '{original_filename}' 过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500
    finally:
        # Always clean up the temp file, success or failure.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
                logger.info(f"已删除临时文件: {temp_path}")
            except OSError as rm_err:
                logger.error(f"删除临时文件失败 {temp_path}: {rm_err}")
@classify_bp.route('/batch', methods=['POST'])
def classify_batch_files():
    """Batch upload-and-classify API (zip/rar archive handling).

    Accepts a multipart 'file' field holding a .zip or .rar archive,
    extracts it to a temp directory, classifies every contained .txt
    file (macOS metadata entries are skipped) and saves each one via
    save_classified_document(). Returns per-archive statistics:
    total == success + failed, and the per-category counts sum to
    'success'.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']
    original_archive_name = file.filename
    if original_archive_name == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400
    if not allowed_archive_file(original_archive_name):
        return jsonify({"success": False, "error": "不支持的文件类型仅支持zip和rar压缩文件"}), 400
    temp_dir = os.path.join(current_app.root_path, 'temp')
    extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
    archive_path = None
    # Aggregate statistics returned to the client.
    results = {
        "total": 0,
        "success": 0,
        "failed": 0,
        "categories": {},
        "failed_files": []
    }
    try:
        os.makedirs(extract_dir, exist_ok=True)
        archive_path = os.path.join(temp_dir, secure_filename(original_archive_name))
        file.save(archive_path)
        logger.info(f"已保存上传的压缩文件: {archive_path}")
        # Extract the archive.
        # NOTE(review): extraction of user-supplied archives relies on
        # zipfile/rarfile path sanitization; verify zip-slip handling if
        # archives come from untrusted users.
        file_extension = original_archive_name.rsplit('.', 1)[1].lower()
        if file_extension == 'zip':
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
            logger.info(f"已解压ZIP文件到: {extract_dir}")
        elif file_extension == 'rar':
            # rarfile needs the external 'unrar' executable at runtime.
            try:
                with rarfile.RarFile(archive_path, 'r') as rar_ref:
                    rar_ref.extractall(extract_dir)
                logger.info(f"已解压RAR文件到: {extract_dir}")
            except rarfile.NeedFirstVolume:
                logger.error(f"RAR文件是分卷压缩的一部分需要第一个分卷: {original_archive_name}")
                raise ValueError("不支持分卷压缩的RAR文件")
            except rarfile.BadRarFile:
                logger.error(f"RAR文件损坏或格式错误: {original_archive_name}")
                raise ValueError("RAR文件损坏或格式错误")
            except Exception as rar_err:  # e.g. missing unrar binary
                logger.error(f"解压RAR文件时出错: {rar_err}")
                raise ValueError(f"解压RAR文件失败: {rar_err}")
        # Walk the extracted tree and process every allowed .txt file.
        for root, dirs, files in os.walk(extract_dir):
            # Skip macOS metadata directories entirely (and don't descend).
            if '__MACOSX' in root.split(os.path.sep):
                logger.info(f"Skipping macOS metadata directory: {root}")
                dirs[:] = []
                files[:] = []
                continue
            for filename in files:
                # Skip macOS per-file metadata.
                if filename.startswith('._') or filename == '.DS_Store':
                    logger.info(f"Skipping macOS metadata file: {filename} in {root}")
                    continue
                if allowed_text_file(filename):
                    file_path = os.path.join(root, filename)
                    results["total"] += 1
                    try:
                        # Read the content, trying multiple encodings.
                        file_content = None
                        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
                        for enc in encodings_to_try:
                            try:
                                with open(file_path, 'r', encoding=enc) as f:
                                    file_content = f.read()
                                logger.info(f"成功以 {enc} 读取文件: {file_path}")
                                break
                            except UnicodeDecodeError:
                                logger.warning(f"使用 {enc} 解码文件失败: {file_path}")
                                continue
                            except Exception as read_err:
                                logger.error(f"读取文件时发生其他错误 ({enc}) {file_path}: {read_err}")
                                continue
                        if file_content is None:
                            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")
                        # Classify the text.
                        result = text_classifier.classify_text(file_content)
                        if not result.get('success', False):
                            error_msg = result.get('error', '未知的分类错误')
                            raise Exception(f"分类失败: {error_msg}")
                        category = result['category']
                        file_size = os.path.getsize(file_path)
                        save_success, message = save_classified_document(
                            user_id,
                            filename,
                            category,
                            file_content,
                            file_size
                        )
                        if save_success:
                            results["success"] += 1
                            # Bug fix: count the category only after a
                            # successful save, so category totals always
                            # sum to 'success' (previously a classify-then-
                            # fail-to-save file inflated its category).
                            results["categories"][category] = results["categories"].get(category, 0) + 1
                        else:
                            raise Exception(f"保存失败: {message}")
                    except Exception as e:
                        logger.error(f"处理文件 '{filename}' 失败: {str(e)}")
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": str(e)
                        })
        # Return the aggregate results.
        return jsonify({
            "success": True,
            "archive_name": original_archive_name,
            "results": results
        })
    except Exception as e:
        logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}")
        return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500
    finally:
        # Clean up the uploaded archive and the extraction directory.
        if archive_path and os.path.exists(archive_path):
            try:
                os.remove(archive_path)
                logger.info(f"已删除临时压缩文件: {archive_path}")
            except OSError as rm_err:
                logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}")
        if os.path.exists(extract_dir):
            try:
                shutil.rmtree(extract_dir)
                logger.info(f"已删除临时解压目录: {extract_dir}")
            except OSError as rmtree_err:
                logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}")
@classify_bp.route('/documents', methods=['GET'])
def get_classified_documents():
    """List the current user's classified documents, paginated.

    Query params:
        category: category name or 'all' (default 'all'; unknown names
            fall back to 'all' behavior).
        page: 1-based page number (default 1).
        per_page: page size, one of 10/25/50/100 (default 10).

    Returns JSON with the documents, pagination info, and the list of
    available categories.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    cursor = None  # bug fix: defined before try so finally is always safe
    try:
        category = request.args.get('category', 'all')
        # Bug fix: non-numeric page/per_page previously raised ValueError
        # and surfaced as a 500; fall back to defaults instead.
        try:
            page = int(request.args.get('page', 1))
        except (TypeError, ValueError):
            page = 1
        try:
            per_page = int(request.args.get('per_page', 10))
        except (TypeError, ValueError):
            per_page = 10
        if per_page not in [10, 25, 50, 100]:
            per_page = 10
        if page < 1:
            page = 1
        offset = (page - 1) * per_page
        db = get_db()
        cursor = db.cursor(dictionary=True)
        # Fetch known categories first (also validates the filter below).
        cursor.execute("SELECT name FROM categories ORDER BY name")
        available_categories = [row['name'] for row in cursor.fetchall()]
        # Shared FROM/WHERE clause for both the count and data queries.
        params = [user_id]
        base_query = """
        FROM documents d
        JOIN categories c ON d.category_id = c.id
        WHERE d.user_id = %s AND d.status = '已分类'
        """
        if category != 'all' and category in available_categories:
            base_query += " AND c.name = %s"
            params.append(category)
        # Total row count for pagination.
        count_query = f"SELECT COUNT(*) as total {base_query}"
        cursor.execute(count_query, params)
        total_count_result = cursor.fetchone()
        total_count = total_count_result['total'] if total_count_result else 0
        total_pages = (total_count + per_page - 1) // per_page if per_page > 0 else 0
        # Page of data, newest classification first.
        data_query = f"""
        SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
        c.name as category, d.upload_time, d.classified_time
        {base_query}
        ORDER BY d.classified_time DESC
        LIMIT %s OFFSET %s
        """
        params.extend([per_page, offset])
        cursor.execute(data_query, params)
        documents = cursor.fetchall()
        # Serialize datetime columns to ISO strings for JSON.
        for doc in documents:
            if hasattr(doc.get('upload_time'), 'isoformat'):
                doc['upload_time'] = doc['upload_time'].isoformat()
            if hasattr(doc.get('classified_time'), 'isoformat'):
                doc['classified_time'] = doc['classified_time'].isoformat()
        return jsonify({
            "success": True,
            "documents": documents,
            "pagination": {
                "total": total_count,
                "per_page": per_page,
                "current_page": page,
                "total_pages": total_pages
            },
            "categories": available_categories,
            "current_category": category
        })
    except Exception as e:
        logger.exception(f"获取文档列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500
    finally:
        if cursor:
            cursor.close()
@classify_bp.route('/download/<int:document_id>', methods=['GET'])
def download_document(document_id):
    """Download a single classified document as an attachment.

    Args:
        document_id (int): Primary key of the document row. The query below
            also enforces that the row belongs to the logged-in user and
            has status '已分类', so other users' documents 404.
    """
    # Require an authenticated session.
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    cursor = None
    try:
        db = get_db()
        cursor = db.cursor(dictionary=True)
        # Ownership and status are enforced directly in the WHERE clause.
        query = """
        SELECT file_path, original_filename
        FROM documents
        WHERE id = %s AND user_id = %s AND status = '已分类'
        """
        cursor.execute(query, (document_id, user_id))
        document = cursor.fetchone()
        if not document:
            return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404
        file_path = document['file_path']
        download_name = document['original_filename']
        # DB row exists but the file is gone from disk.
        if not os.path.exists(file_path):
            logger.error(f"请求下载的文件在服务器上不存在: {file_path} (Doc ID: {document_id})")
            # NOTE(review): consider marking the row, e.g.
            # UPDATE documents SET status = '文件丢失' WHERE id = %s
            # so the missing file is visible in listings.
            return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404
        logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}")
        return send_file(
            file_path,
            as_attachment=True,
            download_name=download_name,  # serve under the user's original name
            mimetype='text/plain'  # stored documents are always .txt (see save_classified_document)
        )
    except Exception as e:
        logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        if cursor:
            cursor.close()
@classify_bp.route('/download-multiple', methods=['POST'])
def download_multiple_documents():
    """Download several documents bundled into a single zip archive.

    Expects JSON {"document_ids": [int, ...]}. Only documents owned by
    the logged-in user with status '已分类' are included; files missing
    on disk are skipped with a warning. The temporary zip is deleted
    after the response is sent.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    cursor = None
    zip_path = None  # defined early so all cleanup paths can test it
    try:
        data = request.get_json()
        if not data or 'document_ids' not in data:
            return jsonify({"success": False, "error": "缺少必要参数 'document_ids'"}), 400
        document_ids = data['document_ids']
        if not isinstance(document_ids, list) or not document_ids:
            return jsonify({"success": False, "error": "文档ID列表无效"}), 400
        # Coerce every id to int; also guards the IN (...) clause below.
        # (TypeError covers JSON values like null or nested lists.)
        try:
            document_ids = [int(doc_id) for doc_id in document_ids]
        except (TypeError, ValueError):
            return jsonify({"success": False, "error": "文档ID必须是数字"}), 400
        db = get_db()
        cursor = db.cursor(dictionary=True)
        # Fetch only documents the user actually owns.
        placeholders = ', '.join(['%s'] * len(document_ids))
        query = f"""
        SELECT id, file_path, original_filename
        FROM documents
        WHERE id IN ({placeholders}) AND user_id = %s AND status = '已分类'
        """
        params = document_ids + [user_id]
        cursor.execute(query, params)
        documents = cursor.fetchall()
        if not documents:
            return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404
        # Build the temporary zip file.
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        zip_filename = f"documents_{user_id}_{int(time.time())}.zip"
        zip_path = os.path.join(temp_dir, zip_filename)
        files_added = 0
        used_arcnames = set()  # bug fix: avoid duplicate entries for identical original names
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for doc in documents:
                file_path = doc['file_path']
                original_filename = doc['original_filename']
                if not os.path.exists(file_path):
                    logger.warning(f"跳过丢失的文件: {original_filename} (路径: {file_path}, Doc ID: {doc['id']})")
                    continue
                # De-duplicate archive names: "name.txt" -> "name_1.txt", ...
                arcname = original_filename
                if arcname in used_arcnames:
                    base, ext = os.path.splitext(original_filename)
                    counter = 1
                    while f"{base}_{counter}{ext}" in used_arcnames:
                        counter += 1
                    arcname = f"{base}_{counter}{ext}"
                used_arcnames.add(arcname)
                zipf.write(file_path, arcname=arcname)
                files_added += 1
                logger.info(f"添加到ZIP: {original_filename} (来自 {file_path})")
        if files_added == 0:
            # Every selected file was missing on disk.
            return jsonify({"success": False, "error": "所有选中的文件在服务器上都已丢失"}), 404
        logger.info(f"用户 {user_id} 请求下载 {files_added} 个文档打包为 {zip_filename}")
        # Delete the temp zip once the response has been sent.
        @after_this_request
        def remove_file(response):
            try:
                if zip_path and os.path.exists(zip_path):
                    os.remove(zip_path)
                    logger.info(f"已删除临时ZIP文件: {zip_path}")
            except Exception as error:
                logger.error(f"删除临时ZIP文件错误 {zip_path}: {error}")
            return response
        return send_file(
            zip_path,
            as_attachment=True,
            download_name=zip_filename,
            mimetype='application/zip'
        )
    except Exception as e:
        logger.exception(f"下载多个文档时出错: {str(e)}")
        # Remove the zip if it was created but never sent.
        if zip_path and os.path.exists(zip_path):
            try:
                os.remove(zip_path)
            except OSError:
                pass
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        if cursor:
            cursor.close()
@classify_bp.route('/classify-text', methods=['POST'])
def classify_text_directly():
    """Classify raw text from a JSON body without saving any file.

    Expects JSON {"text": "..."}. Returns the predicted category,
    its confidence, and (when the model provides them) all per-category
    confidences.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    try:
        data = request.get_json()
        if not data or 'text' not in data:
            return jsonify({"success": False, "error": "缺少必要参数 'text'"}), 400
        text = data['text']
        # Reject non-string or blank input.
        if not isinstance(text, str) or not text.strip():
            return jsonify({"success": False, "error": "文本内容不能为空或无效"}), 400
        # Cap input length to protect the model service (413 Payload Too Large).
        MAX_TEXT_LENGTH = 10000
        if len(text) > MAX_TEXT_LENGTH:
            return jsonify({"success": False, "error": f"文本过长,最大支持 {MAX_TEXT_LENGTH} 字符"}), 413
        # Lazily (re)initialize the classifier if it is not ready yet;
        # return 503 if initialization fails.
        if not text_classifier or not text_classifier.is_initialized:
            logger.warning("文本分类器未初始化,尝试初始化...")
            if not text_classifier.initialize():
                logger.error("文本分类器初始化失败。")
                return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503
        result = text_classifier.classify_text(text)
        if not result.get('success', False):
            error_msg = result.get('error', '未知的分类错误')
            logger.error(f"直接文本分类失败: {error_msg}")
            return jsonify({"success": False, "error": error_msg}), 500
        # Return the classification result.
        return jsonify({
            "success": True,
            "category": result.get('category'),
            "confidence": result.get('confidence'),
            "all_confidences": result.get('all_confidences')  # may be None if the model omits it
        })
    except Exception as e:
        logger.exception(f"直接文本分类过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500
@classify_bp.route('/categories', methods=['GET'])
def get_categories():
    """Return every classification category (id, name, description),
    ordered by name. No authentication is required on this endpoint."""
    cursor = None
    try:
        connection = get_db()
        cursor = connection.cursor(dictionary=True)
        cursor.execute("SELECT id, name, description FROM categories ORDER BY name")
        rows = cursor.fetchall()
        return jsonify({
            "success": True,
            "categories": rows
        })
    except Exception as exc:
        logger.exception(f"获取类别列表时出错: {str(exc)}")
        return jsonify({"success": False, "error": f"获取类别列表失败: {str(exc)}"}), 500
    finally:
        if cursor:
            cursor.close()