# utils/model_service.py
import os
import pickle
import logging

import jieba
import numpy as np
import tensorflow as tf


class TextClassificationModel:
    """Service wrapper around the Chinese text-classification model."""

    # Category labels -- order must match the label indices used at training time.
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育",
                  "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Initialize the model service.

        Args:
            model_path (str): Path to the model file. Defaults to
                model/trained_model.h5 under the project root.
            tokenizer_path (str): Path to the tokenizer file. Defaults to
                model/tokenizer.pickle under the project root.
            max_length (int): Maximum input sequence length. Defaults to 500.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(project_root, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(project_root, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None
        self.tokenizer = None
        self.is_initialized = False
        # Logging setup; basicConfig is a no-op if the root logger is already configured.
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
    def initialize(self):
        """Load the model and the tokenizer."""
        try:
            self.logger.info("Loading text-classification model...")
            # Try the HDF5 format first, since the default file name ends in .h5.
            try:
                self.logger.info(f"Trying to load model in HDF5 format: {self.model_path}")
                # For a correctly saved H5 file a direct load_model call is usually
                # sufficient; compile=False skips the training configuration, which
                # is not needed for inference.
                self.model = tf.keras.models.load_model(self.model_path, compile=False)
                self.logger.info("HDF5 model loaded successfully")
            except Exception as h5_exc:
                self.logger.warning(f"HDF5 load failed ({h5_exc}); trying SavedModel format...")
                # Fall back to the SavedModel format (normally a directory, not an
                # .h5 file), so this may still fail if model_path really points at
                # an H5 file.
                try:
                    self.model = tf.keras.models.load_model(
                        self.model_path,
                        compile=False  # inference only
                        # custom_objects can be passed here if the model needs them
                    )
                    self.logger.info("SavedModel loaded successfully")
                except Exception as sm_exc:
                    self.logger.error(f"SavedModel load failed as well ({sm_exc}); cannot load model.")
                    raise ValueError(
                        f"Unable to load model file: {self.model_path}. "
                        f"H5 error: {h5_exc}; SavedModel error: {sm_exc}")

            # Load the tokenizer.
            self.logger.info(f"Loading tokenizer: {self.tokenizer_path}")
            with open(self.tokenizer_path, 'rb') as handle:
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer loaded successfully")

            self.is_initialized = True
            self.logger.info("Model initialization complete")
            return True
        except Exception as e:
            # logger.exception records the traceback along with the message.
            self.logger.exception(f"Fatal error during model initialization: {e}")
            self.is_initialized = False
            return False
    def preprocess_text(self, text):
        """Preprocess raw text for the model.

        Args:
            text (str): Raw input text.

        Returns:
            str: Preprocessed text.
        """
        # Segment the text with jieba, then join the tokens with spaces so the
        # tokenizer can split on whitespace.
        tokens = jieba.lcut(text)
        return " ".join(tokens)
    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): Text to classify.

        Returns:
            dict: Classification result with the category label and confidences.
        """
        if not self.is_initialized:
            self.logger.warning("Model not initialized yet; initializing now...")
            if not self.initialize():
                self.logger.error("Failed to initialize the model before classification.")
                return {"success": False, "error": "Model initialization failed"}
            self.logger.info("Model initialized successfully; continuing with classification.")
        try:
            # Preprocess the text and convert it to a padded integer sequence.
            processed_text = self.preprocess_text(text)
            sequence = self.tokenizer.texts_to_sequences([processed_text])
            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(
                sequence, maxlen=self.max_length, padding="post"
            )

            # Run inference.
            predictions = self.model.predict(padded_sequence)

            # Take the highest-scoring class; convert the numpy float to a Python float.
            predicted_index = np.argmax(predictions, axis=1)[0]
            confidence = float(predictions[0][predicted_index])

            # Map the index back to a label, guarding against out-of-range indices.
            if predicted_index < len(self.CATEGORIES):
                predicted_label = self.CATEGORIES[predicted_index]
            else:
                self.logger.warning(f"Predicted index {predicted_index} is outside the category list!")
                predicted_label = "unknown"

            # Confidence for every category.
            all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])}

            return {
                "success": True,
                "category": predicted_label,
                "confidence": confidence,
                "all_confidences": all_confidences,
            }
        except Exception as e:
            self.logger.exception(f"Error during text classification: {e}")
            return {"success": False, "error": f"Classification error: {e}"}
    def classify_file(self, file_path):
        """Classify the contents of a file.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict: Classification result with the category label and confidences.
        """
        text = None
        encodings_to_try = ['utf-8', 'gbk', 'gb18030']  # common Chinese-text encodings
        for enc in encodings_to_try:
            try:
                with open(file_path, 'r', encoding=enc) as f:
                    text = f.read().strip()
                self.logger.info(f"Read file with {enc} encoding: {file_path}")
                break  # stop at the first encoding that works
            except UnicodeDecodeError:
                self.logger.warning(f"Failed to decode file with {enc}: {file_path}")
                continue  # try the next encoding
            except Exception as e:
                self.logger.error(f"Unexpected error while reading file ({enc}): {e}")
                return {"success": False, "error": f"File read error ({enc}): {e}"}

        if text is None:
            self.logger.error(f"Could not read file with any of the tried encodings: {file_path}")
            return {"success": False, "error": f"File decoding failed; encodings tried: {encodings_to_try}"}

        # Delegate to the text classifier.
        return self.classify_text(text)

# Module-level singleton so the model is loaded only once. Loading is deferred
# until initialize() is called (triggered lazily by the first classification).
text_classifier = TextClassificationModel()
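

# Minimal usage sketch (assumes model/trained_model.h5 and model/tokenizer.pickle
# exist under the project's model/ directory; the sample sentence is illustrative):
if __name__ == "__main__":
    result = text_classifier.classify_text("央行宣布下调金融机构存款准备金率")
    if result["success"]:
        print(f"category={result['category']} confidence={result['confidence']:.4f}")
    else:
        print(f"classification failed: {result['error']}")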