# utils/model_service.py
import os
import pickle
import logging

import jieba
import numpy as np
import tensorflow as tf


class TextClassificationModel:
    """Service wrapper around a Chinese text classification model."""

    # Category labels -- order must match the labels used when the model was trained
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚",
                  "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Initialize the model service.

        Args:
            model_path (str): Path to the model file. Defaults to
                model/trained_model.h5 under the project root.
            tokenizer_path (str): Path to the tokenizer file. Defaults to
                model/tokenizer.pickle under the project root.
            max_length (int): Maximum input sequence length. Defaults to 500.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(project_root, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(project_root, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None
        self.tokenizer = None
        self.is_initialized = False

        # Set up logging (basicConfig is a no-op if logging was already configured elsewhere)
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)

    def initialize(self):
        """Load the model and the tokenizer."""
        try:
            self.logger.info("Loading text classification model...")

            # Try the HDF5 format (.h5) first, since the default file name is .h5.
            try:
                self.logger.info(f"Trying to load the model as HDF5: {self.model_path}")
                # For H5 files a direct load_model is usually sufficient if the model
                # was saved correctly. compile=False skips restoring the training
                # configuration, which is not needed for inference.
                self.model = tf.keras.models.load_model(self.model_path, compile=False)
                self.logger.info("HDF5 model loaded successfully")
            except Exception as h5_exc:
                self.logger.warning(
                    f"Loading as HDF5 failed ({h5_exc}); trying the SavedModel format...")
                # Fall back to the SavedModel format (a directory rather than an .h5
                # file). This will also fail if model_path really points to a broken
                # .h5 file.
                try:
                    self.model = tf.keras.models.load_model(
                        self.model_path,
                        compile=False,  # usually False for inference
                        # custom_objects can be passed here if needed
                    )
                    self.logger.info("SavedModel loaded successfully")
                except Exception as sm_exc:
                    self.logger.error(
                        f"Loading as SavedModel also failed ({sm_exc}). Cannot load the model.")
                    raise ValueError(
                        f"Unable to load model file: {self.model_path}. "
                        f"H5 error: {h5_exc}; SavedModel error: {sm_exc}"
                    ) from sm_exc

            # Load the tokenizer
            self.logger.info(f"Loading tokenizer: {self.tokenizer_path}")
            with open(self.tokenizer_path, 'rb') as handle:
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer loaded successfully")

            self.is_initialized = True
            self.logger.info("Model initialization complete")
            return True
        except Exception as e:
            # logger.exception records the traceback as well
            self.logger.exception(f"Fatal error during model initialization: {e}")
            self.is_initialized = False
            return False

    def preprocess_text(self, text):
        """Preprocess raw text.

        Args:
            text (str): Raw input text.

        Returns:
            str: Preprocessed text.
        """
        # Segment the text with jieba
        tokens = jieba.lcut(text)
        # Join the tokens into a space-separated string
        return " ".join(tokens)

    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): Text to classify.

        Returns:
            dict: Classification result with the category label and confidence.
        """
        if not self.is_initialized:
            self.logger.warning("Model not initialized yet; initializing now...")
            if not self.initialize():
                self.logger.error("Failed to initialize the model before classification.")
                return {"success": False, "error": "Model initialization failed"}
            self.logger.info("Model initialized successfully; continuing with classification.")

        try:
            # Preprocess the text
            processed_text = self.preprocess_text(text)

            # Convert to a token-id sequence
            sequence = self.tokenizer.texts_to_sequences([processed_text])

            # Pad the sequence to max_length
            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(
                sequence, maxlen=self.max_length, padding="post"
            )

            # Run inference
            predictions = self.model.predict(padded_sequence)

            # Take the index of the highest-scoring class and its confidence
            predicted_index = int(np.argmax(predictions, axis=1)[0])
            confidence = float(predictions[0][predicted_index])  # numpy float -> Python float

            # Map the index to a category label
            if predicted_index < len(self.CATEGORIES):
                predicted_label = self.CATEGORIES[predicted_index]
            else:
                self.logger.warning(f"Predicted index {predicted_index} is outside the category list!")
                predicted_label = "未知类别"  # fallback label for an out-of-bounds index

            # Confidence for every category
            all_confidences = {cat: float(conf)
                               for cat, conf in zip(self.CATEGORIES, predictions[0])}

            return {
                "success": True,
                "category": predicted_label,
                "confidence": confidence,
                "all_confidences": all_confidences
            }
        except Exception as e:
            self.logger.exception(f"Error while classifying text: {e}")
            return {"success": False, "error": f"Classification error: {e}"}

    def classify_file(self, file_path):
        """Classify the contents of a file.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict: Classification result with the category label and confidence.
        """
        text = None
        encodings_to_try = ['utf-8', 'gbk', 'gb18030']  # common encodings for Chinese text

        for enc in encodings_to_try:
            try:
                with open(file_path, 'r', encoding=enc) as f:
                    text = f.read().strip()
                self.logger.info(f"Read file with {enc} encoding: {file_path}")
                break  # stop once a read succeeds
            except UnicodeDecodeError:
                self.logger.warning(f"Failed to decode file as {enc}: {file_path}")
                continue  # try the next encoding
            except Exception as e:
                self.logger.error(f"Error while reading file ({enc}): {e}")
                return {"success": False, "error": f"File read error ({enc}): {e}"}

        if text is None:
            self.logger.error(f"Could not read file with any of the tried encodings: {file_path}")
            return {"success": False,
                    "error": f"File decoding failed; encodings tried: {encodings_to_try}"}

        # Delegate to the text classifier
        return self.classify_text(text)


# Module-level singleton so the model is loaded only once.
# Consider lazy initialization if the model is large and not always needed immediately.
text_classifier = TextClassificationModel()
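

# A minimal usage sketch of the singleton above. It assumes trained_model.h5 and
# tokenizer.pickle exist at the default paths; the sample sentence is illustrative
# only and can be replaced with any Chinese text.
if __name__ == "__main__":
    if text_classifier.initialize():
        result = text_classifier.classify_text("央行宣布下调存款准备金率,释放长期资金约五千亿元。")
        if result["success"]:
            print(f"category: {result['category']}, confidence: {result['confidence']:.4f}")
        else:
            print(f"classification failed: {result['error']}")
    else:
        print("model initialization failed; check the model and tokenizer paths")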