"""
CNN模型实现基于卷积神经网络的文本分类模型
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, Conv1D, MaxPooling1D, GlobalMaxPooling1D,
    Dense, Dropout, Concatenate, BatchNormalization, Activation
)
from typing import List, Dict, Tuple, Optional, Any, Union

from config.model_config import (
    MAX_SEQUENCE_LENGTH, CNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger

logger = get_logger("CNNModel")


class CNNTextClassifier(TextClassificationModel):
    """Convolutional neural network text classification model."""

    def __init__(self, num_classes: int, vocab_size: int,
                 embedding_dim: int = CNN_CONFIG["embedding_dim"],
                 max_sequence_length: int = MAX_SEQUENCE_LENGTH,
                 num_filters: int = CNN_CONFIG["num_filters"],
                 filter_sizes: List[int] = CNN_CONFIG["filter_sizes"],
                 dropout_rate: float = CNN_CONFIG["dropout_rate"],
                 l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"],
                 embedding_matrix: Optional[np.ndarray] = None,
                 trainable_embedding: bool = True,
                 model_name: str = "cnn_text_classifier",
                 batch_size: int = 64,
                 learning_rate: float = 0.001):
"""
初始化CNN文本分类模型
Args:
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_dim: 词嵌入维度
max_sequence_length: 最大序列长度
num_filters: 卷积核数量
filter_sizes: 卷积核大小列表
dropout_rate: Dropout比例
l2_reg_lambda: L2正则化系数
embedding_matrix: 预训练词嵌入矩阵如果为None则使用随机初始化
trainable_embedding: 词嵌入是否可训练
model_name: 模型名称
batch_size: 批大小
learning_rate: 学习率
"""
        super().__init__(num_classes, model_name, batch_size, learning_rate)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.num_filters = num_filters
        self.filter_sizes = filter_sizes
        self.dropout_rate = dropout_rate
        self.l2_reg_lambda = l2_reg_lambda
        self.embedding_matrix = embedding_matrix
        self.trainable_embedding = trainable_embedding

        # Update the model configuration
        self.config.update({
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "max_sequence_length": max_sequence_length,
            "num_filters": num_filters,
            "filter_sizes": filter_sizes,
            "dropout_rate": dropout_rate,
            "l2_reg_lambda": l2_reg_lambda,
            "trainable_embedding": trainable_embedding,
            "model_type": "CNN"
        })

        logger.info(f"Initialized CNN text classifier, vocab size: {vocab_size}, embedding dim: {embedding_dim}")

    def build(self) -> None:
        """Build the CNN model architecture."""
        # Input layer
        sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')

        # Embedding layer
        if self.embedding_matrix is not None:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                weights=[self.embedding_matrix],
                input_length=self.max_sequence_length,
                trainable=self.trainable_embedding,
                name='embedding'
            )
        else:
            embedding_layer = Embedding(
                input_dim=self.vocab_size,
                output_dim=self.embedding_dim,
                input_length=self.max_sequence_length,
                trainable=True,
                name='embedding'
            )
        embedded_sequences = embedding_layer(sequence_input)
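        # embedded_sequences has shape (batch_size, max_sequence_length, embedding_dim)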

        # Convolutional layers with different filter sizes
        conv_blocks = []
        for filter_size in self.filter_sizes:
            conv = Conv1D(
                filters=self.num_filters,
                kernel_size=filter_size,
                padding='valid',
                activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda),
                name=f'conv_{filter_size}'
            )(embedded_sequences)
            # Max pooling
            pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv)
            conv_blocks.append(pooled)
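
        # Each pooled tensor has shape (batch_size, num_filters), so the final
        # feature vector below has num_filters * len(filter_sizes) dimensions.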
        # Concatenate pooled features if we have multiple filter sizes
        if len(self.filter_sizes) > 1:
            concatenated = Concatenate(name='concatenate')(conv_blocks)
        else:
            concatenated = conv_blocks[0]

        # Dropout for regularization
        x = Dropout(self.dropout_rate, name='dropout_1')(concatenated)

        # Dense layer
        x = Dense(128, name='dense_1')(x)
        x = BatchNormalization(name='batch_norm_1')(x)
        x = Activation('relu', name='activation_1')(x)
        x = Dropout(self.dropout_rate, name='dropout_2')(x)
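
        # A single sigmoid unit suffices for binary classification (paired
        # with binary cross-entropy in compile()); the multi-class head uses
        # a softmax over num_classes units.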
        # Output layer
        if self.num_classes == 2:
            # Binary classification
            predictions = Dense(1, activation='sigmoid', name='predictions')(x)
        else:
            # Multi-class classification
            predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)

        # Build the model
        self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
        logger.info(f"CNN model built, filter sizes: {self.filter_sizes}, filters per size: {self.num_filters}")

    def compile(self, optimizer=None, loss=None, metrics=None) -> None:
        """
        Compile the CNN model.

        Args:
            optimizer: Optimizer; defaults to Adam
            loss: Loss function; defaults to a choice based on the number of classes
            metrics: Evaluation metrics; defaults to ['accuracy']
        """
        if self.model is None:
            raise ValueError("Model has not been built yet; call build() first")

        # Default optimizer
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

        # Default loss function
        if loss is None:
            if self.num_classes == 2:
                loss = 'binary_crossentropy'
            else:
                loss = 'sparse_categorical_crossentropy'
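                # Note: sparse_categorical_crossentropy expects integer class
                # ids; use 'categorical_crossentropy' for one-hot labels.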

        # Default metrics
        if metrics is None:
            metrics = ['accuracy']

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        logger.info(f"CNN model compiled, optimizer: {optimizer.__class__.__name__}, loss: {loss}")