# utils/model_service.py
import logging
import os
import pickle

import jieba
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences


class TextClassificationModel:
    """Service wrapper around a Chinese text classification model."""

    # Category labels - the order must match the one used during training.
    # (sports, entertainment, home, lottery, real estate, education, fashion,
    #  politics, horoscope, games, society, technology, stocks, finance)
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚",
                  "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Initialize the model service.

        Args:
            model_path (str): Path to the model file. Defaults to
                model/trained_model.h5 under the project root.
            tokenizer_path (str): Path to the tokenizer file. Defaults to
                model/tokenizer.pickle under the project root.
            max_length (int): Maximum length of a text sequence. Defaults to 500.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(project_root, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(project_root, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None
        self.tokenizer = None
        self.is_initialized = False

        # Set up logging
        self.logger = logging.getLogger(__name__)

    def initialize(self):
        """Load the model and the tokenizer."""
        try:
            self.logger.info("Loading text classification model...")

            # Load the Keras model
            self.model = load_model(self.model_path)
            self.logger.info("Model loaded successfully")

            # Load the tokenizer
            with open(self.tokenizer_path, 'rb') as handle:
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer loaded successfully")

            self.is_initialized = True
            self.logger.info("Model initialization complete")
            return True
        except Exception as e:
            self.logger.error(f"Model initialization failed: {e}")
            self.is_initialized = False
            return False

    def preprocess_text(self, text):
        """Preprocess raw text.

        Args:
            text (str): Raw input text.

        Returns:
            str: Preprocessed text.
        """
        # Segment the text with jieba
        tokens = jieba.lcut(text)
        # Join the tokens with spaces, matching the format the tokenizer was fit on
        return " ".join(tokens)

    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): Text to classify.

        Returns:
            dict: Classification result with the category label and confidence.
        """
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return {"success": False, "error": "Model initialization failed"}

        try:
            # Preprocess the text
            processed_text = self.preprocess_text(text)

            # Convert to an integer sequence
            sequence = self.tokenizer.texts_to_sequences([processed_text])

            # Pad the sequence to the fixed input length
            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post")

            # Run the prediction
            predictions = self.model.predict(padded_sequence)

            # Extract the predicted class index and its confidence
            predicted_index = np.argmax(predictions, axis=1)[0]
            confidence = float(predictions[0][predicted_index])

            # Look up the predicted category label
            predicted_label = self.CATEGORIES[predicted_index]

            # Collect the confidence for every category
            all_confidences = {cat: float(conf)
                               for cat, conf in zip(self.CATEGORIES, predictions[0])}

            return {
                "success": True,
                "category": predicted_label,
                "confidence": confidence,
                "all_confidences": all_confidences
            }
        except Exception as e:
            self.logger.error(f"Error during text classification: {e}")
            return {"success": False, "error": str(e)}

    def classify_file(self, file_path):
        """Classify the contents of a file.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict: Classification result with the category label and confidence.
        """
        try:
            # Read the file contents
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read().strip()
            # Delegate to the text classifier
            return self.classify_text(text)
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, fall back to GBK
            try:
                with open(file_path, 'r', encoding='gbk') as f:
                    text = f.read().strip()
                return self.classify_text(text)
            except Exception as e:
                return {"success": False, "error": f"File decoding failed: {e}"}
        except Exception as e:
            self.logger.error(f"Error while processing file: {e}")
            return {"success": False, "error": f"File processing error: {e}"}


# Module-level singleton so the model is loaded only once
text_classifier = TextClassificationModel()
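

# --- Usage sketch ---
# A minimal, illustrative example of calling the singleton; not part of the
# service itself. It assumes model/trained_model.h5 and model/tokenizer.pickle
# exist under the project root, and the sample headline below is made up.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical sports headline used purely for demonstration
    result = text_classifier.classify_text("国足在昨晚的比赛中以2比0战胜对手")
    if result["success"]:
        print(f"Category: {result['category']} (confidence: {result['confidence']:.3f})")
    else:
        print(f"Classification failed: {result['error']}")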