2025-03-08 01:34:36 +08:00

221 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RNN模型实现：基于循环神经网络的文本分类模型
"""
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout,
BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D
)
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
from config.model_config import (
MAX_SEQUENCE_LENGTH, RNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger
logger = get_logger("RNNModel")
class RNNTextClassifier(TextClassificationModel):
    """Recurrent-neural-network (LSTM/GRU) text classification model.

    Graph created by :meth:`build`::

        token ids -> Embedding -> num_layers x (Bi)LSTM/GRU -> pooling
                  -> Dropout -> Dense(128) -> BatchNorm -> ReLU -> Dropout
                  -> Dense(1, sigmoid) if num_classes == 2
                     else Dense(num_classes, softmax)
    """

    def __init__(self, num_classes: int, vocab_size: int,
                 embedding_dim: int = RNN_CONFIG["embedding_dim"],
                 max_sequence_length: int = MAX_SEQUENCE_LENGTH,
                 hidden_size: int = RNN_CONFIG["hidden_size"],
                 num_layers: int = RNN_CONFIG["num_layers"],
                 bidirectional: bool = RNN_CONFIG["bidirectional"],
                 rnn_type: str = "lstm",  # 'lstm' or 'gru'
                 dropout_rate: float = RNN_CONFIG["dropout_rate"],
                 embedding_matrix: Optional[np.ndarray] = None,
                 trainable_embedding: bool = True,
                 pool_type: str = "max",  # 'max', 'avg', 'both' or 'last'
                 model_name: str = "rnn_text_classifier",
                 batch_size: int = 64,
                 learning_rate: float = 0.001):
        """Store and validate the hyper-parameters of the RNN classifier.

        The Keras graph is not created here; call :meth:`build` and then
        :meth:`compile`.

        Args:
            num_classes: Number of target classes. With exactly 2 classes the
                model ends in a single sigmoid unit, otherwise in a softmax.
            vocab_size: Vocabulary size (``input_dim`` of the embedding).
            embedding_dim: Word-embedding dimension.
            max_sequence_length: Length of the (padded) input sequences.
            hidden_size: Number of units in each RNN layer.
            num_layers: Number of stacked RNN layers; must be >= 1.
            bidirectional: Wrap every RNN layer in ``Bidirectional``.
            rnn_type: ``'lstm'`` or ``'gru'`` (case-insensitive).
            dropout_rate: Input dropout for every RNN layer except the last,
                and the rate of both dropout layers in the classifier head.
            embedding_matrix: Optional pre-trained embedding matrix of shape
                ``(vocab_size, embedding_dim)``; random initialization when
                ``None``.
            trainable_embedding: Whether a pre-trained embedding is fine-tuned.
                A randomly initialized embedding is always trainable.
            pool_type: How the RNN output is reduced to one vector:
                ``'max'``, ``'avg'``, ``'both'`` (max and avg concatenated) or
                ``'last'`` (output of the final time step). Case-insensitive.
            model_name: Name given to the Keras ``Model``.
            batch_size: Batch size, forwarded to the base class.
            learning_rate: Learning rate of the default Adam optimizer.

        Raises:
            ValueError: If ``rnn_type``, ``pool_type`` or ``num_layers`` is
                invalid.
        """
        super().__init__(num_classes, model_name, batch_size, learning_rate)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type.lower()
        self.dropout_rate = dropout_rate
        self.embedding_matrix = embedding_matrix
        self.trainable_embedding = trainable_embedding
        self.pool_type = pool_type.lower()

        if self.rnn_type not in ["lstm", "gru"]:
            raise ValueError("无效的RNN类型支持的类型: 'lstm', 'gru'")
        # 'last' is accepted because build() implements it (final time step
        # of the last RNN layer instead of pooling over all steps).
        if self.pool_type not in ["max", "avg", "both", "last"]:
            raise ValueError("无效的池化类型,支持的类型: 'max', 'avg', 'both', 'last'")
        if num_layers < 1:
            raise ValueError(f"RNN层数必须大于等于1, 当前为: {num_layers}")

        # Record the normalized values so the config matches what build() uses.
        self.config.update({
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "max_sequence_length": max_sequence_length,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "bidirectional": bidirectional,
            "rnn_type": self.rnn_type,
            "dropout_rate": dropout_rate,
            "trainable_embedding": trainable_embedding,
            "pool_type": self.pool_type,
            "model_type": "RNN"
        })
        logger.info(f"初始化RNN文本分类模型类型: {rnn_type.upper()}, 隐藏层大小: {hidden_size}, 层数: {num_layers}")

    def build(self) -> None:
        """Create the Keras graph and store it in ``self.model``.

        Layer names (``sequence_input``, ``embedding``, ``lstm_1``/``gru_1``…,
        pooling layers, ``dense_1``, ``predictions``) are fixed; keep them
        unchanged when editing this method.

        Raises:
            ValueError: If ``embedding_matrix`` does not have shape
                ``(vocab_size, embedding_dim)``.
        """
        sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
        embedded_sequences = self._build_embedding(sequence_input)
        rnn_output = self._build_rnn_stack(embedded_sequences)
        pooled = self._pool(rnn_output)
        predictions = self._build_head(pooled)

        self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
        logger.info(
            f"RNN模型构建完成类型: {self.rnn_type.upper()}, 双向: {self.bidirectional}, 池化类型: {self.pool_type}")

    def _build_embedding(self, sequence_input):
        """Embed the token-id input, optionally from ``self.embedding_matrix``."""
        pretrained = self.embedding_matrix is not None
        if pretrained:
            expected_shape = (self.vocab_size, self.embedding_dim)
            actual_shape = np.shape(self.embedding_matrix)
            if tuple(actual_shape) != expected_shape:
                raise ValueError(
                    f"预训练词嵌入矩阵形状不匹配: 期望 {expected_shape}, 实际 {tuple(actual_shape)}")

        embedding_layer = Embedding(
            input_dim=self.vocab_size,
            output_dim=self.embedding_dim,
            # A randomly initialized embedding has to be learned; only a
            # pre-trained one may be frozen.
            trainable=self.trainable_embedding if pretrained else True,
            name='embedding'
        )
        embedded_sequences = embedding_layer(sequence_input)
        if pretrained:
            # Load the matrix after the call above has built the layer. The
            # `weights=` constructor kwarg is not accepted by every Keras 3
            # release, and `input_length` is deprecated there (the Input
            # layer already fixes the sequence length).
            embedding_layer.set_weights([self.embedding_matrix])
        return embedded_sequences

    def _build_rnn_stack(self, x):
        """Apply ``num_layers`` stacked (optionally bidirectional) RNN layers."""
        rnn_layer = LSTM if self.rnn_type == "lstm" else GRU
        last_index = self.num_layers - 1
        for i in range(self.num_layers):
            is_last = i == last_index
            rnn = rnn_layer(
                self.hidden_size,
                # Intermediate layers feed full sequences to the next RNN; the
                # last layer does too unless only its final step is needed.
                return_sequences=not is_last or self.pool_type != "last",
                # Input dropout on every layer except the last one.
                dropout=0 if is_last else self.dropout_rate,
                name=f'{self.rnn_type}_{i + 1}'
            )
            x = Bidirectional(rnn)(x) if self.bidirectional else rnn(x)
        return x

    def _pool(self, x):
        """Reduce the RNN output to a fixed-size vector per ``pool_type``."""
        if self.pool_type == "max":
            return GlobalMaxPooling1D(name='global_max_pooling')(x)
        if self.pool_type == "avg":
            return GlobalAveragePooling1D(name='global_avg_pooling')(x)
        if self.pool_type == "both":
            max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
            avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
            return tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled])
        # "last": the final RNN layer was built with return_sequences=False,
        # so x already is the last time step's output.
        return x

    def _build_head(self, pooled):
        """Dropout -> Dense(128) -> BatchNorm -> ReLU -> Dropout -> output."""
        x = Dropout(self.dropout_rate, name='dropout_1')(pooled)
        x = Dense(128, name='dense_1')(x)
        x = BatchNormalization(name='batch_norm_1')(x)
        x = Activation('relu', name='activation_1')(x)
        x = Dropout(self.dropout_rate, name='dropout_2')(x)
        if self.num_classes == 2:
            # Binary: one sigmoid unit (compile() defaults to binary_crossentropy).
            return Dense(1, activation='sigmoid', name='predictions')(x)
        # Multi-class: softmax over all classes
        # (compile() defaults to sparse_categorical_crossentropy).
        return Dense(self.num_classes, activation='softmax', name='predictions')(x)

    def compile(self, optimizer=None, loss=None, metrics=None) -> None:
        """Compile the model created by :meth:`build`.

        Args:
            optimizer: Optimizer instance or Keras optimizer identifier
                string. Defaults to Adam with ``self.learning_rate``.
            loss: Loss function. Defaults to ``'binary_crossentropy'`` for
                2 classes (sigmoid output) and
                ``'sparse_categorical_crossentropy'`` otherwise, which expects
                integer class labels.
            metrics: Metrics list. Defaults to ``['accuracy']``.

        Raises:
            ValueError: If :meth:`build` has not been called yet.
        """
        if self.model is None:
            raise ValueError("模型尚未构建请先调用build方法")
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        if loss is None:
            loss = 'binary_crossentropy' if self.num_classes == 2 else 'sparse_categorical_crossentropy'
        if metrics is None:
            metrics = ['accuracy']

        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        # An optimizer given by name (e.g. 'adam') would otherwise log as 'str'.
        optimizer_name = optimizer if isinstance(optimizer, str) else type(optimizer).__name__
        logger.info(f"RNN模型已编译优化器: {optimizer_name}, 损失函数: {loss}")