""" 训练脚本:训练文本分类模型 """ import os import sys import time import argparse import logging from typing import List, Dict, Tuple, Optional, Any, Union import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 检测 GPU physical_devices = tf.config.list_physical_devices('GPU') print("可用的物理 GPU 设备:", physical_devices) if physical_devices: try: # 设置 GPU 内存增长模式 for gpu in physical_devices: tf.config.experimental.set_memory_growth(gpu, True) print("已设置 GPU 内存增长模式") except RuntimeError as e: print(f"设置 GPU 内存增长时出错: {e}") # 将项目根目录添加到系统路径 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(project_root) from config.system_config import ( RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR ) from config.model_config import ( BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS ) from data.dataloader import DataLoader from data.data_manager import DataManager from preprocessing.tokenization import ChineseTokenizer from preprocessing.vectorizer import SequenceVectorizer from models.model_factory import ModelFactory from training.trainer import Trainer from utils.logger import get_logger logger = get_logger("Training") def train_model(data_dir: Optional[str] = None, model_type: str = "cnn", epochs: int = NUM_EPOCHS, batch_size: int = BATCH_SIZE, save_dir: Optional[str] = None, validation_split: float = 0.1, use_pretrained_embedding: bool = False, embedding_path: Optional[str] = None) -> str: """ 训练文本分类模型 Args: data_dir: 数据目录,如果为None则使用默认目录 model_type: 模型类型,'cnn', 'rnn', 或 'transformer' epochs: 训练轮数 batch_size: 批大小 save_dir: 模型保存目录,如果为None则使用默认目录 validation_split: 验证集比例 use_pretrained_embedding: 是否使用预训练词向量 embedding_path: 预训练词向量路径 Returns: 保存的模型路径 """ logger.info(f"开始训练 {model_type.upper()} 模型") start_time = time.time() # 设置数据目录 data_dir = data_dir or RAW_DATA_DIR # 设置保存目录 if save_dir: save_dir = os.path.abspath(save_dir) os.makedirs(save_dir, exist_ok=True) else: save_dir = CLASSIFIERS_DIR os.makedirs(save_dir, exist_ok=True) # 1. 加载数据 logger.info("加载数据...") data_loader = DataLoader(data_dir=data_dir) data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR) # 加载和分割数据 data = data_manager.load_and_split_data( data_loader=data_loader, val_split=validation_split, sample_ratio=1.0, save=True ) # 获取训练集和验证集 train_texts, train_labels = data_manager.get_data(dataset="train") val_texts, val_labels = data_manager.get_data(dataset="val") # 2. 准备数据 # 创建分词器 tokenizer = ChineseTokenizer() # 对训练文本进行分词 logger.info("对文本进行分词...") tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts] tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts] # 创建序列向量化器 logger.info("创建序列向量化器...") vectorizer = SequenceVectorizer( max_features=MAX_NUM_WORDS, max_sequence_length=MAX_SEQUENCE_LENGTH ) # 训练向量化器并转换文本 vectorizer.fit(tokenized_train_texts) X_train = vectorizer.transform(tokenized_train_texts) X_val = vectorizer.transform(tokenized_val_texts) # 保存向量化器 vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl") vectorizer.save(vectorizer_path) logger.info(f"向量化器已保存到: {vectorizer_path}") # 获取一些基本参数 num_classes = len(CATEGORIES) vocab_size = vectorizer.get_vocabulary_size() # 3. 
    # 3. Create the model
    logger.info(f"Creating the {model_type.upper()} model...")

    # Load pretrained word embeddings (if requested)
    embedding_matrix = None
    if use_pretrained_embedding and embedding_path:
        # Simplified here: a real application would load and process
        # pretrained word vectors from embedding_path (see the sketch above)
        logger.info("Loading pretrained word embeddings...")
        embedding_matrix = np.random.random((vocab_size, 200))

    # Create the model
    model = ModelFactory.create_model(
        model_type=model_type,
        num_classes=num_classes,
        vocab_size=vocab_size,
        embedding_matrix=embedding_matrix,
        batch_size=batch_size
    )

    # Build the model
    model.build()
    model.compile()
    model.summary()

    # 4. Train the model
    logger.info("Starting model training...")
    trainer = Trainer(
        model=model,
        epochs=epochs,
        batch_size=batch_size,
        early_stopping=True,
        tensorboard=True
    )

    # Train
    trainer.train(
        x_train=X_train,
        y_train=train_labels,
        x_val=X_val,
        y_val=val_labels
    )

    # 5. Save the model
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
    model.save(model_path)
    logger.info(f"Model saved to: {model_path}")

    # 6. Plot the training history
    logger.info("Plotting training history...")
    model.plot_training_history(
        save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png")
    )

    # 7. Report the training time
    train_time = time.time() - start_time
    logger.info(f"Model training finished in {train_time:.2f} seconds")

    return model_path


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train a text classification model")
    parser.add_argument("--data_dir", help="Data directory")
    parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"],
                        default="cnn", help="Model type")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                        help="Batch size")
    parser.add_argument("--save_dir", help="Directory to save the model")
    parser.add_argument("--validation_split", type=float, default=0.1,
                        help="Validation split ratio")
    parser.add_argument("--use_pretrained_embedding", action="store_true",
                        help="Use pretrained word embeddings")
    parser.add_argument("--embedding_path", help="Path to the pretrained word embeddings")

    args = parser.parse_args()

    # Train the model
    train_model(
        data_dir=args.data_dir,
        model_type=args.model_type,
        epochs=args.epochs,
        batch_size=args.batch_size,
        save_dir=args.save_dir,
        validation_split=args.validation_split,
        use_pretrained_embedding=args.use_pretrained_embedding,
        embedding_path=args.embedding_path
    )
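
# Example invocation (hypothetical script path and directories; adjust to the
# actual project layout):
#   python training/train.py --model_type cnn --epochs 20 --batch_size 64 \
#       --data_dir data/raw --save_dir saved_models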