"""
评估脚本:评估文本分类模型性能
"""
import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import matplotlib.pyplot as plt
# Add the project root directory to the system path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from config.system_config import (
    RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
)
from config.model_config import (
    BATCH_SIZE, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
from data.dataloader import DataLoader
from data.data_manager import DataManager
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from models.model_factory import ModelFactory
from evaluation.evaluator import ModelEvaluator
from utils.logger import get_logger
logger = get_logger("Evaluation")
def evaluate_model(model_path: str,
                   data_dir: Optional[str] = None,
                   batch_size: int = BATCH_SIZE,
                   output_dir: Optional[str] = None) -> Dict[str, float]:
    """
    Evaluate a text classification model.

    Args:
        model_path: Path to the saved model.
        data_dir: Data directory; if None, the default directory is used.
        batch_size: Batch size.
        output_dir: Output directory for evaluation results; if None, the default directory is used.

    Returns:
        Dictionary of evaluation metrics.
    """
    logger.info(f"Starting model evaluation: {model_path}")
    start_time = time.time()

    # Set the data directory
    data_dir = data_dir or RAW_DATA_DIR

    # Set the output directory
    if output_dir:
        output_dir = os.path.abspath(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    # 1. Load the model
    logger.info("Loading model...")
    model = ModelFactory.load_model(model_path)

    # 2. Load the data
    logger.info("Loading data...")
    data_loader = DataLoader(data_dir=data_dir)
    data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)

    # Load the test set
    data_manager.load_data()
    test_texts, test_labels = data_manager.get_data(dataset="test")

    # 3. Prepare the data
    # Create the tokenizer
    tokenizer = ChineseTokenizer()

    # Tokenize the test texts
    logger.info("Tokenizing texts...")
    tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts]

    # Create the sequence vectorizer
    logger.info("Loading vectorizer...")

    # Look for a saved vectorizer file next to the model
    vectorizer_path = None
    for model_type in ["cnn", "rnn", "transformer"]:
        path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
        if os.path.exists(path):
            vectorizer_path = path
            break

    if not vectorizer_path:
        # No saved vectorizer was found, so create a new one
        logger.warning("No vectorizer found, creating a new one")
        vectorizer = SequenceVectorizer(
            max_features=MAX_NUM_WORDS,
            max_sequence_length=MAX_SEQUENCE_LENGTH
        )
    else:
        # Load the saved vectorizer
        vectorizer = SequenceVectorizer()
        vectorizer.load(vectorizer_path)

    # Transform the test texts
    X_test = vectorizer.transform(tokenized_test_texts)

    # 4. Create the evaluator
    logger.info("Creating evaluator...")
    evaluator = ModelEvaluator(
        model=model,
        class_names=CATEGORIES,
        output_dir=output_dir
    )

    # 5. Evaluate the model
    logger.info("Evaluating model...")
    metrics = evaluator.evaluate(X_test, test_labels, batch_size)

    # 6. Save the evaluation results
    logger.info("Saving evaluation results...")
    evaluator.save_evaluation_results(save_plots=True)

    # 7. Visualize the confusion matrix
    logger.info("Plotting confusion matrix...")
    cm = evaluator.evaluation_results['confusion_matrix']
    evaluator.metrics.plot_confusion_matrix(
        y_true=test_labels,
        y_pred=np.argmax(model.predict(X_test), axis=1),
        normalize='true',
        save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png")
    )

    # 8. Per-class performance analysis
    logger.info("Analyzing per-class performance...")
    class_performance = evaluator.evaluate_class_performance(X_test, test_labels)

    # 9. Compute the evaluation time
    eval_time = time.time() - start_time
    logger.info(f"Model evaluation finished in {eval_time:.2f} seconds")

    # 10. Log the main metrics
    logger.info("Main evaluation metrics:")
    for metric_name, metric_value in metrics.items():
        if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']:
            logger.info(f"  {metric_name}: {metric_value:.4f}")

    return metrics
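
# Minimal programmatic-usage sketch (the paths below are hypothetical and
# assume a model checkpoint saved by this project's training script):
#
#   metrics = evaluate_model(
#       model_path="saved_models/cnn_model.h5",
#       output_dir="output/evaluation",
#   )
#   print(metrics["accuracy"])
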
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="评估文本分类模型")
parser.add_argument("--model_path", required=True, help="模型路径")
parser.add_argument("--data_dir", help="数据目录")
parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
parser.add_argument("--output_dir", help="评估结果输出目录")
args = parser.parse_args()
# 评估模型
evaluate_model(
model_path=args.model_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
output_dir=args.output_dir
)
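
# Example command-line invocation (script name and paths are illustrative only):
#
#   python evaluate.py \
#       --model_path saved_models/cnn_model.h5 \
#       --batch_size 64 \
#       --output_dir output/evaluation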