"""
|
||
RNN模型:实现基于循环神经网络的文本分类模型
|
||
"""
|
||
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout,
    BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D
)
from typing import Optional
import numpy as np

from config.model_config import (
    MAX_SEQUENCE_LENGTH, RNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger

logger = get_logger("RNNModel")

class RNNTextClassifier(TextClassificationModel):
    """Recurrent neural network text classification model."""

    def __init__(self, num_classes: int, vocab_size: int,
                 embedding_dim: int = RNN_CONFIG["embedding_dim"],
                 max_sequence_length: int = MAX_SEQUENCE_LENGTH,
                 hidden_size: int = RNN_CONFIG["hidden_size"],
                 num_layers: int = RNN_CONFIG["num_layers"],
                 bidirectional: bool = RNN_CONFIG["bidirectional"],
                 rnn_type: str = "lstm",  # 'lstm' or 'gru'
                 dropout_rate: float = RNN_CONFIG["dropout_rate"],
                 embedding_matrix: Optional[np.ndarray] = None,
                 trainable_embedding: bool = True,
                 pool_type: str = "max",  # 'max', 'avg', 'both', or 'last'
                 model_name: str = "rnn_text_classifier",
                 batch_size: int = 64,
                 learning_rate: float = 0.001):
"""
|
||
初始化RNN文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_dim: 词嵌入维度
|
||
max_sequence_length: 最大序列长度
|
||
hidden_size: 隐藏层大小
|
||
num_layers: RNN层数
|
||
bidirectional: 是否使用双向RNN
|
||
rnn_type: RNN类型,'lstm'或'gru'
|
||
dropout_rate: Dropout比例
|
||
embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
|
||
trainable_embedding: 词嵌入是否可训练
|
||
pool_type: 池化类型,'max'、'avg'或'both'
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
        super().__init__(num_classes, model_name, batch_size, learning_rate)

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type.lower()
        self.dropout_rate = dropout_rate
        self.embedding_matrix = embedding_matrix
        self.trainable_embedding = trainable_embedding
        self.pool_type = pool_type

        # Validate the RNN type
        if self.rnn_type not in ["lstm", "gru"]:
            raise ValueError("Invalid RNN type; supported types: 'lstm', 'gru'")

        # Validate the pooling type ('last' pools by taking the final time step,
        # which build() supports, so it must be accepted here)
        if self.pool_type not in ["max", "avg", "both", "last"]:
            raise ValueError("Invalid pooling type; supported types: 'max', 'avg', 'both', 'last'")

        # Update the model config
        self.config.update({
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "max_sequence_length": max_sequence_length,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "bidirectional": bidirectional,
            "rnn_type": rnn_type,
            "dropout_rate": dropout_rate,
            "trainable_embedding": trainable_embedding,
            "pool_type": pool_type,
            "model_type": "RNN"
        })

        logger.info(f"Initialized RNN text classifier, type: {rnn_type.upper()}, "
                    f"hidden size: {hidden_size}, layers: {num_layers}")

    def build(self) -> None:
        """Build the RNN model architecture."""
        # Input layer
        sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')

        # Embedding layer
        if self.embedding_matrix is not None:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                weights=[self.embedding_matrix],
                input_length=self.max_sequence_length,
                trainable=self.trainable_embedding,
                name='embedding'
            )
        else:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                input_length=self.max_sequence_length,
                trainable=True,
                name='embedding'
            )

        embedded_sequences = embedding_layer(sequence_input)

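        # A pretrained `embedding_matrix` must have shape (vocab_size, embedding_dim);
        # row i holds the vector for token id i.
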
        # Select the RNN cell type
        if self.rnn_type == "lstm":
            rnn_layer = LSTM
        else:  # gru
            rnn_layer = GRU

        # Build the stacked RNN
        x = embedded_sequences
        for i in range(self.num_layers):
            # All layers except the last return full sequences; the last layer also
            # returns sequences unless pooling uses only the final time step ('last')
            return_sequences = i < self.num_layers - 1 or self.pool_type != "last"

            if self.bidirectional:
                x = Bidirectional(
                    rnn_layer(
                        self.hidden_size,
                        return_sequences=return_sequences,
                        dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
                        name=f'{self.rnn_type}_{i + 1}'
                    )
                )(x)
            else:
                x = rnn_layer(
                    self.hidden_size,
                    return_sequences=return_sequences,
                    dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
                    name=f'{self.rnn_type}_{i + 1}'
                )(x)

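        # Shape walkthrough: embeddings are (batch, seq_len, embedding_dim); each
        # layer with return_sequences=True yields (batch, seq_len, hidden_size),
        # doubled to 2 * hidden_size when bidirectional; with return_sequences=False
        # the final layer yields (batch, hidden_size) or (batch, 2 * hidden_size).
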
        # Apply the configured pooling method
        if self.pool_type == "max":
            # Global max pooling over time steps
            pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
        elif self.pool_type == "avg":
            # Global average pooling over time steps
            pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
        elif self.pool_type == "both":
            # Apply both max and average pooling, then concatenate the results
            max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
            avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
            pooled = tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled])
        else:  # "last": use the output of the final time step
            # The last RNN layer already returned only its final state, so no extra pooling is needed
            pooled = x

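        # Every pooling variant collapses the time dimension: 'max' and 'avg' give
        # (batch, features), 'both' concatenates to (batch, 2 * features), and
        # 'last' passes through the RNN's final-step output.
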
        # Dropout for regularization
        x = Dropout(self.dropout_rate, name='dropout_1')(pooled)

        # Dense layer
        x = Dense(128, name='dense_1')(x)
        x = BatchNormalization(name='batch_norm_1')(x)
        x = Activation('relu', name='activation_1')(x)
        x = Dropout(self.dropout_rate, name='dropout_2')(x)

        # Output layer
        if self.num_classes == 2:
            # Binary classification
            predictions = Dense(1, activation='sigmoid', name='predictions')(x)
        else:
            # Multi-class classification
            predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)

        # Build the model
        self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)

        logger.info(
            f"RNN model built, type: {self.rnn_type.upper()}, bidirectional: {self.bidirectional}, pooling: {self.pool_type}")

    def compile(self, optimizer=None, loss=None, metrics=None) -> None:
        """
        Compile the RNN model.

        Args:
            optimizer: Optimizer; defaults to Adam.
            loss: Loss function; chosen from the number of classes by default.
            metrics: Evaluation metrics; defaults to accuracy.
        """
        if self.model is None:
            raise ValueError("Model has not been built yet; call build() first")

        # Default optimizer
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        # Default loss function
        if loss is None:
            if self.num_classes == 2:
                loss = 'binary_crossentropy'
            else:
                loss = 'sparse_categorical_crossentropy'

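        # Note: 'sparse_categorical_crossentropy' expects integer class ids as labels,
        # while the binary sigmoid head expects 0/1 labels.
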
        # Default metrics
        if metrics is None:
            metrics = ['accuracy']

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        logger.info(f"RNN model compiled, optimizer: {optimizer.__class__.__name__}, loss: {loss}")
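
# Minimal usage sketch (illustrative, not part of the library API): the argument
# values below are hypothetical, and running it assumes config.model_config,
# models.base_model, and utils.logger are importable from the project root.
if __name__ == "__main__":
    classifier = RNNTextClassifier(
        num_classes=4,      # hypothetical number of target classes
        vocab_size=20000,   # hypothetical vocabulary size
        rnn_type="gru",
        bidirectional=True,
        pool_type="both",
    )
    classifier.build()
    classifier.compile()
    classifier.model.summary()

    # Sanity-check one forward pass on random token ids
    dummy_batch = np.random.randint(0, 20000, size=(2, classifier.max_sequence_length))
    print(classifier.model.predict(dummy_batch).shape)  # expected: (2, 4)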