This commit is contained in:
superlishunqin 2025-03-31 03:06:26 +08:00
parent a8a0beb277
commit 24953f68df
4 changed files with 421 additions and 298 deletions

.idea/vcs.xml generated Normal file

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>

app.py

@@ -60,5 +60,4 @@ def internal_server_error(e):
     return render_template('error.html', error='服务器内部错误'), 500

 if __name__ == '__main__':
-    app.run(debug=True, host='0.0.0.0', port=5009)
+    app.run(debug=True, host='0.0.0.0', port=50004)
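A side note on this change: the service binds 0.0.0.0 on a hardcoded port with debug=True, so every port move is a code edit. A minimal sketch of reading these values from the environment instead — the APP_PORT and FLASK_DEBUG names are illustrative, not from this repo:

    import os
    from flask import Flask

    app = Flask(__name__)

    if __name__ == '__main__':
        # fall back to the commit's values when the variables are unset
        port = int(os.environ.get('APP_PORT', '50004'))
        debug = os.environ.get('FLASK_DEBUG', '0') == '1'
        app.run(debug=debug, host='0.0.0.0', port=port)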


@@ -8,7 +8,7 @@ import shutil
 from flask import Blueprint, request, jsonify, current_app, send_file, g, session
 from werkzeug.utils import secure_filename
 import mysql.connector
-from utils.model_service import text_classifier
+from utils.model_service import text_classifier  # assumes text_classifier is initialized correctly elsewhere
 from utils.db import get_db, close_db
 import logging
@@ -59,7 +59,8 @@ def get_next_document_number(category):
         if not result or result['max_num'] is None:
             return 1
         else:
-            return result['max_num'] + 1
+            # ensure the value is an int before adding
+            return int(result['max_num']) + 1
     except Exception as e:
         logger.error(f"获取下一个文档编号时出错: {str(e)}")
         return 1
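The int() coercion added above matters because, depending on the column type, mysql-connector can hand MAX(...) back as a Decimal or a string (or None on an empty table) rather than a plain int. A standalone illustration of the same guard:

    from decimal import Decimal

    def next_number(max_num):
        # mirrors get_next_document_number: None -> 1, anything else -> int + 1
        if max_num is None:
            return 1
        return int(max_num) + 1

    assert next_number(None) == 1
    assert next_number(Decimal('7')) == 8
    assert next_number('0012') == 13  # also covers string-typed columns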
@@ -80,12 +81,17 @@ def save_classified_document(user_id, original_filename, category, content, file_size):
     Returns:
         tuple: (success flag, stored filename or error message)
     """
+    db = None      # initialized up front so the except/finally blocks can test them safely
+    cursor = None
    try:
         # get the next document number
         next_num = get_next_document_number(category)

-        # sanitize the filename
-        safe_original_name = secure_filename(original_filename)
+        # sanitize the filename and make sure it doesn't end up empty
+        safe_original_name_base = os.path.splitext(secure_filename(original_filename))[0]
+        if not safe_original_name_base:
+            safe_original_name_base = "untitled"  # default if nothing safe survives
+        safe_original_name = f"{safe_original_name_base}.txt"  # ensure a .txt extension

         # build the new filename (category-number-original name)
         formatted_num = f"{next_num:04d}"  # zero-pad the number to 4 digits
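The "untitled" fallback exists because werkzeug's secure_filename drops everything that is not ASCII-safe, so a fully Chinese filename can shrink to just its extension, or to an empty string. A quick demonstration:

    import os
    from werkzeug.utils import secure_filename

    print(secure_filename("报告.txt"))  # -> 'txt' (the CJK characters are stripped)
    print(secure_filename("???"))       # -> ''   (nothing safe survives)

    base = os.path.splitext(secure_filename("报告"))[0]
    if not base:
        base = "untitled"  # same fallback as save_classified_document above
    print(f"{base}.txt")                # -> 'untitled.txt'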
@@ -117,16 +123,18 @@ def save_classified_document(user_id, original_filename, category, content, file_size):
         category_result = cursor.fetchone()

         if not category_result:
-            return False, "类别不存在"
+            logger.error(f"类别 '{category}' 在数据库中未找到。")
+            # optionally create the category here instead of failing
+            return False, f"类别 '{category}' 不存在"

         category_id = category_result[0]

         # insert the database record
         insert_query = """
             INSERT INTO documents
-            (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time)
-            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
-        """
+            (user_id, original_filename, stored_filename, file_path, file_size, category_id, status, classified_time, upload_time)
+            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
+        """  # upload_time=NOW() added on the assumption it's wanted

         cursor.execute(
             insert_query,
@@ -135,12 +143,19 @@ def save_classified_document(user_id, original_filename, category, content, file_size):
         # commit the transaction
         db.commit()
+        logger.info(f"成功保存文档: {new_filename} (ID: {cursor.lastrowid})")

         return True, new_filename
     except Exception as e:
-        logger.error(f"保存分类文档时出错: {str(e)}")
+        # roll back the transaction on any error
+        if db:
+            db.rollback()
+        logger.exception(f"保存分类文档 '{original_filename}' 时出错: {str(e)}")  # logger.exception keeps the traceback
         return False, str(e)
+    finally:
+        if cursor:
+            cursor.close()
+        # the db connection itself is closed via @app.teardown_appcontext
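The shape introduced here — initialize db and cursor to None, commit only on full success, roll back on any exception, close the cursor unconditionally — is the standard mysql-connector transaction skeleton. The same skeleton in isolation, with placeholder connection details:

    import mysql.connector

    def insert_row(conn_args, sql, params):
        db = None
        cursor = None
        try:
            db = mysql.connector.connect(**conn_args)  # placeholder credentials
            cursor = db.cursor()
            cursor.execute(sql, params)
            db.commit()
            return True, cursor.lastrowid
        except Exception as exc:
            if db:
                db.rollback()  # discard the partial transaction
            return False, str(exc)
        finally:
            if cursor:
                cursor.close()
            if db:
                db.close()  # this sketch owns its connection; the route leaves that to app teardown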
 @classify_bp.route('/single', methods=['POST'])
@@ -157,203 +172,214 @@ def classify_single_file():
         return jsonify({"success": False, "error": "没有文件"}), 400

     file = request.files['file']
+    original_filename = file.filename  # keep the original name around early

     # check the filename
-    if file.filename == '':
+    if original_filename == '':
         return jsonify({"success": False, "error": "未选择文件"}), 400

     # check the file type
-    if not allowed_text_file(file.filename):
+    if not allowed_text_file(original_filename):
         return jsonify({"success": False, "error": "不支持的文件类型,仅支持txt文件"}), 400

+    temp_path = None  # defined up front so cleanup in finally is always safe
     try:
         # create a temporary file for processing
         temp_dir = os.path.join(current_app.root_path, 'temp')
         os.makedirs(temp_dir, exist_ok=True)
-        temp_filename = f"{uuid.uuid4().hex}.txt"
+        # run the temp file name through secure_filename as well
+        temp_filename = f"{uuid.uuid4().hex}-{secure_filename(original_filename)}"
         temp_path = os.path.join(temp_dir, temp_filename)

         # save the upload to the temporary location
         file.save(temp_path)
+        file_size = os.path.getsize(temp_path)  # take the size after saving

-        # read the file content
-        with open(temp_path, 'r', encoding='utf-8') as f:
-            file_content = f.read()
+        # read the file content, trying multiple encodings
+        file_content = None
+        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
+        for enc in encodings_to_try:
+            try:
+                with open(temp_path, 'r', encoding=enc) as f:
+                    file_content = f.read()
+                logger.info(f"成功以 {enc} 读取临时文件: {temp_path}")
+                break
+            except UnicodeDecodeError:
+                logger.warning(f"使用 {enc} 解码临时文件失败: {temp_path}")
+                continue
+            except Exception as read_err:  # catch other potential read errors
+                logger.error(f"读取临时文件时出错 ({enc}): {read_err}")
+                continue
+
+        if file_content is None:
+            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")

         # run the classifier
         result = text_classifier.classify_text(file_content)
-        if not result['success']:
-            return jsonify({"success": False, "error": result['error']}), 500
+        if not result.get('success', False):  # check the key exists and is truthy
+            error_msg = result.get('error', '未知的分类错误')
+            logger.error(f"文本分类失败 for {original_filename}: {error_msg}")
+            return jsonify({"success": False, "error": error_msg}), 500
+
+        category = result['category']

         # save the classified document
-        file_size = os.path.getsize(temp_path)
         save_success, message = save_classified_document(
             user_id,
-            file.filename,
-            result['category'],
+            original_filename,  # pass the original filename here
+            category,
             file_content,
             file_size
         )

-        # clean up the temporary file
-        if os.path.exists(temp_path):
-            os.remove(temp_path)
-
         if not save_success:
             return jsonify({"success": False, "error": f"保存文档失败: {message}"}), 500

         # return the classification result
         return jsonify({
             "success": True,
-            "filename": file.filename,
-            "category": result['category'],
-            "confidence": result['confidence'],
+            "filename": original_filename,
+            "category": category,
+            "confidence": result.get('confidence'),  # .get for safety
             "stored_filename": message
         })
-    except UnicodeDecodeError:
-        # retry with GBK and repeat the classify/save steps
-        try:
-            with open(temp_path, 'r', encoding='gbk') as f:
-                file_content = f.read()
-            # (the classification, save and response logic was duplicated here)
-        except Exception as e:
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-            return jsonify({"success": False, "error": f"文件编码错误,请确保文件为UTF-8或GBK编码: {str(e)}"}), 400
     except Exception as e:
-        # make sure the temporary file is cleaned up
-        if 'temp_path' in locals() and os.path.exists(temp_path):
-            os.remove(temp_path)
-        logger.error(f"文件处理过程中发生错误: {str(e)}")
+        logger.exception(f"处理单文件 '{original_filename}' 过程中发生错误: {str(e)}")  # keep the traceback
         return jsonify({"success": False, "error": f"文件处理错误: {str(e)}"}), 500
+    finally:
+        # clean up the temporary file
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+                logger.info(f"已删除临时文件: {temp_path}")
+            except OSError as rm_err:
+                logger.error(f"删除临时文件失败 {temp_path}: {rm_err}")
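The single-file route above and the batch route below now repeat the same read-with-fallback loop, so it is a natural candidate for a shared helper. A minimal sketch (the function name is ours, not the repo's); note that gb18030 is a superset of gbk and accepts almost any byte sequence, so it belongs last in the list:

    def read_text_with_fallback(path, encodings=('utf-8', 'gbk', 'gb18030')):
        """Return the file's text using the first encoding that decodes it, or None."""
        for enc in encodings:
            try:
                with open(path, 'r', encoding=enc) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
        return None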
 @classify_bp.route('/batch', methods=['POST'])
 def classify_batch_files():
     """Batch upload-and-classify API (archive handling)."""
-    # check that the user is logged in
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']

-    # check that a file was uploaded
     if 'file' not in request.files:
         return jsonify({"success": False, "error": "没有文件"}), 400

     file = request.files['file']
+    original_archive_name = file.filename

-    # check the filename
-    if file.filename == '':
+    if original_archive_name == '':
         return jsonify({"success": False, "error": "未选择文件"}), 400

-    # check the file type
-    if not allowed_archive_file(file.filename):
+    if not allowed_archive_file(original_archive_name):
         return jsonify({"success": False, "error": "不支持的文件类型,仅支持zip和rar压缩文件"}), 400

-    # check the file size
-    if request.content_length > 10 * 1024 * 1024:  # 10MB
-        return jsonify({"success": False, "error": "文件太大,最大支持10MB"}), 400
+    # consider re-adding a size check here if needed, e.g.:
+    # if request.content_length > current_app.config['MAX_CONTENT_LENGTH']:
+    #     return jsonify({"success": False, "error": "文件过大"}), 413

-    try:
-        # create the temporary directories
-        temp_dir = os.path.join(current_app.root_path, 'temp')
-        extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
-        os.makedirs(extract_dir, exist_ok=True)
+    temp_dir = os.path.join(current_app.root_path, 'temp')
+    extract_dir = os.path.join(temp_dir, f"extract_{uuid.uuid4().hex}")
+    archive_path = None

-        # save the uploaded archive
-        archive_path = os.path.join(temp_dir, secure_filename(file.filename))
-        file.save(archive_path)
-
-        # extract the archive
-        file_extension = file.filename.rsplit('.', 1)[1].lower()
-        if file_extension == 'zip':
-            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
-                zip_ref.extractall(extract_dir)
-        elif file_extension == 'rar':
-            with rarfile.RarFile(archive_path, 'r') as rar_ref:
-                rar_ref.extractall(extract_dir)
-
-        # statistics for the run
-        results = {
-            "total": 0,
-            "success": 0,
-            "failed": 0,
-            "categories": {},
-            "failed_files": []
-        }
+    # change 1: renamed total_attempted back to 'total'
+    results = {
+        "total": 0,  # use the key 'total'
+        "success": 0,
+        "failed": 0,
+        "categories": {},
+        "failed_files": []
+    }

+    try:
+        os.makedirs(extract_dir, exist_ok=True)
+        archive_path = os.path.join(temp_dir, secure_filename(original_archive_name))
+        file.save(archive_path)
+        logger.info(f"已保存上传的压缩文件: {archive_path}")
+
+        # extract the archive
+        file_extension = original_archive_name.rsplit('.', 1)[1].lower()
+        if file_extension == 'zip':
+            with zipfile.ZipFile(archive_path, 'r') as zip_ref:
+                zip_ref.extractall(extract_dir)
+            logger.info(f"已解压ZIP文件到: {extract_dir}")
+        elif file_extension == 'rar':
+            # requires the rarfile package and an unrar executable on the host
+            try:
+                with rarfile.RarFile(archive_path, 'r') as rar_ref:
+                    rar_ref.extractall(extract_dir)
+                logger.info(f"已解压RAR文件到: {extract_dir}")
+            except rarfile.NeedFirstVolume:
+                logger.error(f"RAR文件是分卷压缩的一部分,需要第一个分卷: {original_archive_name}")
+                raise ValueError("不支持分卷压缩的RAR文件")
+            except rarfile.BadRarFile:
+                logger.error(f"RAR文件损坏或格式错误: {original_archive_name}")
+                raise ValueError("RAR文件损坏或格式错误")
+            except Exception as rar_err:  # other rar errors, e.g. missing unrar
+                logger.error(f"解压RAR文件时出错: {rar_err}")
+                raise ValueError(f"解压RAR文件失败: {rar_err}")

         # walk the tree and process every txt file
         for root, dirs, files in os.walk(extract_dir):
+            # skip macOS metadata directories
+            if '__MACOSX' in root.split(os.path.sep):
+                logger.info(f"Skipping macOS metadata directory: {root}")
+                dirs[:] = []
+                files[:] = []
+                continue
+
             for filename in files:
-                if filename.lower().endswith('.txt'):
+                # skip macOS metadata files
+                if filename.startswith('._') or filename == '.DS_Store':
+                    logger.info(f"Skipping macOS metadata file: {filename} in {root}")
+                    continue
+
+                # process only allowed text files
+                if allowed_text_file(filename):
                    file_path = os.path.join(root, filename)
+                    # change 2: renamed total_attempted back to 'total'
                    results["total"] += 1

                    try:
-                        # read the file content
-                        try:
-                            with open(file_path, 'r', encoding='utf-8') as f:
-                                file_content = f.read()
-                        except UnicodeDecodeError:
-                            # retry with GBK
-                            with open(file_path, 'r', encoding='gbk') as f:
-                                file_content = f.read()
+                        # read the file content, trying multiple encodings
+                        file_content = None
+                        encodings_to_try = ['utf-8', 'gbk', 'gb18030']
+                        for enc in encodings_to_try:
+                            try:
+                                with open(file_path, 'r', encoding=enc) as f:
+                                    file_content = f.read()
+                                logger.info(f"成功以 {enc} 读取文件: {file_path}")
+                                break
+                            except UnicodeDecodeError:
+                                logger.warning(f"使用 {enc} 解码文件失败: {file_path}")
+                                continue
+                            except Exception as read_err:
+                                logger.error(f"读取文件时发生其他错误 ({enc}) {file_path}: {read_err}")
+                                continue
+
+                        if file_content is None:
+                            raise ValueError(f"无法使用支持的编码 {encodings_to_try} 读取文件内容。")

                        # run the classifier
                        result = text_classifier.classify_text(file_content)
-                        if not result['success']:
-                            results["failed"] += 1
-                            results["failed_files"].append({
-                                "filename": filename,
-                                "error": result['error']
-                            })
-                            continue
+                        if not result.get('success', False):
+                            error_msg = result.get('error', '未知的分类错误')
+                            raise Exception(f"分类失败: {error_msg}")

                        category = result['category']
-                        # tally the category counts
-                        if category not in results["categories"]:
-                            results["categories"][category] = 0
-                        results["categories"][category] += 1
+                        results["categories"][category] = results["categories"].get(category, 0) + 1

-                        # save the classified document
                        file_size = os.path.getsize(file_path)
                        save_success, message = save_classified_document(
                            user_id,
@@ -366,106 +392,108 @@ def classify_batch_files():
                        if save_success:
                            results["success"] += 1
                        else:
-                            results["failed"] += 1
-                            results["failed_files"].append({
-                                "filename": filename,
-                                "error": message
-                            })
+                            raise Exception(f"保存失败: {message}")
                    except Exception as e:
+                        logger.error(f"处理文件 '{filename}' 失败: {str(e)}")
                        results["failed"] += 1
                        results["failed_files"].append({
                            "filename": filename,
                            "error": str(e)
                        })

-        # clean up the temporary files
-        if os.path.exists(archive_path):
-            os.remove(archive_path)
-        if os.path.exists(extract_dir):
-            shutil.rmtree(extract_dir)
-
         # return the run's results
         return jsonify({
             "success": True,
-            "archive_name": file.filename,
-            "results": results
+            "archive_name": original_archive_name,
+            "results": results  # the returned results dict keeps the 'total' key
         })
     except Exception as e:
-        # make sure the temporary files are cleaned up
-        if 'archive_path' in locals() and os.path.exists(archive_path):
-            os.remove(archive_path)
-        if 'extract_dir' in locals() and os.path.exists(extract_dir):
-            shutil.rmtree(extract_dir)
-        logger.error(f"压缩包处理过程中发生错误: {str(e)}")
+        logger.exception(f"处理压缩包 '{original_archive_name}' 过程中发生严重错误: {str(e)}")
         return jsonify({"success": False, "error": f"压缩包处理错误: {str(e)}"}), 500
+    finally:
+        # clean up the temporary archive and extraction directory
+        if archive_path and os.path.exists(archive_path):
+            try:
+                os.remove(archive_path)
+                logger.info(f"已删除临时压缩文件: {archive_path}")
+            except OSError as rm_err:
+                logger.error(f"删除临时压缩文件失败 {archive_path}: {rm_err}")
+        if os.path.exists(extract_dir):
+            try:
+                shutil.rmtree(extract_dir)
+                logger.info(f"已删除临时解压目录: {extract_dir}")
+            except OSError as rmtree_err:
+                logger.error(f"删除临时解压目录失败 {extract_dir}: {rmtree_err}")
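One thing extractall is not asked to do here is validate member paths. CPython's zipfile drops ".." components on extraction, but rarfile's behavior depends on its version and back end, so an explicit guard against path traversal (the classic zip-slip problem) is cheap insurance. A hedged sketch for the zip case:

    import os
    import zipfile

    def safe_extract(zip_path, dest_dir):
        """Extract only members whose resolved path stays inside dest_dir."""
        dest_root = os.path.realpath(dest_dir)
        with zipfile.ZipFile(zip_path, 'r') as zf:
            for member in zf.namelist():
                target = os.path.realpath(os.path.join(dest_root, member))
                if not target.startswith(dest_root + os.sep):
                    raise ValueError(f"unsafe path in archive: {member}")
            zf.extractall(dest_root)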
 @classify_bp.route('/documents', methods=['GET'])
 def get_classified_documents():
     """Fetch the list of classified documents."""
-    # check that the user is logged in
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']

-    # read the query parameters
+    try:
         category = request.args.get('category', 'all')
         page = int(request.args.get('page', 1))
         per_page = int(request.args.get('per_page', 10))

-    # validate the page size
         if per_page not in [10, 25, 50, 100]:
             per_page = 10
+        if page < 1:
+            page = 1

-    # compute the offset
         offset = (page - 1) * per_page

         db = get_db()
         cursor = db.cursor(dictionary=True)  # dictionary cursor

-    try:
-        # build the filter clause
-        where_clause = "WHERE d.user_id = %s AND d.status = '已分类'"
+        # get the available categories first
+        cursor.execute("SELECT name FROM categories ORDER BY name")
+        available_categories = [row['name'] for row in cursor.fetchall()]

+        # build the query
         params = [user_id]
+        base_query = """
+            FROM documents d
+            JOIN categories c ON d.category_id = c.id
+            WHERE d.user_id = %s AND d.status = '已分类'
+        """

-        if category != 'all':
-            where_clause += " AND c.name = %s"
+        if category != 'all' and category in available_categories:
+            base_query += " AND c.name = %s"
             params.append(category)

-        # count the total rows
-        count_query = f"""
-            SELECT COUNT(*) as total
-            FROM documents d
-            JOIN categories c ON d.category_id = c.id
-            {where_clause}
-        """
+        # count query
+        count_query = f"SELECT COUNT(*) as total {base_query}"
         cursor.execute(count_query, params)
-        total_count = cursor.fetchone()['total']
+        total_count_result = cursor.fetchone()
+        total_count = total_count_result['total'] if total_count_result else 0

-        # compute the page count
-        total_pages = (total_count + per_page - 1) // per_page
+        total_pages = (total_count + per_page - 1) // per_page if per_page > 0 else 0

-        # fetch the page of data
-        query = f"""
+        # data query
+        data_query = f"""
             SELECT d.id, d.original_filename, d.stored_filename, d.file_size,
                    c.name as category, d.upload_time, d.classified_time
-            FROM documents d
-            JOIN categories c ON d.category_id = c.id
-            {where_clause}
+            {base_query}
             ORDER BY d.classified_time DESC
             LIMIT %s OFFSET %s
         """
         params.extend([per_page, offset])
-        cursor.execute(query, params)
+        cursor.execute(data_query, params)
         documents = cursor.fetchall()

-        # fetch all available categories
-        cursor.execute("SELECT name FROM categories ORDER BY name")
-        categories = [row['name'] for row in cursor.fetchall()]
+        # serialize datetimes if the driver returned datetime objects
+        for doc in documents:
+            if hasattr(doc.get('upload_time'), 'isoformat'):
+                doc['upload_time'] = doc['upload_time'].isoformat()
+            if hasattr(doc.get('classified_time'), 'isoformat'):
+                doc['classified_time'] = doc['classified_time'].isoformat()

         return jsonify({
             "success": True,
@@ -476,36 +504,35 @@ def get_classified_documents():
                 "current_page": page,
                 "total_pages": total_pages
             },
-            "categories": categories,
+            "categories": available_categories,  # send the list fetched above
             "current_category": category
         })
     except Exception as e:
-        logger.error(f"获取文档列表时出错: {str(e)}")
+        logger.exception(f"获取文档列表时出错: {str(e)}")  # keep the traceback
         return jsonify({"success": False, "error": f"获取文档列表失败: {str(e)}"}), 500
     finally:
+        # cursor close is usually handled by teardown, but explicit close is safe
+        if 'cursor' in locals() and cursor:
             cursor.close()
 @classify_bp.route('/download/<int:document_id>', methods=['GET'])
 def download_document(document_id):
     """Download a classified document."""
-    # check that the user is logged in
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']
+    cursor = None
+    try:
         db = get_db()
         cursor = db.cursor(dictionary=True)

-    try:
-        # look up the document
         query = """
-            SELECT file_path, original_filename, stored_filename
+            SELECT file_path, original_filename
             FROM documents
-            WHERE id = %s AND user_id = %s
+            WHERE id = %s AND user_id = %s AND status = '已分类'
         """
         cursor.execute(query, (document_id, user_id))
         document = cursor.fetchone()
@@ -513,78 +540,110 @@ def download_document(document_id):
         if not document:
             return jsonify({"success": False, "error": "文档不存在或无权访问"}), 404

-        # check that the file exists
-        if not os.path.exists(document['file_path']):
-            return jsonify({"success": False, "error": "文件不存在"}), 404
+        file_path = document['file_path']
+        download_name = document['original_filename']

-        # send the file as a download
+        if not os.path.exists(file_path):
+            logger.error(f"请求下载的文件在服务器上不存在: {file_path} (Doc ID: {document_id})")
+            # optionally mark the row, e.g.:
+            # update_status_query = "UPDATE documents SET status = '文件丢失' WHERE id = %s"
+            # cursor.execute(update_status_query, (document_id,))
+            # db.commit()
+            return jsonify({"success": False, "error": "文件在服务器上丢失"}), 404

+        logger.info(f"用户 {user_id} 请求下载文档 ID: {document_id}, 文件: {download_name}")
         return send_file(
-            document['file_path'],
+            file_path,
             as_attachment=True,
-            download_name=document['original_filename'],
-            mimetype='text/plain'
+            download_name=download_name,  # use the original filename for the download
+            mimetype='text/plain'  # assumes text; adjust if other types are allowed
         )
     except Exception as e:
-        logger.error(f"下载文档时出错: {str(e)}")
+        logger.exception(f"下载文档 ID {document_id} 时出错: {str(e)}")
         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
     finally:
+        if cursor:
             cursor.close()
 @classify_bp.route('/download-multiple', methods=['POST'])
 def download_multiple_documents():
     """Download several documents bundled into a zip."""
-    # check that the user is logged in
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

     user_id = session['user_id']
+    cursor = None
+    zip_path = None  # defined up front for the error-cleanup path

-    # read the request payload
+    try:
         data = request.get_json()
         if not data or 'document_ids' not in data:
-            return jsonify({"success": False, "error": "缺少必要参数"}), 400
+            return jsonify({"success": False, "error": "缺少必要参数 'document_ids'"}), 400

         document_ids = data['document_ids']
         if not isinstance(document_ids, list) or not document_ids:
             return jsonify({"success": False, "error": "文档ID列表无效"}), 400

+        # sanitize the IDs so they are all integers
+        try:
+            document_ids = [int(doc_id) for doc_id in document_ids]
+        except ValueError:
+            return jsonify({"success": False, "error": "文档ID必须是数字"}), 400

         db = get_db()
         cursor = db.cursor(dictionary=True)

-    try:
-        # create a temp directory for the zip file
-        temp_dir = os.path.join(current_app.root_path, 'temp')
-        os.makedirs(temp_dir, exist_ok=True)
-
-        # create the temporary ZIP file
-        zip_filename = f"documents_{int(time.time())}.zip"
-        zip_path = os.path.join(temp_dir, zip_filename)
-
-        # look up all matching documents
+        # look up the user's valid documents
         placeholders = ', '.join(['%s'] * len(document_ids))
         query = f"""
             SELECT id, file_path, original_filename
             FROM documents
-            WHERE id IN ({placeholders}) AND user_id = %s
+            WHERE id IN ({placeholders}) AND user_id = %s AND status = '已分类'
         """
         params = document_ids + [user_id]
         cursor.execute(query, params)
         documents = cursor.fetchall()

         if not documents:
-            return jsonify({"success": False, "error": "没有找到符合条件的文档"}), 404
+            return jsonify({"success": False, "error": "没有找到符合条件的可下载文档"}), 404

-        # create the ZIP file and add the documents
-        with zipfile.ZipFile(zip_path, 'w') as zipf:
+        # create the temporary zip file
+        temp_dir = os.path.join(current_app.root_path, 'temp')
+        os.makedirs(temp_dir, exist_ok=True)
+        zip_filename = f"documents_{user_id}_{int(time.time())}.zip"
+        zip_path = os.path.join(temp_dir, zip_filename)
+
+        files_added = 0
+        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
             for doc in documents:
-                if os.path.exists(doc['file_path']):
-                    # add the file under its original name
-                    zipf.write(doc['file_path'], arcname=doc['original_filename'])
+                file_path = doc['file_path']
+                original_filename = doc['original_filename']
+                if os.path.exists(file_path):
+                    zipf.write(file_path, arcname=original_filename)
+                    files_added += 1
+                    logger.info(f"添加到ZIP: {original_filename} (来自 {file_path})")
+                else:
+                    logger.warning(f"跳过丢失的文件: {original_filename} (路径: {file_path}, Doc ID: {doc['id']})")
+                    # optionally add a readme to the zip noting the missing files
+
+        if files_added == 0:
+            # every selected file was missing on disk
+            return jsonify({"success": False, "error": "所有选中的文件在服务器上都已丢失"}), 404
+
+        logger.info(f"用户 {user_id} 请求下载 {files_added} 个文档,打包为 {zip_filename}")
+
+        # delete the temp file once this response has been sent;
+        # after_this_request registers a one-shot hook (current_app.after_request
+        # would try to add a permanent handler mid-request and fails on newer Flask)
+        from flask import after_this_request
+
+        @after_this_request
+        def remove_file(response):
+            try:
+                if zip_path and os.path.exists(zip_path):
+                    os.remove(zip_path)
+                    logger.info(f"已删除临时ZIP文件: {zip_path}")
+            except Exception as error:
+                logger.error(f"删除临时ZIP文件错误 {zip_path}: {error}")
+            return response

-        # send the ZIP file as a download
         return send_file(
             zip_path,
             as_attachment=True,
@@ -593,56 +652,75 @@ def download_multiple_documents():
         )
     except Exception as e:
-        logger.error(f"下载多个文档时出错: {str(e)}")
+        logger.exception(f"下载多个文档时出错: {str(e)}")
+        # clean up the zip if it was created but never sent
+        if zip_path and os.path.exists(zip_path):
+            try:
+                os.remove(zip_path)
+            except OSError:
+                pass
         return jsonify({"success": False, "error": f"下载文档失败: {str(e)}"}), 500
     finally:
+        if cursor:
             cursor.close()
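flask.after_this_request, used in the route above, registers a one-shot callback that runs when the current response is finalized — exactly what a send-then-delete temp file needs. app.after_request, by contrast, is a setup-time hook, and newer Flask versions refuse to register it mid-request. A self-contained version of the pattern:

    import os
    import tempfile
    import zipfile
    from flask import Flask, send_file, after_this_request

    app = Flask(__name__)

    @app.route('/bundle')
    def bundle():
        fd, zip_path = tempfile.mkstemp(suffix='.zip')
        os.close(fd)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            zf.writestr('hello.txt', 'hello')  # demo content

        @after_this_request
        def cleanup(response):
            try:
                os.remove(zip_path)  # runs once, after this response is sent
            except OSError:
                pass  # e.g. Windows may still hold the file handle
            return response

        return send_file(zip_path, as_attachment=True, download_name='bundle.zip')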
 @classify_bp.route('/classify-text', methods=['POST'])
 def classify_text_directly():
     """Classify raw text directly (no file is saved)."""
-    # check that the user is logged in
     if 'user_id' not in session:
         return jsonify({"success": False, "error": "请先登录"}), 401

-    # read the request payload
+    try:
         data = request.get_json()
         if not data or 'text' not in data:
-            return jsonify({"success": False, "error": "缺少必要参数"}), 400
+            return jsonify({"success": False, "error": "缺少必要参数 'text'"}), 400

         text = data['text']
-        if not text.strip():
-            return jsonify({"success": False, "error": "文本内容不能为空"}), 400
+        # basic validation
+        if not isinstance(text, str) or not text.strip():
+            return jsonify({"success": False, "error": "文本内容不能为空或无效"}), 400
+
+        # cap the text length
+        MAX_TEXT_LENGTH = 10000  # example limit
+        if len(text) > MAX_TEXT_LENGTH:
+            return jsonify({"success": False, "error": f"文本过长,最大支持 {MAX_TEXT_LENGTH} 字符"}), 413

-    try:
         # run the classifier
+        # make sure the model is initialized (lazy loading)
+        if not text_classifier or not text_classifier.is_initialized:
+            logger.warning("文本分类器未初始化,尝试初始化...")
+            if not text_classifier.initialize():
+                logger.error("文本分类器初始化失败。")
+                return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503
+
         result = text_classifier.classify_text(text)
-        if not result['success']:
-            return jsonify({"success": False, "error": result['error']}), 500
+        if not result.get('success', False):
+            error_msg = result.get('error', '未知的分类错误')
+            logger.error(f"直接文本分类失败: {error_msg}")
+            return jsonify({"success": False, "error": error_msg}), 500

         # return the classification result
         return jsonify({
             "success": True,
-            "category": result['category'],
-            "confidence": result['confidence'],
-            "all_confidences": result['all_confidences']
+            "category": result.get('category'),
+            "confidence": result.get('confidence'),
+            "all_confidences": result.get('all_confidences')  # include per-class scores when available
         })
     except Exception as e:
-        logger.error(f"文本分类过程中发生错误: {str(e)}")
+        logger.exception(f"直接文本分类过程中发生错误: {str(e)}")
         return jsonify({"success": False, "error": f"文本分类错误: {str(e)}"}), 500
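The initialize-on-demand guard above (check is_initialized, try initialize(), answer 503 when that fails) will be wanted by every route that touches the model, so it could be factored out. A sketch under that assumption — the helper name is ours, not the repo's:

    from flask import jsonify

    def ensure_classifier_ready(classifier):
        """Return None when the classifier is usable, else a (response, status) pair."""
        if classifier.is_initialized:
            return None
        if classifier.initialize():
            return None
        return jsonify({"success": False, "error": "分类服务暂时不可用"}), 503

A route would then open with: not_ready = ensure_classifier_ready(text_classifier); if not_ready: return not_ready.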
 @classify_bp.route('/categories', methods=['GET'])
 def get_categories():
     """Fetch all classification categories."""
+    cursor = None
+    try:
         db = get_db()
         cursor = db.cursor(dictionary=True)

-    try:
         cursor.execute("SELECT id, name, description FROM categories ORDER BY name")
         categories = cursor.fetchall()
@@ -652,8 +730,8 @@ def get_categories():
         })
     except Exception as e:
-        logger.error(f"获取类别列表时出错: {str(e)}")
+        logger.exception(f"获取类别列表时出错: {str(e)}")
         return jsonify({"success": False, "error": f"获取类别列表失败: {str(e)}"}), 500
     finally:
+        if cursor:
             cursor.close()

utils/model_service.py

@@ -3,9 +3,12 @@ import os
 import jieba
 import numpy as np
 import pickle
-from tensorflow.keras.models import load_model
-from tensorflow.keras.preprocessing.sequence import pad_sequences
+import tensorflow as tf
+# from tensorflow import keras  # tf.keras is preferred
+from keras.models import load_model  # keep only if specifically needed; otherwise use tf.keras.models.load_model
+from keras.preprocessing.sequence import pad_sequences  # likewise, prefer tf.keras.preprocessing.sequence.pad_sequences
 import logging
+import h5py  # imported here because it is only used conditionally

 class TextClassificationModel:
@@ -34,16 +37,42 @@ class TextClassificationModel:
         # set up logging
         self.logger = logging.getLogger(__name__)
+        logging.basicConfig(level=logging.INFO)  # basic logging setup in case none is configured elsewhere

     def initialize(self):
         """Initialize and load the model and tokenizer."""
         try:
             self.logger.info("开始加载文本分类模型...")

-            # load the model
-            self.model = load_model(self.model_path)
-            self.logger.info("模型加载成功")
+            # try the HDF5 format (.h5) first, since the file name ends in .h5
+            try:
+                self.logger.info(f"尝试以 HDF5 格式加载模型: {self.model_path}")
+                # for H5 files a direct load_model is usually sufficient if saved correctly;
+                # compile=False is often wanted when no training features are needed
+                self.model = tf.keras.models.load_model(self.model_path, compile=False)
+                self.logger.info("HDF5 模型加载成功")
+            except Exception as h5_exc:
+                self.logger.warning(f"HDF5 格式加载失败 ({h5_exc}),尝试以 SavedModel 格式加载...")
+                # fall back to the SavedModel format (a directory, not an .h5 file);
+                # this may still fail if model_path really does point at an h5 file
+                try:
+                    self.model = tf.keras.models.load_model(
+                        self.model_path,
+                        compile=False  # usually False for inference
+                        # custom_objects can be passed here if needed
+                    )
+                    self.logger.info("SavedModel 格式加载成功")
+                except Exception as sm_exc:
+                    self.logger.error(f"SavedModel 格式加载也失败 ({sm_exc}). 无法加载模型。")
+                    # a JSON-plus-weights fallback could go here, but it is rarely needed now
+                    raise ValueError(f"无法加载模型文件: {self.model_path}. H5 Error: {h5_exc}, SavedModel Error: {sm_exc}")

             # load the tokenizer
+            self.logger.info(f"开始加载 Tokenizer: {self.tokenizer_path}")
             with open(self.tokenizer_path, 'rb') as handle:
                 self.tokenizer = pickle.load(handle)
             self.logger.info("Tokenizer加载成功")
@@ -51,8 +80,9 @@ class TextClassificationModel:
             self.is_initialized = True
             self.logger.info("模型初始化完成")
             return True
         except Exception as e:
-            self.logger.error(f"模型初始化失败: {str(e)}")
+            self.logger.exception(f"模型初始化过程中发生严重错误: {str(e)}")  # logger.exception keeps the traceback
             self.is_initialized = False
             return False
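The two loader branches above correspond to the two ways a Keras model may have been written to disk; which branch succeeds depends entirely on the save call that produced model_path. For reference, under TensorFlow 2.x (Keras 3 prefers its own .keras format), with illustrative paths:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    model.compile(optimizer='adam', loss='mse')

    model.save('model.h5')         # HDF5: one file; loads via tf.keras.models.load_model('model.h5')
    model.save('saved_model_dir')  # SavedModel: a directory with saved_model.pb and variables/

    reloaded = tf.keras.models.load_model('model.h5', compile=False)  # compile=False for inference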
@@ -80,9 +110,12 @@ class TextClassificationModel:
         dict: classification result with the category label and confidence
         """
         if not self.is_initialized:
+            self.logger.warning("模型尚未初始化,尝试现在初始化...")
             success = self.initialize()
             if not success:
+                self.logger.error("分类前初始化模型失败。")
                 return {"success": False, "error": "模型初始化失败"}
+            self.logger.info("模型初始化成功,继续分类。")

         try:
             # preprocess the text
@@ -92,17 +125,23 @@ class TextClassificationModel:
             sequence = self.tokenizer.texts_to_sequences([processed_text])

             # pad the sequence
-            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post")
+            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(  # use the tf.keras path
+                sequence, maxlen=self.max_length, padding="post"
+            )

             # predict
             predictions = self.model.predict(padded_sequence)

             # take the predicted class index and its confidence
             predicted_index = np.argmax(predictions, axis=1)[0]
-            confidence = float(predictions[0][predicted_index])
+            confidence = float(predictions[0][predicted_index])  # convert the numpy float to a Python float

             # map the index to its category label
-            predicted_label = self.CATEGORIES[predicted_index]
+            if predicted_index < len(self.CATEGORIES):
+                predicted_label = self.CATEGORIES[predicted_index]
+            else:
+                self.logger.warning(f"预测索引 {predicted_index} 超出类别列表范围!")
+                predicted_label = "未知类别"  # handle an out-of-bounds index

             # collect confidences for every category
             all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])}
@@ -114,8 +153,8 @@ class TextClassificationModel:
                 "all_confidences": all_confidences
             }
         except Exception as e:
-            self.logger.error(f"文本分类过程中发生错误: {str(e)}")
-            return {"success": False, "error": str(e)}
+            self.logger.exception(f"文本分类过程中发生错误: {str(e)}")  # keep the traceback
+            return {"success": False, "error": f"分类错误: {str(e)}"}

     def classify_file(self, file_path):
         """Classify the contents of a file.
@@ -126,28 +165,29 @@ class TextClassificationModel:
         Returns:
             dict: classification result with the category label and confidence
         """
-        try:
-            # read the file content
-            with open(file_path, 'r', encoding='utf-8') as f:
-                text = f.read().strip()
+        text = None
+        encodings_to_try = ['utf-8', 'gbk', 'gb18030']  # common encodings
+        for enc in encodings_to_try:
+            try:
+                with open(file_path, 'r', encoding=enc) as f:
+                    text = f.read().strip()
+                self.logger.info(f"成功以 {enc} 编码读取文件: {file_path}")
+                break  # stop once a read succeeds
+            except UnicodeDecodeError:
+                self.logger.warning(f"使用 {enc} 解码文件失败: {file_path}")
+                continue  # try the next encoding
+            except Exception as e:
+                self.logger.error(f"读取文件时发生其他错误 ({enc}): {str(e)}")
+                return {"success": False, "error": f"文件读取错误 ({enc}): {str(e)}"}

-            # run the text classifier
-            return self.classify_text(text)
-        except UnicodeDecodeError:
-            # if UTF-8 fails, retry with another encoding
-            try:
-                with open(file_path, 'r', encoding='gbk') as f:
-                    text = f.read().strip()
-                return self.classify_text(text)
-            except Exception as e:
-                return {"success": False, "error": f"文件解码失败: {str(e)}"}
-        except Exception as e:
-            self.logger.error(f"文件处理过程中发生错误: {str(e)}")
-            return {"success": False, "error": f"文件处理错误: {str(e)}"}
+        if text is None:
+            self.logger.error(f"尝试所有编码后仍无法读取文件: {file_path}")
+            return {"success": False, "error": f"文件解码失败,尝试的编码: {encodings_to_try}"}
+
+        # run the text classifier
+        return self.classify_text(text)

 # create a singleton instance so the model is not loaded repeatedly
+# consider lazy initialization if the model is large and not always needed immediately
 text_classifier = TextClassificationModel()
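Because the module ends by exposing a ready-made singleton, callers import it rather than constructing TextClassificationModel themselves; a typical call site looks like this (the sample text is ours):

    from utils.model_service import text_classifier

    result = text_classifier.classify_text("这是一条测试文本")  # first use triggers the lazy initialize()
    if result["success"]:
        print(result["category"], result["confidence"])
    else:
        print("classification failed:", result["error"])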