"""
训练脚本:训练文本分类模型
"""
import os
import sys
import time
import argparse
from typing import Optional

import numpy as np

# Add the project root to the system path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from config.system_config import (
    RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
)
from config.model_config import (
    BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
from data.dataloader import DataLoader
from data.data_manager import DataManager
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from models.model_factory import ModelFactory
from training.trainer import Trainer
from utils.logger import get_logger

logger = get_logger("Training")


def train_model(data_dir: Optional[str] = None,
                model_type: str = "cnn",
                epochs: int = NUM_EPOCHS,
                batch_size: int = BATCH_SIZE,
                save_dir: Optional[str] = None,
                validation_split: float = 0.1,
                use_pretrained_embedding: bool = False,
                embedding_path: Optional[str] = None) -> str:
"""
训练文本分类模型
Args:
data_dir: 数据目录如果为None则使用默认目录
model_type: 模型类型,'cnn', 'rnn', 或 'transformer'
epochs: 训练轮数
batch_size: 批大小
save_dir: 模型保存目录如果为None则使用默认目录
validation_split: 验证集比例
use_pretrained_embedding: 是否使用预训练词向量
embedding_path: 预训练词向量路径
Returns:
保存的模型路径
"""
logger.info(f"开始训练 {model_type.upper()} 模型")
start_time = time.time()
# 设置数据目录
data_dir = data_dir or RAW_DATA_DIR
# 设置保存目录
if save_dir:
save_dir = os.path.abspath(save_dir)
os.makedirs(save_dir, exist_ok=True)
else:
save_dir = CLASSIFIERS_DIR
os.makedirs(save_dir, exist_ok=True)
    # 1. Load the data
    logger.info("Loading data...")
    data_loader = DataLoader(data_dir=data_dir)
    data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)

    # Load and split the data
    data = data_manager.load_and_split_data(
        data_loader=data_loader,
        val_split=validation_split,
        sample_ratio=1.0,
        save=True
    )

    # Fetch the training and validation sets
    train_texts, train_labels = data_manager.get_data(dataset="train")
    val_texts, val_labels = data_manager.get_data(dataset="val")

    # 2. Prepare the data
    # Create the tokenizer
    tokenizer = ChineseTokenizer()

    # Tokenize the training and validation texts
    logger.info("Tokenizing texts...")
    tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts]
    tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts]

    # Create the sequence vectorizer
    logger.info("Creating the sequence vectorizer...")
    vectorizer = SequenceVectorizer(
        max_features=MAX_NUM_WORDS,
        max_sequence_length=MAX_SEQUENCE_LENGTH
    )

    # Fit the vectorizer on the training texts, then transform both splits
    vectorizer.fit(tokenized_train_texts)
    X_train = vectorizer.transform(tokenized_train_texts)
    X_val = vectorizer.transform(tokenized_val_texts)

    # Save the vectorizer so the same vocabulary can be reused later
    vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl")
    vectorizer.save(vectorizer_path)
    logger.info(f"Vectorizer saved to: {vectorizer_path}")

    # Collect some basic parameters
    num_classes = len(CATEGORIES)
    vocab_size = vectorizer.get_vocabulary_size()

    # 3. Create the model
    logger.info(f"Creating the {model_type.upper()} model...")

    # Load pretrained embeddings (if requested)
    embedding_matrix = None
    if use_pretrained_embedding and embedding_path:
        # Simplified placeholder: a real application should load and process
        # the pretrained embeddings from embedding_path instead of filling
        # the matrix with random values.
        logger.info("Loading pretrained embeddings...")
        embedding_matrix = np.random.random((vocab_size, 200))
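
    # A real loader might look like the sketch below, assuming a word2vec-style
    # text file where each line is "token v1 v2 ... v200", and assuming the
    # vectorizer exposes a word-to-index mapping (here called word_index,
    # a hypothetical attribute):
    #
    #     embedding_matrix = np.zeros((vocab_size, 200))
    #     with open(embedding_path, encoding="utf-8") as f:
    #         for line in f:
    #             parts = line.rstrip().split(" ")
    #             token = parts[0]
    #             vector = np.asarray(parts[1:], dtype="float32")
    #             index = vectorizer.word_index.get(token)  # hypothetical attribute
    #             if index is not None and index < vocab_size:
    #                 embedding_matrix[index] = vector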

    # Create the model
    model = ModelFactory.create_model(
        model_type=model_type,
        num_classes=num_classes,
        vocab_size=vocab_size,
        embedding_matrix=embedding_matrix,
        batch_size=batch_size
    )

    # Build, compile, and summarize the model
    model.build()
    model.compile()
    model.summary()

    # 4. Train the model
    logger.info("Starting model training...")
    trainer = Trainer(
        model=model,
        epochs=epochs,
        batch_size=batch_size,
        early_stopping=True,
        tensorboard=True
    )

    # Train; the trainer handles early stopping and TensorBoard logging
    history = trainer.train(
        x_train=X_train,
        y_train=train_labels,
        x_val=X_val,
        y_val=val_labels
    )
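
    # The returned history can be inspected for the best validation score,
    # e.g. (assuming Trainer.train() returns a Keras-style History object
    # with a 'val_accuracy' metric; both are assumptions):
    #
    #     logger.info(f"Best val accuracy: {max(history.history['val_accuracy']):.4f}")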

    # 5. Save the model
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
    model.save(model_path)
    logger.info(f"Model saved to: {model_path}")

    # 6. Plot the training history
    logger.info("Plotting training history...")
    model.plot_training_history(
        save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png")
    )

    # 7. Report the total training time
    train_time = time.time() - start_time
    logger.info(f"Model training finished in {train_time:.2f} seconds")

    return model_path


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train a text classification model")
    parser.add_argument("--data_dir", help="Data directory")
    parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="Model type")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--save_dir", help="Directory to save the model")
    parser.add_argument("--validation_split", type=float, default=0.1, help="Fraction of data held out for validation")
    parser.add_argument("--use_pretrained_embedding", action="store_true", help="Use pretrained word embeddings")
    parser.add_argument("--embedding_path", help="Path to the pretrained word embeddings")
    args = parser.parse_args()

    # Train the model
    train_model(
        data_dir=args.data_dir,
        model_type=args.model_type,
        epochs=args.epochs,
        batch_size=args.batch_size,
        save_dir=args.save_dir,
        validation_split=args.validation_split,
        use_pretrained_embedding=args.use_pretrained_embedding,
        embedding_path=args.embedding_path
    )
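
# Example invocation (script name and directory paths are illustrative):
#   python train.py --model_type cnn --epochs 10 --batch_size 64 \
#       --data_dir ./data/raw --save_dir ./saved_models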