"""Model configuration file."""

# Text preprocessing parameters
MAX_SEQUENCE_LENGTH = 500  # maximum length of a text sequence
MAX_NUM_WORDS = 50000  # maximum vocabulary size
MAX_CHAR_LENGTH = 2000  # maximum length at character level
MIN_WORD_FREQUENCY = 5  # minimum word frequency

# Model architecture parameters

# Settings for the CNN variant; the code that consumes these keys is not
# visible in this file, so confirm key semantics against the model builder.
CNN_CONFIG = {
    "embedding_dim": 200,
    "num_filters": 256,
    "filter_sizes": [3, 4, 5],
    "dropout_rate": 0.5,
    "l2_reg_lambda": 0.0,
}

# Settings for the RNN variant; the code that consumes these keys is not
# visible in this file, so confirm key semantics against the model builder.
RNN_CONFIG = {
    "embedding_dim": 200,
    "hidden_size": 256,
    "num_layers": 2,
    "bidirectional": True,
    "dropout_rate": 0.5,
}

# Settings for the Transformer variant; the code that consumes these keys is
# not visible in this file, so confirm key semantics against the model builder.
TRANSFORMER_CONFIG = {
    "embedding_dim": 200,
    "num_heads": 8,
    "ff_dim": 512,
    "num_layers": 4,
    "dropout_rate": 0.1,
}

# Settings tuned for the RTX 4090
BATCH_SIZE = 128  # the RTX 4090 has 24 GB of VRAM, so a fairly large batch fits
EVAL_BATCH_SIZE = 256  # evaluation can use an even larger batch

# Training parameters
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
EARLY_STOPPING_PATIENCE = 3
REDUCE_LR_PATIENCE = 2
REDUCE_LR_FACTOR = 0.5
VALIDATION_SPLIT = 0.1
TEST_SPLIT = 0.1

# Word embedding parameters
USE_PRETRAINED_EMBEDDING = True
EMBEDDING_TYPE = "word2vec"  # options: word2vec, glove, fasttext

# Random seed, so that experiments are reproducible
RANDOM_SEED = 42

# Model saving parameters
SAVE_BEST_ONLY = True
MODEL_CHECKPOINT_PATH = "best_model.h5"

# Feature engineering parameters
USE_CHAR_LEVEL = False  # whether to use character-level features
USE_WORD_LEVEL = True  # whether to use word-level features
USE_TFIDF = False  # whether to use TF-IDF features
USE_POS_TAGS = False  # whether to use part-of-speech tag features

# Data augmentation parameters
USE_DATA_AUGMENTATION = False
AUGMENTATION_FACTOR = 0.2  # augment 20% of the data

# Inference parameters
PREDICTION_THRESHOLD = 0.5
TOP_K_PREDICTIONS = 3