text-classify-ui/utils/model_service.py
superlishunqin f434b83090 first commit
2025-03-17 22:43:53 +08:00

154 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# utils/model_service.py
import os
import jieba
import numpy as np
import pickle
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import logging
class TextClassificationModel:
    """Service wrapper around a Keras Chinese text-classification model.

    The model and tokenizer are loaded lazily on first classification so that
    importing this module stays cheap.
    """

    # Category labels — order must match the label encoding used at training time.
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育",
                  "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Configure the service; does NOT load the model yet.

        Args:
            model_path (str): Path to the Keras model file. Defaults to
                <project_root>/model/trained_model.h5.
            tokenizer_path (str): Path to the pickled tokenizer. Defaults to
                <project_root>/model/tokenizer.pickle.
            max_length (int): Maximum padded sequence length (default 500);
                must match the length used during training.
        """
        # Project root = parent of the directory containing this file.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(base_dir, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(base_dir, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None        # set by initialize()
        self.tokenizer = None    # set by initialize()
        self.is_initialized = False
        self.logger = logging.getLogger(__name__)

    def initialize(self):
        """Load the model and tokenizer from disk.

        Returns:
            bool: True on success, False on any failure (error is logged
            with full traceback).
        """
        try:
            self.logger.info("开始加载文本分类模型...")
            self.model = load_model(self.model_path)
            self.logger.info("模型加载成功")
            with open(self.tokenizer_path, 'rb') as handle:
                # NOTE: pickle.load is only acceptable because the tokenizer
                # file ships with the project — never unpickle untrusted data.
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer加载成功")
        except Exception as e:
            # logger.exception records the traceback, aiding diagnosis of
            # missing files, version mismatches, etc.
            self.logger.exception("模型初始化失败: %s", e)
            self.is_initialized = False
            return False
        self.is_initialized = True
        self.logger.info("模型初始化完成")
        return True

    def preprocess_text(self, text):
        """Segment Chinese text with jieba and join tokens with spaces.

        Args:
            text (str): Raw input text.

        Returns:
            str: Space-separated token string, as expected by the tokenizer.
        """
        return " ".join(jieba.lcut(text))

    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): Text to classify.

        Returns:
            dict: On success, ``{"success": True, "category", "confidence",
            "all_confidences"}``; on failure, ``{"success": False, "error"}``.
        """
        # Lazy initialization on first use.
        if not self.is_initialized and not self.initialize():
            return {"success": False, "error": "模型初始化失败"}
        # Guard: empty/whitespace input would produce a meaningless prediction.
        if not text or not text.strip():
            return {"success": False, "error": "输入文本为空"}
        try:
            processed_text = self.preprocess_text(text)
            sequence = self.tokenizer.texts_to_sequences([processed_text])
            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post")
            # Single-sample batch: scores is the probability vector for row 0.
            scores = self.model.predict(padded_sequence)[0]
            predicted_index = int(np.argmax(scores))
            return {
                "success": True,
                "category": self.CATEGORIES[predicted_index],
                "confidence": float(scores[predicted_index]),
                "all_confidences": {cat: float(conf)
                                    for cat, conf in zip(self.CATEGORIES, scores)},
            }
        except Exception as e:
            self.logger.exception("文本分类过程中发生错误: %s", e)
            return {"success": False, "error": str(e)}

    def classify_file(self, file_path):
        """Classify the contents of a text file.

        Args:
            file_path (str): Path of the file to read.

        Returns:
            dict: Same shape as :meth:`classify_text`.
        """
        last_decode_error = None
        # Try UTF-8 first, then fall back to GBK for legacy Chinese files.
        for encoding in ("utf-8", "gbk"):
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    text = f.read().strip()
                return self.classify_text(text)
            except UnicodeDecodeError as e:
                # Remember the error and try the next encoding.
                last_decode_error = e
            except Exception as e:
                self.logger.exception("文件处理过程中发生错误: %s", e)
                return {"success": False, "error": f"文件处理错误: {str(e)}"}
        return {"success": False, "error": f"文件解码失败: {str(last_decode_error)}"}
# Module-level singleton: shared across all importers so the (expensive)
# model/tokenizer load in initialize() happens at most once per process.
# Constructing the instance itself is cheap — nothing is loaded here.
text_classifier = TextClassificationModel()