"""
评估脚本:评估文本分类模型性能
(Evaluation script: evaluates text-classification model performance.)
"""
|
||
import argparse
import logging
import os
import sys
import time
from typing import List, Dict, Tuple, Optional, Any, Union

import matplotlib.pyplot as plt
import numpy as np

# Add the project root to sys.path so project-local packages resolve
# when this file is run as a script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from config.system_config import (
    RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
)
# MAX_NUM_WORDS added: it is used in the vectorizer fallback below but was
# never imported (NameError on that path). Assumes config.model_config
# defines it — TODO confirm.
from config.model_config import (
    BATCH_SIZE, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
from data.dataloader import DataLoader
from data.data_manager import DataManager
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from models.model_factory import ModelFactory
from evaluation.evaluator import ModelEvaluator
from utils.logger import get_logger
logger = get_logger("Evaluation")
|
||
|
||
|
||
def evaluate_model(model_path: str,
                   data_dir: Optional[str] = None,
                   batch_size: int = BATCH_SIZE,
                   output_dir: Optional[str] = None) -> Dict[str, float]:
    """Evaluate a text-classification model on the held-out test split.

    Loads the saved model, the processed test set, and the sequence
    vectorizer stored next to the model, runs the evaluator, saves the
    metrics and plots, and writes a normalized confusion matrix image.

    Args:
        model_path: Path of the saved model to evaluate.
        data_dir: Raw-data directory; defaults to RAW_DATA_DIR when None.
        batch_size: Batch size passed to the evaluator for prediction.
        output_dir: Directory for evaluation artifacts; when None, the
            confusion-matrix image falls back to the model's directory.

    Returns:
        Mapping of evaluation metric names to their values.
    """
    logger.info(f"开始评估模型: {model_path}")
    start_time = time.time()

    # Fall back to the configured raw-data directory.
    data_dir = data_dir or RAW_DATA_DIR

    # Normalize and create the output directory up front so later saves work.
    if output_dir:
        output_dir = os.path.abspath(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    # 1. Load the trained model.
    logger.info("加载模型...")
    model = ModelFactory.load_model(model_path)

    # 2. Load the processed test split.
    # NOTE(review): the original also built DataLoader(data_dir=data_dir)
    # here but never used it; removed as dead code — confirm the
    # constructor has no required side effects.
    logger.info("加载数据...")
    data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
    data_manager.load_data()
    test_texts, test_labels = data_manager.get_data(dataset="test")

    # 3. Tokenize the test texts (string form, as the vectorizer expects).
    tokenizer = ChineseTokenizer()
    logger.info("对文本进行分词...")
    tokenized_test_texts = [tokenizer.tokenize(text, return_string=True)
                            for text in test_texts]

    # Locate the vectorizer pickle saved alongside the model.
    logger.info("加载向量化器...")
    model_dir = os.path.dirname(model_path)
    vectorizer_path = None
    for model_type in ["cnn", "rnn", "transformer"]:
        path = os.path.join(model_dir, f"vectorizer_{model_type}.pkl")
        if os.path.exists(path):
            vectorizer_path = path
            break

    if not vectorizer_path:
        # Fallback: a fresh, unfitted vectorizer. Results on this path are
        # unlikely to be meaningful, hence the warning.
        logger.warning("未找到向量化器,创建一个新的")
        vectorizer = SequenceVectorizer(
            max_features=MAX_NUM_WORDS,
            max_sequence_length=MAX_SEQUENCE_LENGTH
        )
    else:
        vectorizer = SequenceVectorizer()
        vectorizer.load(vectorizer_path)

    # Vectorize the tokenized test texts into model input.
    X_test = vectorizer.transform(tokenized_test_texts)

    # 4. Build the evaluator.
    logger.info("创建评估器...")
    evaluator = ModelEvaluator(
        model=model,
        class_names=CATEGORIES,
        output_dir=output_dir
    )

    # 5. Compute the evaluation metrics.
    logger.info("评估模型...")
    metrics = evaluator.evaluate(X_test, test_labels, batch_size)

    # 6. Persist metrics and standard plots.
    logger.info("保存评估结果...")
    evaluator.save_evaluation_results(save_plots=True)

    # 7. Plot a row-normalized confusion matrix. Predictions are computed
    # once and named (the original fetched an unused 'confusion_matrix'
    # entry and re-ran model.predict inline).
    logger.info("可视化混淆矩阵...")
    y_pred = np.argmax(model.predict(X_test), axis=1)
    evaluator.metrics.plot_confusion_matrix(
        y_true=test_labels,
        y_pred=y_pred,
        normalize='true',
        save_path=os.path.join(output_dir or model_dir, "confusion_matrix.png")
    )

    # 8. Per-class performance breakdown (the evaluator records/saves it;
    # the original bound the result to an unused local).
    logger.info("分析各类别性能...")
    evaluator.evaluate_class_performance(X_test, test_labels)

    # 9. Report elapsed wall-clock time.
    eval_time = time.time() - start_time
    logger.info(f"模型评估完成,耗时: {eval_time:.2f} 秒")

    # 10. Log the headline metrics only.
    logger.info("主要评估指标:")
    for metric_name, metric_value in metrics.items():
        if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']:
            logger.info(f" {metric_name}: {metric_value:.4f}")

    return metrics
if __name__ == "__main__":
    # Build the CLI from a small option table, parse it, and run the
    # evaluation with the parsed values.
    arg_parser = argparse.ArgumentParser(description="评估文本分类模型")
    cli_options = [
        ("--model_path", {"required": True, "help": "模型路径"}),
        ("--data_dir", {"help": "数据目录"}),
        ("--batch_size", {"type": int, "default": BATCH_SIZE, "help": "批大小"}),
        ("--output_dir", {"help": "评估结果输出目录"}),
    ]
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)
    cli_args = arg_parser.parse_args()

    evaluate_model(
        model_path=cli_args.model_path,
        data_dir=cli_args.data_dir,
        batch_size=cli_args.batch_size,
        output_dir=cli_args.output_dir,
    )