# utils/model_service.py

import logging
import os
import pickle

import jieba
import numpy as np
import tensorflow as tf


class TextClassificationModel:
    """Service wrapper around the Chinese text classification model."""

    # Category labels; order must match the one used during training.
    # (sports, entertainment, home, lottery, real estate, education,
    #  fashion, politics, horoscope, gaming, society, tech, stocks, finance)
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育",
                  "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Initialize the model service.

        Args:
            model_path (str): Path to the model file; defaults to
                model/trained_model.h5 under the project root.
            tokenizer_path (str): Path to the tokenizer file; defaults to
                model/tokenizer.pickle under the project root.
            max_length (int): Maximum text sequence length, defaults to 500.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(project_root, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(project_root, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None
        self.tokenizer = None
        self.is_initialized = False

        # Set up logging
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)  # no-op if logging is already configured elsewhere

    def initialize(self):
        """Load the model and the tokenizer."""
        try:
            self.logger.info("Loading text classification model...")

            # Try the HDF5 format (.h5) first, since the file name ends in .h5
            try:
                self.logger.info(f"Trying to load model in HDF5 format: {self.model_path}")
                # For H5 files a direct load_model is usually sufficient if the
                # model was saved correctly. compile=False skips restoring the
                # training configuration, which is not needed for inference.
                self.model = tf.keras.models.load_model(self.model_path, compile=False)
                self.logger.info("HDF5 model loaded successfully")

            except Exception as h5_exc:
                self.logger.warning(f"HDF5 load failed ({h5_exc}); trying SavedModel format...")
                # Fall back to the SavedModel format (normally a directory, not
                # a single .h5 file), so this may also fail if model_path really
                # points to an .h5 file.
                try:
                    self.model = tf.keras.models.load_model(
                        self.model_path,
                        compile=False  # usually False for inference
                        # custom_objects can be passed here if needed
                    )
                    self.logger.info("SavedModel loaded successfully")
                except Exception as sm_exc:
                    self.logger.error(f"SavedModel load failed as well ({sm_exc}). Unable to load model.")
                    # A legacy JSON-architecture + weights fallback could be
                    # added here if needed, but it is less common now.
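                    # Illustrative sketch only; arch_json_path and
                    # weights_h5_path are hypothetical file names, not part of
                    # this project:
                    #   with open(arch_json_path) as jf:
                    #       self.model = tf.keras.models.model_from_json(jf.read())
                    #   self.model.load_weights(weights_h5_path)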
                    raise ValueError(
                        f"Unable to load model file: {self.model_path}. "
                        f"H5 error: {h5_exc}; SavedModel error: {sm_exc}"
                    )

            # Load the tokenizer
            self.logger.info(f"Loading tokenizer: {self.tokenizer_path}")
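            # Assumption: tokenizer.pickle holds a Keras Tokenizer (e.g. a
            # tf.keras.preprocessing.text.Tokenizer) fitted on the same
            # jieba-segmented, space-joined text used during training.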
            with open(self.tokenizer_path, 'rb') as handle:
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer loaded successfully")

            self.is_initialized = True
            self.logger.info("Model initialization complete")
            return True

        except Exception as e:
            # logger.exception also records the traceback
            self.logger.exception(f"Fatal error during model initialization: {str(e)}")
            self.is_initialized = False
            return False

    def preprocess_text(self, text):
        """Preprocess raw text for the model.

        Args:
            text (str): Raw text to process.

        Returns:
            str: The processed text (space-joined tokens).
        """
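        # Illustrative example: jieba.lcut("我爱北京天安门") typically returns
        # ["我", "爱", "北京", "天安门"], so this method yields
        # "我 爱 北京 天安门".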
        # Segment the text with jieba
        tokens = jieba.lcut(text)
        # Join the tokens into a single space-separated string
        return " ".join(tokens)

    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): Text to classify.

        Returns:
            dict: Classification result with category label and confidence.
        """
        if not self.is_initialized:
            self.logger.warning("Model not initialized yet; initializing now...")
            success = self.initialize()
            if not success:
                self.logger.error("Failed to initialize model before classification.")
                return {"success": False, "error": "Model initialization failed"}
            self.logger.info("Model initialized successfully; continuing with classification.")

        try:
            # Preprocess the text
            processed_text = self.preprocess_text(text)

            # Convert to a sequence of token ids
            sequence = self.tokenizer.texts_to_sequences([processed_text])

            # Pad the sequence to a fixed length
            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(
                sequence, maxlen=self.max_length, padding="post"
            )
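            # Note: padding="post" appends zeros at the end, but truncating
            # still defaults to "pre", so inputs longer than max_length lose
            # their beginning. Both must match the training-time settings.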

            # Run inference
            predictions = self.model.predict(padded_sequence)
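            # predictions has shape (1, len(CATEGORIES)); each row is the
            # model's score distribution over CATEGORIES (a softmax output,
            # judging by how the confidences are used below).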

            # Get the predicted class index and its confidence
            predicted_index = np.argmax(predictions, axis=1)[0]
            confidence = float(predictions[0][predicted_index])  # numpy float -> Python float

            # Map the index to a category label
            if predicted_index < len(self.CATEGORIES):
                predicted_label = self.CATEGORIES[predicted_index]
            else:
                self.logger.warning(f"Predicted index {predicted_index} is outside the category list!")
                predicted_label = "未知类别"  # "unknown category"; handles an out-of-range index

            # Collect the confidence for every category
            all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])}

            return {
                "success": True,
                "category": predicted_label,
                "confidence": confidence,
                "all_confidences": all_confidences
            }
        except Exception as e:
            # logger.exception also records the traceback
            self.logger.exception(f"Error during text classification: {str(e)}")
            return {"success": False, "error": f"Classification error: {str(e)}"}

    def classify_file(self, file_path):
        """Classify the contents of a text file.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict: Classification result with category label and confidence.
        """
        text = None
        encodings_to_try = ['utf-8', 'gbk', 'gb18030']  # common encodings for Chinese text
        for enc in encodings_to_try:
            try:
                with open(file_path, 'r', encoding=enc) as f:
                    text = f.read().strip()
                self.logger.info(f"Read file with {enc} encoding: {file_path}")
                break  # stop once a read succeeds
            except UnicodeDecodeError:
                self.logger.warning(f"Failed to decode file as {enc}: {file_path}")
                continue  # try the next encoding
            except Exception as e:
                self.logger.error(f"Unexpected error while reading file ({enc}): {str(e)}")
                return {"success": False, "error": f"File read error ({enc}): {str(e)}"}

        if text is None:
            self.logger.error(f"Could not decode file with any attempted encoding: {file_path}")
            return {"success": False, "error": f"File decoding failed; encodings tried: {encodings_to_try}"}

        # Delegate to the text classifier
        return self.classify_text(text)


# Create a singleton instance so the model is not loaded repeatedly.
# Consider lazy initialization if the model is large and not always needed
# immediately.
text_classifier = TextClassificationModel()
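

# Minimal usage sketch (assumes trained_model.h5 and tokenizer.pickle exist at
# the default paths; the sample sentence is illustrative only):
if __name__ == "__main__":
    if text_classifier.initialize():
        result = text_classifier.classify_text("央行宣布下调存款准备金率")
        if result["success"]:
            print(f"category: {result['category']}, confidence: {result['confidence']:.4f}")
        else:
            print(f"classification failed: {result['error']}")
    else:
        print("Model initialization failed")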