# routes/classify.py
import os
import time
import uuid
import zipfile
import rarfile
import shutil
from flask import Blueprint, request, jsonify, current_app, send_file, g, session
from werkzeug.utils import secure_filename
import mysql.connector
from utils.model_service import text_classifier
from utils.db import get_db, close_db
import logging
# Create the blueprint
classify_bp = Blueprint('classify', __name__, url_prefix='/api/classify')

# Set up logging
logger = logging.getLogger(__name__)

# Allowed file extensions
ALLOWED_TEXT_EXTENSIONS = {'txt'}
ALLOWED_ARCHIVE_EXTENSIONS = {'zip', 'rar'}


def allowed_text_file(filename):
    """Check whether a text file's extension is allowed."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_TEXT_EXTENSIONS


def allowed_archive_file(filename):
    """Check whether an archive file's extension is allowed."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_ARCHIVE_EXTENSIONS


def get_next_document_number(category):
    """Get the next document number for the given category.

    Args:
        category (str): Document category.

    Returns:
        int: The next sequence number.
    """
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Find the largest existing number for this category
        query = """
            SELECT MAX(CAST(SUBSTRING_INDEX(SUBSTRING_INDEX(stored_filename, '-', 2), '-', -1) AS UNSIGNED)) as max_num
            FROM documents
            WHERE stored_filename LIKE %s
        """
        cursor.execute(query, (f"{category}-%",))
        result = cursor.fetchone()
        # If there is no record or the maximum is None, start at 1; otherwise use max + 1
        if not result or result['max_num'] is None:
            return 1
        else:
            return result['max_num'] + 1
    except Exception as e:
        logger.error(f"获取下一个文档编号时出错: {str(e)}")
        return 1
    finally:
        cursor.close()
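
# Illustrative example (hypothetical data): if the documents table already holds stored
# filenames such as "体育-0001-a.txt" and "体育-0003-b.txt", the nested SUBSTRING_INDEX
# in get_next_document_number extracts "0001" and "0003", and the function returns 4.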


def save_classified_document(user_id, original_filename, category, content, file_size=None):
    """Save a classified document to disk and record it in the database.

    Args:
        user_id (int): User ID.
        original_filename (str): Original file name.
        category (str): Classification category.
        content (str): Document content.
        file_size (int, optional): File size; computed automatically if None.

    Returns:
        tuple: (success flag, stored filename or error message)
    """
    try:
        # Get the next document number for this category
        next_num = get_next_document_number(category)
        # Sanitize the file name
        safe_original_name = secure_filename(original_filename)
        # Build the new file name (category-number-original_name)
        formatted_num = f"{next_num:04d}"  # zero-pad the number to 4 digits
        new_filename = f"{category}-{formatted_num}-{safe_original_name}"
        # Make sure the uploads directory exists
        uploads_dir = os.path.join(current_app.root_path, current_app.config['UPLOAD_FOLDER'])
        category_dir = os.path.join(uploads_dir, category)
        os.makedirs(category_dir, exist_ok=True)
        # Full path of the file
        file_path = os.path.join(category_dir, new_filename)
        # Write the file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        # Compute the file size if it was not provided
        if file_size is None:
            file_size = os.path.getsize(file_path)
        # Get a database connection
        db = get_db()
        cursor = db.cursor()
        # Look up the category ID
        category_query = "SELECT id FROM categories WHERE name = %s"
        cursor.execute(category_query, (category,))
        category_result = cursor.fetchone()
        if not category_result:
            return False, "类别不存在"
        category_id = category_result[0]
        # Insert the document record
        insert_query = """
            INSERT INTO documents
            (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
        """
        cursor.execute(
            insert_query,
            (user_id, original_filename, new_filename, file_path, file_size, category_id, '已分类')
        )
        # Commit the transaction
        db.commit()
        return True, new_filename
    except Exception as e:
        logger.error(f"保存分类文档时出错: {str(e)}")
        return False, str(e)
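
# Illustrative usage (values are assumptions, not taken from the application):
#   save_classified_document(1, "news.txt", "体育", "比赛报道……")
# would write uploads/体育/体育-0004-news.txt (if the next number for 体育 is 4) and
# return (True, "体育-0004-news.txt"); on failure it returns (False, <error message>).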


@classify_bp.route('/single', methods=['POST'])
def classify_single_file():
    """Single-file upload and classification API."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    # Make sure a file was uploaded
    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']
    # Check the file name
    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400
    # Check the file type
    if not allowed_text_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型仅支持txt文件"}), 400
    try:
        # Create a temporary file for processing
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        temp_filename = f"{uuid.uuid4().hex}.txt"
        temp_path = os.path.join(temp_dir, temp_filename)
        # Save the upload to the temporary location
        file.save(temp_path)
        # Read the file content, trying UTF-8 first and falling back to GBK
        try:
            with open(temp_path, 'r', encoding='utf-8') as f:
                file_content = f.read()
        except UnicodeDecodeError:
            try:
                with open(temp_path, 'r', encoding='gbk') as f:
                    file_content = f.read()
            except Exception as e:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                return jsonify({"success": False, "error": f"文件编码错误请确保文件为UTF-8或GBK编码: {str(e)}"}), 400
        # Run the classification model
        result = text_classifier.classify_text(file_content)
        if not result['success']:
            # Clean up the temporary file before returning
            if os.path.exists(temp_path):
                os.remove(temp_path)
            return jsonify({"success": False, "error": result['error']}), 500
        # Save the classified document
        file_size = os.path.getsize(temp_path)
        save_success, message = save_classified_document(
            user_id,
            file.filename,
            result['category'],
            file_content,
            file_size
        )
        # Clean up the temporary file
        if os.path.exists(temp_path):
            os.remove(temp_path)
        if not save_success:
            return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500
        # Return the classification result
        return jsonify({
            "success": True,
            "filename": file.filename,
            "category": result['category'],
            "confidence": result['confidence'],
            "stored_filename": message
        })
    except Exception as e:
        # Make sure the temporary file is cleaned up
        if 'temp_path' in locals() and os.path.exists(temp_path):
            os.remove(temp_path)
        logger.error(f"文件处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500


@classify_bp.route('/batch', methods=['POST'])
def classify_batch_files():
    """Batch upload and classification API (archive handling)."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    # Make sure a file was uploaded
    if 'file' not in request.files:
        return jsonify({"success": False, "error": "没有文件"}), 400
    file = request.files['file']
    # Check the file name
    if file.filename == '':
        return jsonify({"success": False, "error": "未选择文件"}), 400
    # Check the file type
    if not allowed_archive_file(file.filename):
        return jsonify({"success": False, "error": "不支持的文件类型仅支持zip和rar压缩文件"}), 400
    # Check the file size (content_length can be None if no Content-Length header was sent)
    if request.content_length and request.content_length > 10 * 1024 * 1024:  # 10 MB
        return jsonify({"success": False, "error": "文件太大最大支持10MB"}), 400
    try:
        # Create the temporary directories
        temp_dir = os.path.join(current_app.root_path, 'temp')
        extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
        os.makedirs(extract_dir, exist_ok=True)
        # Save the uploaded archive
        archive_path = os.path.join(temp_dir, secure_filename(file.filename))
        file.save(archive_path)
        # Extract the archive
        file_extension = file.filename.rsplit('.', 1)[1].lower()
        if file_extension == 'zip':
            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
        elif file_extension == 'rar':
            with rarfile.RarFile(archive_path, 'r') as rar_ref:
                rar_ref.extractall(extract_dir)
        # Processing statistics
        results = {
            "total": 0,
            "success": 0,
            "failed": 0,
            "categories": {},
            "failed_files": []
        }
        # Recursively process every txt file
        for root, dirs, files in os.walk(extract_dir):
            for filename in files:
                if filename.lower().endswith('.txt'):
                    file_path = os.path.join(root, filename)
                    results["total"] += 1
                    try:
                        # Read the file content, trying UTF-8 first and falling back to GBK
                        try:
                            with open(file_path, 'r', encoding='utf-8') as f:
                                file_content = f.read()
                        except UnicodeDecodeError:
                            with open(file_path, 'r', encoding='gbk') as f:
                                file_content = f.read()
                        # Run the classification model
                        result = text_classifier.classify_text(file_content)
                        if not result['success']:
                            results["failed"] += 1
                            results["failed_files"].append({
                                "filename": filename,
                                "error": result['error']
                            })
                            continue
                        category = result['category']
                        # Count documents per category
                        if category not in results["categories"]:
                            results["categories"][category] = 0
                        results["categories"][category] += 1
                        # Save the classified document
                        file_size = os.path.getsize(file_path)
                        save_success, message = save_classified_document(
                            user_id,
                            filename,
                            category,
                            file_content,
                            file_size
                        )
                        if save_success:
                            results["success"] += 1
                        else:
                            results["failed"] += 1
                            results["failed_files"].append({
                                "filename": filename,
                                "error": message
                            })
                    except Exception as e:
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": str(e)
                        })
        # Clean up the temporary files
        if os.path.exists(archive_path):
            os.remove(archive_path)
        if os.path.exists(extract_dir):
            shutil.rmtree(extract_dir)
        # Return the processing results
        return jsonify({
            "success": True,
            "archive_name": file.filename,
            "results": results
        })
    except Exception as e:
        # Make sure the temporary files are cleaned up
        if 'archive_path' in locals() and os.path.exists(archive_path):
            os.remove(archive_path)
        if 'extract_dir' in locals() and os.path.exists(extract_dir):
            shutil.rmtree(extract_dir)
        logger.error(f"压缩包处理过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500


@classify_bp.route('/documents', methods=['GET'])
def get_classified_documents():
    """List classified documents."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    # Read query parameters
    category = request.args.get('category', 'all')
    page = int(request.args.get('page', 1))
    per_page = int(request.args.get('per_page', 10))
    # Validate the page size
    if per_page not in [10, 25, 50, 100]:
        per_page = 10
    # Compute the offset
    offset = (page - 1) * per_page
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Build the filter
        where_clause = "WHERE d.user_id = %s AND d.status = '已分类'"
        params = [user_id]
        if category != 'all':
            where_clause += " AND c.name = %s"
            params.append(category)
        # Count matching records
        count_query = f"""
            SELECT COUNT(*) as total
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
        """
        cursor.execute(count_query, params)
        total_count = cursor.fetchone()['total']
        # Compute the number of pages
        total_pages = (total_count + per_page - 1) // per_page
        # Fetch the requested page
        query = f"""
            SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
                   c.name as category, d.upload_time, d.classified_time
            FROM documents d
            JOIN categories c ON d.category_id = c.id
            {where_clause}
            ORDER BY d.classified_time DESC
            LIMIT %s OFFSET %s
        """
        params.extend([per_page, offset])
        cursor.execute(query, params)
        documents = cursor.fetchall()
        # Fetch all available categories
        cursor.execute("SELECT name FROM categories ORDER BY name")
        categories = [row['name'] for row in cursor.fetchall()]
        return jsonify({
            "success": True,
            "documents": documents,
            "pagination": {
                "total": total_count,
                "per_page": per_page,
                "current_page": page,
                "total_pages": total_pages
            },
            "categories": categories,
            "current_category": category
        })
    except Exception as e:
        logger.error(f"获取文档列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500
    finally:
        cursor.close()


@classify_bp.route('/download/<int:document_id>', methods=['GET'])
def download_document(document_id):
    """Download a classified document."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Look up the document
        query = """
            SELECT file_path, original_filename, stored_filename
            FROM documents
            WHERE id = %s AND user_id = %s
        """
        cursor.execute(query, (document_id, user_id))
        document = cursor.fetchone()
        if not document:
            return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404
        # Make sure the file exists on disk
        if not os.path.exists(document['file_path']):
            return jsonify({"success": False, "error": "文件不存在"}), 404
        # Send the file as an attachment
        return send_file(
            document['file_path'],
            as_attachment=True,
            download_name=document['original_filename'],
            mimetype='text/plain'
        )
    except Exception as e:
        logger.error(f"下载文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        cursor.close()


@classify_bp.route('/download-multiple', methods=['POST'])
def download_multiple_documents():
    """Download several documents packed into a zip archive."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    user_id = session['user_id']
    # Read the request body
    data = request.get_json()
    if not data or 'document_ids' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400
    document_ids = data['document_ids']
    if not isinstance(document_ids, list) or not document_ids:
        return jsonify({"success": False, "error": "文档ID列表无效"}), 400
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        # Create a temporary directory for the zip file
        temp_dir = os.path.join(current_app.root_path, 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        # Temporary zip file
        zip_filename = f"documents_{int(time.time())}.zip"
        zip_path = os.path.join(temp_dir, zip_filename)
        # Fetch all matching documents
        placeholders = ', '.join(['%s'] * len(document_ids))
        query = f"""
            SELECT id, file_path, original_filename
            FROM documents
            WHERE id IN ({placeholders}) AND user_id = %s
        """
        params = document_ids + [user_id]
        cursor.execute(query, params)
        documents = cursor.fetchall()
        if not documents:
            return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404
        # Build the zip archive
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for doc in documents:
                if os.path.exists(doc['file_path']):
                    # Add the file under its original name
                    zipf.write(doc['file_path'], arcname=doc['original_filename'])
        # Send the zip file as an attachment
        return send_file(
            zip_path,
            as_attachment=True,
            download_name=zip_filename,
            mimetype='application/zip'
        )
    except Exception as e:
        logger.error(f"下载多个文档时出错: {str(e)}")
        return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
    finally:
        cursor.close()


@classify_bp.route('/classify-text', methods=['POST'])
def classify_text_directly():
    """Classify raw text directly (no file is saved)."""
    # Require a logged-in user
    if 'user_id' not in session:
        return jsonify({"success": False, "error": "请先登录"}), 401
    # Read the request body
    data = request.get_json()
    if not data or 'text' not in data:
        return jsonify({"success": False, "error": "缺少必要参数"}), 400
    text = data['text']
    if not text.strip():
        return jsonify({"success": False, "error": "文本内容不能为空"}), 400
    try:
        # Run the classification model
        result = text_classifier.classify_text(text)
        if not result['success']:
            return jsonify({"success": False, "error": result['error']}), 500
        # Return the classification result
        return jsonify({
            "success": True,
            "category": result['category'],
            "confidence": result['confidence'],
            "all_confidences": result['all_confidences']
        })
    except Exception as e:
        logger.error(f"文本分类过程中发生错误: {str(e)}")
        return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500


@classify_bp.route('/categories', methods=['GET'])
def get_categories():
    """Return all classification categories."""
    db = get_db()
    cursor = db.cursor(dictionary=True)
    try:
        cursor.execute("SELECT id, name, description FROM categories ORDER BY name")
        categories = cursor.fetchall()
        return jsonify({
            "success": True,
            "categories": categories
        })
    except Exception as e:
        logger.error(f"获取类别列表时出错: {str(e)}")
        return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500
    finally:
        cursor.close()