"""
|
||
训练脚本:训练文本分类模型
|
||
"""
|
||
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import logging
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import matplotlib.pyplot as plt
|
||
|
||
# 将项目根目录添加到系统路径
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.append(project_root)
|
||
|
||
from config.system_config import (
|
||
RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
|
||
)
|
||
from config.model_config import (
|
||
BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
|
||
)
|
||
from data.dataloader import DataLoader
|
||
from data.data_manager import DataManager
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from models.model_factory import ModelFactory
|
||
from training.trainer import Trainer
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Training")
|
||
|
||
|
||
def train_model(data_dir: Optional[str] = None,
                model_type: str = "cnn",
                epochs: int = NUM_EPOCHS,
                batch_size: int = BATCH_SIZE,
                save_dir: Optional[str] = None,
                validation_split: float = 0.1,
                use_pretrained_embedding: bool = False,
                embedding_path: Optional[str] = None) -> str:
    """
    Train a text classification model.

    Args:
        data_dir: Data directory; defaults to RAW_DATA_DIR when None.
        model_type: Model type: 'cnn', 'rnn', or 'transformer'.
        epochs: Number of training epochs.
        batch_size: Batch size.
        save_dir: Directory the model is saved to; defaults to CLASSIFIERS_DIR when None.
        validation_split: Fraction of the data held out for validation.
        use_pretrained_embedding: Whether to use pretrained word embeddings.
        embedding_path: Path to the pretrained word embeddings.

    Returns:
        Path of the saved model.
    """
    logger.info(f"Starting training of the {model_type.upper()} model")
    start_time = time.time()

    # Resolve the data directory
    data_dir = data_dir or RAW_DATA_DIR

    # Resolve the save directory and make sure it exists
    save_dir = os.path.abspath(save_dir) if save_dir else CLASSIFIERS_DIR
    os.makedirs(save_dir, exist_ok=True)

    # 1. Load the data
    logger.info("Loading data...")
    data_loader = DataLoader(data_dir=data_dir)
    data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)

    # Load and split the data (sample_ratio=1.0 keeps the full dataset)
    data_manager.load_and_split_data(
        data_loader=data_loader,
        val_split=validation_split,
        sample_ratio=1.0,
        save=True
    )

    # Fetch the training and validation sets
    train_texts, train_labels = data_manager.get_data(dataset="train")
    val_texts, val_labels = data_manager.get_data(dataset="val")

    # 2. Prepare the data
    # Create the tokenizer
    tokenizer = ChineseTokenizer()

    # Tokenize the training and validation texts
    logger.info("Tokenizing texts...")
    tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts]
    tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts]
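    # Note: return_string=True is read here as returning the tokens joined into a
    # single space-separated string (an assumption about ChineseTokenizer's API),
    # which is the form the SequenceVectorizer below consumes.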

    # Create the sequence vectorizer
    logger.info("Creating the sequence vectorizer...")
    vectorizer = SequenceVectorizer(
        max_features=MAX_NUM_WORDS,
        max_sequence_length=MAX_SEQUENCE_LENGTH
    )

    # Fit the vectorizer on the training texts, then transform both splits
    vectorizer.fit(tokenized_train_texts)
    X_train = vectorizer.transform(tokenized_train_texts)
    X_val = vectorizer.transform(tokenized_val_texts)
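    # Assuming SequenceVectorizer pads/truncates to max_sequence_length, X_train
    # and X_val are integer id matrices of shape (num_samples, MAX_SEQUENCE_LENGTH).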

    # Save the vectorizer so inference can reuse the same word-to-id mapping
    vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl")
    vectorizer.save(vectorizer_path)
    logger.info(f"Vectorizer saved to: {vectorizer_path}")

    # Gather basic model parameters
    num_classes = len(CATEGORIES)
    vocab_size = vectorizer.get_vocabulary_size()

    # 3. Create the model
    logger.info(f"Creating the {model_type.upper()} model...")

    # Load pretrained word embeddings (if requested)
    embedding_matrix = None
    if use_pretrained_embedding and embedding_path:
        # Simplified placeholder: a real run should parse the vectors at
        # embedding_path instead; see the load_pretrained_embeddings sketch below.
        logger.info("Loading pretrained word embeddings...")
        embedding_matrix = np.random.random((vocab_size, 200))

    # Create the model
    model = ModelFactory.create_model(
        model_type=model_type,
        num_classes=num_classes,
        vocab_size=vocab_size,
        embedding_matrix=embedding_matrix,
        batch_size=batch_size
    )

    # Build, compile, and summarize the model
    model.build()
    model.compile()
    model.summary()

    # 4. Train the model
    logger.info("Starting model training...")
    trainer = Trainer(
        model=model,
        epochs=epochs,
        batch_size=batch_size,
        early_stopping=True,
        tensorboard=True
    )

    # Train
    history = trainer.train(
        x_train=X_train,
        y_train=train_labels,
        x_val=X_val,
        y_val=val_labels
    )

    # 5. Save the model
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
    model.save(model_path)
    logger.info(f"Model saved to: {model_path}")

    # 6. Plot the training history
    logger.info("Plotting training history...")
    model.plot_training_history(
        save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png")
    )

    # 7. Report the training time
    train_time = time.time() - start_time
    logger.info(f"Model training finished in {train_time:.2f} seconds")

    return model_path
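

# Hedged sketch (not wired into train_model above): how the pretrained-embedding
# placeholder could actually be filled. Assumes a plain-text word2vec-style file
# ("word v1 v2 ... vN" per line) and a word->index mapping from the vectorizer;
# the word_index argument and the default 200-dimension size are illustrative
# assumptions, not part of this project's confirmed API.
def load_pretrained_embeddings(embedding_path: str,
                               word_index: Dict[str, int],
                               vocab_size: int,
                               embedding_dim: int = 200) -> np.ndarray:
    """Build an embedding matrix from a word2vec-style text file."""
    # Words missing from the file keep small random vectors
    matrix = np.random.uniform(-0.05, 0.05, (vocab_size, embedding_dim))
    with open(embedding_path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            word, values = parts[0], parts[1:]
            if len(values) != embedding_dim:
                continue  # skip header lines and malformed rows
            index = word_index.get(word)
            if index is not None and index < vocab_size:
                matrix[index] = np.asarray(values, dtype="float32")
    return matrix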


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train a text classification model")
    parser.add_argument("--data_dir", help="Data directory")
    parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="Model type")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--save_dir", help="Directory to save the model")
    parser.add_argument("--validation_split", type=float, default=0.1, help="Validation split ratio")
    parser.add_argument("--use_pretrained_embedding", action="store_true", help="Use pretrained word embeddings")
    parser.add_argument("--embedding_path", help="Path to the pretrained word embeddings")

    args = parser.parse_args()

    # Train the model
    train_model(
        data_dir=args.data_dir,
        model_type=args.model_type,
        epochs=args.epochs,
        batch_size=args.batch_size,
        save_dir=args.save_dir,
        validation_split=args.validation_split,
        use_pretrained_embedding=args.use_pretrained_embedding,
        embedding_path=args.embedding_path
    )
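
# Example invocations (assuming this file lives one directory below the project
# root, e.g. at <project_root>/training/train.py; the path is an assumption
# based on the sys.path setup above, and vectors.txt is a placeholder):
#     python training/train.py --model_type cnn --epochs 10
#     python training/train.py --model_type transformer --batch_size 64 \
#         --use_pretrained_embedding --embedding_path vectors.txt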