# utils/model_service.py
import os
import logging
import pickle

import jieba
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences


class TextClassificationModel:
    """Service wrapper around a Chinese text classification model."""

    # Category labels; the order must match the label encoding used at training time
    CATEGORIES = ["体育", "娱乐", "家居", "彩票", "房产", "教育",
                  "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]

    def __init__(self, model_path=None, tokenizer_path=None, max_length=500):
        """Initialize the model service.

        Args:
            model_path (str): Path to the model file; defaults to
                model/trained_model.h5 under the project root.
            tokenizer_path (str): Path to the tokenizer file; defaults to
                model/tokenizer.pickle under the project root.
            max_length (int): Maximum token-sequence length; defaults to 500.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path or os.path.join(project_root, 'model', 'trained_model.h5')
        self.tokenizer_path = tokenizer_path or os.path.join(project_root, 'model', 'tokenizer.pickle')
        self.max_length = max_length
        self.model = None
        self.tokenizer = None
        self.is_initialized = False

        # Set up logging
        self.logger = logging.getLogger(__name__)

    def initialize(self):
        """Load the model and the tokenizer."""
        try:
            self.logger.info("Loading text classification model...")
            # Load the Keras model
            self.model = load_model(self.model_path)
            self.logger.info("Model loaded successfully")

            # Load the tokenizer
            with open(self.tokenizer_path, 'rb') as handle:
                self.tokenizer = pickle.load(handle)
            self.logger.info("Tokenizer loaded successfully")

            self.is_initialized = True
            self.logger.info("Model initialization complete")
            return True
        except Exception as e:
            self.logger.error(f"Model initialization failed: {str(e)}")
            self.is_initialized = False
            return False
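
    # Usage note: classify_text() calls initialize() lazily on first use, so an
    # explicit initialize() at service startup is optional, but it avoids paying
    # the model-loading cost on the first request.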

    def preprocess_text(self, text):
        """Preprocess raw text for the model.

        Args:
            text (str): Raw input text.

        Returns:
            str: The preprocessed text.
        """
        # Segment the text with jieba
        tokens = jieba.lcut(text)
        # Join the tokens into a single space-separated string
        return " ".join(tokens)

    def classify_text(self, text):
        """Classify a piece of text.

        Args:
            text (str): The text to classify.

        Returns:
            dict: Classification result containing the category label and confidence.
        """
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return {"success": False, "error": "Model initialization failed"}

        try:
            # Preprocess the text
            processed_text = self.preprocess_text(text)

            # Convert to a token-id sequence
            sequence = self.tokenizer.texts_to_sequences([processed_text])

            # Pad the sequence to a fixed length
            padded_sequence = pad_sequences(sequence, maxlen=self.max_length, padding="post")

            # Run inference
            predictions = self.model.predict(padded_sequence)

            # Get the predicted class index and its confidence
            predicted_index = np.argmax(predictions, axis=1)[0]
            confidence = float(predictions[0][predicted_index])

            # Map the index to its category label
            predicted_label = self.CATEGORIES[predicted_index]

            # Collect the confidence for every category
            all_confidences = {cat: float(conf) for cat, conf in zip(self.CATEGORIES, predictions[0])}

            return {
                "success": True,
                "category": predicted_label,
                "confidence": confidence,
                "all_confidences": all_confidences
            }
        except Exception as e:
            self.logger.error(f"Error during text classification: {str(e)}")
            return {"success": False, "error": str(e)}

    def classify_file(self, file_path):
        """Classify the contents of a file.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict: Classification result containing the category label and confidence.
        """
        try:
            # Read the file contents
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read().strip()

            # Delegate to the text classifier
            return self.classify_text(text)

        except UnicodeDecodeError:
            # If UTF-8 decoding fails, fall back to GBK
            try:
                with open(file_path, 'r', encoding='gbk') as f:
                    text = f.read().strip()
                return self.classify_text(text)
            except Exception as e:
                return {"success": False, "error": f"File decoding failed: {str(e)}"}

        except Exception as e:
            self.logger.error(f"Error while processing file: {str(e)}")
            return {"success": False, "error": f"File processing error: {str(e)}"}


# Create a module-level singleton so the model is loaded only once
text_classifier = TextClassificationModel()
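

# A minimal usage sketch (assumes the trained model and tokenizer files exist
# at their default paths; the sample text is illustrative only):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    result = text_classifier.classify_text("这场比赛双方球员的表现都非常出色")
    if result["success"]:
        print(f"category={result['category']}, confidence={result['confidence']:.4f}")
    else:
        print(f"classification failed: {result['error']}")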