""" 评估脚本:评估文本分类模型性能 """ import os import sys import time import argparse import logging from typing import List, Dict, Tuple, Optional, Any, Union import numpy as np import matplotlib.pyplot as plt # 将项目根目录添加到系统路径 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(project_root) from config.system_config import ( RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR ) from config.model_config import ( BATCH_SIZE, MAX_SEQUENCE_LENGTH ) from data.dataloader import DataLoader from data.data_manager import DataManager from preprocessing.tokenization import ChineseTokenizer from preprocessing.vectorizer import SequenceVectorizer from models.model_factory import ModelFactory from evaluation.evaluator import ModelEvaluator from utils.logger import get_logger logger = get_logger("Evaluation") def evaluate_model(model_path: str, data_dir: Optional[str] = None, batch_size: int = BATCH_SIZE, output_dir: Optional[str] = None) -> Dict[str, float]: """ 评估文本分类模型 Args: model_path: 模型路径 data_dir: 数据目录,如果为None则使用默认目录 batch_size: 批大小 output_dir: 评估结果输出目录,如果为None则使用默认目录 Returns: 评估指标 """ logger.info(f"开始评估模型: {model_path}") start_time = time.time() # 设置数据目录 data_dir = data_dir or RAW_DATA_DIR # 设置输出目录 if output_dir: output_dir = os.path.abspath(output_dir) os.makedirs(output_dir, exist_ok=True) # 1. 加载模型 logger.info("加载模型...") model = ModelFactory.load_model(model_path) # 2. 加载数据 logger.info("加载数据...") data_loader = DataLoader(data_dir=data_dir) data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR) # 加载测试集 data_manager.load_data() test_texts, test_labels = data_manager.get_data(dataset="test") # 3. 准备数据 # 创建分词器 tokenizer = ChineseTokenizer() # 对测试文本进行分词 logger.info("对文本进行分词...") tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts] # 创建序列向量化器 logger.info("加载向量化器...") # 查找向量化器文件 vectorizer_path = None for model_type in ["cnn", "rnn", "transformer"]: path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl") if os.path.exists(path): vectorizer_path = path break if not vectorizer_path: # 如果找不到向量化器,创建一个新的 logger.warning("未找到向量化器,创建一个新的") vectorizer = SequenceVectorizer( max_features=MAX_NUM_WORDS, max_sequence_length=MAX_SEQUENCE_LENGTH ) else: # 加载向量化器 vectorizer = SequenceVectorizer() vectorizer.load(vectorizer_path) # 转换测试文本 X_test = vectorizer.transform(tokenized_test_texts) # 4. 创建评估器 logger.info("创建评估器...") evaluator = ModelEvaluator( model=model, class_names=CATEGORIES, output_dir=output_dir ) # 5. 评估模型 logger.info("评估模型...") metrics = evaluator.evaluate(X_test, test_labels, batch_size) # 6. 保存评估结果 logger.info("保存评估结果...") evaluator.save_evaluation_results(save_plots=True) # 7. 可视化混淆矩阵 logger.info("可视化混淆矩阵...") cm = evaluator.evaluation_results['confusion_matrix'] evaluator.metrics.plot_confusion_matrix( y_true=test_labels, y_pred=np.argmax(model.predict(X_test), axis=1), normalize='true', save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png") ) # 8. 类别性能分析 logger.info("分析各类别性能...") class_performance = evaluator.evaluate_class_performance(X_test, test_labels) # 9. 计算评估时间 eval_time = time.time() - start_time logger.info(f"模型评估完成,耗时: {eval_time:.2f} 秒") # 10. 输出主要指标 logger.info("主要评估指标:") for metric_name, metric_value in metrics.items(): if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']: logger.info(f" {metric_name}: {metric_value:.4f}") return metrics if __name__ == "__main__": # 解析命令行参数 parser = argparse.ArgumentParser(description="评估文本分类模型") parser.add_argument("--model_path", required=True, help="模型路径") parser.add_argument("--data_dir", help="数据目录") parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小") parser.add_argument("--output_dir", help="评估结果输出目录") args = parser.parse_args() # 评估模型 evaluate_model( model_path=args.model_path, data_dir=args.data_dir, batch_size=args.batch_size, output_dir=args.output_dir )