660 lines
21 KiB
Python
660 lines
21 KiB
Python
# routes/classify.py
|
||
import io
import logging
import os
import shutil
import time
import uuid
import zipfile

import mysql.connector
import rarfile
from flask import Blueprint, request, jsonify, current_app, send_file, g, session
from werkzeug.utils import secure_filename

from utils.db import get_db, close_db
from utils.model_service import text_classifier
|
||
|
||
# Blueprint that groups every classification endpoint under /api/classify.
classify_bp = Blueprint('classify', __name__, url_prefix='/api/classify')

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
# Accepted file extensions (lower-case, without the leading dot):
# plain-text uploads, and archives for batch uploads.
ALLOWED_TEXT_EXTENSIONS = {'txt'}
ALLOWED_ARCHIVE_EXTENSIONS = {'zip', 'rar'}
|
||
|
||
|
||
def allowed_text_file(filename):
    """Tell whether *filename* carries an accepted text-file extension.

    The extension is whatever follows the last dot, compared
    case-insensitively against ALLOWED_TEXT_EXTENSIONS. A name without
    any dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_TEXT_EXTENSIONS
|
||
|
||
|
||
def allowed_archive_file(filename):
    """Tell whether *filename* carries an accepted archive extension.

    The extension is whatever follows the last dot, compared
    case-insensitively against ALLOWED_ARCHIVE_EXTENSIONS. A name without
    any dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_ARCHIVE_EXTENSIONS
|
||
|
||
|
||
def get_next_document_number(category):
    """Return the next sequence number for a document stored under *category*.

    The query takes the second dash-separated field of ``stored_filename``
    as the number, restricted to rows whose ``stored_filename`` starts with
    ``"<category>-"``, and picks the maximum.

    Args:
        category (str): Category name used as the stored-filename prefix.

    Returns:
        int: The current maximum plus one. 1 is returned when no matching
        row carries a number, and also when the lookup raises (the error is
        logged, not propagated).
    """
    sql = """
        SELECT MAX(CAST(SUBSTRING_INDEX(SUBSTRING_INDEX(stored_filename, '-', 2), '-', -1) AS UNSIGNED)) as max_num
        FROM documents
        WHERE stored_filename LIKE %s
    """
    connection = get_db()
    cur = connection.cursor(dictionary=True)
    try:
        cur.execute(sql, (f"{category}-%",))
        row = cur.fetchone()
        highest = row['max_num'] if row else None
        return 1 if highest is None else highest + 1
    except Exception as e:
        logger.error(f"获取下一个文档编号时出错: {str(e)}")
        return 1
    finally:
        cur.close()
|
||
|
||
|
||
def save_classified_document(user_id, original_filename, category, content, file_size=None):
    """Write a classified document to disk and record it in the database.

    The file is stored as ``<category>-<NNNN>-<sanitized original name>``
    inside ``<app root>/<UPLOAD_FOLDER>/<category>/`` and a row is inserted
    into ``documents`` with status ``'已分类'``.

    Args:
        user_id (int): Owner of the document.
        original_filename (str): Name of the uploaded file.
        category (str): Category name; must exist in ``categories``.
        content (str): Text content, written as UTF-8.
        file_size (int, optional): Size to record; measured from the written
            file when None.

    Returns:
        tuple: ``(True, stored_filename)`` on success, otherwise
        ``(False, error_message)``. Exceptions are logged and reported
        through the tuple rather than raised.
    """
    db = None
    cursor = None
    file_path = None
    file_written = False
    try:
        next_num = get_next_document_number(category)

        # secure_filename() can strip a non-ASCII name down to nothing;
        # fall back to a generic name so the stored filename stays well-formed.
        safe_original_name = secure_filename(original_filename) or 'document.txt'
        new_filename = f"{category}-{next_num:04d}-{safe_original_name}"

        db = get_db()
        cursor = db.cursor()

        # Resolve the category BEFORE touching the disk, so an unknown
        # category does not leave an orphan file behind.
        cursor.execute("SELECT id FROM categories WHERE name = %s", (category,))
        category_result = cursor.fetchone()
        if not category_result:
            return False, "类别不存在"
        category_id = category_result[0]

        uploads_dir = os.path.join(current_app.root_path, current_app.config['UPLOAD_FOLDER'])
        category_dir = os.path.join(uploads_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        file_path = os.path.join(category_dir, new_filename)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        file_written = True

        if file_size is None:
            file_size = os.path.getsize(file_path)

        insert_query = """
            INSERT INTO documents
            (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
        """
        cursor.execute(
            insert_query,
            (user_id, original_filename, new_filename, file_path, file_size, category_id, '已分类')
        )
        db.commit()
        return True, new_filename

    except Exception as e:
        logger.error(f"保存分类文档时出错: {str(e)}")
        # Undo partial work: roll back the transaction and drop the file that
        # has no database row. Both steps are best-effort.
        if db is not None:
            try:
                db.rollback()
            except Exception:
                pass
        if file_written and file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError:
                pass
        return False, str(e)

    finally:
        if cursor is not None:
            cursor.close()
|
||
|
||
|
||
def _read_text_with_fallback(path):
    """Read *path* as UTF-8 text, retrying with GBK if UTF-8 decoding fails.

    Raises:
        UnicodeDecodeError: If the file decodes in neither encoding.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, 'r', encoding='gbk') as f:
            return f.read()


@classify_bp.route('/single', methods=['POST'])
def classify_single_file():
    """Upload one .txt file, classify it and store the result.

    Expects a logged-in session and a multipart field named ``file``.

    Returns:
        JSON with ``success``, and on success ``filename``, ``category``,
        ``confidence`` and ``stored_filename``. Status codes: 401 when not
        logged in, 400 for a missing/invalid file or undecodable content,
        500 when classification, saving or any other step fails.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    user_id = session['user_id']

    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400

    file = request.files['file']

    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400

    if not allowed_text_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400

    temp_path = None
    try:
        # Stage the upload under a random name so concurrent requests
        # never share a temp file.
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"{uuid.uuid4().hex}.txt")
        file.save(temp_path)

        # Only a genuine decoding failure is reported as an encoding error;
        # model or database failures fall through to the 500 handler below.
        try:
            file_content = _read_text_with_fallback(temp_path)
        except UnicodeDecodeError as e:
            return jsonify({"success": False, "error": f"文件编码错误,请确保文件为UTF-8或GBK编码: {str(e)}"}), 400

        result = text_classifier.classify_text(file_content)
        if not result['success']:
            return jsonify({"success": False, "error": result['error']}), 500

        file_size = os.path.getsize(temp_path)
        save_success, message = save_classified_document(
            user_id,
            file.filename,
            result['category'],
            file_content,
            file_size
        )
        if not save_success:
            return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500

        return jsonify({
            "success": True,
            "filename": file.filename,
            "category": result['category'],
            "confidence": result['confidence'],
            "stored_filename": message
        })

    except Exception as e:
        logger.error(f"文件处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500

    finally:
        # Runs on every exit path, including the early error returns above.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
|
||
|
||
|
||
@classify_bp.route('/batch', methods=['POST'])
def classify_batch_files():
    """Upload a zip/rar archive and classify every .txt file inside it.

    Expects a logged-in session and a multipart field named ``file``. The
    archive is extracted into a per-request temp directory, each ``.txt``
    file (found recursively) is read as UTF-8 with a GBK fallback,
    classified and saved.

    Returns:
        JSON with ``success``, ``archive_name`` and ``results`` (``total``,
        ``success``, ``failed``, per-category counts and ``failed_files``).
        Status codes: 401 when not logged in, 400 for a missing, unsupported
        or oversized file, 500 when archive handling fails.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    user_id = session['user_id']

    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400

    file = request.files['file']

    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400

    if not allowed_archive_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400

    max_archive_bytes = 10 * 1024 * 1024  # 10MB

    # content_length is None when the client sends no Content-Length header;
    # comparing None with an int would raise TypeError.
    if (request.content_length or 0) > max_archive_bytes:
        return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400

    archive_path = None
    extract_dir = None
    try:
        temp_dir = os.path.join(current_app.root_path, 'temp')
        unique_id = uuid.uuid4().hex
        extract_dir = os.path.join(temp_dir, f"extract_{unique_id}")
        os.makedirs(extract_dir, exist_ok=True)

        # Store the archive under a random name: the client-supplied name can
        # collide between concurrent requests, and secure_filename() reduces
        # non-ASCII names to little more than the extension.
        file_extension = file.filename.rsplit('.', 1)[1].lower()
        archive_path = os.path.join(temp_dir, f"archive_{unique_id}.{file_extension}")
        file.save(archive_path)

        # Enforce the limit on the stored file as well, which covers requests
        # that carried no Content-Length header.
        if os.path.getsize(archive_path) > max_archive_bytes:
            return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400

        # NOTE(review): extractall() runs on an untrusted upload and only the
        # compressed size is limited above; a highly compressed archive can
        # still expand far beyond 10MB. Consider checking member sizes first.
        if file_extension == 'zip':
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
        elif file_extension == 'rar':
            with rarfile.RarFile(archive_path, 'r') as rar_ref:
                rar_ref.extractall(extract_dir)

        results = {
            "total": 0,
            "success": 0,
            "failed": 0,
            "categories": {},
            "failed_files": []
        }

        for root, _dirs, files in os.walk(extract_dir):
            for filename in files:
                if not filename.lower().endswith('.txt'):
                    continue

                file_path = os.path.join(root, filename)
                results["total"] += 1

                # A failure in one file is recorded and must not abort the batch.
                try:
                    try:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            file_content = f.read()
                    except UnicodeDecodeError:
                        with open(file_path, 'r', encoding='gbk') as f:
                            file_content = f.read()

                    result = text_classifier.classify_text(file_content)
                    if not result['success']:
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": result['error']
                        })
                        continue

                    category = result['category']
                    results["categories"][category] = results["categories"].get(category, 0) + 1

                    file_size = os.path.getsize(file_path)
                    save_success, message = save_classified_document(
                        user_id,
                        filename,
                        category,
                        file_content,
                        file_size
                    )

                    if save_success:
                        results["success"] += 1
                    else:
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": message
                        })

                except Exception as e:
                    results["failed"] += 1
                    results["failed_files"].append({
                        "filename": filename,
                        "error": str(e)
                    })

        return jsonify({
            "success": True,
            "archive_name": file.filename,
            "results": results
        })

    except Exception as e:
        logger.error(f"压缩包处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500

    finally:
        # Single cleanup point for every exit path.
        if archive_path is not None and os.path.exists(archive_path):
            os.remove(archive_path)
        if extract_dir is not None:
            shutil.rmtree(extract_dir, ignore_errors=True)
|
||
|
||
|
||
@classify_bp.route('/documents', methods=['GET'])
def get_classified_documents():
    """List the logged-in user's classified documents, paginated.

    Query parameters:
        category: Category name to filter by, or ``all`` (default).
        page: 1-based page number (default 1; invalid or < 1 becomes 1).
        per_page: One of 10, 25, 50, 100 (anything else becomes 10).

    Returns:
        JSON with ``documents``, ``pagination`` (``total``, ``per_page``,
        ``current_page``, ``total_pages``), ``categories`` (all category
        names) and ``current_category``. 401 when not logged in, 500 when
        the query fails.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    user_id = session['user_id']

    category = request.args.get('category', 'all')
    # type=int makes the lookup fall back to the default instead of raising
    # when the client sends a non-numeric value.
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('per_page', 10, type=int)

    # A page below 1 would produce a negative OFFSET, which MySQL rejects.
    if page < 1:
        page = 1

    if per_page not in (10, 25, 50, 100):
        per_page = 10

    offset = (page - 1) * per_page

    db = get_db()
    cursor = db.cursor(dictionary=True)

    try:
        # The WHERE clause is assembled from fixed fragments only; every
        # user-supplied value travels as a bound parameter.
        where_clause = "WHERE d.user_id = %s AND d.status = '已分类'"
        params = [user_id]

        if category != 'all':
            where_clause += " AND c.name = %s"
            params.append(category)

        count_query = f"""
            SELECT COUNT(*) as total
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
        """
        cursor.execute(count_query, params)
        total_count = cursor.fetchone()['total']

        # Ceiling division.
        total_pages = (total_count + per_page - 1) // per_page

        query = f"""
            SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
            c.name as category, d.upload_time, d.classified_time
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
            ORDER BY d.classified_time DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(query, params + [per_page, offset])
        documents = cursor.fetchall()

        cursor.execute("SELECT name FROM categories ORDER BY name")
        categories = [row['name'] for row in cursor.fetchall()]

        return jsonify({
            "success": True,
            "documents": documents,
            "pagination": {
                "total": total_count,
                "per_page": per_page,
                "current_page": page,
                "total_pages": total_pages
            },
            "categories": categories,
            "current_category": category
        })

    except Exception as e:
        logger.error(f"获取文档列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500

    finally:
        cursor.close()
|
||
|
||
|
||
@classify_bp.route('/download/<int:document_id>', methods=['GET'])
def download_document(document_id):
    """Send one document owned by the logged-in user as an attachment.

    Returns the file under its original filename, or a JSON error: 401 when
    not logged in, 404 when no row matches the id for this user or the file
    is missing on disk, 500 on any other failure.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    owner_id = session['user_id']

    connection = get_db()
    cur = connection.cursor(dictionary=True)
    try:
        # Filtering on user_id as well keeps users away from each other's files.
        cur.execute(
            """
            SELECT file_path, original_filename, stored_filename
            FROM documents
            WHERE id = %s AND user_id = %s
            """,
            (document_id, owner_id),
        )
        record = cur.fetchone()
        if not record:
            return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404

        stored_path = record['file_path']
        if not os.path.exists(stored_path):
            return jsonify({"success": False, "error": "文件不存在"}), 404

        return send_file(
            stored_path,
            as_attachment=True,
            download_name=record['original_filename'],
            mimetype='text/plain',
        )
    except Exception as e:
        logger.error(f"下载文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        cur.close()
|
||
|
||
|
||
@classify_bp.route('/download-multiple', methods=['POST'])
def download_multiple_documents():
    """Bundle several of the logged-in user's documents into one zip download.

    Expects a JSON body ``{"document_ids": [...]}``. Only rows owned by the
    session user are included; rows whose file is missing on disk are
    skipped. The archive is built in memory, so nothing is left behind in
    the temp directory.

    Returns:
        The zip as an attachment, or a JSON error: 401 when not logged in,
        400 for a missing/invalid id list, 404 when no row matches, 500 on
        any other failure.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    user_id = session['user_id']

    data = request.get_json()
    if not data or 'document_ids' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400

    document_ids = data['document_ids']
    # Each id is bound as a SQL parameter, so it must be a scalar; nested
    # lists or objects would only surface later as a driver error.
    if (
        not isinstance(document_ids, list)
        or not document_ids
        or not all(isinstance(doc_id, (int, str)) for doc_id in document_ids)
    ):
        return jsonify({"success": False, "error": "文档ID列表无效"}), 400

    db = get_db()
    cursor = db.cursor(dictionary=True)

    try:
        placeholders = ', '.join(['%s'] * len(document_ids))
        query = f"""
            SELECT id, file_path, original_filename
            FROM documents
            WHERE id IN ({placeholders}) AND user_id = %s
        """
        cursor.execute(query, document_ids + [user_id])
        documents = cursor.fetchall()

        if not documents:
            return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404

        zip_filename = f"documents_{int(time.time())}.zip"

        # An in-memory buffer avoids leaving a zip on disk after the response
        # and avoids name clashes between requests in the same second.
        zip_buffer = io.BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zipf:
            for doc in documents:
                if not os.path.exists(doc['file_path']):
                    continue
                # Different documents may share an original filename; prefix
                # the id so one entry does not shadow the other on extraction.
                arcname = doc['original_filename']
                if arcname in used_names:
                    arcname = f"{doc['id']}_{arcname}"
                used_names.add(arcname)
                zipf.write(doc['file_path'], arcname=arcname)

        # send_file reads from the current position of a file-like object.
        zip_buffer.seek(0)
        return send_file(
            zip_buffer,
            as_attachment=True,
            download_name=zip_filename,
            mimetype='application/zip'
        )

    except Exception as e:
        logger.error(f"下载多个文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500

    finally:
        cursor.close()
|
||
|
||
|
||
@classify_bp.route('/classify-text', methods=['POST'])
def classify_text_directly():
    """Classify a piece of text sent in the request body, without saving it.

    Expects a logged-in session and a JSON object ``{"text": "<string>"}``.

    Returns:
        JSON with ``success``, and on success ``category``, ``confidence``
        and ``all_confidences``. Status codes: 401 when not logged in, 400
        when the body is not a JSON object with a non-blank string ``text``,
        500 when classification fails.
    """
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401

    # silent=True yields None for a missing or malformed JSON body, so the
    # client gets this endpoint's JSON error instead of a framework error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400

    text = data['text']
    # Reject non-strings (null, numbers, lists) before calling .strip().
    if not isinstance(text, str) or not text.strip():
        return jsonify({"success": False, "error": "文本内容不能为空"}), 400

    try:
        result = text_classifier.classify_text(text)

        if not result['success']:
            return jsonify({"success": False, "error": result['error']}), 500

        return jsonify({
            "success": True,
            "category": result['category'],
            "confidence": result['confidence'],
            "all_confidences": result['all_confidences']
        })

    except Exception as e:
        logger.error(f"文本分类过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500
|
||
|
||
|
||
@classify_bp.route('/categories', methods=['GET'])
def get_categories():
    """Return every category (id, name, description) ordered by name.

    Responds with JSON ``{"success": True, "categories": [...]}``, or a JSON
    error with status 500 when the query fails.
    """
    connection = get_db()
    cur = connection.cursor(dictionary=True)
    try:
        cur.execute("SELECT id, name, description FROM categories ORDER BY name")
        rows = cur.fetchall()
        payload = {
            "success": True,
            "categories": rows
        }
        return jsonify(payload)
    except Exception as e:
        logger.error(f"获取类别列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500
    finally:
        cur.close()
|