"""
|
||
CNN模型:实现基于卷积神经网络的文本分类模型
|
||
"""
|
||
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, Conv1D, GlobalMaxPooling1D,
    Dense, Dropout, Concatenate, BatchNormalization, Activation
)
from typing import List, Optional

from config.model_config import (
    MAX_SEQUENCE_LENGTH, CNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger

logger = get_logger("CNNModel")


class CNNTextClassifier(TextClassificationModel):
    """Convolutional neural network text classification model."""

    def __init__(self, num_classes: int, vocab_size: int,
                 embedding_dim: int = CNN_CONFIG["embedding_dim"],
                 max_sequence_length: int = MAX_SEQUENCE_LENGTH,
                 num_filters: int = CNN_CONFIG["num_filters"],
                 filter_sizes: List[int] = CNN_CONFIG["filter_sizes"],
                 dropout_rate: float = CNN_CONFIG["dropout_rate"],
                 l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"],
                 embedding_matrix: Optional[np.ndarray] = None,
                 trainable_embedding: bool = True,
                 model_name: str = "cnn_text_classifier",
                 batch_size: int = 64,
                 learning_rate: float = 0.001):
"""
|
||
初始化CNN文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_dim: 词嵌入维度
|
||
max_sequence_length: 最大序列长度
|
||
num_filters: 卷积核数量
|
||
filter_sizes: 卷积核大小列表
|
||
dropout_rate: Dropout比例
|
||
l2_reg_lambda: L2正则化系数
|
||
embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
|
||
trainable_embedding: 词嵌入是否可训练
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
        super().__init__(num_classes, model_name, batch_size, learning_rate)

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.num_filters = num_filters
        self.filter_sizes = filter_sizes
        self.dropout_rate = dropout_rate
        self.l2_reg_lambda = l2_reg_lambda
        self.embedding_matrix = embedding_matrix
        self.trainable_embedding = trainable_embedding

        # Record hyperparameters in the model config
        self.config.update({
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "max_sequence_length": max_sequence_length,
            "num_filters": num_filters,
            "filter_sizes": filter_sizes,
            "dropout_rate": dropout_rate,
            "l2_reg_lambda": l2_reg_lambda,
            "trainable_embedding": trainable_embedding,
            "model_type": "CNN"
        })

        logger.info(f"Initialized CNN text classifier, vocab size: {vocab_size}, embedding dim: {embedding_dim}")

    def build(self) -> None:
        """Build the CNN model architecture."""
        # Input layer: a batch of padded token-id sequences
        sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')

        # Embedding layer, optionally initialized from pretrained vectors
        if self.embedding_matrix is not None:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                weights=[self.embedding_matrix],
                input_length=self.max_sequence_length,
                trainable=self.trainable_embedding,
                name='embedding'
            )
        else:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                input_length=self.max_sequence_length,
                trainable=True,
                name='embedding'
            )

        embedded_sequences = embedding_layer(sequence_input)

        # Parallel convolutional branches, one per filter size
        conv_blocks = []
        for filter_size in self.filter_sizes:
            conv = Conv1D(
                filters=self.num_filters,
                kernel_size=filter_size,
                padding='valid',
                activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda),
                name=f'conv_{filter_size}'
            )(embedded_sequences)

            # Max-over-time pooling: keep the strongest response of each filter
            pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv)
            conv_blocks.append(pooled)

        # Concatenate pooled features if we have multiple filter sizes
        if len(self.filter_sizes) > 1:
            concatenated = Concatenate(name='concatenate')(conv_blocks)
        else:
            concatenated = conv_blocks[0]

        # Dropout for regularization
        x = Dropout(self.dropout_rate, name='dropout_1')(concatenated)

        # Dense layer
        x = Dense(128, name='dense_1')(x)
        x = BatchNormalization(name='batch_norm_1')(x)
        x = Activation('relu', name='activation_1')(x)
        x = Dropout(self.dropout_rate, name='dropout_2')(x)

        # Output layer
        if self.num_classes == 2:
            # Binary classification: single sigmoid unit
            predictions = Dense(1, activation='sigmoid', name='predictions')(x)
        else:
            # Multi-class classification: softmax over all classes
            predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)

        # Build the model
        self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)

        logger.info(f"CNN model built, filter sizes: {self.filter_sizes}, filters per size: {self.num_filters}")

    def compile(self, optimizer=None, loss=None, metrics=None) -> None:
        """
        Compile the CNN model.

        Args:
            optimizer: Optimizer; defaults to Adam.
            loss: Loss function; chosen from the number of classes by default.
            metrics: Evaluation metrics; defaults to accuracy.
        """
        if self.model is None:
            raise ValueError("Model has not been built yet; call build() first")

        # Default optimizer
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        # Default loss: binary crossentropy matches the sigmoid head,
        # sparse categorical crossentropy matches integer-labelled multi-class
        if loss is None:
            if self.num_classes == 2:
                loss = 'binary_crossentropy'
            else:
                loss = 'sparse_categorical_crossentropy'

        # Default metrics
        if metrics is None:
            metrics = ['accuracy']

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        logger.info(f"CNN model compiled, optimizer: {optimizer.__class__.__name__}, loss: {loss}")
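

# ---------------------------------------------------------------------------
# Minimal usage sketch. This is illustrative only: it assumes CNN_CONFIG and
# MAX_SEQUENCE_LENGTH resolve via config.model_config as imported above, that
# TextClassificationModel accepts the (num_classes, model_name, batch_size,
# learning_rate) constructor used in __init__, and it targets the tf.keras
# (TF 2.x) API used in build(). VOCAB_SIZE, NUM_CLASSES, and the random
# embedding matrix below are stand-ins, not values from the project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    VOCAB_SIZE = 10000   # illustrative vocabulary size
    NUM_CLASSES = 4      # illustrative class count (multi-class path)

    # A random matrix stands in for real pretrained vectors (e.g. word2vec
    # or GloVe); the required shape is (vocab_size, embedding_dim).
    dummy_embeddings = np.random.uniform(
        -0.05, 0.05, size=(VOCAB_SIZE, CNN_CONFIG["embedding_dim"])
    )

    classifier = CNNTextClassifier(
        num_classes=NUM_CLASSES,
        vocab_size=VOCAB_SIZE,
        embedding_matrix=dummy_embeddings,
        trainable_embedding=False,  # freeze the pretrained embeddings
    )
    classifier.build()
    classifier.compile()  # defaults: Adam + sparse_categorical_crossentropy

    # Smoke test: one forward pass on a batch of random token ids.
    dummy_batch = np.random.randint(
        0, VOCAB_SIZE, size=(2, MAX_SEQUENCE_LENGTH)
    )
    predictions = classifier.model.predict(dummy_batch)
    print(predictions.shape)  # expected: (2, NUM_CLASSES)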