"""
|
||
Transformer模型:实现基于Transformer的文本分类模型
|
||
"""
|
||
import tensorflow as tf
|
||
from tensorflow.keras.models import Model
|
||
from tensorflow.keras.layers import (
|
||
Input, Embedding, Dense, Dropout, LayerNormalization,
|
||
GlobalAveragePooling1D, MultiHeadAttention, Add
|
||
)
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
|
||
from config.model_config import (
|
||
MAX_SEQUENCE_LENGTH, TRANSFORMER_CONFIG
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("TransformerModel")
|
||
|
||
|
||
class TransformerBlock(tf.keras.layers.Layer):
    """Transformer block consisting of multi-head self-attention and a feed-forward network."""

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1, **kwargs):
        """
        Initialize the Transformer block.

        Args:
            embed_dim: Embedding dimension
            num_heads: Number of attention heads
            ff_dim: Feed-forward network dimension
            dropout_rate: Dropout rate
        """
        # Forward **kwargs so standard Keras layer arguments such as `name` are accepted.
        super(TransformerBlock, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, inputs, training=False):
        """
        Forward pass.

        Args:
            inputs: Input tensor
            training: Whether the layer is in training mode

        Returns:
            Output tensor
        """
        # Multi-head self-attention with a residual connection and layer normalization
        attention_output = self.attention(inputs, inputs)
        attention_output = self.dropout1(attention_output, training=training)
        out1 = self.layernorm1(inputs + attention_output)

        # Feed-forward network with a residual connection and layer normalization
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2

    def get_config(self):
        """Return the layer configuration for serialization."""
        config = super(TransformerBlock, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout_rate": self.dropout_rate
        })
        return config


class TransformerTextClassifier(TextClassificationModel):
    """Transformer-based text classification model."""

    def __init__(self, num_classes: int, vocab_size: int,
                 embedding_dim: int = TRANSFORMER_CONFIG["embedding_dim"],
                 max_sequence_length: int = MAX_SEQUENCE_LENGTH,
                 num_heads: int = TRANSFORMER_CONFIG["num_heads"],
                 ff_dim: int = TRANSFORMER_CONFIG["ff_dim"],
                 num_layers: int = TRANSFORMER_CONFIG["num_layers"],
                 dropout_rate: float = TRANSFORMER_CONFIG["dropout_rate"],
                 embedding_matrix: Optional[np.ndarray] = None,
                 trainable_embedding: bool = True,
                 use_positional_encoding: bool = True,
                 model_name: str = "transformer_text_classifier",
                 batch_size: int = 64,
                 learning_rate: float = 0.001):
        """
        Initialize the Transformer text classification model.

        Args:
            num_classes: Number of classes
            vocab_size: Vocabulary size
            embedding_dim: Word embedding dimension
            max_sequence_length: Maximum sequence length
            num_heads: Number of attention heads
            ff_dim: Feed-forward network dimension
            num_layers: Number of Transformer layers
            dropout_rate: Dropout rate
            embedding_matrix: Pre-trained embedding matrix; if None, embeddings are randomly initialized
            trainable_embedding: Whether the embedding layer is trainable
            use_positional_encoding: Whether to add positional encoding
            model_name: Model name
            batch_size: Batch size
            learning_rate: Learning rate
        """
        super().__init__(num_classes, model_name, batch_size, learning_rate)

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.embedding_matrix = embedding_matrix
        self.trainable_embedding = trainable_embedding
        self.use_positional_encoding = use_positional_encoding

        # Update the model configuration
        self.config.update({
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "max_sequence_length": max_sequence_length,
            "num_heads": num_heads,
            "ff_dim": ff_dim,
            "num_layers": num_layers,
            "dropout_rate": dropout_rate,
            "trainable_embedding": trainable_embedding,
            "use_positional_encoding": use_positional_encoding,
            "model_type": "Transformer"
        })

        logger.info(f"Initialized Transformer text classifier, heads: {num_heads}, layers: {num_layers}")

    def _positional_encoding(self, max_length: int, d_model: int) -> tf.Tensor:
        """
        Generate sinusoidal positional encodings.

        Args:
            max_length: Maximum sequence length
            d_model: Model (embedding) dimension

        Returns:
            Positional encoding tensor of shape (1, max_length, d_model)
        """
        # Standard sinusoidal encoding:
        # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        positions = np.arange(max_length)[:, np.newaxis]
        depths = np.arange(d_model)[np.newaxis, :] // 2 * 2
        angle_rates = 1 / np.power(10000, depths / d_model)
        angle_rads = positions * angle_rates

        # sine for even indices, cosine for odd indices
        sines = np.sin(angle_rads[:, 0::2])
        cosines = np.cos(angle_rads[:, 1::2])

        pos_encoding = np.zeros((max_length, d_model))
        pos_encoding[:, 0::2] = sines
        pos_encoding[:, 1::2] = cosines

        return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)

    def build(self) -> None:
        """Build the Transformer model architecture."""
        # Input layer
        sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')

        # Embedding layer
        if self.embedding_matrix is not None:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                weights=[self.embedding_matrix],
                input_length=self.max_sequence_length,
                trainable=self.trainable_embedding,
                name='embedding'
            )
        else:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                input_length=self.max_sequence_length,
                trainable=True,
                name='embedding'
            )

        embedded_sequences = embedding_layer(sequence_input)

        # Add positional encoding
        if self.use_positional_encoding:
            pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
            embedded_sequences = embedded_sequences + pos_encoding

        # Stacked Transformer blocks
        x = embedded_sequences
        for i in range(self.num_layers):
            x = TransformerBlock(
                embed_dim=self.embedding_dim,
                num_heads=self.num_heads,
                ff_dim=self.ff_dim,
                dropout_rate=self.dropout_rate,
                name=f'transformer_block_{i + 1}'
            )(x)

        # Global average pooling over the sequence dimension
        x = GlobalAveragePooling1D(name='global_avg_pooling')(x)

        # Dropout for regularization
        x = Dropout(self.dropout_rate, name='dropout_1')(x)

        # Dense layer
        x = Dense(128, activation='relu', name='dense_1')(x)
        x = Dropout(self.dropout_rate, name='dropout_2')(x)

        # Output layer
        if self.num_classes == 2:
            # Binary classification
            predictions = Dense(1, activation='sigmoid', name='predictions')(x)
        else:
            # Multi-class classification
            predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)

        # Build the model
        self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)

        logger.info(f"Transformer model built, heads: {self.num_heads}, layers: {self.num_layers}")

    def compile(self, optimizer=None, loss=None, metrics=None) -> None:
        """
        Compile the Transformer model.

        Args:
            optimizer: Optimizer, defaults to Adam
            loss: Loss function; by default chosen based on the number of classes
            metrics: Evaluation metrics, defaults to accuracy
        """
        if self.model is None:
            raise ValueError("Model has not been built yet; call build() first")

        # Default optimizer
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        # Default loss function
        if loss is None:
            if self.num_classes == 2:
                loss = 'binary_crossentropy'
            else:
                loss = 'sparse_categorical_crossentropy'

        # Default metrics
        if metrics is None:
            metrics = ['accuracy']

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        logger.info(f"Transformer model compiled, optimizer: {optimizer.__class__.__name__}, loss: {loss}")
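

# Minimal usage sketch (illustrative only, not part of the training pipeline). It
# assumes the TextClassificationModel base class defines self.config, self.model,
# self.model_name, self.batch_size and self.learning_rate, and that TRANSFORMER_CONFIG
# supplies the keys used in the constructor defaults; the vocabulary size below is an
# arbitrary placeholder.
if __name__ == "__main__":
    # Shape check for a single TransformerBlock: (batch, seq_len, embed_dim) is preserved.
    block = TransformerBlock(embed_dim=32, num_heads=4, ff_dim=64)
    print(block(tf.random.uniform((2, 10, 32))).shape)  # -> (2, 10, 32)

    # End-to-end: build and compile the classifier, then inspect the architecture.
    classifier = TransformerTextClassifier(num_classes=2, vocab_size=20000)
    classifier.build()
    classifier.compile()
    classifier.model.summary()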