2025-03-08 01:34:36 +08:00

72 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Model configuration file.
"""
# Text preprocessing parameters

# Maximum length of a text sequence.
MAX_SEQUENCE_LENGTH: int = 500
# Maximum size of the vocabulary.
MAX_NUM_WORDS: int = 50_000
# Maximum length at the character level.
MAX_CHAR_LENGTH: int = 2_000
# Minimum word frequency.
MIN_WORD_FREQUENCY: int = 5
# Model architecture parameters

# Settings for the CNN variant, grouped in a single dict.
CNN_CONFIG: dict = dict(
    embedding_dim=200,
    num_filters=256,
    filter_sizes=[3, 4, 5],
    dropout_rate=0.5,
    # NOTE(review): 0.0 presumably leaves L2 regularization switched off;
    # confirm against the model code that reads this key.
    l2_reg_lambda=0.0,
)
# Settings for the RNN variant, grouped in a single dict.
RNN_CONFIG: dict = dict(
    embedding_dim=200,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
    dropout_rate=0.5,
)
# Settings for the Transformer variant, grouped in a single dict.
# NOTE(review): embedding_dim (200) divides evenly by num_heads (8), giving 25
# per head; multi-head attention layers typically require this, so keep the
# two values in step when editing - confirm against the model code.
TRANSFORMER_CONFIG: dict = dict(
    embedding_dim=200,
    num_heads=8,
    ff_dim=512,
    num_layers=4,
    dropout_rate=0.1,
)
# Settings tuned for the RTX 4090

# The RTX 4090 has 24 GB of VRAM, which can support a fairly large batch.
BATCH_SIZE: int = 128
# A larger batch can be used during evaluation.
EVAL_BATCH_SIZE: int = 256
# Training parameters
# NOTE(review): the *_PATIENCE values are presumably counted in epochs and
# REDUCE_LR_FACTOR is presumably a multiplier applied to the learning rate;
# confirm against the training loop that consumes them.
LEARNING_RATE: float = 0.001
NUM_EPOCHS: int = 20
EARLY_STOPPING_PATIENCE: int = 3
REDUCE_LR_PATIENCE: int = 2
REDUCE_LR_FACTOR: float = 0.5
VALIDATION_SPLIT: float = 0.1
TEST_SPLIT: float = 0.1
# Word embedding parameters

USE_PRETRAINED_EMBEDDING: bool = True
# Options: word2vec, glove, fasttext
EMBEDDING_TYPE: str = "word2vec"
# Random seed, fixed so that experiments are reproducible.
RANDOM_SEED: int = 42
# Model saving parameters

SAVE_BEST_ONLY: bool = True
# NOTE(review): this is a relative path; where it resolves depends on the
# code that writes the checkpoint - confirm there.
MODEL_CHECKPOINT_PATH: str = "best_model.h5"
# Feature engineering parameters

# Whether to use character-level features.
USE_CHAR_LEVEL: bool = False
# Whether to use word-level features.
USE_WORD_LEVEL: bool = True
# Whether to use TF-IDF features.
USE_TFIDF: bool = False
# Whether to use part-of-speech tag features.
USE_POS_TAGS: bool = False
# Data augmentation parameters

USE_DATA_AUGMENTATION: bool = False
# Augment 20% of the data.
AUGMENTATION_FACTOR: float = 0.2
# Inference parameters
# NOTE(review): presumably a score cut-off and the number of top-ranked
# predictions to report; confirm against the prediction code.
PREDICTION_THRESHOLD: float = 0.5
TOP_K_PREDICTIONS: int = 3