commit ba6d4c40ea982cdf1d936d11910610eecaa68a8b
Author: superlishunqin <852326703@qq.com>
Date: Sat Mar 8 01:34:36 2025 +0800
Initial commit, excluding the large dataset
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..819786a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+data/raw/THUCNews/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..66e8cc5
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,45 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..a971a2c
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..a476500
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/text_classification.iml b/.idea/text_classification.iml
new file mode 100644
index 0000000..e62124b
--- /dev/null
+++ b/.idea/text_classification.iml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..88b1321
--- /dev/null
+++ b/README.md
@@ -0,0 +1,281 @@
+# Chinese Text Classification System in Python
+
+This project is a comprehensive Chinese text classification system built on TensorFlow. It uses deep learning to automatically classify Chinese text and follows a layered design with four main layers: data, processing, model, and interface.
+
+## Features
+
+- Multiple deep learning models: CNN, RNN/LSTM, Transformer
+- Complete text preprocessing: Chinese word segmentation, stop-word filtering, feature engineering
+- Rich evaluation metrics: accuracy, precision, recall, F1 score, and more
+- Multiple interfaces: command line, web UI, REST API
+- Batch processing: batched text input and file upload
+
+## System Architecture
+
+1. **Data layer**: stores and manages the raw text data
+2. **Processing layer**: implements text preprocessing and feature engineering
+3. **Model layer**: contains the core classification models and handles training and prediction
+4. **Interface layer**: provides the user-facing interfaces for data input/output and result display
+
+## Installation
+
+### Requirements
+
+- Python 3.7+
+- TensorFlow 2.5+
+- Other dependencies listed in `requirements.txt`
+
+### Steps
+
+1. Clone the repository:
+
+```bash
+git clone https://git.sq0715.com/qin/Chinese_Text_Classification_System.git
+cd Chinese_Text_Classification_System
+```
+
+2. Create a virtual environment (optional):
+
+```bash
+python -m venv venv
+source venv/bin/activate  # Linux/Mac
+venv\Scripts\activate     # Windows
+```
+
+3. Install dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+4. Install the project:
+
+```bash
+pip install -e .
+```
+
+## Usage
+
+### Training a model
+
+Train a model with:
+
+```bash
+python main.py train --data_dir path/to/data --model_type cnn --epochs 10 --batch_size 64
+```
+
+Arguments (a sketch of how these flags might be wired is shown below):
+- `--data_dir`: data directory; defaults to the directory from the configuration
+- `--model_type`: model type, one of 'cnn', 'rnn', 'transformer'
+- `--epochs`: number of training epochs
+- `--batch_size`: batch size
+- `--save_dir`: directory where the trained model is saved
+
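+For illustration only, a subcommand-style CLI like the one described above could be wired up with `argparse` roughly as follows. `main.py` itself is not part of this commit, so treat this as a sketch rather than the actual implementation:
+
+```python
+import argparse
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Chinese text classification system")
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    # "train" subcommand mirroring the flags documented above
+    train = subparsers.add_parser("train", help="Train a classifier")
+    train.add_argument("--data_dir", default=None, help="Data directory (defaults to config)")
+    train.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn")
+    train.add_argument("--epochs", type=int, default=10)
+    train.add_argument("--batch_size", type=int, default=64)
+    train.add_argument("--save_dir", default=None, help="Where to save the trained model")
+
+    return parser
+
+
+if __name__ == "__main__":
+    args = build_parser().parse_args()
+    print(args)  # a real main.py would dispatch to the training routine here
+```
+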
+### Evaluating a model
+
+Evaluate a trained model with:
+
+```bash
+python main.py evaluate --model_path path/to/model --data_dir path/to/data
+```
+
+Arguments:
+- `--model_path`: path to the trained model
+- `--data_dir`: data directory; defaults to the directory from the configuration
+- `--output_dir`: output directory for the evaluation results
+
+### Predicting text
+
+Predict a single piece of text:
+
+```bash
+python main.py predict --text "这是一条测试文本" --model_path path/to/model
+```
+
+Predict the contents of a file:
+
+```bash
+python main.py predict --file path/to/file.txt --model_path path/to/model
+```
+
+### Starting the web service
+
+Start the web service with:
+
+```bash
+python main.py web --host 0.0.0.0 --port 5000
+```
+
+Then open `http://localhost:5000` in a browser.
+
+### Starting the API service
+
+Start the API service with:
+
+```bash
+python main.py api --host 0.0.0.0 --port 8000
+```
+
+The API documentation is then available at `http://localhost:8000/docs`. A minimal client example follows.
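+
+The exact request/response schema depends on the API implementation, which is not included in this commit; assuming a JSON `/predict` endpoint that accepts a `text` field, a client call might look like this:
+
+```python
+import requests
+
+# Hypothetical endpoint and payload shape -- adjust to the actual API schema.
+response = requests.post(
+    "http://localhost:8000/predict",
+    json={"text": "这是一条测试文本"},
+    timeout=10,
+)
+response.raise_for_status()
+print(response.json())  # e.g. the predicted category and confidence scores
+```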
+
+### Using the command-line interface
+
+Start the interactive command-line interface with:
+
+```bash
+python main.py cli --interactive
+```
+
+## Directory Structure
+
+```
+text_classification_system/
+│
+├── config/              # Configuration files
+├── data/                # Data layer
+├── preprocessing/       # Processing layer
+├── models/              # Model layer
+├── training/            # Training code
+├── evaluation/          # Evaluation code
+├── inference/           # Inference code
+├── interface/           # Interface layer
+├── utils/               # Utilities
+├── saved_models/        # Saved models
+├── tests/               # Tests
+├── docs/                # Documentation
+├── scripts/             # Scripts
+│
+├── main.py              # Main entry point
+├── requirements.txt     # Dependency list
+├── setup.py             # Installation script
+└── README.md            # Project description
+```
+
+## Dataset
+
+This project uses the THUCNews dataset released by Tsinghua University, which contains news articles in 14 categories. The dataset can be downloaded here:
+
+[THUCNews download page](http://thuctc.thunlp.org/)
+
+After downloading, extract the data into the `data/raw/THUCNews` directory.
+
+## License
+
+This project is released under the MIT License; see the LICENSE file for details.
+
+## Contributing
+
+Contributions, bug reports, and suggestions are welcome. Please fork the repository and submit a pull request.
+
+## Contact
+
+For questions or suggestions, please open an issue or reach out via:
+
+- Email: your.email@example.com
+
+# How to Use This Text Classification System
+
+## 1. Install the project
+
+First, make sure your environment meets these requirements:
+- Python 3.7+
+- TensorFlow 2.5+
+
+Installation steps:
+
+```bash
+# Clone the code (assuming you already have it)
+cd text_classification
+
+# Create a virtual environment (optional)
+python -m venv venv
+source venv/bin/activate  # Linux/Mac
+# or on Windows:
+# venv\Scripts\activate
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Install the project (optional)
+pip install -e .
+```
+
+## 2. Prepare the data
+
+Make sure the THUCNews dataset is in the expected location:
+```
+data/raw/THUCNews/
+```
+
+The dataset should contain one folder per category, with each folder holding the text files for that category.
+
+## 3. Train a model
+
+Train a CNN model with:
+
+```bash
+python main.py train --model_type cnn --epochs 10 --batch_size 64
+```
+
+You can also try the other model types:
+- RNN/LSTM model: `--model_type rnn`
+- Transformer model: `--model_type transformer`
+
+After training, the model is saved under `saved_models/classifiers/`.
+
+## 4. Evaluate a model
+
+Evaluate model performance with:
+
+```bash
+python main.py evaluate --model_path path/to/your/model
+```
+
+The evaluation reports accuracy, precision, recall, F1 score, and other metrics, along with a confusion-matrix visualization.
+
+## 5. Run predictions
+
+There are several ways to use a trained model:
+
+### Command-line prediction
+```bash
+# Predict a single piece of text
+python main.py predict --text "2021年羽毛球冠军是林丹"
+
+# Predict the contents of a file
+python main.py predict --file path/to/your/file.txt
+```
+
+### Start the web UI
+```bash
+python main.py web --port 5000
+```
+Then open `http://localhost:5000` in a browser.
+
+### Start the API service
+```bash
+python main.py api --port 8000
+```
+The API documentation is available at `http://localhost:8000/docs`.
+
+### Interactive command line
+```bash
+python main.py cli --interactive
+```
+
+## 6. Extending the system
+
+To extend the system:
+
+1. Add a new model: add a new model class under `models/`
+2. Adjust preprocessing: modify the relevant modules under `preprocessing/`
+3. Add a new interface: extend the code under `interface/`
+
+## 7. Troubleshooting
+
+- Out of memory: reduce `batch_size` or use the data generator
+- Slow training: reduce model complexity or use GPU acceleration
+- Low accuracy: try different model architectures, add preprocessing steps, or tune hyperparameters
+
+The complete Chinese text classification system is now ready; follow the instructions above to start training and using it!
\ No newline at end of file
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/config/model_config.py b/config/model_config.py
new file mode 100644
index 0000000..1b70999
--- /dev/null
+++ b/config/model_config.py
@@ -0,0 +1,72 @@
+"""
+模型配置文件
+"""
+
+# 文本预处理参数
+MAX_SEQUENCE_LENGTH = 500 # 文本序列最大长度
+MAX_NUM_WORDS = 50000 # 词汇表最大大小
+MAX_CHAR_LENGTH = 2000 # 字符级最大长度
+MIN_WORD_FREQUENCY = 5 # 最小词频
+
+# 模型架构参数
+CNN_CONFIG = {
+ "embedding_dim": 200,
+ "num_filters": 256,
+ "filter_sizes": [3, 4, 5],
+ "dropout_rate": 0.5,
+ "l2_reg_lambda": 0.0,
+}
+
+RNN_CONFIG = {
+ "embedding_dim": 200,
+ "hidden_size": 256,
+ "num_layers": 2,
+ "bidirectional": True,
+ "dropout_rate": 0.5,
+}
+
+TRANSFORMER_CONFIG = {
+ "embedding_dim": 200,
+ "num_heads": 8,
+ "ff_dim": 512,
+ "num_layers": 4,
+ "dropout_rate": 0.1,
+}
+
+# 针对RTX 4090的优化设置
+BATCH_SIZE = 128 # RTX 4090有24GB显存,可以支持较大的batch
+EVAL_BATCH_SIZE = 256 # 评估时可以用更大的batch
+
+# 训练参数
+LEARNING_RATE = 1e-3
+NUM_EPOCHS = 20
+EARLY_STOPPING_PATIENCE = 3
+REDUCE_LR_PATIENCE = 2
+REDUCE_LR_FACTOR = 0.5
+VALIDATION_SPLIT = 0.1
+TEST_SPLIT = 0.1
+
+# 词嵌入参数
+USE_PRETRAINED_EMBEDDING = True
+EMBEDDING_TYPE = "word2vec" # 可选: word2vec, glove, fasttext
+
+# 随机种子,保证实验可重复性
+RANDOM_SEED = 42
+
+# 模型保存参数
+SAVE_BEST_ONLY = True
+MODEL_CHECKPOINT_PATH = "best_model.h5"
+
+# 特征工程参数
+USE_CHAR_LEVEL = False # 是否使用字符级特征
+USE_WORD_LEVEL = True # 是否使用词级特征
+USE_TFIDF = False # 是否使用TF-IDF特征
+USE_POS_TAGS = False # 是否使用词性标注特征
+
+# 数据增强参数
+USE_DATA_AUGMENTATION = False
+AUGMENTATION_FACTOR = 0.2 # 增强20%的数据
+
+# 推理参数
+PREDICTION_THRESHOLD = 0.5
+TOP_K_PREDICTIONS = 3
\ No newline at end of file
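As a rough illustration of how these constants might be consumed, the sketch below builds a Keras TextCNN from `CNN_CONFIG` and the preprocessing limits. The actual model classes under `models/` are empty in this commit, so the function name and wiring here are assumptions rather than the project's implementation.

```python
import tensorflow as tf

from config.model_config import (
    CNN_CONFIG, LEARNING_RATE, MAX_NUM_WORDS, MAX_SEQUENCE_LENGTH
)


def build_text_cnn(num_classes: int) -> tf.keras.Model:
    """Multi-filter-size TextCNN driven entirely by CNN_CONFIG (illustrative)."""
    inputs = tf.keras.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype="int32")
    x = tf.keras.layers.Embedding(MAX_NUM_WORDS, CNN_CONFIG["embedding_dim"])(inputs)

    # One convolution branch per filter size, each followed by global max pooling.
    pooled = []
    for size in CNN_CONFIG["filter_sizes"]:
        conv = tf.keras.layers.Conv1D(CNN_CONFIG["num_filters"], size, activation="relu")(x)
        pooled.append(tf.keras.layers.GlobalMaxPooling1D()(conv))
    merged = tf.keras.layers.Concatenate()(pooled)

    merged = tf.keras.layers.Dropout(CNN_CONFIG["dropout_rate"])(merged)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(merged)
    return tf.keras.Model(inputs, outputs)


model = build_text_cnn(num_classes=14)  # 14 THUCNews categories
model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.summary()
```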
diff --git a/config/system_config.py b/config/system_config.py
new file mode 100644
index 0000000..b677cd8
--- /dev/null
+++ b/config/system_config.py
@@ -0,0 +1,71 @@
+"""
+系统全局配置文件
+"""
+import os
+import platform
+from pathlib import Path
+
+# 项目根目录
+ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+"""
+Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 是当前文件的上一级目录
+这种写法主要是为了方便移植项目到不同的平台运行
+"""
+
+# 数据相关路径
+DATA_DIR = ROOT_DIR / "data"
+RAW_DATA_DIR = DATA_DIR / "raw" / "THUCNews"
+PROCESSED_DATA_DIR = DATA_DIR / "processed"
+RESOURCES_DIR = DATA_DIR / "resources"
+STOPWORDS_DIR = RESOURCES_DIR / "stopwords"
+EMBEDDINGS_DIR = RESOURCES_DIR / "embeddings"
+
+# 确保必要的目录存在
+for directory in [PROCESSED_DATA_DIR, RESOURCES_DIR, STOPWORDS_DIR, EMBEDDINGS_DIR]:
+ directory.mkdir(parents=True, exist_ok=True)
+
+# 保存模型的路径
+SAVED_MODELS_DIR = ROOT_DIR / "saved_models"
+TOKENIZERS_DIR = SAVED_MODELS_DIR / "tokenizers"
+CLASSIFIERS_DIR = SAVED_MODELS_DIR / "classifiers"
+
+# 确保模型保存目录存在
+for directory in [SAVED_MODELS_DIR, TOKENIZERS_DIR, CLASSIFIERS_DIR]:
+ directory.mkdir(parents=True, exist_ok=True)
+
+# 系统资源配置
+CPU_COUNT = os.cpu_count()
+USE_GPU = True
+MULTI_GPU = False # 目前只使用单个GPU
+
+# 基于13900K性能设置并行处理参数
+DATA_LOADING_WORKERS = min(16, CPU_COUNT) # 数据加载线程数
+PREPROCESSING_WORKERS = min(24, CPU_COUNT) # 预处理线程数,13900K有强大的多线程能力
+
+# 基于64GB内存设置内存相关参数
+MAX_MEMORY_GB = 48 # 保留部分内存给系统和其他应用
+MAX_TEXT_PER_BATCH = 10000 # 每批处理的最大文本数量
+
+# 日志配置
+LOG_DIR = ROOT_DIR / "logs"
+LOG_DIR.mkdir(exist_ok=True)
+LOG_LEVEL = "INFO"
+LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+# 类别标签映射(与THUCNews数据集一致)
+CATEGORIES = [
+ "体育", "娱乐", "家居", "彩票", "房产", "教育",
+ "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"
+]
+CATEGORY_TO_ID = {category: idx for idx, category in enumerate(CATEGORIES)}
+ID_TO_CATEGORY = {idx: category for idx, category in enumerate(CATEGORIES)}
+
+# 文件编码
+ENCODING = "utf-8"
+
+# 系统信息
+SYSTEM_INFO = {
+ "platform": platform.platform(),
+ "python_version": platform.python_version(),
+ "processor": platform.processor(),
+}
\ No newline at end of file
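A quick sanity check of the mappings defined above (note that simply importing `config.system_config` also creates the processed/resources/model/log directories as a side effect of the `mkdir` calls):

```python
from config.system_config import (
    CATEGORIES, CATEGORY_TO_ID, ID_TO_CATEGORY, RAW_DATA_DIR
)

print(len(CATEGORIES))         # 14 categories
print(CATEGORY_TO_ID["体育"])  # -> 0
print(ID_TO_CATEGORY[0])       # -> "体育"
print(RAW_DATA_DIR)            # <project root>/data/raw/THUCNews
```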
diff --git a/create_project_structure.sh b/create_project_structure.sh
new file mode 100755
index 0000000..5715ac6
--- /dev/null
+++ b/create_project_structure.sh
@@ -0,0 +1,113 @@
+#!/bin/bash
+# 项目结构创建脚本
+# 用于创建文本分类系统的完整目录结构
+# 作者:AI助手
+# 日期:2023
+
+echo "开始创建项目结构..."
+
+# 创建配置目录和文件
+mkdir -p config
+touch config/__init__.py
+touch config/model_config.py
+touch config/system_config.py
+
+# 创建数据层目录和文件
+mkdir -p data/raw data/processed data/resources/stopwords data/resources/embeddings
+touch data/__init__.py
+touch data/dataloader.py
+touch data/data_manager.py
+touch data/dataset.py
+
+# 创建处理层目录和文件
+mkdir -p preprocessing
+touch preprocessing/__init__.py
+touch preprocessing/text_cleaner.py
+touch preprocessing/tokenization.py
+touch preprocessing/feature_extraction.py
+touch preprocessing/vectorizer.py
+touch preprocessing/data_augmentation.py
+
+# 创建模型层目录和文件
+mkdir -p models/layers
+touch models/__init__.py
+touch models/base_model.py
+touch models/cnn_model.py
+touch models/rnn_model.py
+touch models/transformer_model.py
+touch models/ensemble_model.py
+touch models/model_factory.py
+touch models/layers/__init__.py
+
+# 创建训练模块
+mkdir -p training
+touch training/__init__.py
+touch training/trainer.py
+touch training/optimizer.py
+touch training/callbacks.py
+touch training/scheduler.py
+
+# 创建评估模块
+mkdir -p evaluation
+touch evaluation/__init__.py
+touch evaluation/evaluator.py
+touch evaluation/metrics.py
+touch evaluation/visualization.py
+
+# 创建推理模块
+mkdir -p inference
+touch inference/__init__.py
+touch inference/predictor.py
+touch inference/batch_processor.py
+
+# 创建接口层
+mkdir -p interface/web/templates
+touch interface/__init__.py
+touch interface/cli.py
+touch interface/api.py
+touch interface/web/__init__.py
+touch interface/web/app.py
+touch interface/web/routes.py
+
+# 创建工具模块
+mkdir -p utils
+touch utils/__init__.py
+touch utils/logger.py
+touch utils/file_utils.py
+touch utils/time_utils.py
+touch utils/text_utils.py
+
+# 创建模型保存目录
+mkdir -p saved_models/tokenizers saved_models/classifiers
+
+# 创建测试目录
+mkdir -p tests
+touch tests/__init__.py
+touch tests/test_preprocessing.py
+touch tests/test_models.py
+touch tests/test_evaluation.py
+
+# 创建文档目录
+mkdir -p docs
+touch docs/architecture.md
+touch docs/api_reference.md
+touch docs/usage_guide.md
+
+# 创建脚本目录
+mkdir -p scripts
+touch scripts/train.py
+touch scripts/evaluate.py
+touch scripts/predict.py
+
+# 创建主要文件
+touch main.py
+touch requirements.txt
+touch setup.py
+touch README.md
+
+echo "项目结构创建完成!"
+echo "------------------------------"
+echo "目录结构概览:"
+find . -type d | sort
+echo "------------------------------"
+echo "文件总数: $(find . -type f | wc -l)"
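For environments without a POSIX shell, the same directory skeleton could be created with a short, cross-platform Python script. The sketch below is an illustrative equivalent (individual module stubs omitted for brevity), not a file from the repository:

```python
from pathlib import Path

# Python packages get an __init__.py; plain directories do not.
PACKAGES = [
    "config", "data", "preprocessing", "models", "models/layers", "training",
    "evaluation", "inference", "interface", "interface/web", "utils", "tests",
]
PLAIN_DIRS = [
    "data/raw", "data/processed", "data/resources/stopwords",
    "data/resources/embeddings", "saved_models/tokenizers",
    "saved_models/classifiers", "interface/web/templates", "docs", "scripts",
]

for pkg in PACKAGES:
    path = Path(pkg)
    path.mkdir(parents=True, exist_ok=True)
    (path / "__init__.py").touch()

for d in PLAIN_DIRS:
    Path(d).mkdir(parents=True, exist_ok=True)

for name in ["main.py", "requirements.txt", "setup.py", "README.md"]:
    Path(name).touch()
```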
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/data_manager.py b/data/data_manager.py
new file mode 100644
index 0000000..5fd95f6
--- /dev/null
+++ b/data/data_manager.py
@@ -0,0 +1,583 @@
+"""
+数据管理模块:负责数据的存储、读取和转换
+"""
+import os
+import pickle
+import json
+import time
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+import pandas as pd
+from collections import Counter
+import matplotlib.pyplot as plt
+from sklearn.model_selection import train_test_split
+
+from config.system_config import (
+ PROCESSED_DATA_DIR, ENCODING, CATEGORY_TO_ID, ID_TO_CATEGORY
+)
+from config.model_config import (
+ VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED
+)
+from utils.logger import get_logger
+from utils.file_utils import (
+ save_pickle, load_pickle, save_json, load_json, ensure_dir
+)
+from data.dataloader import DataLoader
+
+logger = get_logger("DataManager")
+
+
+class DataManager:
+ """数据管理类,负责数据的存储、读取和转换"""
+
+ def __init__(self, processed_dir: Optional[str] = None):
+ """
+ 初始化数据管理器
+
+ Args:
+ processed_dir: 处理后数据的存储目录,默认使用配置文件中的路径
+ """
+ self.processed_dir = processed_dir or PROCESSED_DATA_DIR
+ ensure_dir(self.processed_dir)
+
+ # 数据分割后的存储
+ self.train_texts = []
+ self.train_labels = []
+ self.val_texts = []
+ self.val_labels = []
+ self.test_texts = []
+ self.test_labels = []
+
+ # 数据统计信息
+ self.stats = {}
+
+ # 标签编码映射
+ self.label_to_id = CATEGORY_TO_ID
+ self.id_to_label = ID_TO_CATEGORY
+
+ logger.info(f"数据管理器初始化完成,处理后数据将存储在 {self.processed_dir}")
+
+ def load_and_split_data(self, data_loader: DataLoader,
+ categories: Optional[List[str]] = None,
+ val_split: float = VALIDATION_SPLIT,
+ test_split: float = TEST_SPLIT,
+ sample_ratio: float = 1.0,
+ balanced: bool = False,
+ n_per_category: int = 1000,
+ save: bool = True) -> Dict[str, Any]:
+ """
+ 加载并分割数据集
+
+ Args:
+ data_loader: 数据加载器实例
+ categories: 要包含的类别列表,默认为所有类别
+ val_split: 验证集比例
+ test_split: 测试集比例
+ sample_ratio: 采样比例,默认为1.0(全部)
+ balanced: 是否平衡各类别的样本数量
+ n_per_category: 平衡模式下每个类别的样本数量
+ save: 是否保存处理后的数据
+
+ Returns:
+ 包含分割后数据集的字典
+ """
+ start_time = time.time()
+
+ # 加载数据
+ if balanced:
+ logger.info(f"加载平衡数据集,每个类别 {n_per_category} 个样本")
+ data = data_loader.load_balanced_data(
+ n_per_category=n_per_category,
+ categories=categories,
+ shuffle=True
+ )
+ else:
+ logger.info(f"加载数据集,采样比例 {sample_ratio}")
+ data = data_loader.load_data(
+ categories=categories,
+ sample_ratio=sample_ratio,
+ shuffle=True,
+ return_generator=False
+ )
+
+ logger.info(f"加载了 {len(data)} 个样本")
+
+ # 分离文本和标签
+ texts = [text for text, _ in data]
+ labels = [label for _, label in data]
+
+ # 进行标签编码
+ encoded_labels = np.array([self.label_to_id[label] for label in labels])
+
+ # 计算数据统计信息
+ self._compute_stats(texts, labels)
+
+ # 划分训练集、验证集和测试集
+ # 先分出测试集
+ if test_split > 0:
+ train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split(
+ texts, encoded_labels,
+ test_size=test_split,
+ random_state=RANDOM_SEED,
+ stratify=encoded_labels if len(set(encoded_labels)) > 1 else None
+ )
+ else:
+ train_val_texts, train_val_labels = texts, encoded_labels
+ self.test_texts, self.test_labels = [], []
+
+ # 再划分训练集和验证集
+ if val_split > 0:
+ self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split(
+ train_val_texts, train_val_labels,
+ test_size=val_split / (1 - test_split),
+ random_state=RANDOM_SEED,
+ stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
+ )
+ else:
+ self.train_texts, self.train_labels = train_val_texts, train_val_labels
+ self.val_texts, self.val_labels = [], []
+
+ # 打印数据集划分结果
+ logger.info(f"数据集划分结果:")
+ logger.info(f" 训练集:{len(self.train_texts)} 个样本")
+ logger.info(f" 验证集:{len(self.val_texts)} 个样本")
+ logger.info(f" 测试集:{len(self.test_texts)} 个样本")
+
+ # 保存处理后的数据
+ if save:
+ self.save_data()
+
+ elapsed = time.time() - start_time
+ logger.info(f"数据加载和分割完成,用时 {elapsed:.2f} 秒")
+
+ return {
+ "train_texts": self.train_texts,
+ "train_labels": self.train_labels,
+ "val_texts": self.val_texts,
+ "val_labels": self.val_labels,
+ "test_texts": self.test_texts,
+ "test_labels": self.test_labels,
+ "stats": self.stats
+ }
+
+ def _compute_stats(self, texts: List[str], labels: List[str]) -> None:
+ """
+ 计算数据统计信息
+
+ Args:
+ texts: 文本列表
+ labels: 标签列表
+ """
+ # 文本数量
+ num_samples = len(texts)
+
+ # 类别分布
+ label_counter = Counter(labels)
+ label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()}
+
+ # 文本长度统计
+ text_lengths = [len(text) for text in texts]
+ avg_length = sum(text_lengths) / len(text_lengths)
+ max_length = max(text_lengths)
+ min_length = min(text_lengths)
+
+ # 前5个最长和最短的文本的长度
+ sorted_lengths = sorted(text_lengths)
+ shortest_lengths = sorted_lengths[:5]
+ longest_lengths = sorted_lengths[-5:]
+
+ # 95%的文本长度分位数
+ percentile_95 = np.percentile(text_lengths, 95)
+
+ # 存储统计信息
+ self.stats = {
+ "num_samples": num_samples,
+ "num_categories": len(label_counter),
+ "label_counter": label_counter,
+ "label_distribution": label_distribution,
+ "text_length": {
+ "average": avg_length,
+ "max": max_length,
+ "min": min_length,
+ "percentile_95": percentile_95,
+ "shortest_5": shortest_lengths,
+ "longest_5": longest_lengths
+ }
+ }
+
+ def save_data(self, save_dir: Optional[str] = None) -> None:
+ """
+ 保存处理后的数据
+
+ Args:
+ save_dir: 保存目录,默认使用初始化时设置的目录
+ """
+ save_dir = save_dir or self.processed_dir
+ ensure_dir(save_dir)
+
+ # 保存训练集
+ save_pickle(
+ {"texts": self.train_texts, "labels": self.train_labels},
+ os.path.join(save_dir, "train_data.pkl")
+ )
+
+ # 保存验证集
+ if len(self.val_texts) > 0:
+ save_pickle(
+ {"texts": self.val_texts, "labels": self.val_labels},
+ os.path.join(save_dir, "val_data.pkl")
+ )
+
+ # 保存测试集
+ if len(self.test_texts) > 0:
+ save_pickle(
+ {"texts": self.test_texts, "labels": self.test_labels},
+ os.path.join(save_dir, "test_data.pkl")
+ )
+
+ # 保存标签编码映射
+ save_json(
+ {"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
+ os.path.join(save_dir, "label_mapping.json")
+ )
+
+ # 保存数据统计信息
+ # 将Counter对象转换为普通字典以便JSON序列化
+ stats_for_json = self.stats.copy()
+ if "label_counter" in stats_for_json:
+ stats_for_json["label_counter"] = dict(stats_for_json["label_counter"])
+
+ save_json(
+ stats_for_json,
+ os.path.join(save_dir, "data_stats.json")
+ )
+
+ logger.info(f"已将处理后的数据保存到 {save_dir}")
+
+ def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]:
+ """
+ 加载处理后的数据
+
+ Args:
+ load_dir: 加载目录,默认使用初始化时设置的目录
+
+ Returns:
+ 包含加载的数据集的字典
+ """
+ load_dir = load_dir or self.processed_dir
+
+ # 加载训练集
+ train_data_path = os.path.join(load_dir, "train_data.pkl")
+ if os.path.exists(train_data_path):
+ train_data = load_pickle(train_data_path)
+ self.train_texts = train_data["texts"]
+ self.train_labels = train_data["labels"]
+ logger.info(f"已加载训练集,包含 {len(self.train_texts)} 个样本")
+ else:
+ logger.warning(f"训练集文件不存在: {train_data_path}")
+ self.train_texts, self.train_labels = [], []
+
+ # 加载验证集
+ val_data_path = os.path.join(load_dir, "val_data.pkl")
+ if os.path.exists(val_data_path):
+ val_data = load_pickle(val_data_path)
+ self.val_texts = val_data["texts"]
+ self.val_labels = val_data["labels"]
+ logger.info(f"已加载验证集,包含 {len(self.val_texts)} 个样本")
+ else:
+ logger.warning(f"验证集文件不存在: {val_data_path}")
+ self.val_texts, self.val_labels = [], []
+
+ # 加载测试集
+ test_data_path = os.path.join(load_dir, "test_data.pkl")
+ if os.path.exists(test_data_path):
+ test_data = load_pickle(test_data_path)
+ self.test_texts = test_data["texts"]
+ self.test_labels = test_data["labels"]
+ logger.info(f"已加载测试集,包含 {len(self.test_texts)} 个样本")
+ else:
+ logger.warning(f"测试集文件不存在: {test_data_path}")
+ self.test_texts, self.test_labels = [], []
+
+ # 加载标签编码映射
+ mapping_path = os.path.join(load_dir, "label_mapping.json")
+ if os.path.exists(mapping_path):
+ mapping = load_json(mapping_path)
+ self.label_to_id = mapping["label_to_id"]
+ self.id_to_label = mapping["id_to_label"]
+ # 将字符串键转换为整数(JSON序列化会将所有键转为字符串)
+ self.id_to_label = {int(k): v for k, v in self.id_to_label.items()}
+ logger.info(f"已加载标签编码映射,共 {len(self.label_to_id)} 个类别")
+
+ # 加载数据统计信息
+ stats_path = os.path.join(load_dir, "data_stats.json")
+ if os.path.exists(stats_path):
+ self.stats = load_json(stats_path)
+ logger.info("已加载数据统计信息")
+
+ return {
+ "train_texts": self.train_texts,
+ "train_labels": self.train_labels,
+ "val_texts": self.val_texts,
+ "val_labels": self.val_labels,
+ "test_texts": self.test_texts,
+ "test_labels": self.test_labels,
+ "stats": self.stats
+ }
+
+ def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]:
+ """
+ 获取指定数据集的标签分布
+
+ Args:
+ dataset: 数据集名称,可选值:'train', 'val', 'test'
+
+ Returns:
+ 标签分布字典,键为类别名称,值为比例
+ """
+ if dataset == "train":
+ labels = self.train_labels
+ elif dataset == "val":
+ labels = self.val_labels
+ elif dataset == "test":
+ labels = self.test_labels
+ else:
+ raise ValueError(f"不支持的数据集名称: {dataset}")
+
+ # 计算标签分布
+ label_counter = Counter(labels)
+ num_samples = len(labels)
+
+ # 将数字标签转换为类别名称
+ distribution = {}
+ for label_id, count in label_counter.items():
+ label_name = self.id_to_label.get(label_id, str(label_id))
+ distribution[label_name] = count / num_samples * 100
+
+ return distribution
+
+ def visualize_label_distribution(self, dataset: str = "train",
+ save_path: Optional[str] = None) -> None:
+ """
+ 可视化标签分布
+
+ Args:
+ dataset: 数据集名称,可选值:'train', 'val', 'test', 'all'
+ save_path: 图表保存路径,默认为None(显示而不保存)
+ """
+ plt.figure(figsize=(12, 8))
+
+ if dataset == "all":
+ # 显示所有数据集的标签分布
+ train_dist = self.get_label_distribution("train")
+ val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {}
+ test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {}
+
+ # 准备数据
+ categories = list(train_dist.keys())
+ train_values = [train_dist.get(cat, 0) for cat in categories]
+ val_values = [val_dist.get(cat, 0) for cat in categories]
+ test_values = [test_dist.get(cat, 0) for cat in categories]
+
+ # 绘制条形图
+ x = np.arange(len(categories))
+ width = 0.25
+
+ plt.bar(x - width, train_values, width, label="Training")
+ if val_values:
+ plt.bar(x, val_values, width, label="Validation")
+ if test_values:
+ plt.bar(x + width, test_values, width, label="Testing")
+
+ plt.xlabel("Categories")
+ plt.ylabel("Percentage (%)")
+ plt.title("Label Distribution Across Datasets")
+ plt.xticks(x, categories, rotation=45, ha="right")
+ plt.legend()
+ plt.tight_layout()
+ else:
+ # 显示单个数据集的标签分布
+ distribution = self.get_label_distribution(dataset)
+
+ # 按值排序
+ sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True)
+ categories = [item[0] for item in sorted_items]
+ values = [item[1] for item in sorted_items]
+
+ # 绘制条形图
+ plt.bar(categories, values, color='skyblue')
+ plt.xlabel("Categories")
+ plt.ylabel("Percentage (%)")
+ plt.title(f"Label Distribution in {dataset.capitalize()} Dataset")
+ plt.xticks(rotation=45, ha="right")
+ plt.tight_layout()
+
+ # 保存或显示图表
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"标签分布图已保存到 {save_path}")
+ else:
+ plt.show()
+
+ def visualize_text_length_distribution(self, dataset: str = "train",
+ bins: int = 50,
+ save_path: Optional[str] = None) -> None:
+ """
+ 可视化文本长度分布
+
+ Args:
+ dataset: 数据集名称,可选值:'train', 'val', 'test'
+ bins: 直方图的箱数
+ save_path: 图表保存路径,默认为None(显示而不保存)
+ """
+ if dataset == "train":
+ texts = self.train_texts
+ elif dataset == "val":
+ texts = self.val_texts
+ elif dataset == "test":
+ texts = self.test_texts
+ else:
+ raise ValueError(f"不支持的数据集名称: {dataset}")
+
+ # 计算文本长度
+ text_lengths = [len(text) for text in texts]
+
+ # 绘制直方图
+ plt.figure(figsize=(10, 6))
+ plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7)
+
+ # 计算并绘制一些统计量
+ avg_length = sum(text_lengths) / len(text_lengths)
+ median_length = np.median(text_lengths)
+ percentile_95 = np.percentile(text_lengths, 95)
+
+ plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}')
+ plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}')
+ plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1,
+ label=f'95th Percentile: {percentile_95:.1f}')
+
+ plt.xlabel('Text Length (characters)')
+ plt.ylabel('Frequency')
+ plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset')
+ plt.legend()
+ plt.tight_layout()
+
+ # 保存或显示图表
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"文本长度分布图已保存到 {save_path}")
+ else:
+ plt.show()
+
+ def get_data_summary(self) -> Dict[str, Any]:
+ """
+ 获取数据集的摘要信息
+
+ Returns:
+ 包含数据摘要的字典
+ """
+ # 获取数据集的基本信息
+ summary = {
+ "train_size": len(self.train_texts),
+ "val_size": len(self.val_texts),
+ "test_size": len(self.test_texts),
+ "num_categories": len(self.label_to_id),
+ "categories": list(self.label_to_id.keys()),
+ }
+
+ # 添加训练集的标签分布
+ if len(self.train_texts) > 0:
+ summary["train_label_distribution"] = self.get_label_distribution("train")
+
+ # 添加验证集的标签分布
+ if len(self.val_texts) > 0:
+ summary["val_label_distribution"] = self.get_label_distribution("val")
+
+ # 添加测试集的标签分布
+ if len(self.test_texts) > 0:
+ summary["test_label_distribution"] = self.get_label_distribution("test")
+
+ # 添加更多统计信息(如果有)
+ if self.stats:
+ # 只添加一些关键的统计信息
+ if "text_length" in self.stats:
+ summary["text_length_stats"] = self.stats["text_length"]
+
+ return summary
+
+ def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame:
+ """
+ 将数据导出为Pandas DataFrame
+
+ Args:
+ dataset: 数据集名称,可选值:'train', 'val', 'test'
+
+ Returns:
+ Pandas DataFrame
+ """
+ if dataset == "train":
+ texts = self.train_texts
+ labels_ids = self.train_labels
+ elif dataset == "val":
+ texts = self.val_texts
+ labels_ids = self.val_labels
+ elif dataset == "test":
+ texts = self.test_texts
+ labels_ids = self.test_labels
+ else:
+ raise ValueError(f"不支持的数据集名称: {dataset}")
+
+ # 将数字标签转换为类别名称
+ labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids]
+
+ # 创建DataFrame
+ df = pd.DataFrame({
+ "text": texts,
+ "label_id": labels_ids,
+ "label": labels
+ })
+
+ return df
+
+ def get_label_name(self, label_id: int) -> str:
+ """
+ 获取标签ID对应的类别名称
+
+ Args:
+ label_id: 标签ID
+
+ Returns:
+ 类别名称
+ """
+ return self.id_to_label.get(label_id, str(label_id))
+
+ def get_label_id(self, label_name: str) -> int:
+ """
+ 获取类别名称对应的标签ID
+
+ Args:
+ label_name: 类别名称
+
+ Returns:
+ 标签ID
+ """
+ return self.label_to_id.get(label_name, -1)
+
+ def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]:
+ """
+ 获取指定数据集的文本和标签
+
+ Args:
+ dataset: 数据集名称,可选值:'train', 'val', 'test'
+
+ Returns:
+ (文本列表, 标签数组)的元组
+ """
+ if dataset == "train":
+ return self.train_texts, self.train_labels
+ elif dataset == "val":
+ return self.val_texts, self.val_labels
+ elif dataset == "test":
+ return self.test_texts, self.test_labels
+ else:
+ raise ValueError(f"不支持的数据集名称: {dataset}")
\ No newline at end of file
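Taken together, `DataLoader` and `DataManager` are intended to be used roughly as follows. This is a usage sketch (the surrounding script is not part of the commit); note that with the default `VALIDATION_SPLIT = TEST_SPLIT = 0.1`, the second `train_test_split` call uses `0.1 / (1 - 0.1) ≈ 0.111` of the remaining 90%, so about 10% of the full dataset ends up in the validation set.

```python
from data.data_manager import DataManager
from data.dataloader import DataLoader

loader = DataLoader()        # reads data/raw/THUCNews by default
manager = DataManager()      # writes processed pickles under data/processed

# Load a balanced subset (1000 texts per category) and split it ~80/10/10.
splits = manager.load_and_split_data(
    data_loader=loader,
    balanced=True,
    n_per_category=1000,
    save=True,
)

print(manager.get_data_summary()["train_size"])

train_texts, train_labels = manager.get_data("train")
print(train_texts[0][:100])
print(manager.get_label_name(int(train_labels[0])))
```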
diff --git a/data/dataloader.py b/data/dataloader.py
new file mode 100644
index 0000000..4a8aa07
--- /dev/null
+++ b/data/dataloader.py
@@ -0,0 +1,296 @@
+"""
+数据加载模块:负责从文件系统加载原始文本数据
+"""
+import os
+import glob
+import time
+from pathlib import Path
+from typing import List, Dict, Tuple, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import random
+import numpy as np
+
+from config.system_config import (
+ RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES,
+ CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH
+)
+from config.model_config import RANDOM_SEED
+from utils.logger import get_logger
+from utils.file_utils import read_text_file, read_files_parallel, list_files
+
+# 设置随机种子以保证可重复性
+random.seed(RANDOM_SEED)
+np.random.seed(RANDOM_SEED)
+
+logger = get_logger("DataLoader")
+
+
+class DataLoader:
+ """负责加载THUCNews数据集的类"""
+
+ def __init__(self, data_dir: Optional[str] = None,
+ categories: Optional[List[str]] = None,
+ encoding: str = ENCODING,
+ max_workers: int = DATA_LOADING_WORKERS,
+ max_text_per_batch: int = MAX_TEXT_PER_BATCH):
+ """
+ 初始化数据加载器
+
+ Args:
+ data_dir: 数据目录,默认使用配置文件中的路径
+ categories: 要加载的类别列表,默认加载所有类别
+ encoding: 文件编码
+ max_workers: 最大工作线程数
+ max_text_per_batch: 每批处理的最大文本数量
+ """
+ self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR
+ self.categories = categories if categories else CATEGORIES
+ self.encoding = encoding
+ self.max_workers = max_workers
+ self.max_text_per_batch = max_text_per_batch
+
+ # 验证数据目录是否存在
+ if not self.data_dir.exists():
+ raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
+
+ # 验证类别是否存在
+ for category in self.categories:
+ category_dir = self.data_dir / category
+ if not category_dir.exists():
+ logger.warning(f"类别目录不存在: {category_dir}")
+
+ # 存储类别目录的映射
+ self.category_dirs = {
+ category: self.data_dir / category
+ for category in self.categories
+ if (self.data_dir / category).exists()
+ }
+
+ # 记录类别文件数量
+ self.category_file_counts = {}
+
+ # 统计并记录每个类别的文件数量
+ self._count_files()
+ logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件")
+
+ def _count_files(self) -> None:
+ """统计每个类别的文件数量"""
+ for category, category_dir in self.category_dirs.items():
+ files = list(category_dir.glob("*.txt"))
+ self.category_file_counts[category] = len(files)
+ logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件")
+
+ def get_file_paths(self, category: Optional[str] = None,
+ sample_ratio: float = 1.0,
+ shuffle: bool = True) -> List[Tuple[str, str]]:
+ """
+ 获取指定类别的文件路径列表
+
+ Args:
+ category: 类别名称,如果为None则获取所有类别
+ sample_ratio: 采样比例,默认为1.0(全部)
+ shuffle: 是否打乱文件顺序
+
+ Returns:
+ 包含(文件路径, 类别)元组的列表
+ """
+ file_paths = []
+
+ # 确定要处理的类别
+ categories_to_process = [category] if category else self.categories
+
+ # 获取每个类别的文件路径
+ for cat in categories_to_process:
+ if cat in self.category_dirs:
+ category_dir = self.category_dirs[cat]
+ cat_files = list(category_dir.glob("*.txt"))
+
+ # 采样
+ if sample_ratio < 1.0:
+ sample_size = int(len(cat_files) * sample_ratio)
+ if shuffle:
+ cat_files = random.sample(cat_files, sample_size)
+ else:
+ cat_files = cat_files[:sample_size]
+
+ # 添加文件路径和对应的类别
+ file_paths.extend([(str(file), cat) for file in cat_files])
+
+ # 打乱全局顺序(如果需要)
+ if shuffle:
+ random.shuffle(file_paths)
+
+ return file_paths
+
+ def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+ """
+ 加载指定路径的文本内容
+
+ Args:
+ file_paths: 包含(文件路径, 类别)元组的列表
+
+ Returns:
+ 包含(文本内容, 类别)元组的列表
+ """
+ start_time = time.time()
+ texts_with_labels = []
+
+ # 提取文件路径列表
+ paths = [path for path, _ in file_paths]
+ labels = [label for _, label in file_paths]
+
+ # 并行加载文本内容
+ contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding)
+
+ # 将内容与标签配对
+ for content, label in zip(contents, labels):
+ if content: # 确保内容不为空
+ texts_with_labels.append((content, label))
+
+ elapsed = time.time() - start_time
+ logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f} 秒")
+
+ return texts_with_labels
+
+ def load_data(self, categories: Optional[List[str]] = None,
+ sample_ratio: float = 1.0,
+ shuffle: bool = True,
+ return_generator: bool = False) -> Any:
+ """
+ 加载指定类别的所有数据
+
+ Args:
+ categories: 要加载的类别列表,默认为所有类别
+ sample_ratio: 采样比例,默认为1.0(全部)
+ shuffle: 是否打乱数据顺序
+ return_generator: 是否返回生成器(批量加载)
+
+ Returns:
+ 如果return_generator为False,返回包含(文本内容, 类别)元组的列表
+ 如果return_generator为True,返回一个生成器,每次产生一批数据
+ """
+ # 确定要处理的类别
+ cats_to_process = categories if categories else self.categories
+
+ # 验证类别是否存在
+ for cat in cats_to_process:
+ if cat not in self.category_dirs:
+ logger.warning(f"类别 {cat} 不存在,将被忽略")
+
+ # 筛选存在的类别
+ cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
+
+ # 获取所有文件路径
+ all_file_paths = []
+ for cat in cats_to_process:
+ cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle)
+ all_file_paths.extend(cat_files)
+
+ # 打乱全局顺序(如果需要)
+ if shuffle:
+ random.shuffle(all_file_paths)
+
+ # 如果需要返回生成器,分批次加载数据
+ if return_generator:
+ def data_generator():
+ for i in range(0, len(all_file_paths), self.max_text_per_batch):
+ batch_paths = all_file_paths[i:i + self.max_text_per_batch]
+ batch_data = self.load_texts(batch_paths)
+ yield batch_data
+
+ return data_generator()
+
+ # 否则,一次性加载所有数据
+ return self.load_texts(all_file_paths)
+
+ def load_balanced_data(self, n_per_category: int = 1000,
+ categories: Optional[List[str]] = None,
+ shuffle: bool = True) -> List[Tuple[str, str]]:
+ """
+ 加载平衡的数据集(每个类别的样本数量相同)
+
+ Args:
+ n_per_category: 每个类别加载的样本数量
+ categories: 要加载的类别列表,默认为所有类别
+ shuffle: 是否打乱数据顺序
+
+ Returns:
+ 包含(文本内容, 类别)元组的列表
+ """
+ # 确定要处理的类别
+ cats_to_process = categories if categories else self.categories
+ cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
+
+ balanced_data = []
+
+ for cat in cats_to_process:
+ # 获取该类别的文件路径
+ cat_files = self.get_file_paths(cat, shuffle=shuffle)
+
+ # 限制数量
+ cat_files = cat_files[:n_per_category]
+
+ # 加载文本
+ cat_data = self.load_texts(cat_files)
+ balanced_data.extend(cat_data)
+
+ # 打乱全局顺序(如果需要)
+ if shuffle:
+ random.shuffle(balanced_data)
+
+ return balanced_data
+
+ def get_category_distribution(self) -> Dict[str, int]:
+ """
+ 获取数据集的类别分布
+
+ Returns:
+ 包含各类别样本数量的字典
+ """
+ return self.category_file_counts
+
+ def get_data_stats(self) -> Dict[str, Any]:
+ """
+ 获取数据集的统计信息
+
+ Returns:
+ 包含统计信息的字典
+ """
+ # 计算总样本数
+ total_samples = sum(self.category_file_counts.values())
+
+ # 计算各类别占比
+ category_percentages = {
+ cat: count / total_samples * 100
+ for cat, count in self.category_file_counts.items()
+ }
+
+ # 采样几个文件计算平均文本长度
+ sample_files = []
+ for cat in self.categories:
+ if cat in self.category_dirs:
+ cat_files = list((self.data_dir / cat).glob("*.txt"))
+ if cat_files:
+ # 每个类别最多采样10个文件
+ sample_files.extend(random.sample(cat_files, min(10, len(cat_files))))
+
+ # 加载采样的文件内容
+ sample_contents = []
+ for file_path in sample_files:
+ content = read_text_file(str(file_path), encoding=self.encoding)
+ if content:
+ sample_contents.append(content)
+
+ # 计算平均文本长度(字符数)
+ avg_char_length = sum(len(content) for content in sample_contents) / len(
+ sample_contents) if sample_contents else 0
+
+ # 返回统计信息
+ return {
+ "total_samples": total_samples,
+ "category_counts": self.category_file_counts,
+ "category_percentages": category_percentages,
+ "average_text_length": avg_char_length,
+ "categories": self.categories,
+ "num_categories": len(self.categories),
+ }
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/raw/.DS_Store b/data/raw/.DS_Store
new file mode 100644
index 0000000..4774f66
Binary files /dev/null and b/data/raw/.DS_Store differ
diff --git a/docs/api_reference.md b/docs/api_reference.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/architecture.md b/docs/architecture.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/usage_guide.md b/docs/usage_guide.md
new file mode 100644
index 0000000..e69de29
diff --git a/evaluation/__init__.py b/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/evaluation/evaluator.py b/evaluation/evaluator.py
new file mode 100644
index 0000000..3814dfb
--- /dev/null
+++ b/evaluation/evaluator.py
@@ -0,0 +1,491 @@
+"""
+评估器模块:实现模型评估流程
+"""
+import numpy as np
+import tensorflow as tf
+import time
+import os
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import pandas as pd
+import matplotlib.pyplot as plt
+import json
+
+from config.system_config import SAVED_MODELS_DIR
+from models.base_model import TextClassificationModel
+from evaluation.metrics import ClassificationMetrics
+from utils.logger import get_logger
+from utils.file_utils import ensure_dir, save_json
+
+logger = get_logger("Evaluator")
+
+
+class ModelEvaluator:
+ """模型评估器,负责评估模型性能"""
+
+ def __init__(self, model: TextClassificationModel,
+ class_names: Optional[List[str]] = None,
+ output_dir: Optional[str] = None,
+ batch_size: Optional[int] = None):
+ """
+ 初始化模型评估器
+
+ Args:
+ model: 要评估的模型
+ class_names: 类别名称列表
+ output_dir: 输出目录,用于保存评估结果
+ batch_size: 批大小,如果为None则使用模型默认值
+ """
+ self.model = model
+ self.class_names = class_names
+ self.batch_size = batch_size or model.batch_size
+
+ # 设置输出目录
+ if output_dir is None:
+ self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name)
+ else:
+ self.output_dir = output_dir
+
+ ensure_dir(self.output_dir)
+
+ # 创建评估指标计算器
+ self.metrics = ClassificationMetrics(class_names)
+
+ # 评估结果
+ self.evaluation_results = None
+
+ logger.info(f"初始化模型评估器,模型: {model.model_name}")
+
+ def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
+ y_test: Optional[np.ndarray] = None,
+ batch_size: Optional[int] = None,
+ verbose: int = 1) -> Dict[str, float]:
+ """
+ 评估模型
+
+ Args:
+ x_test: 测试数据特征
+ y_test: 测试数据标签
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 评估结果
+ """
+ batch_size = batch_size or self.batch_size
+
+ logger.info(f"开始评估模型: {self.model.model_name}")
+ start_time = time.time()
+
+ # 使用模型评估
+ model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose)
+
+ # 获取预测结果
+ y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
+ y_pred = np.argmax(y_prob, axis=1)
+
+ # 处理y_test,确保y_test是一维数组
+ if isinstance(x_test, tf.data.Dataset):
+ # 如果是TensorFlow Dataset,需要从中提取y_test
+ y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
+ y_test = y_test_extracted
+
+ if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
+ y_test = np.argmax(y_test, axis=1)
+
+ # 计算所有指标
+ all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob)
+
+ # 合并模型内置指标和自定义指标
+ metrics_names = self.model.model.metrics_names
+ model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)}
+ all_metrics.update(model_metrics_dict)
+
+ # 记录评估时间
+ evaluation_time = time.time() - start_time
+ all_metrics['evaluation_time'] = evaluation_time
+
+ # 保存评估结果
+ self.evaluation_results = {
+ 'metrics': all_metrics,
+ 'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(),
+ 'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True)
+ }
+
+ logger.info(f"模型评估完成,用时: {evaluation_time:.2f} 秒")
+        logger.info(f"主要评估指标: accuracy={all_metrics.get('accuracy', float('nan')):.4f}, "
+                    f"f1_macro={all_metrics.get('f1_macro', float('nan')):.4f}")
+
+ return all_metrics
+
+ def save_evaluation_results(self, save_plots: bool = True) -> str:
+ """
+ 保存评估结果
+
+ Args:
+ save_plots: 是否保存可视化图表
+
+ Returns:
+ 结果保存路径
+ """
+ if self.evaluation_results is None:
+ raise ValueError("请先调用evaluate方法进行评估")
+
+ # 保存评估结果为JSON
+ results_path = os.path.join(self.output_dir, 'evaluation_results.json')
+ with open(results_path, 'w', encoding='utf-8') as f:
+ json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4)
+
+ # 保存评估指标为CSV
+ metrics_df = pd.DataFrame(
+ self.evaluation_results['metrics'].items(),
+ columns=['Metric', 'Value']
+ ).set_index('Metric')
+
+ metrics_path = os.path.join(self.output_dir, 'metrics.csv')
+ metrics_df.to_csv(metrics_path)
+
+ # 保存可视化图表
+ if save_plots:
+ self._save_plots()
+
+ logger.info(f"评估结果已保存到: {self.output_dir}")
+
+ return self.output_dir
+
+ def _save_plots(self) -> None:
+ """保存评估结果可视化图表"""
+ if self.evaluation_results is None:
+ raise ValueError("请先调用evaluate方法进行评估")
+
+ # 创建可视化目录
+ plots_dir = os.path.join(self.output_dir, 'plots')
+ ensure_dir(plots_dir)
+
+        # 混淆矩阵图:直接绘制评估时保存的混淆矩阵,
+        # 避免用占位标签重新计算导致图与实际评估结果不符
+        cm_path = os.path.join(plots_dir, 'confusion_matrix.png')
+        cm = np.array(self.evaluation_results['confusion_matrix'], dtype=float)
+
+        # 按真实标签(行)归一化
+        row_sums = cm.sum(axis=1, keepdims=True)
+        cm_norm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
+
+        class_names = self.class_names or [str(i) for i in range(cm.shape[0])]
+        plt.figure(figsize=(10, 8))
+        plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
+        plt.colorbar()
+        plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right')
+        plt.yticks(np.arange(cm.shape[0]), class_names)
+        plt.title('Normalized Confusion Matrix')
+        plt.ylabel('True label')
+        plt.xlabel('Predicted label')
+        plt.tight_layout()
+        plt.savefig(cm_path)
+        plt.close()
+
+ # 保存评估指标条形图
+ metrics_path = os.path.join(plots_dir, 'metrics_bar.png')
+ metrics = self.evaluation_results['metrics']
+
+ # 选择要展示的主要指标
+ main_metrics = {
+ 'accuracy': metrics.get('accuracy', 0),
+ 'precision_macro': metrics.get('precision_macro', 0),
+ 'recall_macro': metrics.get('recall_macro', 0),
+ 'f1_macro': metrics.get('f1_macro', 0)
+ }
+
+ # 绘制条形图
+ plt.figure(figsize=(10, 6))
+ plt.bar(main_metrics.keys(), main_metrics.values())
+ plt.title('Main Evaluation Metrics')
+ plt.ylabel('Score')
+ plt.ylim(0, 1)
+ plt.xticks(rotation=45, ha='right')
+ plt.tight_layout()
+ plt.savefig(metrics_path)
+ plt.close()
+
+ # 如果有类别级别的指标,绘制每个类别的指标
+ if 'classification_report' in self.evaluation_results:
+ report = self.evaluation_results['classification_report']
+
+ # 提取每个类别的精确率、召回率和F1值
+ class_metrics = {}
+ for key, value in report.items():
+ if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
+ if isinstance(value, dict):
+ class_metrics[key] = value
+
+ if class_metrics:
+ # 绘制每个类别的F1分数
+ class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png')
+
+ classes = list(class_metrics.keys())
+ f1_scores = [metrics['f1-score'] for metrics in class_metrics.values()]
+
+ plt.figure(figsize=(12, 6))
+ bars = plt.bar(classes, f1_scores)
+
+ # 在柱状图上方显示数值
+ for bar in bars:
+ height = bar.get_height()
+ plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
+ f'{height:.2f}',
+ ha='center', va='bottom', rotation=0)
+
+ plt.title('F1 Score by Class')
+ plt.ylabel('F1 Score')
+ plt.ylim(0, 1.1)
+ plt.xticks(rotation=45, ha='right')
+ plt.tight_layout()
+ plt.savefig(class_f1_path)
+ plt.close()
+
+ # 绘制每个类别的精确率和召回率
+ class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png')
+
+ precisions = [metrics['precision'] for metrics in class_metrics.values()]
+ recalls = [metrics['recall'] for metrics in class_metrics.values()]
+
+ plt.figure(figsize=(12, 6))
+ x = np.arange(len(classes))
+ width = 0.35
+
+ plt.bar(x - width / 2, precisions, width, label='Precision')
+ plt.bar(x + width / 2, recalls, width, label='Recall')
+
+ plt.ylabel('Score')
+ plt.title('Precision and Recall by Class')
+ plt.xticks(x, classes, rotation=45, ha='right')
+ plt.legend()
+ plt.ylim(0, 1.1)
+ plt.tight_layout()
+ plt.savefig(class_prec_rec_path)
+ plt.close()
+
+ logger.info(f"评估可视化图表已保存到: {plots_dir}")
+
+ def compare_models(self, other_evaluators: List['ModelEvaluator'],
+ metrics: Optional[List[str]] = None,
+ save_path: Optional[str] = None) -> pd.DataFrame:
+ """
+ 比较多个模型的评估结果
+
+ Args:
+ other_evaluators: 其他模型评估器列表
+ metrics: 要比较的指标列表,默认为['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
+ save_path: 比较结果的保存路径
+
+ Returns:
+ 比较结果DataFrame
+ """
+ if self.evaluation_results is None:
+ raise ValueError("请先调用evaluate方法进行评估")
+
+ # 默认比较指标
+ if metrics is None:
+ metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
+
+ # 收集所有模型的评估指标
+ models_metrics = {}
+
+ # 当前模型
+ models_metrics[self.model.model_name] = {
+ metric: self.evaluation_results['metrics'].get(metric, 'N/A')
+ for metric in metrics
+ }
+
+ # 其他模型
+ for evaluator in other_evaluators:
+ if evaluator.evaluation_results is None:
+ logger.warning(f"模型 {evaluator.model.model_name} 尚未评估,跳过")
+ continue
+
+ models_metrics[evaluator.model.model_name] = {
+ metric: evaluator.evaluation_results['metrics'].get(metric, 'N/A')
+ for metric in metrics
+ }
+
+ # 创建比较DataFrame
+ comparison_df = pd.DataFrame(models_metrics).T
+
+ # 保存比较结果
+ if save_path:
+ comparison_df.to_csv(save_path)
+ logger.info(f"模型比较结果已保存到: {save_path}")
+
+ # 绘制比较条形图
+ plt.figure(figsize=(12, 6))
+ comparison_df.plot(kind='bar', figsize=(12, 6))
+ plt.title('Model Comparison')
+ plt.ylabel('Score')
+ plt.ylim(0, 1)
+ plt.legend(loc='lower right')
+ plt.tight_layout()
+
+ # 如果save_path是CSV文件,将其替换为PNG文件
+ if save_path.endswith('.csv'):
+ plot_path = save_path.replace('.csv', '.png')
+ else:
+ plot_path = save_path + '.png'
+
+ plt.savefig(plot_path)
+ plt.close()
+
+ logger.info(f"模型比较图表已保存到: {plot_path}")
+
+ return comparison_df
+
+ def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset],
+ y_test: Optional[np.ndarray] = None,
+ batch_size: Optional[int] = None,
+ verbose: int = 0) -> pd.DataFrame:
+ """
+ 评估模型在各个类别上的性能
+
+ Args:
+ x_test: 测试数据特征
+ y_test: 测试数据标签
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 各类别性能指标DataFrame
+ """
+ batch_size = batch_size or self.batch_size
+
+ # 获取预测结果
+ y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose)
+ y_pred = np.argmax(y_prob, axis=1)
+
+ # 处理y_test,确保y_test是一维数组
+ if isinstance(x_test, tf.data.Dataset):
+ # 如果是TensorFlow Dataset,需要从中提取y_test
+ y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
+ y_test = y_test_extracted
+
+ if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
+ y_test = np.argmax(y_test, axis=1)
+
+ # 获取分类报告
+ report = self.metrics.classification_report(y_test, y_pred, output_dict=True)
+
+ # 提取各类别指标
+ class_metrics = {}
+ for key, value in report.items():
+ if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
+ if isinstance(value, dict):
+ class_metrics[key] = value
+
+ # 转换为DataFrame
+ class_performance_df = pd.DataFrame(class_metrics).T
+
+        # 添加支持度(样本数量);分类报告的索引可能是类别名称而非数字标签
+        class_counts = np.bincount(y_test)
+        for idx, count in enumerate(class_counts):
+            class_key = (self.class_names[idx]
+                         if self.class_names and idx < len(self.class_names)
+                         else str(idx))
+            if class_key in class_performance_df.index:
+                class_performance_df.loc[class_key, 'support'] = count
+
+        # 添加类别名称(仅当索引是数字标签时才需要映射)
+        if self.class_names:
+            class_performance_df['class_name'] = [
+                self.class_names[int(idx)] if str(idx).isdigit() and int(idx) < len(self.class_names) else idx
+                for idx in class_performance_df.index
+            ]
+
+ # 保存类别性能指标
+ performance_path = os.path.join(self.output_dir, 'class_performance.csv')
+ class_performance_df.to_csv(performance_path)
+ logger.info(f"各类别性能指标已保存到: {performance_path}")
+
+ return class_performance_df
+
+ def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset],
+ y_test: Optional[np.ndarray] = None,
+ batch_size: Optional[int] = None,
+ num_samples: int = 10,
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制误分类样本分析
+
+ Args:
+ x_test: 测试数据特征
+ y_test: 测试数据标签
+ batch_size: 批大小
+ num_samples: 要展示的误分类样本数量
+ save_path: 保存路径
+ """
+ # 仅适用于文本数据的分析,需要原始文本
+ logger.info("误分类样本分析需要原始文本数据,此方法可能需要根据实际数据类型进行修改")
+
+ # 在实际应用中,这里应该根据实际数据类型进行修改
+ # 例如,对于序列化的文本,可能需要反序列化,或者使用词汇表将索引转换回文本
+
+ # 此处仅展示一个基本框架
+ batch_size = batch_size or self.batch_size
+
+ # 获取预测结果
+ y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
+ y_pred = np.argmax(y_prob, axis=1)
+
+ # 处理y_test,确保y_test是一维数组
+ if isinstance(x_test, tf.data.Dataset):
+ # 如果是TensorFlow Dataset,需要从中提取y_test和x_test
+ dataset_iterator = iter(x_test)
+ x_test_extracted = []
+ y_test_extracted = []
+
+ try:
+ while True:
+ x, y = next(dataset_iterator)
+ x_test_extracted.append(x)
+ y_test_extracted.append(y)
+ except StopIteration:
+ pass
+
+ x_test = np.concatenate(x_test_extracted, axis=0)
+ y_test = np.concatenate(y_test_extracted, axis=0)
+
+ if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
+ y_test = np.argmax(y_test, axis=1)
+
+ # 找出误分类样本
+ misclassified_indices = np.where(y_pred != y_test)[0]
+
+ # 如果没有误分类样本,返回
+ if len(misclassified_indices) == 0:
+ logger.info("没有误分类样本")
+ return
+
+ # 随机选择一些误分类样本
+ if len(misclassified_indices) > num_samples:
+ misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False)
+
+ # 保存误分类样本分析结果
+ misclassified_data = []
+
+ for idx in misclassified_indices:
+ true_label = y_test[idx]
+ pred_label = y_pred[idx]
+
+ true_class = self.class_names[true_label] if self.class_names else str(true_label)
+ pred_class = self.class_names[pred_label] if self.class_names else str(pred_label)
+
+ # 对于序列化的文本,此处需要进行反序列化
+ # 这里仅作示例,实际应用中需要根据具体数据类型修改
+ sample_text = f"Sample {idx}"
+
+ misclassified_data.append({
+ 'sample_id': idx,
+ 'true_label': true_label,
+ 'predicted_label': pred_label,
+ 'true_class': true_class,
+ 'predicted_class': pred_class,
+ 'confidence': float(y_prob[idx, pred_label]),
+ 'sample_text': sample_text
+ })
+
+ # 创建DataFrame
+ misclassified_df = pd.DataFrame(misclassified_data)
+
+ # 保存结果
+ if save_path:
+ misclassified_df.to_csv(save_path)
+ logger.info(f"误分类样本分析已保存到: {save_path}")
+
+ return misclassified_df
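A typical evaluation run with this class might look like the sketch below; `model`, `x_test`, and `y_test` are assumed to come from the training/preprocessing pipeline (not part of this commit), and `CATEGORIES` supplies the class names:

```python
from config.system_config import CATEGORIES
from evaluation.evaluator import ModelEvaluator

# `model` is a trained TextClassificationModel; x_test is a padded id matrix
# and y_test holds integer label ids (both produced by the preprocessing step).
evaluator = ModelEvaluator(model, class_names=CATEGORIES)

metrics = evaluator.evaluate(x_test, y_test)
print(f"accuracy={metrics['accuracy']:.4f}, f1_macro={metrics['f1_macro']:.4f}")

evaluator.save_evaluation_results(save_plots=True)   # JSON, CSV and plots
per_class = evaluator.evaluate_class_performance(x_test, y_test)
print(per_class.head())
```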
diff --git a/evaluation/metrics.py b/evaluation/metrics.py
new file mode 100644
index 0000000..40d544c
--- /dev/null
+++ b/evaluation/metrics.py
@@ -0,0 +1,356 @@
+"""
+评估指标模块:实现各种评估指标
+"""
+import numpy as np
+import tensorflow as tf
+from sklearn.metrics import (
+ accuracy_score, precision_score, recall_score, f1_score,
+ confusion_matrix, classification_report, roc_auc_score,
+ precision_recall_curve, average_precision_score
+)
+import matplotlib.pyplot as plt
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import pandas as pd
+
+from utils.logger import get_logger
+
+logger = get_logger("Metrics")
+
+
+class ClassificationMetrics:
+ """分类评估指标类,计算各种分类评估指标"""
+
+ def __init__(self, class_names: Optional[List[str]] = None):
+ """
+ 初始化分类评估指标类
+
+ Args:
+ class_names: 类别名称列表
+ """
+ self.class_names = class_names
+
+ def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+ """
+ 计算准确率
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+
+ Returns:
+ 准确率
+ """
+ return accuracy_score(y_true, y_pred)
+
+ def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
+ average: str = 'macro') -> Union[float, np.ndarray]:
+ """
+ 计算精确率
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
+
+ Returns:
+ 精确率
+ """
+ return precision_score(y_true, y_pred, average=average, zero_division=0)
+
+ def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
+ average: str = 'macro') -> Union[float, np.ndarray]:
+ """
+ 计算召回率
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
+
+ Returns:
+ 召回率
+ """
+ return recall_score(y_true, y_pred, average=average, zero_division=0)
+
+ def f1(self, y_true: np.ndarray, y_pred: np.ndarray,
+ average: str = 'macro') -> Union[float, np.ndarray]:
+ """
+ 计算F1分数
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
+
+ Returns:
+ F1分数
+ """
+ return f1_score(y_true, y_pred, average=average, zero_division=0)
+
+ def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
+ normalize: Optional[str] = None) -> np.ndarray:
+ """
+ 计算混淆矩阵
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
+
+ Returns:
+ 混淆矩阵
+ """
+ cm = confusion_matrix(y_true, y_pred)
+
+ if normalize is not None:
+ if normalize == 'true':
+ cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
+ elif normalize == 'pred':
+ cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
+ elif normalize == 'all':
+ cm = cm.astype('float') / cm.sum()
+
+ return cm
+
+ def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
+ normalize: Optional[str] = None,
+ figsize: Tuple[int, int] = (10, 8),
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制混淆矩阵
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
+ figsize: 图像大小
+ save_path: 保存路径,如果为None则显示图像
+ """
+ # 计算混淆矩阵
+ cm = self.confusion_matrix(y_true, y_pred, normalize)
+
+ # 确定类别名称
+ if self.class_names is None:
+ class_names = [str(i) for i in range(cm.shape[0])]
+ else:
+ class_names = self.class_names
+
+ # 创建图像
+ plt.figure(figsize=figsize)
+
+ # 使用热图显示混淆矩阵
+ im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
+ plt.colorbar(im)
+
+ # 设置坐标轴标签
+ plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right')
+ plt.yticks(np.arange(cm.shape[0]), class_names)
+
+ # 设置标题
+ if normalize is not None:
+ plt.title(f"Normalized ({normalize}) Confusion Matrix")
+ else:
+ plt.title("Confusion Matrix")
+
+ plt.ylabel('True label')
+ plt.xlabel('Predicted label')
+
+ # 在每个单元格中显示数值
+ thresh = cm.max() / 2.0
+ for i in range(cm.shape[0]):
+ for j in range(cm.shape[1]):
+ if normalize is not None:
+ plt.text(j, i, f"{cm[i, j]:.2f}",
+ ha="center", va="center",
+ color="white" if cm[i, j] > thresh else "black")
+ else:
+ plt.text(j, i, f"{cm[i, j]}",
+ ha="center", va="center",
+ color="white" if cm[i, j] > thresh else "black")
+
+ plt.tight_layout()
+
+ # 保存或显示图像
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"混淆矩阵图已保存到: {save_path}")
+ else:
+ plt.show()
+
+ def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
+ output_dict: bool = False) -> Union[str, Dict]:
+ """
+ 生成分类报告
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ output_dict: 是否以字典形式返回
+
+ Returns:
+ 分类报告
+ """
+ target_names = self.class_names if self.class_names else None
+ return classification_report(y_true, y_pred,
+ target_names=target_names,
+ output_dict=output_dict,
+ zero_division=0)
+
+ def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray,
+ multi_class: str = 'ovr') -> Union[float, np.ndarray]:
+ """
+ 计算AUC-ROC
+
+ Args:
+ y_true: 真实标签
+ y_prob: 预测概率
+ multi_class: 多分类处理方式,可选值: 'ovr', 'ovo'
+
+ Returns:
+ AUC-ROC
+ """
+ try:
+ # 如果y_true是one-hot编码,转换为类别索引
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 多分类
+ if y_prob.shape[1] > 2:
+ return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro')
+ # 二分类
+ else:
+ return roc_auc_score(y_true, y_prob[:, 1])
+ except Exception as e:
+ logger.error(f"计算AUC-ROC时出错: {e}")
+ return 0.0
+
+ def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray,
+ average: str = 'macro') -> Union[float, np.ndarray]:
+ """
+ 计算平均精确率
+
+ Args:
+ y_true: 真实标签
+ y_prob: 预测概率
+ average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
+
+ Returns:
+ 平均精确率
+ """
+ try:
+ # 如果y_true是one-hot编码,转换为类别索引
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 多分类:使用sklearn的方法
+ return average_precision_score(
+ tf.keras.utils.to_categorical(y_true, num_classes=y_prob.shape[1]),
+ y_prob,
+ average=average
+ )
+ except Exception as e:
+ logger.error(f"计算平均精确率时出错: {e}")
+ return 0.0
+
+ def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray,
+ class_id: Optional[int] = None,
+ figsize: Tuple[int, int] = (10, 8),
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制精确率-召回率曲线
+
+ Args:
+ y_true: 真实标签
+ y_prob: 预测概率
+ class_id: 要绘制的类别ID,如果为None则绘制所有类别
+ figsize: 图像大小
+ save_path: 保存路径,如果为None则显示图像
+ """
+ # 如果y_true是one-hot编码,转换为类别索引
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 创建图像
+ plt.figure(figsize=figsize)
+
+ # 确定要绘制的类别
+ if class_id is not None and class_id < y_prob.shape[1]:
+ # 绘制指定类别的PR曲线
+ y_true_bin = (y_true == class_id).astype(int)
+ precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, class_id])
+ avg_prec = average_precision_score(y_true_bin, y_prob[:, class_id])
+
+ plt.step(recall, precision, where='post',
+ label=f'Class {class_id} (AP = {avg_prec:.3f})')
+ else:
+ # 绘制所有类别的PR曲线
+ for i in range(y_prob.shape[1]):
+ y_true_bin = (y_true == i).astype(int)
+ precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
+ avg_prec = average_precision_score(y_true_bin, y_prob[:, i])
+
+ class_name = self.class_names[i] if self.class_names else f"Class {i}"
+ plt.step(recall, precision, where='post',
+ label=f'{class_name} (AP = {avg_prec:.3f})')
+
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ plt.title('Precision-Recall Curve')
+ plt.legend(loc='lower left')
+ plt.grid(True)
+
+ # 保存或显示图像
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
+ else:
+ plt.show()
+
+ def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
+ y_prob: Optional[np.ndarray] = None) -> Dict[str, float]:
+ """
+ 计算所有评估指标
+
+ Args:
+ y_true: 真实标签
+ y_pred: 预测标签
+ y_prob: 预测概率
+
+ Returns:
+ 包含所有评估指标的字典
+ """
+ metrics = {}
+
+ # 基础指标
+ metrics['accuracy'] = self.accuracy(y_true, y_pred)
+ metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro')
+ metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro')
+ metrics['f1_macro'] = self.f1(y_true, y_pred, average='macro')
+
+ # 如果提供了预测概率,计算AUC-ROC和平均精确率
+ if y_prob is not None:
+ try:
+ metrics['auc_roc'] = self.auc_roc(y_true, y_prob)
+ metrics['average_precision'] = self.average_precision(y_true, y_prob)
+ except Exception as e:
+ logger.error(f"计算概率指标时出错: {e}")
+
+ # 类别级别的指标
+ for avg in ['micro', 'weighted']:
+ metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg)
+ metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg)
+ metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg)
+
+ return metrics
+
+ def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame:
+ """
+ 将评估指标转换为DataFrame
+
+ Args:
+ metrics: 评估指标字典
+
+ Returns:
+ 评估指标DataFrame
+ """
+ return pd.DataFrame(metrics.items(), columns=['Metric', 'Value']).set_index('Metric')
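+
+
+if __name__ == "__main__":
+    # Minimal, self-contained sketch of the probability-based metrics used above
+    # (illustrative only: random 3-class predictions stand in for real model output).
+    import numpy as np
+    from sklearn.metrics import roc_auc_score, average_precision_score
+
+    rng = np.random.default_rng(0)
+    y_true = rng.integers(0, 3, size=100)
+    y_prob = rng.random((100, 3))
+    y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)  # normalize rows into probability distributions
+    y_onehot = np.eye(3)[y_true]
+
+    print("AUC-ROC (ovr, macro):", roc_auc_score(y_true, y_prob, multi_class='ovr', average='macro'))
+    print("Average precision (macro):", average_precision_score(y_onehot, y_prob, average='macro'))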
diff --git a/evaluation/visualization.py b/evaluation/visualization.py
new file mode 100644
index 0000000..5006052
--- /dev/null
+++ b/evaluation/visualization.py
@@ -0,0 +1,370 @@
+"""
+可视化模块:实现评估结果的可视化
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from typing import List, Dict, Tuple, Optional, Any, Union
+import os
+import itertools
+from sklearn.metrics import roc_curve, precision_recall_curve, auc
+from sklearn.manifold import TSNE
+from sklearn.decomposition import PCA
+
+from utils.logger import get_logger
+from utils.file_utils import ensure_dir
+
+logger = get_logger("Visualization")
+
+
+class EvaluationVisualizer:
+ """评估结果可视化类"""
+
+ def __init__(self, output_dir: Optional[str] = None,
+ class_names: Optional[List[str]] = None,
+ figsize: Tuple[int, int] = (10, 8)):
+ """
+ 初始化评估结果可视化类
+
+ Args:
+ output_dir: 输出目录,用于保存可视化结果
+ class_names: 类别名称列表
+ figsize: 图像默认大小
+ """
+ self.output_dir = output_dir
+ if output_dir:
+ ensure_dir(output_dir)
+
+ self.class_names = class_names
+ self.figsize = figsize
+
+ def plot_confusion_matrix(self, cm: np.ndarray,
+ normalize: Optional[str] = None,
+ title: str = 'Confusion Matrix',
+ cmap: str = 'Blues',
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制混淆矩阵
+
+ Args:
+ cm: 混淆矩阵
+ normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
+ title: 图像标题
+ cmap: 颜色映射
+ save_path: 保存路径,如果为None则使用output_dir/confusion_matrix.png
+ """
+ if normalize is not None:
+ if normalize == 'true':
+ cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
+ title = 'Normalized (by true) ' + title
+ elif normalize == 'pred':
+ cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
+ title = 'Normalized (by pred) ' + title
+ elif normalize == 'all':
+ cm = cm.astype('float') / cm.sum()
+ title = 'Normalized (by all) ' + title
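+        # 示例:normalize='true' 且 cm = [[8, 2], [1, 9]] 时,按行归一化得到
+        # [[0.8, 0.2], [0.1, 0.9]],对角线即为各类别的召回率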
+
+ plt.figure(figsize=self.figsize)
+
+ # 使用seaborn绘制热图
+ sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
+ cmap=cmap, square=True, cbar=True)
+
+ # 设置坐标轴标签
+ if self.class_names:
+ tick_marks = np.arange(len(self.class_names))
+ plt.xticks(tick_marks + 0.5, self.class_names, rotation=45, ha='right')
+ plt.yticks(tick_marks + 0.5, self.class_names, rotation=0)
+
+ plt.title(title)
+ plt.ylabel('True label')
+ plt.xlabel('Predicted label')
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, 'confusion_matrix.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"混淆矩阵图已保存到: {save_path}")
+
+ plt.close()
+
+ def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]],
+ selected_metrics: Optional[List[str]] = None,
+ title: str = 'Metrics Comparison',
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制多个模型的评估指标比较
+
+ Args:
+ metrics_dict: 模型评估指标字典,格式为{model_name: {metric_name: value}}
+ selected_metrics: 要比较的指标列表,如果为None则使用所有指标
+ title: 图像标题
+ save_path: 保存路径,如果为None则使用output_dir/metrics_comparison.png
+ """
+ # 创建DataFrame
+ df = pd.DataFrame(metrics_dict).T
+
+ # 筛选指标
+ if selected_metrics:
+ df = df[selected_metrics]
+
+        # 绘制条形图(df.plot 会自行创建图像,不再额外调用 plt.figure,避免遗留空白图)
+        df.plot(kind='bar', figsize=self.figsize)
+
+ plt.title(title)
+ plt.ylabel('Score')
+ plt.ylim(0, 1)
+ plt.legend(loc='best')
+ plt.grid(axis='y', linestyle='--', alpha=0.7)
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, 'metrics_comparison.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"评估指标比较图已保存到: {save_path}")
+
+ plt.close()
+
+ def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
+ title: str = 'ROC Curves',
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制ROC曲线
+
+ Args:
+ y_true: 真实标签
+ y_prob: 预测概率
+ title: 图像标题
+ save_path: 保存路径,如果为None则使用output_dir/roc_curves.png
+ """
+ plt.figure(figsize=self.figsize)
+
+ # 确保y_true是一维数组
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 获取类别数
+ num_classes = y_prob.shape[1]
+
+ # 绘制每个类别的ROC曲线
+ for i in range(num_classes):
+ # 二分类转换:当前类别为正类,其他为负类
+ y_true_bin = (y_true == i).astype(int)
+
+ # 计算ROC曲线
+ fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
+ roc_auc = auc(fpr, tpr)
+
+ # 确定类别名称
+ if self.class_names and i < len(self.class_names):
+ class_name = self.class_names[i]
+ else:
+ class_name = f'Class {i}'
+
+ # 绘制ROC曲线
+ plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})')
+
+ # 绘制随机猜测的基准线
+ plt.plot([0, 1], [0, 1], 'k--', lw=2)
+
+ plt.xlim([0.0, 1.0])
+ plt.ylim([0.0, 1.05])
+ plt.xlabel('False Positive Rate')
+ plt.ylabel('True Positive Rate')
+ plt.title(title)
+ plt.legend(loc='lower right')
+ plt.grid(True, alpha=0.3)
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, 'roc_curves.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"ROC曲线图已保存到: {save_path}")
+
+ plt.close()
+
+ def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
+ title: str = 'Precision-Recall Curves',
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制精确率-召回率曲线
+
+ Args:
+ y_true: 真实标签
+ y_prob: 预测概率
+ title: 图像标题
+ save_path: 保存路径,如果为None则使用output_dir/precision_recall_curves.png
+ """
+ plt.figure(figsize=self.figsize)
+
+ # 确保y_true是一维数组
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 获取类别数
+ num_classes = y_prob.shape[1]
+
+ # 绘制每个类别的PR曲线
+ for i in range(num_classes):
+ # 二分类转换:当前类别为正类,其他为负类
+ y_true_bin = (y_true == i).astype(int)
+
+ # 计算PR曲线
+ precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
+ pr_auc = auc(recall, precision)
+
+ # 确定类别名称
+ if self.class_names and i < len(self.class_names):
+ class_name = self.class_names[i]
+ else:
+ class_name = f'Class {i}'
+
+ # 绘制PR曲线
+ plt.plot(recall, precision, lw=2, label=f'{class_name} (AUC = {pr_auc:.3f})')
+
+ plt.xlim([0.0, 1.0])
+ plt.ylim([0.0, 1.05])
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ plt.title(title)
+ plt.legend(loc='best')
+ plt.grid(True, alpha=0.3)
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, 'precision_recall_curves.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
+
+ plt.close()
+
+ def plot_feature_importance(self, feature_names: List[str],
+ importance: np.ndarray,
+ title: str = 'Feature Importance',
+ top_n: int = 20,
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制特征重要性
+
+ Args:
+ feature_names: 特征名称列表
+ importance: 特征重要性数组
+ title: 图像标题
+ top_n: 显示前N个重要的特征
+ save_path: 保存路径,如果为None则使用output_dir/feature_importance.png
+ """
+ # 创建DataFrame
+ df = pd.DataFrame({'Feature': feature_names, 'Importance': importance})
+
+ # 按重要性排序
+ df = df.sort_values('Importance', ascending=False).head(top_n)
+
+ # 绘制条形图
+ plt.figure(figsize=self.figsize)
+ sns.barplot(x='Importance', y='Feature', data=df)
+
+ plt.title(title)
+ plt.xlabel('Importance')
+ plt.ylabel('Feature')
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, 'feature_importance.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"特征重要性图已保存到: {save_path}")
+
+ plt.close()
+
+ def plot_embedding_visualization(self, embeddings: np.ndarray,
+ labels: np.ndarray,
+ method: str = 'tsne',
+ title: str = 'Embedding Visualization',
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制嵌入向量可视化
+
+ Args:
+ embeddings: 嵌入向量,形状为(样本数, 嵌入维度)
+ labels: 类别标签,形状为(样本数,)
+ method: 降维方法,'tsne'或'pca'
+ title: 图像标题
+ save_path: 保存路径,如果为None则使用output_dir/embedding_visualization.png
+ """
+ # 降维
+ if method.lower() == 'tsne':
+ reducer = TSNE(n_components=2, random_state=42)
+ elif method.lower() == 'pca':
+ reducer = PCA(n_components=2, random_state=42)
+ else:
+ raise ValueError(f"不支持的降维方法: {method}")
+
+ # 如果嵌入向量太多,采样一部分
+ max_samples = 5000
+ if len(embeddings) > max_samples:
+ indices = np.random.choice(len(embeddings), max_samples, replace=False)
+ embeddings_sample = embeddings[indices]
+ labels_sample = labels[indices]
+ else:
+ embeddings_sample = embeddings
+ labels_sample = labels
+
+ # 执行降维
+ embeddings_2d = reducer.fit_transform(embeddings_sample)
+
+ # 绘制散点图
+ plt.figure(figsize=self.figsize)
+
+ # 确保标签是一维数组
+ if len(labels_sample.shape) > 1 and labels_sample.shape[1] > 1:
+ labels_sample = np.argmax(labels_sample, axis=1)
+
+ # 获取唯一类别
+ unique_labels = np.unique(labels_sample)
+
+ # 为每个类别分配不同的颜色
+ colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
+
+ for i, label in enumerate(unique_labels):
+ mask = labels_sample == label
+
+ # 确定类别名称
+ if self.class_names and label < len(self.class_names):
+ class_name = self.class_names[int(label)]
+ else:
+ class_name = f'Class {int(label)}'
+
+ plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
+ c=[colors[i]], label=class_name, alpha=0.7)
+
+ plt.title(title)
+ plt.legend(loc='best')
+ plt.grid(True, alpha=0.3)
+ plt.tight_layout()
+
+ # 保存图像
+ if save_path is None and self.output_dir:
+ save_path = os.path.join(self.output_dir, f'embedding_visualization_{method}.png')
+
+ if save_path:
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
+ logger.info(f"嵌入向量可视化图已保存到: {save_path}")
+
+ plt.close()
+
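+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative): synthetic 3-class predictions are visualized
+    # with the class above; the output directory and class names below are placeholders.
+    from sklearn.metrics import confusion_matrix
+
+    rng = np.random.default_rng(42)
+    y_true = rng.integers(0, 3, size=200)
+    y_prob = rng.random((200, 3))
+    y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)
+    y_pred = np.argmax(y_prob, axis=1)
+
+    viz = EvaluationVisualizer(output_dir="outputs/demo_viz", class_names=["体育", "财经", "科技"])
+    viz.plot_confusion_matrix(confusion_matrix(y_true, y_pred), normalize='true')
+    viz.plot_roc_curves(y_true, y_prob)
+    viz.plot_precision_recall_curves(y_true, y_prob)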
diff --git a/inference/__init__.py b/inference/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/inference/batch_processor.py b/inference/batch_processor.py
new file mode 100644
index 0000000..d020f7f
--- /dev/null
+++ b/inference/batch_processor.py
@@ -0,0 +1,362 @@
+"""
+批处理模块:实现批量处理大规模文本数据
+"""
+import os
+import time
+import pandas as pd
+import numpy as np
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator
+import concurrent.futures
+from tqdm import tqdm
+import glob
+import json
+
+from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH
+from utils.logger import get_logger
+from utils.file_utils import read_text_file, ensure_dir
+from inference.predictor import Predictor
+
+logger = get_logger("BatchProcessor")
+
+
+class BatchProcessor:
+ """批处理器,负责批量处理大规模文本数据"""
+
+ def __init__(self, predictor: Predictor,
+ batch_size: int = 64,
+ max_workers: int = DATA_LOADING_WORKERS,
+ max_batch_queue: int = 10):
+ """
+ 初始化批处理器
+
+ Args:
+ predictor: 预测器实例
+ batch_size: 批大小
+ max_workers: 最大工作线程数
+ max_batch_queue: 最大批次队列长度
+ """
+ self.predictor = predictor
+ self.batch_size = batch_size
+ self.max_workers = max_workers
+ self.max_batch_queue = max_batch_queue
+
+ logger.info(f"初始化批处理器,批大小: {batch_size}, 最大工作线程数: {max_workers}")
+
+ def _extract_text_from_file(self, file_path: str) -> str:
+ """
+ 从文件中提取文本
+
+ Args:
+ file_path: 文件路径
+
+ Returns:
+ 文本内容
+ """
+ return read_text_file(file_path, encoding=ENCODING)
+
+ def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]:
+ """
+ 生成文本批次
+
+ Args:
+ texts: 文本列表
+ batch_size: 批大小
+
+ Returns:
+ 文本批次生成器
+ """
+ for i in range(0, len(texts), batch_size):
+ yield texts[i:i + batch_size]
+
+ def process_files(self, file_paths: List[str], output_path: Optional[str] = None,
+ return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
+ """
+ 批量处理文件
+
+ Args:
+ file_paths: 文件路径列表
+ output_path: 输出文件路径,如果为None则不保存
+ return_top_k: 返回概率最高的前k个类别
+ format: 输出格式,'csv'或'json'
+
+ Returns:
+ 预测结果DataFrame
+ """
+ logger.info(f"开始批量处理 {len(file_paths)} 个文件")
+ start_time = time.time()
+
+ # 使用线程池并行读取文件
+ texts = []
+ file_names = []
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path for file_path in
+ file_paths}
+
+ for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(file_paths), desc="读取文件"):
+ file_path = future_to_file[future]
+ try:
+ text = future.result()
+ if text:
+ texts.append(text)
+ file_names.append(os.path.basename(file_path))
+ except Exception as e:
+ logger.error(f"处理文件 {file_path} 时出错: {e}")
+
+ # 批量预测
+ all_predictions = []
+
+ for batch in tqdm(self._batch_generator(texts, self.batch_size),
+ total=(len(texts) + self.batch_size - 1) // self.batch_size, desc="预测"):
+ predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k, return_probabilities=True)
+ all_predictions.extend(predictions)
+
+ # 整合结果
+ if return_top_k > 1:
+ # 多个类别的情况
+ results = []
+ for i, preds in enumerate(all_predictions):
+ for j, pred in enumerate(preds):
+ results.append({
+ 'file_name': file_names[i],
+ 'rank': j + 1,
+ 'predicted_class': pred['class'],
+ 'probability': pred['probability']
+ })
+ df = pd.DataFrame(results)
+ else:
+ # 单个类别的情况
+ df = pd.DataFrame({
+ 'file_name': file_names,
+ 'predicted_class': [pred['class'] for pred in all_predictions],
+ 'probability': [pred['probability'] for pred in all_predictions]
+ })
+
+ # 保存结果
+ if output_path:
+ if format.lower() == 'csv':
+ df.to_csv(output_path, index=False, encoding='utf-8')
+ elif format.lower() == 'json':
+ # 转换为嵌套的JSON格式
+ if return_top_k > 1:
+ # 分组后转换为嵌套格式
+ result = {}
+ for file_name in df['file_name'].unique():
+ sub_df = df[df['file_name'] == file_name]
+ predictions = []
+ for _, row in sub_df.iterrows():
+ predictions.append({
+ 'class': row['predicted_class'],
+ 'probability': row['probability']
+ })
+ result[file_name] = {
+ 'predictions': predictions
+ }
+ else:
+ # 直接构建JSON
+ result = {}
+ for _, row in df.iterrows():
+ result[row['file_name']] = {
+ 'predicted_class': row['predicted_class'],
+ 'probability': row['probability']
+ }
+
+ # 保存为JSON
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=2)
+ else:
+ raise ValueError(f"不支持的输出格式: {format}")
+
+ logger.info(f"预测结果已保存到: {output_path}")
+
+ processing_time = time.time() - start_time
+ logger.info(f"批量处理完成,共处理 {len(texts)} 个文件,用时: {processing_time:.2f} 秒")
+
+ return df
+
+ def process_directory(self, directory: str, pattern: str = "*.txt",
+ output_path: Optional[str] = None,
+ return_top_k: int = 1, format: str = 'csv',
+ recursive: bool = True) -> pd.DataFrame:
+ """
+ 批量处理目录中的文件
+
+ Args:
+ directory: 目录路径
+ pattern: 文件匹配模式
+ output_path: 输出文件路径,如果为None则不保存
+ return_top_k: 返回概率最高的前k个类别
+ format: 输出格式,'csv'或'json'
+ recursive: 是否递归处理子目录
+
+ Returns:
+ 预测结果DataFrame
+ """
+ # 获取符合模式的文件路径
+ if recursive:
+ file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True)
+ else:
+ file_paths = glob.glob(os.path.join(directory, pattern))
+
+ if not file_paths:
+ logger.warning(f"在目录 {directory} 中未找到符合模式 {pattern} 的文件")
+ return pd.DataFrame()
+
+ logger.info(f"在目录 {directory} 中找到 {len(file_paths)} 个符合模式 {pattern} 的文件")
+
+ # 调用process_files处理文件
+ return self.process_files(file_paths, output_path, return_top_k, format)
+
+ def process_dataframe(self, df: pd.DataFrame, text_column: str,
+ id_column: Optional[str] = None,
+ output_path: Optional[str] = None,
+ return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
+ """
+ 批量处理DataFrame中的文本
+
+ Args:
+ df: 输入DataFrame
+ text_column: 文本列名
+ id_column: ID列名,如果为None则使用索引
+ output_path: 输出文件路径,如果为None则不保存
+ return_top_k: 返回概率最高的前k个类别
+ format: 输出格式,'csv'或'json'
+
+ Returns:
+ 预测结果DataFrame
+ """
+ # 获取文本和ID
+ texts = df[text_column].tolist()
+
+ if id_column:
+ ids = df[id_column].tolist()
+ else:
+ ids = df.index.tolist()
+
+ # 批量预测
+ result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k)
+
+ # 保存结果
+ if output_path:
+ if format.lower() == 'csv':
+ result_df.to_csv(output_path, index=False, encoding='utf-8')
+ elif format.lower() == 'json':
+ # 转换为嵌套的JSON格式
+ self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json')
+ else:
+ raise ValueError(f"不支持的输出格式: {format}")
+
+ logger.info(f"预测结果已保存到: {output_path}")
+
+ return result_df
+
+ def process_large_file(self, file_path: str, output_path: Optional[str] = None,
+ return_top_k: int = 1, format: str = 'csv',
+ chunk_size: int = MAX_TEXT_PER_BATCH,
+ delimiter: str = '\n\n') -> None:
+ """
+ 处理大型文本文件,文件会被分块读取和处理
+
+ Args:
+ file_path: 文件路径
+ output_path: 输出文件路径,如果为None则不保存
+ return_top_k: 返回概率最高的前k个类别
+ format: 输出格式,'csv'或'json'
+ chunk_size: 每个块的大小(文本数量)
+ delimiter: 文本分隔符
+ """
+ logger.info(f"开始处理大型文件: {file_path}")
+ start_time = time.time()
+
+ # 读取文件内容
+ with open(file_path, 'r', encoding=ENCODING) as f:
+ content = f.read()
+
+ # 分割文本
+ texts = content.split(delimiter)
+ texts = [text.strip() for text in texts if text.strip()]
+
+ logger.info(f"文件共包含 {len(texts)} 条文本")
+
+ # 创建输出文件
+ if output_path:
+ if format.lower() == 'csv':
+ # 创建CSV文件头
+ if return_top_k > 1:
+ header = "text_id,text,rank,predicted_class,probability\n"
+ else:
+ header = "text_id,text,predicted_class,probability\n"
+
+ with open(output_path, 'w', encoding=ENCODING) as f:
+ f.write(header)
+ elif format.lower() == 'json':
+ # 创建JSON文件
+ with open(output_path, 'w', encoding=ENCODING) as f:
+ f.write('{\n')
+
+ # 分块处理
+ total_chunks = (len(texts) + chunk_size - 1) // chunk_size
+
+ for i in range(0, len(texts), chunk_size):
+ chunk = texts[i:i + chunk_size]
+ chunk_ids = list(range(i, i + len(chunk)))
+
+ logger.info(f"处理第 {i // chunk_size + 1}/{total_chunks} 块,包含 {len(chunk)} 条文本")
+
+ # 批量预测
+ result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k)
+
+ # 追加到输出文件
+ if output_path:
+ if format.lower() == 'csv':
+ result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False)
+ elif format.lower() == 'json':
+ # 转换为JSON并追加
+ if return_top_k > 1:
+ # 分组后转换为嵌套格式
+ for id_val in result_df['id'].unique():
+ sub_df = result_df[result_df['id'] == id_val]
+ predictions = []
+ for _, row in sub_df.iterrows():
+ predictions.append({
+ 'class': row['predicted_class'],
+ 'probability': float(row['probability'])
+ })
+
+ json_str = f' "{id_val}": {{\n'
+ json_str += f' "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n'
+ json_str += f' "predictions": {json.dumps(predictions, ensure_ascii=False)}\n'
+ json_str += ' },'
+
+ with open(output_path, 'a', encoding=ENCODING) as f:
+ f.write(json_str + '\n')
+ else:
+ # 直接构建JSON
+ for _, row in result_df.iterrows():
+ json_str = f' "{row["id"]}": {{\n'
+ json_str += f' "text": {json.dumps(row["text"], ensure_ascii=False)},\n'
+ json_str += f' "predicted_class": "{row["predicted_class"]}",\n'
+ json_str += f' "probability": {float(row["probability"])}\n'
+ json_str += ' },'
+
+ with open(output_path, 'a', encoding=ENCODING) as f:
+ f.write(json_str + '\n')
+
+ # 完成JSON文件
+ if output_path and format.lower() == 'json':
+ with open(output_path, 'a', encoding=ENCODING) as f:
+ f.write('}\n')
+
+ # 修复JSON文件中的最后一个逗号
+ with open(output_path, 'r', encoding=ENCODING) as f:
+ content = f.read()
+
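+            # rstrip('\n}') 会去掉末尾的换行与右花括号,停在最后一个条目的逗号处;
+            # 再去掉该逗号,并重新补上 '\n}\n'(效果示例:'…  },\n}\n' -> '…  }\n}\n')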
+ content = content.rstrip('\n}')
+ content = content.rstrip(',')
+ content += '\n}\n'
+
+ with open(output_path, 'w', encoding=ENCODING) as f:
+ f.write(content)
+
+ processing_time = time.time() - start_time
+ logger.info(f"处理大型文件完成,共处理 {len(texts)} 条文本,用时: {processing_time:.2f} 秒")
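+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative): it assumes a trained model has already been
+    # saved where ModelFactory can find it; the input and output paths are placeholders.
+    from models.model_factory import ModelFactory
+    from preprocessing.tokenization import ChineseTokenizer
+
+    available_models = ModelFactory.get_available_models()
+    if not available_models:
+        print("未找到可用模型,请先训练并保存模型")
+    else:
+        model = ModelFactory.load_model(available_models[0]['path'])
+        predictor = Predictor(model=model, tokenizer=ChineseTokenizer())
+        processor = BatchProcessor(predictor, batch_size=32)
+
+        ensure_dir("outputs")
+        result_df = processor.process_directory(
+            directory="data/raw/samples",
+            pattern="*.txt",
+            output_path="outputs/batch_predictions.csv",
+            return_top_k=3
+        )
+        print(result_df.head())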
diff --git a/inference/predictor.py b/inference/predictor.py
new file mode 100644
index 0000000..9e37a08
--- /dev/null
+++ b/inference/predictor.py
@@ -0,0 +1,316 @@
+"""
+预测器模块:实现模型预测功能,支持单条和批量文本预测
+"""
+import os
+import time
+import numpy as np
+import tensorflow as tf
+from typing import List, Dict, Tuple, Optional, Any, Union
+import pandas as pd
+import json
+
+from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY
+from models.base_model import TextClassificationModel
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from utils.logger import get_logger
+
+logger = get_logger("Predictor")
+
+
+class Predictor:
+ """预测器,负责加载模型和进行预测"""
+
+ def __init__(self, model: TextClassificationModel,
+ tokenizer: Optional[ChineseTokenizer] = None,
+ vectorizer: Optional[SequenceVectorizer] = None,
+ class_names: Optional[List[str]] = None,
+ max_sequence_length: int = 500,
+ batch_size: Optional[int] = None):
+ """
+ 初始化预测器
+
+ Args:
+ model: 已训练的模型实例
+ tokenizer: 分词器实例,如果为None则创建一个新的分词器
+ vectorizer: 文本向量化器实例,如果为None则表示模型直接接收序列
+ class_names: 类别名称列表,如果为None则使用ID_TO_CATEGORY
+ max_sequence_length: 最大序列长度
+ batch_size: 批大小,如果为None则使用模型默认值
+ """
+ self.model = model
+ self.tokenizer = tokenizer or ChineseTokenizer()
+ self.vectorizer = vectorizer
+ self.class_names = class_names
+
+ if class_names is None and hasattr(model, 'num_classes'):
+ # 如果模型具有类别数量信息,从ID_TO_CATEGORY获取类别名称
+ self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)]
+
+ self.max_sequence_length = max_sequence_length
+ self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32)
+
+ logger.info(f"初始化预测器,批大小: {self.batch_size}")
+
+ def preprocess_text(self, text: str) -> Any:
+ """
+ 预处理单条文本
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 预处理后的文本表示
+ """
+ # 分词
+ tokenized_text = self.tokenizer.tokenize(text, return_string=True)
+
+ # 如果有向量化器,应用向量化
+ if self.vectorizer is not None:
+ return self.vectorizer.transform([tokenized_text])[0]
+
+ return tokenized_text
+
+ def preprocess_texts(self, texts: List[str]) -> Any:
+ """
+ 批量预处理文本
+
+ Args:
+ texts: 原始文本列表
+
+ Returns:
+ 预处理后的批量文本表示
+ """
+ # 分词
+ tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts]
+
+ # 如果有向量化器,应用向量化
+ if self.vectorizer is not None:
+ return self.vectorizer.transform(tokenized_texts)
+
+ return tokenized_texts
+
+ def predict(self, text: str, return_top_k: int = 1,
+ return_probabilities: bool = False) -> Union[str, Dict, List]:
+ """
+ 预测单条文本的类别
+
+ Args:
+ text: 原始文本
+ return_top_k: 返回概率最高的前k个类别
+ return_probabilities: 是否返回概率值
+
+ Returns:
+ 预测结果,格式取决于参数设置
+ """
+ # 预处理文本
+ processed_text = self.preprocess_text(text)
+
+ # 添加批次维度
+ if isinstance(processed_text, str):
+ input_data = np.array([processed_text])
+ else:
+ input_data = np.expand_dims(processed_text, axis=0)
+
+ # 预测
+ start_time = time.time()
+ predictions = self.model.predict(input_data)
+ prediction_time = time.time() - start_time
+
+ # 获取前k个预测结果
+ if return_top_k > 1:
+ top_indices = np.argsort(predictions[0])[::-1][:return_top_k]
+ top_probs = predictions[0][top_indices]
+
+ if self.class_names:
+ top_classes = [self.class_names[idx] for idx in top_indices]
+ else:
+ top_classes = [str(idx) for idx in top_indices]
+
+ if return_probabilities:
+ return [{'class': cls, 'probability': float(prob)}
+ for cls, prob in zip(top_classes, top_probs)]
+ else:
+ return top_classes
+ else:
+ # 获取最高概率的类别
+ pred_idx = np.argmax(predictions[0])
+ pred_prob = float(predictions[0][pred_idx])
+
+ if self.class_names:
+ pred_class = self.class_names[pred_idx]
+ else:
+ pred_class = str(pred_idx)
+
+ if return_probabilities:
+ return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time}
+ else:
+ return pred_class
+
+ def predict_batch(self, texts: List[str], return_top_k: int = 1,
+ return_probabilities: bool = False) -> List:
+ """
+ 批量预测文本类别
+
+ Args:
+ texts: 原始文本列表
+ return_top_k: 返回概率最高的前k个类别
+ return_probabilities: 是否返回概率值
+
+ Returns:
+ 预测结果列表
+ """
+ # 空列表检查
+ if not texts:
+ return []
+
+ # 预处理文本
+ processed_texts = self.preprocess_texts(texts)
+
+ # 预测
+ start_time = time.time()
+ predictions = self.model.predict(processed_texts, batch_size=self.batch_size)
+ prediction_time = time.time() - start_time
+
+ # 处理预测结果
+ results = []
+
+ for i, pred in enumerate(predictions):
+ if return_top_k > 1:
+ top_indices = np.argsort(pred)[::-1][:return_top_k]
+ top_probs = pred[top_indices]
+
+ if self.class_names:
+ top_classes = [self.class_names[idx] for idx in top_indices]
+ else:
+ top_classes = [str(idx) for idx in top_indices]
+
+ if return_probabilities:
+ results.append([{'class': cls, 'probability': float(prob)}
+ for cls, prob in zip(top_classes, top_probs)])
+ else:
+ results.append(top_classes)
+ else:
+ # 获取最高概率的类别
+ pred_idx = np.argmax(pred)
+ pred_prob = float(pred[pred_idx])
+
+ if self.class_names:
+ pred_class = self.class_names[pred_idx]
+ else:
+ pred_class = str(pred_idx)
+
+ if return_probabilities:
+ results.append({'class': pred_class, 'probability': pred_prob})
+ else:
+ results.append(pred_class)
+
+ logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f} 秒")
+
+ return results
+
+ def predict_to_dataframe(self, texts: List[str],
+ text_ids: Optional[List[Union[str, int]]] = None,
+ return_top_k: int = 1) -> pd.DataFrame:
+ """
+ 批量预测并返回DataFrame
+
+ Args:
+ texts: 原始文本列表
+ text_ids: 文本ID列表,如果为None则使用索引
+ return_top_k: 返回概率最高的前k个类别
+
+ Returns:
+ 预测结果DataFrame
+ """
+ # 预测
+ predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True)
+
+ # 创建DataFrame
+ if text_ids is None:
+ text_ids = list(range(len(texts)))
+
+ if return_top_k > 1:
+ # 多个类别的情况
+ results = []
+ for i, preds in enumerate(predictions):
+ for j, pred in enumerate(preds):
+ results.append({
+ 'id': text_ids[i],
+ 'text': texts[i],
+ 'rank': j + 1,
+ 'predicted_class': pred['class'],
+ 'probability': pred['probability']
+ })
+ df = pd.DataFrame(results)
+ else:
+ # 单个类别的情况
+ df = pd.DataFrame({
+ 'id': text_ids,
+ 'text': texts,
+ 'predicted_class': [pred['class'] for pred in predictions],
+ 'probability': [pred['probability'] for pred in predictions]
+ })
+
+ return df
+
+ def save_predictions(self, texts: List[str],
+ output_path: str,
+ text_ids: Optional[List[Union[str, int]]] = None,
+ return_top_k: int = 1,
+ format: str = 'csv') -> str:
+ """
+ 批量预测并保存结果
+
+ Args:
+ texts: 原始文本列表
+ output_path: 输出文件路径
+ text_ids: 文本ID列表,如果为None则使用索引
+ return_top_k: 返回概率最高的前k个类别
+ format: 输出格式,'csv'或'json'
+
+ Returns:
+ 输出文件路径
+ """
+ # 获取预测结果DataFrame
+ df = self.predict_to_dataframe(texts, text_ids, return_top_k)
+
+ # 保存结果
+ if format.lower() == 'csv':
+ df.to_csv(output_path, index=False, encoding='utf-8')
+ elif format.lower() == 'json':
+ # 转换为嵌套的JSON格式
+ if return_top_k > 1:
+ # 分组后转换为嵌套格式
+ result = {}
+ for id_val in df['id'].unique():
+ sub_df = df[df['id'] == id_val]
+ predictions = []
+ for _, row in sub_df.iterrows():
+ predictions.append({
+ 'class': row['predicted_class'],
+ 'probability': row['probability']
+ })
+ result[str(id_val)] = {
+ 'text': sub_df.iloc[0]['text'],
+ 'predictions': predictions
+ }
+ else:
+ # 直接构建JSON
+ result = {}
+ for _, row in df.iterrows():
+ result[str(row['id'])] = {
+ 'text': row['text'],
+ 'predicted_class': row['predicted_class'],
+ 'probability': row['probability']
+ }
+
+ # 保存为JSON
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=2)
+ else:
+ raise ValueError(f"不支持的输出格式: {format}")
+
+ logger.info(f"预测结果已保存到: {output_path}")
+
+ return output_path
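+
+
+if __name__ == "__main__":
+    # Minimal sketch of the top-k selection used in predict()/predict_batch(), run on a
+    # synthetic probability vector (illustrative labels, no trained model required).
+    probs = np.array([0.05, 0.70, 0.15, 0.10])
+    demo_class_names = ["体育", "财经", "科技", "时政"]
+    top_k = 2
+
+    top_indices = np.argsort(probs)[::-1][:top_k]
+    for rank, idx in enumerate(top_indices, start=1):
+        print(f"{rank}. {demo_class_names[idx]} (probability={probs[idx]:.2f})")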
diff --git a/interface/__init__.py b/interface/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/interface/api.py b/interface/api.py
new file mode 100644
index 0000000..844588a
--- /dev/null
+++ b/interface/api.py
@@ -0,0 +1,391 @@
+"""
+API接口模块:提供REST API接口
+"""
+import os
+import sys
+import json
+import time
+from typing import List, Dict, Tuple, Optional, Any, Union
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+import uvicorn
+import asyncio
+import pandas as pd
+import io
+
+# 将项目根目录添加到sys.path
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from config.system_config import CLASSIFIERS_DIR, CATEGORIES
+from models.model_factory import ModelFactory
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from inference.predictor import Predictor
+from inference.batch_processor import BatchProcessor
+from utils.logger import get_logger
+
+logger = get_logger("API")
+
+
+# 数据模型
+class TextItem(BaseModel):
+ text: str
+ id: Optional[str] = None
+
+
+class BatchPredictRequest(BaseModel):
+ texts: List[TextItem]
+ top_k: Optional[int] = 1
+
+
+class BatchFileRequest(BaseModel):
+ file_paths: List[str]
+ top_k: Optional[int] = 1
+
+
+class ModelInfo(BaseModel):
+ id: str
+ name: str
+ type: str
+ num_classes: int
+ created_time: str
+ file_size: str
+
+
+# 应用实例
+app = FastAPI(
+ title="中文文本分类系统API",
+ description="提供中文文本分类功能的REST API",
+ version="1.0.0"
+)
+
+# 允许跨域请求
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+# 全局对象
+predictor = None
+
+
+def get_predictor() -> Predictor:
+ """
+ 获取或创建全局Predictor实例
+
+ Returns:
+ Predictor实例
+ """
+ global predictor
+ if predictor is None:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ if not models_info:
+ raise HTTPException(status_code=500, detail="未找到可用的模型")
+
+ # 使用最新的模型
+ model_path = models_info[0]['path']
+ logger.info(f"API加载模型: {model_path}")
+
+ # 加载模型
+ model = ModelFactory.load_model(model_path)
+
+ # 创建分词器
+ tokenizer = ChineseTokenizer()
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ class_names=CATEGORIES,
+ batch_size=64
+ )
+
+ return predictor
+
+
+@app.get("/")
+async def root():
+ """API根路径"""
+ return {"message": "欢迎使用中文文本分类系统API"}
+
+
+@app.post("/predict/text")
+async def predict_text(text: str = Form(...), top_k: int = Form(1)):
+ """
+ 预测单条文本
+
+ Args:
+ text: 要预测的文本
+ top_k: 返回概率最高的前k个类别
+
+ Returns:
+ 预测结果
+ """
+ logger.info(f"接收到文本预测请求,文本长度: {len(text)}")
+
+ try:
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 预测
+ start_time = time.time()
+ result = predictor.predict(
+ text=text,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+ prediction_time = time.time() - start_time
+
+ # 构建响应
+ response = {
+ "success": True,
+ "predictions": result if top_k > 1 else [result],
+ "time": prediction_time
+ }
+
+ return response
+
+ except Exception as e:
+ logger.error(f"预测文本时出错: {e}")
+ raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}")
+
+
+@app.post("/predict/batch")
+async def predict_batch(request: BatchPredictRequest):
+ """
+ 批量预测文本
+
+ Args:
+ request: 包含文本列表和参数的请求
+
+ Returns:
+ 批量预测结果
+ """
+ texts = [item.text for item in request.texts]
+ ids = [item.id or str(i) for i, item in enumerate(request.texts)]
+
+ logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")
+
+ try:
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 预测
+ start_time = time.time()
+ results = predictor.predict_batch(
+ texts=texts,
+ return_top_k=request.top_k,
+ return_probabilities=True
+ )
+ prediction_time = time.time() - start_time
+
+ # 构建响应
+ response = {
+ "success": True,
+ "total": len(texts),
+ "time": prediction_time,
+ "results": {}
+ }
+
+ # 将结果关联到ID
+ for i, (id_val, result) in enumerate(zip(ids, results)):
+ response["results"][id_val] = result
+
+ return response
+
+ except Exception as e:
+ logger.error(f"批量预测文本时出错: {e}")
+ raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}")
+
+
+@app.post("/predict/file")
+async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
+ """
+ 预测文件内容
+
+ Args:
+ file: 上传的文件
+ top_k: 返回概率最高的前k个类别
+
+ Returns:
+ 预测结果
+ """
+ logger.info(f"接收到文件预测请求,文件名: {file.filename}")
+
+ try:
+ # 读取文件内容
+ content = await file.read()
+
+ # 根据文件类型处理内容
+ if file.filename.endswith(('.txt', '.md')):
+ # 文本文件
+ text = content.decode('utf-8')
+
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 构建响应
+ response = {
+ "success": True,
+ "filename": file.filename,
+ "predictions": result if top_k > 1 else [result]
+ }
+
+ return response
+
+ elif file.filename.endswith(('.csv', '.xls', '.xlsx')):
+ # 表格文件
+ if file.filename.endswith('.csv'):
+ df = pd.read_csv(io.BytesIO(content))
+ else:
+ df = pd.read_excel(io.BytesIO(content))
+
+ # 查找可能的文本列
+ text_columns = [col for col in df.columns if df[col].dtype == 'object']
+
+ if not text_columns:
+ raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")
+
+ # 使用第一个文本列
+ text_column = text_columns[0]
+ texts = df[text_column].fillna('').tolist()
+
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 批量预测
+ results = predictor.predict_batch(
+ texts=texts,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 构建响应
+ response = {
+ "success": True,
+ "filename": file.filename,
+ "text_column": text_column,
+ "total": len(texts),
+ "results": results
+ }
+
+ return response
+
+ else:
+ raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")
+
+ except Exception as e:
+ logger.error(f"预测文件时出错: {e}")
+ raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}")
+
+
+@app.get("/models")
+async def list_models():
+ """
+ 列出可用的模型
+
+ Returns:
+ 可用模型列表
+ """
+ try:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ # 转换为响应格式
+ models = []
+ for info in models_info:
+ models.append(ModelInfo(
+ id=os.path.basename(info['path']),
+ name=info['name'],
+ type=info['type'],
+ num_classes=info['num_classes'],
+ created_time=info['created_time'],
+ file_size=info['file_size']
+ ))
+
+ return {"models": models}
+
+ except Exception as e:
+ logger.error(f"获取模型列表时出错: {e}")
+ raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}")
+
+
+@app.get("/categories")
+async def list_categories():
+ """
+ 列出支持的类别
+
+ Returns:
+ 支持的类别列表
+ """
+ try:
+ return {"categories": CATEGORIES}
+ except Exception as e:
+ logger.error(f"获取类别列表时出错: {e}")
+ raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}")
+
+
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+ """
+ 记录请求日志
+
+ Args:
+ request: 请求对象
+ call_next: 下一个处理函数
+
+ Returns:
+ 响应对象
+ """
+ start_time = time.time()
+
+ # 记录请求信息
+ logger.info(f"请求: {request.method} {request.url}")
+
+ # 处理请求
+ response = await call_next(request)
+
+ # 记录响应信息
+ process_time = time.time() - start_time
+ logger.info(f"响应: {response.status_code} ({process_time:.2f}s)")
+
+ return response
+
+
+def run_server(host: str = "0.0.0.0", port: int = 8000):
+ """
+ 运行API服务器
+
+ Args:
+ host: 主机地址
+ port: 端口号
+ """
+ uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+ # 解析命令行参数
+ import argparse
+
+ parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
+ parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
+ parser.add_argument("--port", type=int, default=8000, help="服务器端口号")
+
+ args = parser.parse_args()
+
+ # 运行服务器
+ run_server(host=args.host, port=args.port)
diff --git a/interface/cli.py b/interface/cli.py
new file mode 100644
index 0000000..4eaf76b
--- /dev/null
+++ b/interface/cli.py
@@ -0,0 +1,378 @@
+"""
+命令行界面模块:提供命令行交互功能
+"""
+import argparse
+import os
+import sys
+import pandas as pd
+from typing import List, Dict, Tuple, Optional, Any, Union
+import json
+
+# 将项目根目录添加到sys.path
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from config.system_config import CLASSIFIERS_DIR, CATEGORIES
+from models.model_factory import ModelFactory
+from models.base_model import TextClassificationModel
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from inference.predictor import Predictor
+from inference.batch_processor import BatchProcessor
+from utils.logger import get_logger
+from utils.file_utils import ensure_dir, read_text_file
+
+logger = get_logger("CLI")
+
+
+def load_model_and_components(model_path: Optional[str] = None,
+ tokenizer_path: Optional[str] = None,
+ vectorizer_path: Optional[str] = None,
+ class_names: Optional[List[str]] = None) -> Tuple[
+ TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]:
+ """
+ 加载模型和相关组件
+
+ Args:
+ model_path: 模型路径,如果为None则使用最新的模型
+ tokenizer_path: 分词器路径,如果为None则创建一个新的分词器
+ vectorizer_path: 向量化器路径,如果为None则不使用向量化器
+ class_names: 类别名称列表,如果为None则使用CATEGORIES
+
+ Returns:
+ (模型, 分词器, 向量化器)的元组
+ """
+ # 加载模型
+ if model_path is None:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ if not models_info:
+ raise ValueError("未找到可用的模型,请指定模型路径")
+
+ # 使用最新的模型
+ model_path = models_info[0]['path']
+ logger.info(f"使用最新的模型: {model_path}")
+
+ # 加载模型
+ model = ModelFactory.load_model(model_path)
+
+ # 加载或创建分词器
+ if tokenizer_path:
+ tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理
+ logger.info(f"已加载分词器: {tokenizer_path}")
+ else:
+ tokenizer = ChineseTokenizer()
+ logger.info("已创建新的分词器")
+
+ # 加载向量化器
+ vectorizer = None
+ if vectorizer_path:
+ vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理
+ vectorizer.load(vectorizer_path)
+ logger.info(f"已加载向量化器: {vectorizer_path}")
+
+ return model, tokenizer, vectorizer
+
+
+def predict_text(args):
+ """处理单条文本预测命令"""
+ # 加载模型和组件
+ model, tokenizer, vectorizer = load_model_and_components(
+ args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
+ )
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ vectorizer=vectorizer,
+ class_names=args.class_names or CATEGORIES,
+ batch_size=args.batch_size
+ )
+
+ # 获取文本
+ text = args.text
+
+ # 如果提供的是文件路径而非文本内容
+ if args.file and os.path.exists(text):
+ text = read_text_file(text)
+
+ # 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=args.top_k,
+ return_probabilities=True
+ )
+
+ # 输出结果
+ if args.top_k > 1:
+ print("\n预测结果:")
+ for i, pred in enumerate(result):
+ print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
+ else:
+ print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
+
+ # 保存结果
+ if args.output:
+ if args.output.endswith('.json'):
+ with open(args.output, 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=2)
+ else:
+ with open(args.output, 'w', encoding='utf-8') as f:
+ if args.top_k > 1:
+ f.write("rank,class,probability\n")
+ for i, pred in enumerate(result):
+ f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
+ else:
+ f.write(f"class,probability\n")
+ f.write(f"{result['class']},{result['probability']}\n")
+
+ print(f"结果已保存到: {args.output}")
+
+
+def predict_batch(args):
+ """处理批量文本预测命令"""
+ # 加载模型和组件
+ model, tokenizer, vectorizer = load_model_and_components(
+ args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
+ )
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ vectorizer=vectorizer,
+ class_names=args.class_names or CATEGORIES,
+ batch_size=args.batch_size
+ )
+
+ # 创建批处理器
+ batch_processor = BatchProcessor(
+ predictor=predictor,
+ batch_size=args.batch_size,
+ max_workers=args.workers
+ )
+
+ # 确保输出目录存在
+    if args.output:
+        output_dir = os.path.dirname(args.output)
+        if output_dir:
+            ensure_dir(output_dir)
+
+ # 根据输入类型选择处理方法
+ if args.input_type == 'file' and os.path.isfile(args.input):
+ # 单个文件
+ if args.large_file:
+ # 大型文件,分块处理
+ batch_processor.process_large_file(
+ file_path=args.input,
+ output_path=args.output,
+ return_top_k=args.top_k,
+ format=args.format
+ )
+ else:
+ # CSV或Excel文件
+ if args.input.endswith('.csv'):
+ df = pd.read_csv(args.input, encoding='utf-8')
+ elif args.input.endswith(('.xls', '.xlsx')):
+ df = pd.read_excel(args.input)
+ else:
+ print(f"不支持的文件格式: {args.input}")
+ return
+
+ # 检查文本列是否存在
+ if args.text_column not in df.columns:
+ print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}")
+ return
+
+ # 处理DataFrame
+ result_df = batch_processor.process_dataframe(
+ df=df,
+ text_column=args.text_column,
+ id_column=args.id_column,
+ output_path=args.output,
+ return_top_k=args.top_k,
+ format=args.format
+ )
+
+ # 输出结果统计
+ print(f"\n已处理 {len(result_df)} 条文本")
+ print("类别分布:")
+ if args.top_k == 1:
+ class_counts = result_df['predicted_class'].value_counts()
+ for cls, count in class_counts.items():
+ print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
+
+ elif args.input_type == 'dir' and os.path.isdir(args.input):
+ # 目录
+ result_df = batch_processor.process_directory(
+ directory=args.input,
+ pattern=args.pattern,
+ output_path=args.output,
+ return_top_k=args.top_k,
+ format=args.format,
+ recursive=args.recursive
+ )
+
+ # 输出结果统计
+ if not result_df.empty:
+ print(f"\n已处理 {len(result_df)} 个文件")
+ print("类别分布:")
+ if args.top_k == 1:
+ class_counts = result_df['predicted_class'].value_counts()
+ for cls, count in class_counts.items():
+ print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
+
+ else:
+ print(f"无效的输入: {args.input}")
+
+
+def list_models(args):
+ """列出可用的模型"""
+ models_info = ModelFactory.get_available_models()
+
+ if not models_info:
+ print("未找到可用的模型")
+ return
+
+ print(f"找到 {len(models_info)} 个可用模型:")
+ for i, info in enumerate(models_info):
+ print(f"\n{i + 1}. {info['name']} ({info['type']})")
+ print(f" 路径: {info['path']}")
+ print(f" 创建时间: {info['created_time']}")
+ print(f" 类别数: {info['num_classes']}")
+ print(f" 文件大小: {info['file_size']}")
+
+
+def interactive_mode(args):
+ """交互模式"""
+ print("启动交互模式...")
+
+ # 加载模型和组件
+ model, tokenizer, vectorizer = load_model_and_components(
+ args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
+ )
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ vectorizer=vectorizer,
+ class_names=args.class_names or CATEGORIES,
+ batch_size=args.batch_size
+ )
+
+ print("\n模型已加载,可以开始交互式文本分类")
+ print("输入 'quit' 或 'exit' 退出交互模式\n")
+
+ while True:
+ try:
+ # 获取用户输入
+ text = input("请输入要分类的文本: ")
+
+ # 检查是否退出
+ if text.lower() in ['quit', 'exit', 'q']:
+ print("退出交互模式")
+ break
+
+ # 空输入
+ if not text.strip():
+ continue
+
+ # 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=args.top_k,
+ return_probabilities=True
+ )
+
+ # 输出结果
+ if args.top_k > 1:
+ print("\n预测结果:")
+ for i, pred in enumerate(result):
+ print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
+ else:
+ print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
+
+ print() # 空行
+
+ except KeyboardInterrupt:
+ print("\n退出交互模式")
+ break
+ except Exception as e:
+ print(f"处理过程中出错: {e}")
+
+
+def main():
+ """主函数,解析命令行参数并调用相应的功能"""
+ parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具")
+
+ # 创建子命令
+ subparsers = parser.add_subparsers(dest="command", help="子命令")
+
+ # 预测单条文本命令
+ predict_parser = subparsers.add_parser("predict", help="预测单条文本")
+ predict_parser.add_argument("text", help="要预测的文本或文本文件路径")
+ predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径")
+ predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
+ predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
+ predict_parser.add_argument("--vectorizer_path", help="向量化器路径")
+ predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
+ predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
+ predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
+ predict_parser.add_argument("--output", help="保存预测结果的文件路径")
+ predict_parser.set_defaults(func=predict_text)
+
+ # 批量预测命令
+ batch_parser = subparsers.add_parser("batch", help="批量预测文本")
+ batch_parser.add_argument("input", help="输入文件或目录路径")
+ batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型")
+ batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名")
+ batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名")
+ batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式")
+ batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录")
+ batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件")
+ batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
+ batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
+ batch_parser.add_argument("--vectorizer_path", help="向量化器路径")
+ batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
+ batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
+ batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
+ batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数")
+ batch_parser.add_argument("--output", required=True, help="输出文件路径")
+ batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="输出格式")
+ batch_parser.set_defaults(func=predict_batch)
+
+ # 列出可用模型命令
+ list_parser = subparsers.add_parser("list", help="列出可用的模型")
+ list_parser.set_defaults(func=list_models)
+
+ # 交互模式命令
+ interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式")
+ interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
+ interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
+ interactive_parser.add_argument("--vectorizer_path", help="向量化器路径")
+ interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
+ interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
+ interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小")
+ interactive_parser.set_defaults(func=interactive_mode)
+
+ # 解析参数
+ args = parser.parse_args()
+
+ # 如果没有指定命令,显示帮助
+ if not hasattr(args, 'func'):
+ parser.print_help()
+ return
+
+ # 执行命令
+ try:
+ args.func(args)
+ except Exception as e:
+ logger.error(f"执行命令时出错: {e}")
+ print(f"执行命令时出错: {e}")
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
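+
+
+# Example invocations (illustrative; the data paths and output files are placeholders):
+#   python interface/cli.py list
+#   python interface/cli.py predict "这是一条测试文本" --top_k 3
+#   python interface/cli.py batch data/raw/samples --input_type dir --pattern "*.txt" --output outputs/preds.csv
+#   python interface/cli.py interactive --top_k 3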
diff --git a/interface/web/__init__.py b/interface/web/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/interface/web/app.py b/interface/web/app.py
new file mode 100644
index 0000000..de3fa52
--- /dev/null
+++ b/interface/web/app.py
@@ -0,0 +1,260 @@
+"""
+Web应用模块:基于Flask提供Web界面
+"""
+import os
+import sys
+from datetime import datetime
+
+import pandas as pd
+from flask import Flask, render_template, request, flash
+from werkzeug.utils import secure_filename
+
+# 将项目根目录添加到sys.path
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from config.system_config import CATEGORIES
+from models.model_factory import ModelFactory
+from preprocessing.tokenization import ChineseTokenizer
+from inference.predictor import Predictor
+from utils.logger import get_logger
+
+logger = get_logger("WebApp")
+
+# 应用实例与基础配置(密钥与上传目录为假设的默认值,可按需调整)
+app = Flask(__name__)
+app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'chinese-text-classification')
+app.config['UPLOAD_FOLDER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+
+# 允许上传的文件类型(与下方文件处理逻辑保持一致)
+ALLOWED_EXTENSIONS = {'txt', 'csv', 'xls', 'xlsx'}
+
+# 全局预测器实例
+predictor = None
+
+
+def allowed_file(filename: str) -> bool:
+    """检查文件扩展名是否在允许范围内"""
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+
+def get_predictor() -> Predictor:
+    """获取或创建全局Predictor实例(与 interface/api.py 中的实现一致)"""
+    global predictor
+    if predictor is None:
+        models_info = ModelFactory.get_available_models()
+        if not models_info:
+            raise RuntimeError("未找到可用的模型")
+
+        model_path = models_info[0]['path']
+        logger.info(f"Web应用加载模型: {model_path}")
+
+        model = ModelFactory.load_model(model_path)
+        tokenizer = ChineseTokenizer()
+        predictor = Predictor(
+            model=model,
+            tokenizer=tokenizer,
+            class_names=CATEGORIES,
+            batch_size=64
+        )
+
+    return predictor
+
+
+@app.route('/')
+def index():
+    """首页"""
+    return render_template('index.html')
+
+
+@app.route('/predict_text', methods=['GET', 'POST'])
+def predict_text():
+    """文本预测页面"""
+    if request.method == 'POST':
+        text = request.form.get('text', '')
+        top_k = int(request.form.get('top_k', 3))
+
+        if not text:
+            flash('请输入文本内容')
+            return render_template('predict_text.html')
+
+        try:
+            # 获取预测器
+            predictor = get_predictor()
+
+            # 预测
+            result = predictor.predict(
+                text=text,
+                return_top_k=top_k,
+                return_probabilities=True
+            )
+
+            # 准备结果
+            predictions = result if top_k > 1 else [result]
+
+            return render_template('predict_text.html', text=text, predictions=predictions)
+
+        except Exception as e:
+            logger.error(f"预测文本时出错: {e}")
+            flash(f'预测失败: {str(e)}')
+            return render_template('predict_text.html', text=text)
+
+    return render_template('predict_text.html')
+
+
+@app.route('/predict_file', methods=['GET', 'POST'])
+def predict_file():
+ """文件预测页面"""
+ if request.method == 'POST':
+ # 检查是否有文件上传
+ if 'file' not in request.files:
+ flash('未选择文件')
+ return render_template('predict_file.html')
+
+ file = request.files['file']
+
+ # 如果用户没有选择文件
+ if file.filename == '':
+ flash('未选择文件')
+ return render_template('predict_file.html')
+
+ if file and allowed_file(file.filename):
+ # 获取参数
+ top_k = int(request.form.get('top_k', 3))
+ text_column = request.form.get('text_column', 'text')
+
+ try:
+ # 安全地保存文件
+ filename = secure_filename(file.filename)
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ file.save(file_path)
+
+ # 处理文件类型
+ if filename.endswith('.txt'):
+ # 文本文件
+ with open(file_path, 'r', encoding='utf-8') as f:
+ text = f.read()
+
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 准备结果
+ predictions = result if top_k > 1 else [result]
+
+ return render_template(
+ 'predict_file.html',
+ filename=filename,
+ file_type='text',
+ predictions=predictions
+ )
+
+ elif filename.endswith(('.csv', '.xls', '.xlsx')):
+ # 表格文件
+ if filename.endswith('.csv'):
+ df = pd.read_csv(file_path)
+ else:
+ df = pd.read_excel(file_path)
+
+ # 检查文本列是否存在
+ if text_column not in df.columns:
+ flash(f"文件中没有找到列: {text_column}")
+ return render_template('predict_file.html')
+
+ # 提取文本
+ texts = df[text_column].fillna('').tolist()
+
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 批量预测
+ results = predictor.predict_batch(
+ texts=texts[:100], # 仅处理前100条记录
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 准备结果
+ batch_results = []
+ for i, result in enumerate(results):
+ batch_results.append({
+ 'id': i,
+ 'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i],
+ 'predictions': result if top_k > 1 else [result]
+ })
+
+ return render_template(
+ 'predict_file.html',
+ filename=filename,
+ file_type='table',
+ total_records=len(texts),
+ processed_records=min(100, len(texts)),
+ batch_results=batch_results
+ )
+
+ else:
+ flash(f"不支持的文件类型: {filename}")
+ return render_template('predict_file.html')
+
+ except Exception as e:
+ logger.error(f"处理文件时出错: {e}")
+ flash(f'处理失败: {str(e)}')
+ return render_template('predict_file.html')
+ else:
+ flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
+ return render_template('predict_file.html')
+
+ return render_template('predict_file.html')
+
+
+@app.route('/batch_predict', methods=['GET', 'POST'])
+def batch_predict():
+ """批量预测页面"""
+ if request.method == 'POST':
+ texts = request.form.get('texts', '')
+ top_k = int(request.form.get('top_k', 3))
+
+ if not texts:
+ flash('请输入文本内容')
+ return render_template('batch_predict.html')
+
+ # 分割文本
+ text_list = [text.strip() for text in texts.split('\n') if text.strip()]
+
+ if not text_list:
+ flash('请输入有效的文本内容')
+ return render_template('batch_predict.html', texts=texts)
+
+ try:
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 批量预测
+ results = predictor.predict_batch(
+ texts=text_list,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 准备结果
+ batch_results = []
+ for i, result in enumerate(results):
+ batch_results.append({
+ 'id': i + 1,
+ 'text': text_list[i],
+ 'predictions': result if top_k > 1 else [result]
+ })
+
+ return render_template(
+ 'batch_predict.html',
+ texts=texts,
+ batch_results=batch_results
+ )
+
+ except Exception as e:
+ logger.error(f"批量预测时出错: {e}")
+ flash(f'预测失败: {str(e)}')
+ return render_template('batch_predict.html', texts=texts)
+
+ return render_template('batch_predict.html')
+
+
+@app.route('/models')
+def list_models():
+ """模型列表页面"""
+ try:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ return render_template('models.html', models=models_info)
+
+ except Exception as e:
+ logger.error(f"获取模型列表时出错: {e}")
+ flash(f'获取模型列表失败: {str(e)}')
+ return render_template('models.html', models=[])
+
+
+@app.route('/about')
+def about():
+ """关于页面"""
+ return render_template('about.html')
+
+
+@app.errorhandler(404)
+def page_not_found(e):
+ """404错误处理"""
+ return render_template('404.html'), 404
+
+
+@app.errorhandler(500)
+def internal_server_error(e):
+ """500错误处理"""
+ return render_template('500.html'), 500
+
+
+# 过滤器
+@app.template_filter('format_time')
+def format_time_filter(timestamp):
+ """格式化时间戳"""
+ if isinstance(timestamp, str):
+ try:
+ dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
+ return dt.strftime("%Y年%m月%d日 %H:%M")
+        except (ValueError, TypeError):
+            return timestamp
+ return timestamp
+
+
+@app.template_filter('truncate_text')
+def truncate_text_filter(text, length=100):
+ """截断文本"""
+ if len(text) <= length:
+ return text
+ return text[:length] + '...'
+
+
+@app.template_filter('format_percent')
+def format_percent_filter(value):
+ """格式化百分比"""
+ if isinstance(value, (int, float)):
+ return f"{value * 100:.2f}%"
+ return value
+
+
+def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
+ """
+ 运行Web服务器
+
+ Args:
+ host: 主机地址
+ port: 端口号
+ debug: 是否开启调试模式
+ """
+ app.run(host=host, port=port, debug=debug)
+
+
+if __name__ == "__main__":
+ # 解析命令行参数
+ import argparse
+
+ parser = argparse.ArgumentParser(description="中文文本分类系统Web应用")
+ parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
+ parser.add_argument("--port", type=int, default=5000, help="服务器端口号")
+ parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
+
+ args = parser.parse_args()
+
+ # 运行服务器
+ run_server(host=args.host, port=args.port, debug=args.debug)
diff --git a/interface/web/routes.py b/interface/web/routes.py
new file mode 100644
index 0000000..be05188
--- /dev/null
+++ b/interface/web/routes.py
@@ -0,0 +1,166 @@
+"""
+Web路由模块:定义Web应用的路由
+"""
+from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash
+from werkzeug.utils import secure_filename
+import os
+import pandas as pd
+import io
+import time
+from typing import List, Dict, Tuple, Optional
+
+from interface.web.app import get_predictor, allowed_file, app
+
+# 创建蓝图
+bp = Blueprint('routes', __name__)
+
+
+@bp.route('/')
+def index():
+ """首页"""
+ return render_template('index.html')
+
+
+@bp.route('/predict_text', methods=['GET', 'POST'])
+def predict_text():
+ """文本预测页面"""
+ if request.method == 'POST':
+ text = request.form.get('text', '')
+ top_k = int(request.form.get('top_k', 3))
+
+ if not text:
+ flash('请输入文本内容')
+ return render_template('predict_text.html')
+
+ try:
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 准备结果
+ predictions = result if top_k > 1 else [result]
+
+ return render_template('predict_text.html', text=text, predictions=predictions)
+
+ except Exception as e:
+ flash(f'预测失败: {str(e)}')
+ return render_template('predict_text.html', text=text)
+
+ return render_template('predict_text.html')
+
+
+@bp.route('/predict_file', methods=['GET', 'POST'])
+def predict_file():
+ """文件预测页面"""
+ if request.method == 'POST':
+ # 检查是否有文件上传
+ if 'file' not in request.files:
+ flash('未选择文件')
+ return render_template('predict_file.html')
+
+ file = request.files['file']
+
+ # 如果用户没有选择文件
+ if file.filename == '':
+ flash('未选择文件')
+ return render_template('predict_file.html')
+
+ if file and allowed_file(file.filename):
+ # 获取参数
+ top_k = int(request.form.get('top_k', 3))
+ text_column = request.form.get('text_column', 'text')
+
+ try:
+ # 安全地保存文件
+ filename = secure_filename(file.filename)
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ file.save(file_path)
+
+ # 处理文件
+ # 注意:此处与app.py中的predict_file重复,实际项目中应该将逻辑抽取到单独的函数中
+ # 这里简化处理,仅返回渲染模板
+
+ return render_template('predict_file.html', filename=filename)
+
+ except Exception as e:
+ flash(f'处理失败: {str(e)}')
+ return render_template('predict_file.html')
+ else:
+ flash(f'不支持的文件类型')
+ return render_template('predict_file.html')
+
+ return render_template('predict_file.html')
+
+
+@bp.route('/batch_predict', methods=['GET', 'POST'])
+def batch_predict():
+ """批量预测页面"""
+ if request.method == 'POST':
+ texts = request.form.get('texts', '')
+ top_k = int(request.form.get('top_k', 3))
+
+ if not texts:
+ flash('请输入文本内容')
+ return render_template('batch_predict.html')
+
+ # 分割文本
+ text_list = [text.strip() for text in texts.split('\n') if text.strip()]
+
+ if not text_list:
+ flash('请输入有效的文本内容')
+ return render_template('batch_predict.html', texts=texts)
+
+ try:
+ # 获取预测器
+ predictor = get_predictor()
+
+ # 批量预测
+ results = predictor.predict_batch(
+ texts=text_list,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 准备结果
+ batch_results = []
+ for i, result in enumerate(results):
+ batch_results.append({
+ 'id': i + 1,
+ 'text': text_list[i],
+ 'predictions': result if top_k > 1 else [result]
+ })
+
+ return render_template(
+ 'batch_predict.html',
+ texts=texts,
+ batch_results=batch_results
+ )
+
+ except Exception as e:
+ flash(f'预测失败: {str(e)}')
+ return render_template('batch_predict.html', texts=texts)
+
+ return render_template('batch_predict.html')
+
+
+@bp.route('/models')
+def list_models():
+ """模型列表页面"""
+ return render_template('models.html')
+
+
+@bp.route('/about')
+def about():
+ """关于页面"""
+ return render_template('about.html')
+
+
+# 注册蓝图
+def init_app(app):
+ app.register_blueprint(bp)
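+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative, run from the project root): register the blueprint on a
+    # bare Flask app and print its routes. The real entry point remains interface/web/app.py.
+    from flask import Flask
+
+    demo_app = Flask(__name__)
+    init_app(demo_app)
+    print(demo_app.url_map)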
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..05170bf
--- /dev/null
+++ b/main.py
@@ -0,0 +1,142 @@
+"""
+主入口文件:整合系统的所有功能,提供命令行接口
+"""
+import os
+import sys
+import argparse
+import logging
+from typing import List, Optional
+
+# 设置日志
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("main")
+
+# 命令行参数
+parser = argparse.ArgumentParser(description="中文文本分类系统")
+subparsers = parser.add_subparsers(dest="command", help="命令")
+
+# 训练命令
+train_parser = subparsers.add_parser("train", help="训练模型")
+train_parser.add_argument("--data_dir", help="数据目录")
+train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
+train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
+train_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
+train_parser.add_argument("--save_dir", help="模型保存目录")
+
+# 评估命令
+evaluate_parser = subparsers.add_parser("evaluate", help="评估模型")
+evaluate_parser.add_argument("--model_path", required=True, help="模型路径")
+evaluate_parser.add_argument("--data_dir", help="数据目录")
+evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
+evaluate_parser.add_argument("--output_dir", help="评估结果输出目录")
+
+# 预测命令
+predict_parser = subparsers.add_parser("predict", help="使用模型预测")
+predict_parser.add_argument("--model_path", help="模型路径")
+predict_parser.add_argument("--text", help="要预测的文本")
+predict_parser.add_argument("--file", help="要预测的文件")
+predict_parser.add_argument("--output", help="输出文件")
+
+# Web服务命令
+web_parser = subparsers.add_parser("web", help="启动Web服务")
+web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
+web_parser.add_argument("--port", type=int, default=5000, help="服务器端口")
+web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
+
+# API服务命令
+api_parser = subparsers.add_parser("api", help="启动API服务")
+api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
+api_parser.add_argument("--port", type=int, default=8000, help="服务器端口")
+
+# CLI命令
+cli_parser = subparsers.add_parser("cli", help="启动命令行接口")
+cli_parser.add_argument("--model_path", help="模型路径")
+cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式")
+
+
+def main():
+ """主函数"""
+ args = parser.parse_args()
+
+ # 如果没有指定命令,显示帮助信息
+ if not args.command:
+ parser.print_help()
+ return 0
+
+ # 根据命令调用相应的功能
+ if args.command == "train":
+ # 导入训练模块
+ from scripts.train import train_model
+
+ # 调用训练功能
+ train_model(
+ data_dir=args.data_dir,
+ model_type=args.model_type,
+ epochs=args.epochs,
+ batch_size=args.batch_size,
+ save_dir=args.save_dir
+ )
+
+ elif args.command == "evaluate":
+ # 导入评估模块
+ from scripts.evaluate import evaluate_model
+
+ # 调用评估功能
+ evaluate_model(
+ model_path=args.model_path,
+ data_dir=args.data_dir,
+ batch_size=args.batch_size,
+ output_dir=args.output_dir
+ )
+
+ elif args.command == "predict":
+ # 导入预测模块
+ from scripts.predict import predict_text, predict_file
+
+ # 根据输入类型调用相应的预测功能
+ if args.text:
+ predict_text(args.text, args.model_path, args.output)
+ elif args.file:
+ predict_file(args.file, args.model_path, args.output)
+ else:
+ logger.error("请提供要预测的文本或文件")
+ return 1
+
+ elif args.command == "web":
+ # 导入Web服务模块
+ from interface.web.app import run_server
+
+ # 启动Web服务
+ run_server(host=args.host, port=args.port, debug=args.debug)
+
+ elif args.command == "api":
+ # 导入API服务模块
+ from interface.api import run_server
+
+ # 启动API服务
+ run_server(host=args.host, port=args.port)
+
+ elif args.command == "cli":
+ # 导入CLI模块
+ from interface.cli import main as cli_main
+
+ # 将命令行参数转换为CLI模块可接受的格式
+ sys.argv = ["interface/cli.py"]
+
+ if args.interactive:
+ sys.argv.append("interactive")
+ if args.model_path:
+ sys.argv.extend(["--model_path", args.model_path])
+ elif args.model_path:
+ sys.argv.extend(["list", "--model_path", args.model_path])
+ else:
+ sys.argv.append("list")
+
+ # 调用CLI主函数
+ return cli_main()
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
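
For readers unfamiliar with the sub-command pattern used above, a self-contained toy illustration of how `argparse` sub-parsers drive the dispatch in `main()` (names and values are made up):

```python
import argparse

parser = argparse.ArgumentParser(description="subcommand demo")
subparsers = parser.add_subparsers(dest="command")

train_parser = subparsers.add_parser("train")
train_parser.add_argument("--epochs", type=int, default=10)

# Parsing a simulated command line; main() branches on args.command the same way.
args = parser.parse_args(["train", "--epochs", "3"])
if args.command == "train":
    print(f"would train for {args.epochs} epochs")   # -> would train for 3 epochs
```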
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000..090ddc5
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,419 @@
+"""
+模型基类:定义所有文本分类模型的通用接口
+"""
+import os
+import time
+import json
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.models import Model, load_model
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+from abc import ABC, abstractmethod
+
+from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
+from config.model_config import (
+ BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
+ REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
+)
+from utils.logger import get_logger
+from utils.file_utils import ensure_dir, save_json
+
+logger = get_logger("BaseModel")
+
+
+class TextClassificationModel(ABC):
+ """文本分类模型基类,定义所有模型的通用接口"""
+
+ def __init__(self, num_classes: int, model_name: str = "text_classifier",
+ batch_size: int = BATCH_SIZE,
+ learning_rate: float = LEARNING_RATE):
+ """
+ 初始化文本分类模型
+
+ Args:
+ num_classes: 类别数量
+ model_name: 模型名称
+ batch_size: 批大小
+ learning_rate: 学习率
+ """
+ self.num_classes = num_classes
+ self.model_name = model_name
+ self.batch_size = batch_size
+ self.learning_rate = learning_rate
+
+ # 模型实例
+ self.model = None
+
+ # 训练历史
+ self.history = None
+
+ # 训练配置
+ self.config = {
+ "model_name": model_name,
+ "num_classes": num_classes,
+ "batch_size": batch_size,
+ "learning_rate": learning_rate
+ }
+
+ # 验证集合最佳性能
+ self.best_val_loss = float('inf')
+ self.best_val_accuracy = 0.0
+
+ logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")
+
+ @abstractmethod
+ def build(self) -> None:
+ """构建模型架构,这是一个抽象方法,子类必须实现"""
+ pass
+
+ def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
+ loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
+ metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
+ """
+ 编译模型
+
+ Args:
+ optimizer: 优化器,默认为Adam
+ loss: 损失函数,默认为sparse_categorical_crossentropy
+ metrics: 评估指标,默认为accuracy
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 默认优化器
+ if optimizer is None:
+ optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
+
+ # 默认损失函数
+ if loss is None:
+ loss = 'sparse_categorical_crossentropy'
+
+ # 默认评估指标
+ if metrics is None:
+ metrics = ['accuracy']
+
+ # 编译模型
+ self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
+ logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}")
+
+ def summary(self) -> None:
+ """打印模型概要"""
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ self.model.summary()
+
+ def fit(self, x_train: Union[np.ndarray, tf.data.Dataset],
+ y_train: Optional[np.ndarray] = None,
+ validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
+ epochs: int = 10,
+ callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
+ class_weights: Optional[Dict[int, float]] = None,
+ verbose: int = 1) -> tf.keras.callbacks.History:
+ """
+ 训练模型
+
+ Args:
+ x_train: 训练数据特征
+ y_train: 训练数据标签
+ validation_data: 验证数据
+ epochs: 训练轮数
+ callbacks: 回调函数列表
+ class_weights: 类别权重
+ verbose: 详细程度
+
+ Returns:
+ 训练历史
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 记录开始时间
+ start_time = time.time()
+
+ # 添加默认回调函数
+ if callbacks is None:
+ callbacks = self._get_default_callbacks()
+
+ # 训练模型
+ if isinstance(x_train, tf.data.Dataset):
+ # 如果输入是TensorFlow Dataset
+ history = self.model.fit(
+ x_train,
+ epochs=epochs,
+ validation_data=validation_data,
+ callbacks=callbacks,
+ class_weight=class_weights,
+ verbose=verbose
+ )
+ else:
+ # 如果输入是NumPy数组
+ history = self.model.fit(
+ x_train, y_train,
+ batch_size=self.batch_size,
+ epochs=epochs,
+ validation_data=validation_data,
+ callbacks=callbacks,
+ class_weight=class_weights,
+ verbose=verbose
+ )
+
+ # 计算训练时间
+ train_time = time.time() - start_time
+
+ # 保存训练历史
+ self.history = history.history
+ self.history['train_time'] = train_time
+
+ logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
+
+ return history
+
+ def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
+ y_test: Optional[np.ndarray] = None,
+ verbose: int = 1) -> Dict[str, float]:
+ """
+ 评估模型
+
+ Args:
+ x_test: 测试数据特征
+ y_test: 测试数据标签
+ verbose: 详细程度
+
+ Returns:
+ 评估结果字典
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 评估模型
+ if isinstance(x_test, tf.data.Dataset):
+ # 如果输入是TensorFlow Dataset
+ results = self.model.evaluate(x_test, verbose=verbose)
+ else:
+ # 如果输入是NumPy数组
+ results = self.model.evaluate(x_test, y_test, batch_size=self.batch_size, verbose=verbose)
+
+ # 构建评估结果字典
+ metrics_names = self.model.metrics_names
+ evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)}
+
+ logger.info(f"模型评估结果: {evaluation_results}")
+
+ return evaluation_results
+
+ def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
+ batch_size: Optional[int] = None,
+ verbose: int = 0) -> np.ndarray:
+ """
+ 使用模型进行预测
+
+ Args:
+ x: 预测数据
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 预测结果
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 使用模型进行预测
+ if batch_size is None:
+ batch_size = self.batch_size
+
+ return self.model.predict(x, batch_size=batch_size, verbose=verbose)
+
+ def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
+ batch_size: Optional[int] = None,
+ verbose: int = 0) -> np.ndarray:
+ """
+ 使用模型预测类别
+
+ Args:
+ x: 预测数据
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 预测的类别索引
+ """
+ # 获取模型预测概率
+ predictions = self.predict(x, batch_size, verbose)
+
+ # 获取最大概率的类别索引
+ return np.argmax(predictions, axis=1)
+
+ def save(self, filepath: Optional[str] = None,
+ save_format: str = 'tf',
+ include_optimizer: bool = True) -> str:
+ """
+ 保存模型
+
+ Args:
+ filepath: 保存路径,如果为None则使用默认路径
+ save_format: 保存格式,'tf'或'h5'
+ include_optimizer: 是否包含优化器状态
+
+ Returns:
+ 保存路径
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 如果未指定保存路径,使用默认路径
+ if filepath is None:
+ ensure_dir(CLASSIFIERS_DIR)
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
+ filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}")
+
+ # 保存模型
+ self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)
+
+ # 保存模型配置
+ config_path = f"{filepath}_config.json"
+ with open(config_path, 'w', encoding='utf-8') as f:
+ json.dump(self.config, f, ensure_ascii=False, indent=4)
+
+ logger.info(f"模型已保存到: {filepath}")
+
+ return filepath
+
+ @classmethod
+ def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
+ """
+ 加载模型
+
+ Args:
+ filepath: 模型文件路径
+ custom_objects: 自定义对象字典
+
+ Returns:
+ 加载的模型实例
+ """
+ # 加载模型配置
+ config_path = f"{filepath}_config.json"
+
+ try:
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+ except FileNotFoundError:
+ logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
+ config = {}
+
+ # 创建模型实例
+ model_name = config.get('model_name', 'loaded_model')
+ num_classes = config.get('num_classes', 1)
+ batch_size = config.get('batch_size', BATCH_SIZE)
+ learning_rate = config.get('learning_rate', LEARNING_RATE)
+
+ instance = cls(num_classes, model_name, batch_size, learning_rate)
+
+ # 加载Keras模型
+ instance.model = load_model(filepath, custom_objects=custom_objects)
+
+ # 加载配置
+ instance.config = config
+
+ logger.info(f"从 {filepath} 加载模型成功")
+
+ return instance
+
+ def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
+ """获取默认的回调函数列表"""
+ # 早停
+ early_stopping = tf.keras.callbacks.EarlyStopping(
+ monitor='val_loss',
+ patience=EARLY_STOPPING_PATIENCE,
+ restore_best_weights=True,
+ verbose=1
+ )
+
+ # 学习率衰减
+ reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
+ monitor='val_loss',
+ factor=REDUCE_LR_FACTOR,
+ patience=REDUCE_LR_PATIENCE,
+ min_lr=1e-6,
+ verbose=1
+ )
+
+ # 模型检查点
+ checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
+ ensure_dir(os.path.dirname(checkpoint_path))
+
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path,
+ save_best_only=True,
+ monitor='val_loss',
+ verbose=1
+ )
+
+ # TensorBoard日志
+ log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
+ ensure_dir(log_dir)
+
+ tensorboard = tf.keras.callbacks.TensorBoard(
+ log_dir=log_dir,
+ histogram_freq=1
+ )
+
+ return [early_stopping, reduce_lr, model_checkpoint, tensorboard]
+
+ def get_config(self) -> Dict[str, Any]:
+ """获取模型配置"""
+ return self.config.copy()
+
+ def get_model(self) -> Model:
+ """获取Keras模型实例"""
+ return self.model
+
+ def get_training_history(self) -> Optional[Dict[str, List[float]]]:
+ """获取训练历史"""
+ return self.history
+
+ def plot_training_history(self, save_path: Optional[str] = None,
+ metrics: Optional[List[str]] = None) -> None:
+ """
+ 绘制训练历史
+
+ Args:
+ save_path: 保存路径,如果为None则显示图像
+ metrics: 要绘制的指标列表,默认为['loss', 'accuracy']
+ """
+ if self.history is None:
+ raise ValueError("模型尚未训练,没有训练历史")
+
+ import matplotlib.pyplot as plt
+
+ if metrics is None:
+ metrics = ['loss', 'accuracy']
+
+ # 创建图形
+ plt.figure(figsize=(12, 5))
+
+ # 绘制指标
+ for i, metric in enumerate(metrics):
+ plt.subplot(1, len(metrics), i + 1)
+
+ if metric in self.history:
+ plt.plot(self.history[metric], label=f'train_{metric}')
+
+ val_metric = f'val_{metric}'
+ if val_metric in self.history:
+ plt.plot(self.history[val_metric], label=f'val_{metric}')
+
+ plt.title(f'Model {metric}')
+ plt.xlabel('Epoch')
+ plt.ylabel(metric)
+ plt.legend()
+
+ plt.tight_layout()
+
+ # 保存或显示图像
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"训练历史图已保存到: {save_path}")
+ else:
+ plt.show()
\ No newline at end of file
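
A minimal sketch of how a concrete subclass might use this base class; it assumes the project's `config` and `utils` packages are importable, and the toy architecture below is purely illustrative:

```python
import numpy as np
import tensorflow as tf
from models.base_model import TextClassificationModel

class TinyClassifier(TextClassificationModel):
    """Toy subclass used only to illustrate the base-class workflow."""

    def build(self) -> None:
        inputs = tf.keras.Input(shape=(50,), dtype='int32')
        x = tf.keras.layers.Embedding(input_dim=1000, output_dim=32)(inputs)
        x = tf.keras.layers.GlobalAveragePooling1D()(x)
        outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(x)
        self.model = tf.keras.Model(inputs, outputs, name=self.model_name)

clf = TinyClassifier(num_classes=3, model_name="tiny_demo")
clf.build()
clf.compile()                          # defaults: Adam + sparse_categorical_crossentropy
x = np.random.randint(0, 1000, size=(64, 50))
y = np.random.randint(0, 3, size=(64,))
clf.fit(x, y, epochs=1, callbacks=[])  # empty callback list avoids writing checkpoints/logs
print(clf.predict_classes(x[:5]))
```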
diff --git a/models/cnn_model.py b/models/cnn_model.py
new file mode 100644
index 0000000..056f823
--- /dev/null
+++ b/models/cnn_model.py
@@ -0,0 +1,180 @@
+"""
+CNN模型:实现基于卷积神经网络的文本分类模型
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import (
+ Input, Embedding, Conv1D, MaxPooling1D, GlobalMaxPooling1D,
+ Dense, Dropout, Concatenate, BatchNormalization, Activation
+)
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+
+from config.model_config import (
+ MAX_SEQUENCE_LENGTH, CNN_CONFIG
+)
+from models.base_model import TextClassificationModel
+from utils.logger import get_logger
+
+logger = get_logger("CNNModel")
+
+
+class CNNTextClassifier(TextClassificationModel):
+ """卷积神经网络文本分类模型"""
+
+ def __init__(self, num_classes: int, vocab_size: int,
+ embedding_dim: int = CNN_CONFIG["embedding_dim"],
+ max_sequence_length: int = MAX_SEQUENCE_LENGTH,
+ num_filters: int = CNN_CONFIG["num_filters"],
+ filter_sizes: List[int] = CNN_CONFIG["filter_sizes"],
+ dropout_rate: float = CNN_CONFIG["dropout_rate"],
+ l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"],
+ embedding_matrix: Optional[np.ndarray] = None,
+ trainable_embedding: bool = True,
+ model_name: str = "cnn_text_classifier",
+ batch_size: int = 64,
+ learning_rate: float = 0.001):
+ """
+ 初始化CNN文本分类模型
+
+ Args:
+ num_classes: 类别数量
+ vocab_size: 词汇表大小
+ embedding_dim: 词嵌入维度
+ max_sequence_length: 最大序列长度
+ num_filters: 卷积核数量
+ filter_sizes: 卷积核大小列表
+ dropout_rate: Dropout比例
+ l2_reg_lambda: L2正则化系数
+ embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
+ trainable_embedding: 词嵌入是否可训练
+ model_name: 模型名称
+ batch_size: 批大小
+ learning_rate: 学习率
+ """
+ super().__init__(num_classes, model_name, batch_size, learning_rate)
+
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+ self.max_sequence_length = max_sequence_length
+ self.num_filters = num_filters
+ self.filter_sizes = filter_sizes
+ self.dropout_rate = dropout_rate
+ self.l2_reg_lambda = l2_reg_lambda
+ self.embedding_matrix = embedding_matrix
+ self.trainable_embedding = trainable_embedding
+
+ # 更新配置
+ self.config.update({
+ "vocab_size": vocab_size,
+ "embedding_dim": embedding_dim,
+ "max_sequence_length": max_sequence_length,
+ "num_filters": num_filters,
+ "filter_sizes": filter_sizes,
+ "dropout_rate": dropout_rate,
+ "l2_reg_lambda": l2_reg_lambda,
+ "trainable_embedding": trainable_embedding,
+ "model_type": "CNN"
+ })
+
+ logger.info(f"初始化CNN文本分类模型,词汇表大小: {vocab_size}, 嵌入维度: {embedding_dim}")
+
+ def build(self) -> None:
+ """构建CNN模型架构"""
+ # Input layer
+ sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
+
+ # Embedding layer
+ if self.embedding_matrix is not None:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ weights=[self.embedding_matrix],
+ input_length=self.max_sequence_length,
+ trainable=self.trainable_embedding,
+ name='embedding'
+ )
+ else:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ input_length=self.max_sequence_length,
+ trainable=True,
+ name='embedding'
+ )
+
+ embedded_sequences = embedding_layer(sequence_input)
+
+ # Convolutional layers with different filter sizes
+ conv_blocks = []
+ for filter_size in self.filter_sizes:
+ conv = Conv1D(
+ filters=self.num_filters,
+ kernel_size=filter_size,
+ padding='valid',
+ activation='relu',
+ kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda),
+ name=f'conv_{filter_size}'
+ )(embedded_sequences)
+
+ # Max pooling
+ pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv)
+ conv_blocks.append(pooled)
+
+ # Concatenate pooled features if we have multiple filter sizes
+ if len(self.filter_sizes) > 1:
+ concatenated = Concatenate(name='concatenate')(conv_blocks)
+ else:
+ concatenated = conv_blocks[0]
+
+ # Dropout for regularization
+ x = Dropout(self.dropout_rate, name='dropout_1')(concatenated)
+
+ # Dense layer
+ x = Dense(128, name='dense_1')(x)
+ x = BatchNormalization(name='batch_norm_1')(x)
+ x = Activation('relu', name='activation_1')(x)
+ x = Dropout(self.dropout_rate, name='dropout_2')(x)
+
+ # Output layer
+ if self.num_classes == 2:
+ # Binary classification
+ predictions = Dense(1, activation='sigmoid', name='predictions')(x)
+ else:
+ # Multi-class classification
+ predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
+
+ # Build the model
+ self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
+
+ logger.info(f"CNN模型构建完成,过滤器大小: {self.filter_sizes}, 每种大小的过滤器数量: {self.num_filters}")
+
+ def compile(self, optimizer=None, loss=None, metrics=None) -> None:
+ """
+ 编译CNN模型
+
+ Args:
+ optimizer: 优化器,默认为Adam
+ loss: 损失函数,默认根据类别数量选择
+ metrics: 评估指标,默认为accuracy
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 默认优化器
+ if optimizer is None:
+ optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
+
+ # 默认损失函数
+ if loss is None:
+ if self.num_classes == 2:
+ loss = 'binary_crossentropy'
+ else:
+ loss = 'sparse_categorical_crossentropy'
+
+ # 默认评估指标
+ if metrics is None:
+ metrics = ['accuracy']
+
+ # 编译模型
+ self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
+ logger.info(f"CNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")
diff --git a/models/ensemble_model.py b/models/ensemble_model.py
new file mode 100644
index 0000000..2850895
--- /dev/null
+++ b/models/ensemble_model.py
@@ -0,0 +1,216 @@
+"""
+集成模型:实现多个模型的集成
+"""
+import numpy as np
+import tensorflow as tf
+from typing import List, Dict, Tuple, Optional, Any, Union
+import os
+
+from config.system_config import CLASSIFIERS_DIR
+from models.base_model import TextClassificationModel
+from utils.logger import get_logger
+
+logger = get_logger("EnsembleModel")
+
+
+class EnsembleModel:
+ """模型集成类,集成多个模型的预测结果"""
+
+ def __init__(self, models: List[TextClassificationModel],
+ weights: Optional[List[float]] = None,
+ voting: str = 'soft',
+ name: str = "ensemble_model"):
+ """
+ 初始化集成模型
+
+ Args:
+ models: 模型列表
+ weights: 各模型的权重,默认为均等权重
+ voting: 投票方式,'hard'表示多数投票,'soft'表示概率平均
+ name: 集成模型名称
+ """
+ self.models = models
+ self.num_models = len(models)
+
+ # 验证模型数量
+ if self.num_models == 0:
+ raise ValueError("模型列表不能为空")
+
+ # 设置权重
+ if weights is None:
+ self.weights = np.ones(self.num_models) / self.num_models
+ else:
+ if len(weights) != self.num_models:
+ raise ValueError("权重数量必须与模型数量相同")
+
+ # 归一化权重
+ self.weights = np.array(weights) / np.sum(weights)
+
+ # 验证投票方式
+ self.voting = voting.lower()
+ if self.voting not in ['hard', 'soft']:
+ raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'")
+
+ # 从第一个模型获取类别数
+ self.num_classes = models[0].num_classes
+
+ # 验证所有模型的类别数是否相同
+ for i, model in enumerate(models[1:], 1):
+ if model.num_classes != self.num_classes:
+ raise ValueError(
+ f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同")
+
+ self.name = name
+
+ logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}")
+
+ def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
+ batch_size: Optional[int] = None,
+ verbose: int = 0) -> np.ndarray:
+ """
+ 使用集成模型进行预测
+
+ Args:
+ x: 预测数据
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 预测概率
+ """
+ # 获取每个模型的预测结果
+ all_predictions = []
+
+ for i, model in enumerate(self.models):
+ logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果")
+ predictions = model.predict(x, batch_size, verbose)
+
+ # 如果是二分类且输出形状是(n,1),转换为(n,2)
+ if self.num_classes == 2 and predictions.shape[1:] == (1,):
+ predictions = np.hstack([1 - predictions, predictions])
+
+ all_predictions.append(predictions)
+
+ # 根据投票方式进行集成
+ if self.voting == 'hard':
+ # 硬投票:每个模型预测的类别,取众数
+ individual_classes = [np.argmax(pred, axis=1) for pred in all_predictions]
+
+ # 获取带权重的预测类别频率(样本数取自预测结果,避免对tf.data.Dataset调用len)
+ num_samples = all_predictions[0].shape[0]
+ ensemble_result = np.zeros((num_samples, self.num_classes))
+
+ for i, classes in enumerate(individual_classes):
+ for j, cls in enumerate(classes):
+ ensemble_result[j, cls] += self.weights[i]
+
+ return ensemble_result
+ else: # soft voting
+ # 软投票:对每个模型的预测概率进行加权平均
+ weighted_predictions = [pred * weight for pred, weight in zip(all_predictions, self.weights)]
+ ensemble_result = np.sum(weighted_predictions, axis=0)
+
+ return ensemble_result
+
+ def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
+ batch_size: Optional[int] = None,
+ verbose: int = 0) -> np.ndarray:
+ """
+ 使用集成模型预测类别
+
+ Args:
+ x: 预测数据
+ batch_size: 批大小
+ verbose: 详细程度
+
+ Returns:
+ 预测的类别索引
+ """
+ # 获取预测概率
+ predictions = self.predict(x, batch_size, verbose)
+
+ # 获取最大概率的类别索引
+ return np.argmax(predictions, axis=1)
+
+ def save(self, directory: Optional[str] = None) -> str:
+ """
+ 保存集成模型
+
+ Args:
+ directory: 保存目录,默认为CLASSIFIERS_DIR
+
+ Returns:
+ 保存路径
+ """
+ if directory is None:
+ import time
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
+ directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}")
+
+ os.makedirs(directory, exist_ok=True)
+
+ # 保存模型列表
+ model_paths = []
+ for i, model in enumerate(self.models):
+ model_path = os.path.join(directory, f"model_{i}")
+ model.save(model_path)
+ model_paths.append(model_path)
+
+ # 保存集成配置
+ config = {
+ "name": self.name,
+ "num_models": self.num_models,
+ "model_paths": model_paths,
+ "weights": self.weights.tolist(),
+ "voting": self.voting,
+ "num_classes": self.num_classes
+ }
+
+ import json
+ config_path = os.path.join(directory, "ensemble_config.json")
+ with open(config_path, 'w', encoding='utf-8') as f:
+ json.dump(config, f, ensure_ascii=False, indent=4)
+
+ logger.info(f"集成模型已保存到目录: {directory}")
+
+ return directory
+
+ @classmethod
+ def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel':
+ """
+ 加载集成模型
+
+ Args:
+ directory: 模型目录
+ custom_objects: 自定义对象字典
+
+ Returns:
+ 加载的集成模型实例
+ """
+ # 加载配置
+ config_path = os.path.join(directory, "ensemble_config.json")
+
+ import json
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+
+ # 加载子模型
+ from models.model_factory import ModelFactory
+
+ models = []
+ model_paths = config["model_paths"]
+
+ for model_path in model_paths:
+ model = ModelFactory.load_model(model_path, custom_objects)
+ models.append(model)
+
+ # 创建集成模型
+ ensemble = cls(
+ models=models,
+ weights=config["weights"],
+ voting=config["voting"],
+ name=config["name"]
+ )
+
+ logger.info(f"从目录 {directory} 加载集成模型成功")
+
+ return ensemble
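
To make the soft-voting step concrete, this is the weighted-average arithmetic that `predict` performs, shown for two hypothetical models, three samples and two classes (numbers are made up):

```python
import numpy as np

preds_a = np.array([[0.70, 0.30], [0.20, 0.80], [0.55, 0.45]])  # model A probabilities
preds_b = np.array([[0.60, 0.40], [0.40, 0.60], [0.30, 0.70]])  # model B probabilities
weights = np.array([0.6, 0.4])                                  # normalised model weights

ensemble = weights[0] * preds_a + weights[1] * preds_b
print(ensemble.argmax(axis=1))   # [0 1 1] -> final class per sample
```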
diff --git a/models/layers/__init__.py b/models/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/model_factory.py b/models/model_factory.py
new file mode 100644
index 0000000..7a2b67e
--- /dev/null
+++ b/models/model_factory.py
@@ -0,0 +1,169 @@
+"""
+模型工厂:统一创建和管理不同类型的模型
+"""
+from typing import List, Dict, Tuple, Optional, Any, Union
+import os
+import glob
+import time
+import numpy as np
+
+from config.system_config import CLASSIFIERS_DIR
+from config.model_config import (
+ BATCH_SIZE, LEARNING_RATE
+)
+from models.base_model import TextClassificationModel
+from models.cnn_model import CNNTextClassifier
+from models.rnn_model import RNNTextClassifier
+from models.transformer_model import TransformerTextClassifier
+from utils.logger import get_logger
+
+logger = get_logger("ModelFactory")
+
+
+class ModelFactory:
+ """模型工厂,用于创建和管理不同类型的模型"""
+
+ @staticmethod
+ def create_model(model_type: str, num_classes: int, vocab_size: int,
+ embedding_matrix: Optional[np.ndarray] = None,
+ model_config: Optional[Dict[str, Any]] = None,
+ **kwargs) -> TextClassificationModel:
+ """
+ 创建指定类型的模型
+
+ Args:
+ model_type: 模型类型,可选值: 'cnn', 'rnn', 'transformer'
+ num_classes: 类别数量
+ vocab_size: 词汇表大小
+ embedding_matrix: 预训练词嵌入矩阵
+ model_config: 模型配置字典
+ **kwargs: 其他参数
+
+ Returns:
+ 创建的模型实例
+ """
+ model_type = model_type.lower()
+
+ # 合并配置
+ config = model_config or {}
+ config.update(kwargs)
+
+ # 创建模型
+ if model_type == 'cnn':
+ model = CNNTextClassifier(
+ num_classes=num_classes,
+ vocab_size=vocab_size,
+ embedding_matrix=embedding_matrix,
+ **config
+ )
+ elif model_type == 'rnn':
+ model = RNNTextClassifier(
+ num_classes=num_classes,
+ vocab_size=vocab_size,
+ embedding_matrix=embedding_matrix,
+ **config
+ )
+ elif model_type == 'transformer':
+ model = TransformerTextClassifier(
+ num_classes=num_classes,
+ vocab_size=vocab_size,
+ embedding_matrix=embedding_matrix,
+ **config
+ )
+ else:
+ raise ValueError(f"不支持的模型类型: {model_type}")
+
+ logger.info(f"已创建 {model_type.upper()} 模型")
+
+ return model
+
+ @staticmethod
+ def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel:
+ """
+ 加载保存的模型
+
+ Args:
+ model_path: 模型路径
+ custom_objects: 自定义对象字典
+
+ Returns:
+ 加载的模型实例
+ """
+ # 添加Transformer相关的自定义对象
+ if custom_objects is None:
+ custom_objects = {}
+
+ if 'TransformerBlock' not in custom_objects:
+ from models.transformer_model import TransformerBlock
+ custom_objects['TransformerBlock'] = TransformerBlock
+
+ # 根据配置确定模型类型
+ model_config_path = f"{model_path}_config.json"
+
+ import json
+ with open(model_config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+
+ model_type = config.get('model_type', '').lower()
+
+ # 根据模型类型选择加载方法
+ if model_type == 'cnn':
+ model = CNNTextClassifier.load(model_path, custom_objects)
+ elif model_type == 'rnn':
+ model = RNNTextClassifier.load(model_path, custom_objects)
+ elif model_type == 'transformer':
+ model = TransformerTextClassifier.load(model_path, custom_objects)
+ else:
+ # 如果无法确定模型类型,使用基类加载
+ logger.warning(f"无法确定模型类型,使用基类加载: {model_path}")
+ model = TextClassificationModel.load(model_path, custom_objects)
+
+ logger.info(f"已加载模型: {model_path}")
+
+ return model
+
+ @staticmethod
+ def get_available_models() -> List[Dict[str, Any]]:
+ """
+ 获取可用的已保存模型列表
+
+ Returns:
+ 模型信息列表,每个元素是包含模型信息的字典
+ """
+ model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*"))
+ model_files = [f for f in model_files if not f.endswith("_config.json")]
+
+ models_info = []
+
+ for model_file in model_files:
+ config_file = f"{model_file}_config.json"
+
+ if os.path.exists(config_file):
+ try:
+ import json
+ with open(config_file, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+
+ # 获取模型文件的创建时间
+ created_time = os.path.getctime(model_file)
+ created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time))
+
+ # 获取模型文件大小
+ file_size = os.path.getsize(model_file) / (1024 * 1024) # MB
+
+ models_info.append({
+ "path": model_file,
+ "name": config.get("model_name", os.path.basename(model_file)),
+ "type": config.get("model_type", "unknown"),
+ "num_classes": config.get("num_classes", 0),
+ "created_time": created_time_str,
+ "file_size": f"{file_size:.2f} MB",
+ "config": config
+ })
+ except Exception as e:
+ logger.error(f"读取模型配置失败: {config_file}, 错误: {e}")
+
+ # 按创建时间降序排序
+ models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True)
+
+ return models_info
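
A sketch of the factory round-trip, under the assumption that the project's configured model directories exist and are writable; hyper-parameters are placeholders:

```python
from models.model_factory import ModelFactory

# Create a model through the factory and prepare it for training.
model = ModelFactory.create_model('cnn', num_classes=10, vocab_size=5000)
model.build()
model.compile()

# Persist it, then reload; load_model picks the right subclass from the
# "model_type" field stored in the accompanying *_config.json file.
path = model.save()
reloaded = ModelFactory.load_model(path)

for info in ModelFactory.get_available_models():
    print(info['name'], info['type'], info['created_time'], info['file_size'])
```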
diff --git a/models/rnn_model.py b/models/rnn_model.py
new file mode 100644
index 0000000..8895dbf
--- /dev/null
+++ b/models/rnn_model.py
@@ -0,0 +1,220 @@
+"""
+RNN模型:实现基于循环神经网络的文本分类模型
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import (
+ Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout,
+ BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D
+)
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+
+from config.model_config import (
+ MAX_SEQUENCE_LENGTH, RNN_CONFIG
+)
+from models.base_model import TextClassificationModel
+from utils.logger import get_logger
+
+logger = get_logger("RNNModel")
+
+
+class RNNTextClassifier(TextClassificationModel):
+ """循环神经网络文本分类模型"""
+
+ def __init__(self, num_classes: int, vocab_size: int,
+ embedding_dim: int = RNN_CONFIG["embedding_dim"],
+ max_sequence_length: int = MAX_SEQUENCE_LENGTH,
+ hidden_size: int = RNN_CONFIG["hidden_size"],
+ num_layers: int = RNN_CONFIG["num_layers"],
+ bidirectional: bool = RNN_CONFIG["bidirectional"],
+ rnn_type: str = "lstm", # 'lstm' or 'gru'
+ dropout_rate: float = RNN_CONFIG["dropout_rate"],
+ embedding_matrix: Optional[np.ndarray] = None,
+ trainable_embedding: bool = True,
+ pool_type: str = "max", # 'max', 'avg', or 'both'
+ model_name: str = "rnn_text_classifier",
+ batch_size: int = 64,
+ learning_rate: float = 0.001):
+ """
+ 初始化RNN文本分类模型
+
+ Args:
+ num_classes: 类别数量
+ vocab_size: 词汇表大小
+ embedding_dim: 词嵌入维度
+ max_sequence_length: 最大序列长度
+ hidden_size: 隐藏层大小
+ num_layers: RNN层数
+ bidirectional: 是否使用双向RNN
+ rnn_type: RNN类型,'lstm'或'gru'
+ dropout_rate: Dropout比例
+ embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
+ trainable_embedding: 词嵌入是否可训练
+ pool_type: 池化类型,'max'、'avg'、'both'或'last'
+ model_name: 模型名称
+ batch_size: 批大小
+ learning_rate: 学习率
+ """
+ super().__init__(num_classes, model_name, batch_size, learning_rate)
+
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+ self.max_sequence_length = max_sequence_length
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.bidirectional = bidirectional
+ self.rnn_type = rnn_type.lower()
+ self.dropout_rate = dropout_rate
+ self.embedding_matrix = embedding_matrix
+ self.trainable_embedding = trainable_embedding
+ self.pool_type = pool_type
+
+ # 验证RNN类型
+ if self.rnn_type not in ["lstm", "gru"]:
+ raise ValueError("无效的RNN类型,支持的类型: 'lstm', 'gru'")
+
+ # 验证池化类型
+ if self.pool_type not in ["max", "avg", "both", "last"]:
+ raise ValueError("无效的池化类型,支持的类型: 'max', 'avg', 'both', 'last'")
+
+ # 更新配置
+ self.config.update({
+ "vocab_size": vocab_size,
+ "embedding_dim": embedding_dim,
+ "max_sequence_length": max_sequence_length,
+ "hidden_size": hidden_size,
+ "num_layers": num_layers,
+ "bidirectional": bidirectional,
+ "rnn_type": rnn_type,
+ "dropout_rate": dropout_rate,
+ "trainable_embedding": trainable_embedding,
+ "pool_type": pool_type,
+ "model_type": "RNN"
+ })
+
+ logger.info(f"初始化RNN文本分类模型,类型: {rnn_type.upper()}, 隐藏层大小: {hidden_size}, 层数: {num_layers}")
+
+ def build(self) -> None:
+ """构建RNN模型架构"""
+ # Input layer
+ sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
+
+ # Embedding layer
+ if self.embedding_matrix is not None:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ weights=[self.embedding_matrix],
+ input_length=self.max_sequence_length,
+ trainable=self.trainable_embedding,
+ name='embedding'
+ )
+ else:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ input_length=self.max_sequence_length,
+ trainable=True,
+ name='embedding'
+ )
+
+ embedded_sequences = embedding_layer(sequence_input)
+
+ # 选择RNN层类型
+ if self.rnn_type == "lstm":
+ rnn_layer = LSTM
+ else: # gru
+ rnn_layer = GRU
+
+ # 构建多层RNN
+ x = embedded_sequences
+ for i in range(self.num_layers):
+ return_sequences = i < self.num_layers - 1 or self.pool_type != "last"
+
+ if self.bidirectional:
+ x = Bidirectional(
+ rnn_layer(
+ self.hidden_size,
+ return_sequences=return_sequences,
+ dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
+ name=f'{self.rnn_type}_{i + 1}'
+ )
+ )(x)
+ else:
+ x = rnn_layer(
+ self.hidden_size,
+ return_sequences=return_sequences,
+ dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
+ name=f'{self.rnn_type}_{i + 1}'
+ )(x)
+
+ # 根据池化类型选择池化方法
+ if self.pool_type == "max":
+ # 使用全局最大池化
+ pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
+ elif self.pool_type == "avg":
+ # 使用全局平均池化
+ pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
+ elif self.pool_type == "both":
+ # 同时使用最大池化和平均池化,然后拼接
+ max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
+ avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
+ pooled = tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled])
+ else: # "last",使用最后一个时间步的输出
+ # 最后一层RNN已经返回了最后一个时间步的状态,不需要额外池化
+ pooled = x
+
+ # Dropout for regularization
+ x = Dropout(self.dropout_rate, name='dropout_1')(pooled)
+
+ # Dense layer
+ x = Dense(128, name='dense_1')(x)
+ x = BatchNormalization(name='batch_norm_1')(x)
+ x = Activation('relu', name='activation_1')(x)
+ x = Dropout(self.dropout_rate, name='dropout_2')(x)
+
+ # Output layer
+ if self.num_classes == 2:
+ # Binary classification
+ predictions = Dense(1, activation='sigmoid', name='predictions')(x)
+ else:
+ # Multi-class classification
+ predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
+
+ # Build the model
+ self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
+
+ logger.info(
+ f"RNN模型构建完成,类型: {self.rnn_type.upper()}, 双向: {self.bidirectional}, 池化类型: {self.pool_type}")
+
+ def compile(self, optimizer=None, loss=None, metrics=None) -> None:
+ """
+ 编译RNN模型
+
+ Args:
+ optimizer: 优化器,默认为Adam
+ loss: 损失函数,默认根据类别数量选择
+ metrics: 评估指标,默认为accuracy
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 默认优化器
+ if optimizer is None:
+ optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
+
+ # 默认损失函数
+ if loss is None:
+ if self.num_classes == 2:
+ loss = 'binary_crossentropy'
+ else:
+ loss = 'sparse_categorical_crossentropy'
+
+ # 默认评估指标
+ if metrics is None:
+ metrics = ['accuracy']
+
+ # 编译模型
+ self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
+ logger.info(f"RNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")
diff --git a/models/transformer_model.py b/models/transformer_model.py
new file mode 100644
index 0000000..786c0c1
--- /dev/null
+++ b/models/transformer_model.py
@@ -0,0 +1,270 @@
+"""
+Transformer模型:实现基于Transformer的文本分类模型
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import (
+ Input, Embedding, Dense, Dropout, LayerNormalization,
+ GlobalAveragePooling1D, MultiHeadAttention, Add
+)
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+
+from config.model_config import (
+ MAX_SEQUENCE_LENGTH, TRANSFORMER_CONFIG
+)
+from models.base_model import TextClassificationModel
+from utils.logger import get_logger
+
+logger = get_logger("TransformerModel")
+
+
+class TransformerBlock(tf.keras.layers.Layer):
+ """Transformer块,包含多头注意力和前馈网络"""
+
+ def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1, **kwargs):
+ """
+ 初始化Transformer块
+
+ Args:
+ embed_dim: 嵌入维度
+ num_heads: 注意力头数
+ ff_dim: 前馈网络维度
+ dropout_rate: Dropout比例
+ """
+ super(TransformerBlock, self).__init__(**kwargs)  # 透传name等Keras层关键字参数
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.ff_dim = ff_dim
+ self.dropout_rate = dropout_rate
+
+ self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
+ self.ffn = tf.keras.Sequential([
+ Dense(ff_dim, activation="relu"),
+ Dense(embed_dim),
+ ])
+ self.layernorm1 = LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = LayerNormalization(epsilon=1e-6)
+ self.dropout1 = Dropout(dropout_rate)
+ self.dropout2 = Dropout(dropout_rate)
+
+ def call(self, inputs, training=False):
+ """
+ 前向传播
+
+ Args:
+ inputs: 输入张量
+ training: 是否处于训练模式
+
+ Returns:
+ 输出张量
+ """
+ # 多头自注意力
+ attention_output = self.attention(inputs, inputs)
+ attention_output = self.dropout1(attention_output, training=training)
+ out1 = self.layernorm1(inputs + attention_output)
+
+ # 前馈网络
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output, training=training)
+ out2 = self.layernorm2(out1 + ffn_output)
+
+ return out2
+
+ def get_config(self):
+ """获取配置"""
+ config = super(TransformerBlock, self).get_config()
+ config.update({
+ "embed_dim": self.embed_dim,
+ "num_heads": self.num_heads,
+ "ff_dim": self.ff_dim,
+ "dropout_rate": self.dropout_rate
+ })
+ return config
+
+
+class TransformerTextClassifier(TextClassificationModel):
+ """Transformer文本分类模型"""
+
+ def __init__(self, num_classes: int, vocab_size: int,
+ embedding_dim: int = TRANSFORMER_CONFIG["embedding_dim"],
+ max_sequence_length: int = MAX_SEQUENCE_LENGTH,
+ num_heads: int = TRANSFORMER_CONFIG["num_heads"],
+ ff_dim: int = TRANSFORMER_CONFIG["ff_dim"],
+ num_layers: int = TRANSFORMER_CONFIG["num_layers"],
+ dropout_rate: float = TRANSFORMER_CONFIG["dropout_rate"],
+ embedding_matrix: Optional[np.ndarray] = None,
+ trainable_embedding: bool = True,
+ use_positional_encoding: bool = True,
+ model_name: str = "transformer_text_classifier",
+ batch_size: int = 64,
+ learning_rate: float = 0.001):
+ """
+ 初始化Transformer文本分类模型
+
+ Args:
+ num_classes: 类别数量
+ vocab_size: 词汇表大小
+ embedding_dim: 词嵌入维度
+ max_sequence_length: 最大序列长度
+ num_heads: 注意力头数
+ ff_dim: 前馈网络维度
+ num_layers: Transformer层数
+ dropout_rate: Dropout比例
+ embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
+ trainable_embedding: 词嵌入是否可训练
+ use_positional_encoding: 是否使用位置编码
+ model_name: 模型名称
+ batch_size: 批大小
+ learning_rate: 学习率
+ """
+ super().__init__(num_classes, model_name, batch_size, learning_rate)
+
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+ self.max_sequence_length = max_sequence_length
+ self.num_heads = num_heads
+ self.ff_dim = ff_dim
+ self.num_layers = num_layers
+ self.dropout_rate = dropout_rate
+ self.embedding_matrix = embedding_matrix
+ self.trainable_embedding = trainable_embedding
+ self.use_positional_encoding = use_positional_encoding
+
+ # 更新配置
+ self.config.update({
+ "vocab_size": vocab_size,
+ "embedding_dim": embedding_dim,
+ "max_sequence_length": max_sequence_length,
+ "num_heads": num_heads,
+ "ff_dim": ff_dim,
+ "num_layers": num_layers,
+ "dropout_rate": dropout_rate,
+ "trainable_embedding": trainable_embedding,
+ "use_positional_encoding": use_positional_encoding,
+ "model_type": "Transformer"
+ })
+
+ logger.info(f"初始化Transformer文本分类模型,头数: {num_heads}, 层数: {num_layers}")
+
+ def _positional_encoding(self, max_length: int, d_model: int) -> tf.Tensor:
+ """
+ 生成位置编码
+
+ Args:
+ max_length: 最大序列长度
+ d_model: 模型维度
+
+ Returns:
+ 位置编码张量
+ """
+ positions = np.arange(max_length)[:, np.newaxis]
+ depths = np.arange(d_model)[np.newaxis, :] // 2 * 2
+ angle_rates = 1 / np.power(10000, depths / d_model)
+ angle_rads = positions * angle_rates
+
+ # sin用于偶数索引,cos用于奇数索引
+ sines = np.sin(angle_rads[:, 0::2])
+ cosines = np.cos(angle_rads[:, 1::2])
+
+ pos_encoding = np.zeros((max_length, d_model))
+ pos_encoding[:, 0::2] = sines
+ pos_encoding[:, 1::2] = cosines
+
+ return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)
+
+ def build(self) -> None:
+ """构建Transformer模型架构"""
+ # Input layer
+ sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
+
+ # Embedding layer
+ if self.embedding_matrix is not None:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ weights=[self.embedding_matrix],
+ input_length=self.max_sequence_length,
+ trainable=self.trainable_embedding,
+ name='embedding'
+ )
+ else:
+ embedding_layer = Embedding(
+ input_dim=self.vocab_size,
+ output_dim=self.embedding_dim,
+ input_length=self.max_sequence_length,
+ trainable=True,
+ name='embedding'
+ )
+
+ embedded_sequences = embedding_layer(sequence_input)
+
+ # 添加位置编码
+ if self.use_positional_encoding:
+ pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
+ embedded_sequences = embedded_sequences + pos_encoding
+
+ # Transformer层
+ x = embedded_sequences
+ for i in range(self.num_layers):
+ x = TransformerBlock(
+ embed_dim=self.embedding_dim,
+ num_heads=self.num_heads,
+ ff_dim=self.ff_dim,
+ dropout_rate=self.dropout_rate,
+ name=f'transformer_block_{i + 1}'
+ )(x)
+
+ # 全局池化
+ x = GlobalAveragePooling1D(name='global_avg_pooling')(x)
+
+ # Dropout for regularization
+ x = Dropout(self.dropout_rate, name='dropout_1')(x)
+
+ # Dense layer
+ x = Dense(128, activation='relu', name='dense_1')(x)
+ x = Dropout(self.dropout_rate, name='dropout_2')(x)
+
+ # Output layer
+ if self.num_classes == 2:
+ # Binary classification
+ predictions = Dense(1, activation='sigmoid', name='predictions')(x)
+ else:
+ # Multi-class classification
+ predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
+
+ # Build the model
+ self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
+
+ logger.info(f"Transformer模型构建完成,头数: {self.num_heads}, 层数: {self.num_layers}")
+
+ def compile(self, optimizer=None, loss=None, metrics=None) -> None:
+ """
+ 编译Transformer模型
+
+ Args:
+ optimizer: 优化器,默认为Adam
+ loss: 损失函数,默认根据类别数量选择
+ metrics: 评估指标,默认为accuracy
+ """
+ if self.model is None:
+ raise ValueError("模型尚未构建,请先调用build方法")
+
+ # 默认优化器
+ if optimizer is None:
+ optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
+
+ # 默认损失函数
+ if loss is None:
+ if self.num_classes == 2:
+ loss = 'binary_crossentropy'
+ else:
+ loss = 'sparse_categorical_crossentropy'
+
+ # 默认评估指标
+ if metrics is None:
+ metrics = ['accuracy']
+
+ # 编译模型
+ self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
+ logger.info(f"Transformer模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")
\ No newline at end of file
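
For reference, `_positional_encoding` above implements the standard sinusoidal encoding, which in the usual notation is

$$
\mathrm{PE}_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right),
\qquad
\mathrm{PE}_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right),
$$

where $pos$ is the token position and $d_{\mathrm{model}}$ is the embedding dimension; the `depths = np.arange(d_model) // 2 * 2` trick produces the shared exponent $2i$ for each sin/cos pair.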
diff --git a/preprocessing/__init__.py b/preprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/preprocessing/data_augmentation.py b/preprocessing/data_augmentation.py
new file mode 100644
index 0000000..5743bea
--- /dev/null
+++ b/preprocessing/data_augmentation.py
@@ -0,0 +1,414 @@
+"""
+数据增强模块:实现文本数据增强技术
+"""
+import random
+import re
+import jieba
+import synonyms
+import numpy as np
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import copy
+
+from config.model_config import RANDOM_SEED
+from utils.logger import get_logger
+from preprocessing.tokenization import ChineseTokenizer
+
+# 设置随机种子以保证可重复性
+random.seed(RANDOM_SEED)
+np.random.seed(RANDOM_SEED)
+
+logger = get_logger("DataAugmentation")
+
+
+class TextAugmenter:
+ """文本增强基类,定义通用接口"""
+
+ def __init__(self):
+ """初始化文本增强器"""
+ pass
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def batch_augment(self, texts: List[str]) -> List[str]:
+ """
+ 批量对文本进行增强
+
+ Args:
+ texts: 原始文本列表
+
+ Returns:
+ 增强后的文本列表
+ """
+ return [self.augment(text) for text in texts]
+
+ def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]:
+ """
+ 对文本进行增强,同时保留标签
+
+ Args:
+ text: 原始文本
+ label: 标签
+
+ Returns:
+ (增强后的文本, 标签)的元组
+ """
+ return self.augment(text), label
+
+ def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]:
+ """
+ 批量对文本进行增强,同时保留标签
+
+ Args:
+ texts: 原始文本列表
+ labels: 标签列表
+
+ Returns:
+ (增强后的文本, 标签)的元组列表
+ """
+ return [self.augment_with_label(text, label) for text, label in zip(texts, labels)]
+
+
+class SynonymReplacement(TextAugmenter):
+ """同义词替换增强器"""
+
+ def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
+ replace_ratio: float = 0.1,
+ min_similarity: float = 0.7):
+ """
+ 初始化同义词替换增强器
+
+ Args:
+ tokenizer: 分词器,如果为None则创建一个新的分词器
+ replace_ratio: 替换比例,表示要替换的词占总词数的比例
+ min_similarity: 最小相似度,只有相似度大于该值的同义词才会被用于替换
+ """
+ super().__init__()
+ self.tokenizer = tokenizer or ChineseTokenizer()
+ self.replace_ratio = replace_ratio
+ self.min_similarity = min_similarity
+
+ def _get_synonym(self, word: str) -> Optional[str]:
+ """
+ 获取词的同义词
+
+ Args:
+ word: 原始词
+
+ Returns:
+ 同义词,如果没有合适的同义词则返回None
+ """
+ # 使用synonyms包获取同义词
+ try:
+ synonyms_list = synonyms.nearby(word)
+
+ # synonyms.nearby返回一个元组,第一个元素是相似词列表,第二个元素是相似度列表
+ words = synonyms_list[0]
+ similarities = synonyms_list[1]
+
+ # 过滤掉相似度低于阈值的词和原词本身
+ valid_synonyms = [(w, s) for w, s in zip(words, similarities)
+ if s >= self.min_similarity and w != word]
+
+ if valid_synonyms:
+ # 按相似度排序,选择最相似的词
+ valid_synonyms.sort(key=lambda x: x[1], reverse=True)
+ return valid_synonyms[0][0]
+
+ return None
+ except Exception:
+ return None
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行同义词替换增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ if not text:
+ return text
+
+ # 分词
+ words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
+
+ if not words:
+ return text
+
+ # 计算要替换的词数量
+ n_replace = max(1, int(len(words) * self.replace_ratio))
+
+ # 随机选择要替换的词索引
+ replace_indices = random.sample(range(len(words)), min(n_replace, len(words)))
+
+ # 替换为同义词
+ for idx in replace_indices:
+ synonym = self._get_synonym(words[idx])
+ if synonym:
+ words[idx] = synonym
+
+ # 拼接为文本
+ augmented_text = ''.join(words)
+
+ return augmented_text
+
+
+class RandomDeletion(TextAugmenter):
+ """随机删除增强器"""
+
+ def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
+ delete_ratio: float = 0.1):
+ """
+ 初始化随机删除增强器
+
+ Args:
+ tokenizer: 分词器,如果为None则创建一个新的分词器
+ delete_ratio: 删除比例,表示要删除的词占总词数的比例
+ """
+ super().__init__()
+ self.tokenizer = tokenizer or ChineseTokenizer()
+ self.delete_ratio = delete_ratio
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行随机删除增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ if not text:
+ return text
+
+ # 分词
+ words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
+
+ if len(words) <= 1:
+ return text
+
+ # 计算要删除的词数量
+ n_delete = max(1, int(len(words) * self.delete_ratio))
+
+ # 随机选择要删除的词索引
+ delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1))
+
+ # 删除选中的词
+ augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices]
+
+ # 拼接为文本
+ augmented_text = ''.join(augmented_words)
+
+ return augmented_text
+
+
+class RandomSwap(TextAugmenter):
+ """随机交换增强器"""
+
+ def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
+ n_swaps: int = 1):
+ """
+ 初始化随机交换增强器
+
+ Args:
+ tokenizer: 分词器,如果为None则创建一个新的分词器
+ n_swaps: 交换次数
+ """
+ super().__init__()
+ self.tokenizer = tokenizer or ChineseTokenizer()
+ self.n_swaps = n_swaps
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行随机交换增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ if not text:
+ return text
+
+ # 分词
+ words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
+
+ if len(words) <= 1:
+ return text
+
+ # 进行n_swaps次随机交换
+ augmented_words = words.copy()
+ for _ in range(min(self.n_swaps, len(words) // 2)):
+ # 随机选择两个不同的索引
+ idx1, idx2 = random.sample(range(len(augmented_words)), 2)
+
+ # 交换两个词
+ augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1]
+
+ # 拼接为文本
+ augmented_text = ''.join(augmented_words)
+
+ return augmented_text
+
+
+class CompositeAugmenter(TextAugmenter):
+ """组合增强器,组合多个增强器"""
+
+ def __init__(self, augmenters: List[TextAugmenter],
+ probs: Optional[List[float]] = None):
+ """
+ 初始化组合增强器
+
+ Args:
+ augmenters: 增强器列表
+ probs: 各增强器被选择的概率列表,如果为None则均匀选择
+ """
+ super().__init__()
+ self.augmenters = augmenters
+
+ # 如果没有提供概率,则均匀分配
+ if probs is None:
+ self.probs = [1.0 / len(augmenters)] * len(augmenters)
+ else:
+ # 确保概率和为1
+ total = sum(probs)
+ self.probs = [p / total for p in probs]
+
+ assert len(self.augmenters) == len(self.probs), "增强器数量与概率数量不匹配"
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行组合增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ if not text:
+ return text
+
+ # 根据概率随机选择一个增强器
+ augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0]
+
+ # 使用选中的增强器进行增强
+ return augmenter.augment(text)
+
+
+class BackTranslation(TextAugmenter):
+ """回译增强器"""
+
+ def __init__(self, translator=None, source_lang: str = 'zh',
+ target_langs: Optional[List[str]] = None):
+ """
+ 初始化回译增强器
+
+ Args:
+ translator: 翻译器,需要实现translate方法
+ source_lang: 源语言代码
+ target_langs: 目标语言代码列表,如果为None则使用默认语言
+ """
+ super().__init__()
+
+ # 如果没有提供翻译器,尝试使用第三方翻译库
+ if translator is None:
+ try:
+ # 尝试导入多种翻译库
+ # 首先尝试使用googletrans (需要单独安装: pip install googletrans==4.0.0-rc1)
+ try:
+ from googletrans import Translator
+ self.translator = Translator()
+ self.translate_func = self._google_translate
+ except ImportError:
+ # 如果googletrans不可用,尝试使用py-translate
+ try:
+ import translate
+ self.translator = translate
+ self.translate_func = self._py_translate
+ except ImportError:
+ logger.warning("未安装翻译库,回译功能将不可用。请安装googletrans或py-translate")
+ self.translator = None
+ self.translate_func = self._dummy_translate
+ except Exception as e:
+ logger.error(f"初始化翻译器失败: {e}")
+ self.translator = None
+ self.translate_func = self._dummy_translate
+ else:
+ self.translator = translator
+ self.translate_func = self._custom_translate
+
+ self.source_lang = source_lang
+ self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja']
+
+ def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str:
+ """使用googletrans进行翻译"""
+ try:
+ result = self.translator.translate(text, src=source_lang, dest=target_lang)
+ return result.text
+ except Exception as e:
+ logger.error(f"翻译失败: {e}")
+ return text
+
+ def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str:
+ """使用py-translate进行翻译"""
+ try:
+ return self.translator.translate(text, source_lang, target_lang)
+ except Exception as e:
+ logger.error(f"翻译失败: {e}")
+ return text
+
+ def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str:
+ """使用自定义翻译器进行翻译"""
+ try:
+ return self.translator.translate(text, source_lang, target_lang)
+ except Exception as e:
+ logger.error(f"翻译失败: {e}")
+ return text
+
+ def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str:
+ """虚拟翻译功能,仅返回原文本"""
+ logger.warning("翻译功能不可用,使用原文本")
+ return text
+
+ def augment(self, text: str) -> str:
+ """
+ 对文本进行回译增强
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 增强后的文本
+ """
+ if not text or self.translator is None:
+ return text
+
+ # 随机选择一个目标语言
+ target_lang = random.choice(self.target_langs)
+
+ try:
+ # 将源语言翻译为目标语言
+ translated = self.translate_func(text, self.source_lang, target_lang)
+
+ # 将目标语言翻译回源语言
+ back_translated = self.translate_func(translated, target_lang, self.source_lang)
+
+ return back_translated
+ except Exception as e:
+ logger.error(f"回译失败: {e}")
+ return text
\ No newline at end of file
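
A hedged usage sketch combining the augmenters above; it assumes `jieba` and `synonyms` are installed, and the sample sentence and probabilities are illustrative only:

```python
from preprocessing.data_augmentation import (
    SynonymReplacement, RandomDeletion, RandomSwap, CompositeAugmenter
)

augmenter = CompositeAugmenter(
    augmenters=[SynonymReplacement(replace_ratio=0.2),
                RandomDeletion(delete_ratio=0.1),
                RandomSwap(n_swaps=1)],
    probs=[0.5, 0.3, 0.2],
)

text = "这家公司的股票今天大幅上涨"
print(augmenter.augment(text))                                   # one randomly chosen augmentation
print(augmenter.batch_augment_with_label([text, text], [1, 1]))  # (text, label) pairs preserved
```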
diff --git a/preprocessing/feature_extraction.py b/preprocessing/feature_extraction.py
new file mode 100644
index 0000000..fc90ef3
--- /dev/null
+++ b/preprocessing/feature_extraction.py
@@ -0,0 +1,430 @@
+"""
+特征提取模块:实现文本特征提取,包括语法特征、语义特征等
+"""
+import re
+import numpy as np
+from typing import List, Dict, Tuple, Optional, Any, Union, Set
+from collections import Counter
+import jieba.posseg as pseg
+
+from config.system_config import CATEGORIES
+from utils.logger import get_logger
+from preprocessing.tokenization import ChineseTokenizer
+
+logger = get_logger("FeatureExtraction")
+
+
+class FeatureExtractor:
+ """特征提取基类,定义通用接口"""
+
+ def __init__(self):
+ """初始化特征提取器"""
+ pass
+
+ def extract(self, text: str) -> Dict[str, Any]:
+ """
+ 从文本中提取特征
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征字典
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def batch_extract(self, texts: List[str]) -> List[Dict[str, Any]]:
+ """
+ 批量提取特征
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 特征字典列表
+ """
+ return [self.extract(text) for text in texts]
+
+ def extract_as_vector(self, text: str) -> np.ndarray:
+ """
+ 从文本中提取特征,并转换为向量表示
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征向量
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def batch_extract_as_vector(self, texts: List[str]) -> np.ndarray:
+ """
+ 批量提取特征,并转换为向量表示
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 特征向量数组
+ """
+ return np.array([self.extract_as_vector(text) for text in texts])
+
+
+class StatisticalFeatureExtractor(FeatureExtractor):
+ """统计特征提取器,提取文本的统计特征"""
+
+ def __init__(self, tokenizer: Optional[ChineseTokenizer] = None):
+ """
+ 初始化统计特征提取器
+
+ Args:
+ tokenizer: 分词器,如果为None则创建一个新的分词器
+ """
+ super().__init__()
+ self.tokenizer = tokenizer or ChineseTokenizer()
+
+ def extract(self, text: str) -> Dict[str, Any]:
+ """
+ 从文本中提取统计特征
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征字典,包含各种统计特征
+ """
+ if not text:
+ return {
+ "char_count": 0,
+ "word_count": 0,
+ "sentence_count": 0,
+ "avg_word_length": 0,
+ "avg_sentence_length": 0,
+ "contains_number": False,
+ "contains_english": False,
+ "punctuation_ratio": 0,
+ "top_words": []
+ }
+
+ # 字符数
+ char_count = len(text)
+
+ # 分词
+ words = self.tokenizer.tokenize(text, return_string=False)
+ word_count = len(words)
+
+ # 句子数(按标点符号分割)
+ sentences = re.split(r'[。!?!?]+', text)
+ sentences = [s for s in sentences if s.strip()]
+ sentence_count = len(sentences)
+
+ # 平均词长
+ avg_word_length = sum(len(word) for word in words) / word_count if word_count > 0 else 0
+
+ # 平均句长(以字符为单位)
+ avg_sentence_length = char_count / sentence_count if sentence_count > 0 else 0
+
+ # 是否包含数字
+ contains_number = bool(re.search(r'\d', text))
+
+ # 是否包含英文
+ contains_english = bool(re.search(r'[a-zA-Z]', text))
+
+ # 标点符号比例
+ punctuation_pattern = re.compile(r'[^\w\s]')
+ punctuations = punctuation_pattern.findall(text)
+ punctuation_ratio = len(punctuations) / char_count if char_count > 0 else 0
+
+ # 高频词
+ word_counter = Counter(words)
+ top_words = word_counter.most_common(5)
+
+ return {
+ "char_count": char_count,
+ "word_count": word_count,
+ "sentence_count": sentence_count,
+ "avg_word_length": avg_word_length,
+ "avg_sentence_length": avg_sentence_length,
+ "contains_number": contains_number,
+ "contains_english": contains_english,
+ "punctuation_ratio": punctuation_ratio,
+ "top_words": top_words
+ }
+
+ def extract_as_vector(self, text: str) -> np.ndarray:
+ """
+ 从文本中提取统计特征,并转换为向量表示
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征向量,包含各种统计特征
+ """
+ features = self.extract(text)
+
+ # 提取数值特征
+ vector = [
+ features['char_count'],
+ features['word_count'],
+ features['sentence_count'],
+ features['avg_word_length'],
+ features['avg_sentence_length'],
+ int(features['contains_number']),
+ int(features['contains_english']),
+ features['punctuation_ratio']
+ ]
+
+ return np.array(vector, dtype=np.float32)
+
+
+class POSFeatureExtractor(FeatureExtractor):
+ """词性特征提取器,提取文本的词性特征"""
+
+ def __init__(self):
+ """初始化词性特征提取器"""
+ super().__init__()
+
+ # 常见中文词性及其解释
+ self.pos_tags = {
+ 'n': '名词', 'f': '方位名词', 's': '处所名词', 't': '时间名词',
+ 'nr': '人名', 'ns': '地名', 'nt': '机构团体', 'nw': '作品名',
+ 'nz': '其他专名', 'v': '动词', 'vd': '副动词', 'vn': '名动词',
+ 'a': '形容词', 'ad': '副形词', 'an': '名形词', 'd': '副词',
+ 'm': '数词', 'q': '量词', 'r': '代词', 'p': '介词',
+ 'c': '连词', 'u': '助词', 'xc': '其他虚词', 'w': '标点符号'
+ }
+
+ def extract(self, text: str) -> Dict[str, Any]:
+ """
+ 从文本中提取词性特征
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征字典,包含各种词性特征
+ """
+ if not text:
+ return {
+ "pos_counts": {},
+ "pos_ratios": {}
+ }
+
+ # 使用jieba进行词性标注
+ pos_list = pseg.cut(text)
+
+ # 统计各词性的数量
+ pos_counts = {}
+ total_count = 0
+
+ for word, pos in pos_list:
+ if pos in pos_counts:
+ pos_counts[pos] += 1
+ else:
+ pos_counts[pos] = 1
+ total_count += 1
+
+ # 计算各词性的比例
+ pos_ratios = {pos: count / total_count for pos, count in pos_counts.items()} if total_count > 0 else {}
+
+ return {
+ "pos_counts": pos_counts,
+ "pos_ratios": pos_ratios
+ }
+
+ def extract_as_vector(self, text: str) -> np.ndarray:
+ """
+ 从文本中提取词性特征,并转换为向量表示
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征向量,包含各词性的比例
+ """
+ features = self.extract(text)
+ pos_ratios = features['pos_ratios']
+
+ # 按照 self.pos_tags 的顺序构建向量
+ vector = []
+ for pos in self.pos_tags.keys():
+ vector.append(pos_ratios.get(pos, 0.0))
+
+ return np.array(vector, dtype=np.float32)
+
+
+class KeywordFeatureExtractor(FeatureExtractor):
+ """关键词特征提取器,基于预定义关键词提取特征"""
+
+ def __init__(self, category_keywords: Optional[Dict[str, List[str]]] = None):
+ """
+ 初始化关键词特征提取器
+
+ Args:
+ category_keywords: 类别关键词字典,键为类别名称,值为关键词列表
+ """
+ super().__init__()
+ self.category_keywords = category_keywords or self._get_default_keywords()
+ self.tokenizer = ChineseTokenizer()
+
+ def _get_default_keywords(self) -> Dict[str, List[str]]:
+ """
+ 获取默认的类别关键词
+
+ Returns:
+ 类别关键词字典
+ """
+ # 为每个类别定义一些示例关键词
+ default_keywords = {
+ "体育": ["比赛", "运动", "球员", "冠军", "球队", "足球", "篮球"],
+ "财经": ["股票", "基金", "投资", "市场", "经济", "金融", "股市"],
+ "房产": ["房价", "楼市", "地产", "购房", "房贷", "物业", "小区"],
+ "家居": ["装修", "家具", "设计", "卧室", "客厅", "厨房", "风格"],
+ "教育": ["学校", "学生", "考试", "教育", "大学", "课程", "老师"],
+ "科技": ["互联网", "科技", "创新", "数字", "智能", "研发", "技术"],
+ "时尚": ["时尚", "潮流", "服装", "搭配", "品牌", "美容", "穿着"],
+ "时政": ["政府", "政策", "国家", "发展", "会议", "主席", "总理"],
+ "游戏": ["游戏", "玩家", "电竞", "网游", "手游", "角色", "任务"],
+ "娱乐": ["明星", "电影", "节目", "综艺", "电视", "演员", "导演"],
+ "其他": ["其他", "一般", "常见", "普通", "正常", "通常", "传统"]
+ }
+
+ # 确保 CATEGORIES 中的每个类别都有关键词
+ for category in CATEGORIES:
+ if category not in default_keywords:
+ default_keywords[category] = [category]
+
+ return default_keywords
+
+ def extract(self, text: str) -> Dict[str, Any]:
+ """
+ 从文本中提取关键词特征
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征字典,包含各类别的关键词匹配情况
+ """
+ if not text:
+ return {
+ "keyword_matches": {cat: 0 for cat in self.category_keywords},
+ "keyword_match_ratios": {cat: 0.0 for cat in self.category_keywords}
+ }
+
+ # 对文本分词
+ words = set(self.tokenizer.tokenize(text, return_string=False))
+
+ # 统计各类别的关键词匹配数量
+ keyword_matches = {}
+ for category, keywords in self.category_keywords.items():
+ # 计算文本中包含的该类别关键词数量
+ matches = sum(1 for kw in keywords if kw in words)
+ keyword_matches[category] = matches
+
+ # 计算匹配比例(归一化)
+ total_matches = sum(keyword_matches.values())
+ keyword_match_ratios = {
+ cat: matches / total_matches if total_matches > 0 else 0.0
+ for cat, matches in keyword_matches.items()
+ }
+
+ return {
+ "keyword_matches": keyword_matches,
+ "keyword_match_ratios": keyword_match_ratios
+ }
+
+ def extract_as_vector(self, text: str) -> np.ndarray:
+ """
+ 从文本中提取关键词特征,并转换为向量表示
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征向量,包含各类别的关键词匹配比例
+ """
+ features = self.extract(text)
+ match_ratios = features['keyword_match_ratios']
+
+ # 按照 CATEGORIES 的顺序构建向量
+ vector = [match_ratios.get(cat, 0.0) for cat in CATEGORIES]
+
+ return np.array(vector, dtype=np.float32)
+
+ def update_keywords(self, category: str, keywords: List[str]) -> None:
+ """
+ 更新指定类别的关键词
+
+ Args:
+ category: 类别名称
+ keywords: 关键词列表
+ """
+ self.category_keywords[category] = keywords
+ logger.info(f"已更新类别 {category} 的关键词,共 {len(keywords)} 个")
+
+ def add_keywords(self, category: str, keywords: List[str]) -> None:
+ """
+ 向指定类别添加关键词
+
+ Args:
+ category: 类别名称
+ keywords: 要添加的关键词列表
+ """
+ if category in self.category_keywords:
+ existing_keywords = set(self.category_keywords[category])
+ for keyword in keywords:
+ existing_keywords.add(keyword)
+ self.category_keywords[category] = list(existing_keywords)
+ else:
+ self.category_keywords[category] = keywords
+
+ logger.info(f"已向类别 {category} 添加关键词,当前共 {len(self.category_keywords[category])} 个")
+
+
+class CombinedFeatureExtractor(FeatureExtractor):
+ """组合特征提取器,组合多个特征提取器的结果"""
+
+ def __init__(self, extractors: List[FeatureExtractor]):
+ """
+ 初始化组合特征提取器
+
+ Args:
+ extractors: 特征提取器列表
+ """
+ super().__init__()
+ self.extractors = extractors
+
+ def extract(self, text: str) -> Dict[str, Any]:
+ """
+ 从文本中提取组合特征
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征字典,包含所有特征提取器的结果
+ """
+ combined_features = {}
+ for i, extractor in enumerate(self.extractors):
+ extractor_name = type(extractor).__name__
+ features = extractor.extract(text)
+ combined_features[extractor_name] = features
+
+ return combined_features
+
+ def extract_as_vector(self, text: str) -> np.ndarray:
+ """
+ 从文本中提取组合特征,并转换为向量表示
+
+ Args:
+ text: 文本
+
+ Returns:
+ 特征向量,包含所有特征提取器的向量拼接
+ """
+ # 获取所有特征提取器的向量
+ feature_vectors = [extractor.extract_as_vector(text) for extractor in self.extractors]
+
+ # 拼接向量
+ return np.concatenate(feature_vectors)
\ No newline at end of file
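The combined extractor simply concatenates whatever each child extractor returns, so the layout of the final vector follows the order of the `extractors` list. A minimal usage sketch, assuming the module is importable as `preprocessing.feature_extraction` (the actual path is given by this file's diff header) and that `config.system_config.CATEGORIES` is defined:

```python
# Sketch only: the import path below is an assumption, not confirmed by this hunk.
from preprocessing.feature_extraction import (
    POSFeatureExtractor, KeywordFeatureExtractor, CombinedFeatureExtractor,
)

extractor = CombinedFeatureExtractor(extractors=[
    POSFeatureExtractor(),       # 24 part-of-speech ratio dimensions
    KeywordFeatureExtractor(),   # one keyword-match ratio per category in CATEGORIES
])

text = "央行宣布下调存款准备金率,股市应声上涨。"
features = extractor.extract(text)           # nested dict keyed by extractor class name
vector = extractor.extract_as_vector(text)   # concatenated 1-D float32 vector
print(vector.shape)
```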
diff --git a/preprocessing/text_cleaner.py b/preprocessing/text_cleaner.py
new file mode 100644
index 0000000..3f19507
--- /dev/null
+++ b/preprocessing/text_cleaner.py
@@ -0,0 +1,229 @@
+"""
+文本清洗模块:实现文本清洗,去除无用字符、HTML标签等
+"""
+import re
+import unicodedata
+import html
+from typing import List, Dict, Tuple, Optional, Any, Callable, Set, Union
+import string
+
+from utils.logger import get_logger
+
+logger = get_logger("TextCleaner")
+
+
+class TextCleaner:
+ """文本清洗类,提供各种文本清洗方法"""
+
+ def __init__(self, remove_html: bool = True,
+ remove_urls: bool = True,
+ remove_emails: bool = True,
+ remove_numbers: bool = False,
+ remove_punctuation: bool = False,
+ lowercase: bool = False,
+ normalize_unicode: bool = True,
+ remove_excessive_spaces: bool = True,
+ remove_short_texts: bool = False,
+ min_text_length: int = 10,
+ custom_patterns: Optional[List[str]] = None):
+ """
+ 初始化文本清洗器
+
+ Args:
+ remove_html: 是否移除HTML标签
+ remove_urls: 是否移除URL
+ remove_emails: 是否移除电子邮件地址
+ remove_numbers: 是否移除数字
+ remove_punctuation: 是否移除标点符号
+ lowercase: 是否转为小写(对中文无效)
+ normalize_unicode: 是否规范化Unicode字符
+ remove_excessive_spaces: 是否移除多余空格
+ remove_short_texts: 是否过滤掉短文本
+ min_text_length: 最小文本长度(当remove_short_texts=True时有效)
+ custom_patterns: 自定义的正则表达式模式列表,用于额外的文本清洗
+ """
+ self.remove_html = remove_html
+ self.remove_urls = remove_urls
+ self.remove_emails = remove_emails
+ self.remove_numbers = remove_numbers
+ self.remove_punctuation = remove_punctuation
+ self.lowercase = lowercase
+ self.normalize_unicode = normalize_unicode
+ self.remove_excessive_spaces = remove_excessive_spaces
+ self.remove_short_texts = remove_short_texts
+ self.min_text_length = min_text_length
+ self.custom_patterns = custom_patterns or []
+
+ # 编译正则表达式
+ self.html_pattern = re.compile(r'<.*?>')
+ self.url_pattern = re.compile(r'https?://\S+|www\.\S+')
+ self.email_pattern = re.compile(r'\S+@\S+\.\S+')
+ self.number_pattern = re.compile(r'\d+')
+ self.space_pattern = re.compile(r'\s+')
+
+ # 编译自定义模式
+ self.compiled_custom_patterns = [re.compile(pattern) for pattern in self.custom_patterns]
+
+ # 中文标点符号
+ self.chinese_punctuation = ",。!?;:“”‘’【】《》()、…—~·"
+
+ logger.info("文本清洗器初始化完成")
+
+ def clean_text(self, text: str) -> str:
+ """
+ 清洗文本,应用所有已配置的清洗方法
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 清洗后的文本
+ """
+ if not text:
+ return ""
+
+ # HTML解码
+ if self.remove_html:
+ text = html.unescape(text)
+ text = self.html_pattern.sub(' ', text)
+
+ # 移除URL
+ if self.remove_urls:
+ text = self.url_pattern.sub(' ', text)
+
+ # 移除电子邮件
+ if self.remove_emails:
+ text = self.email_pattern.sub(' ', text)
+
+ # Unicode规范化
+ if self.normalize_unicode:
+ text = unicodedata.normalize('NFKC', text)
+
+ # 移除数字
+ if self.remove_numbers:
+ text = self.number_pattern.sub(' ', text)
+
+ # 移除标点符号
+ if self.remove_punctuation:
+ # 处理英文标点
+ for punct in string.punctuation:
+ text = text.replace(punct, ' ')
+ # 处理中文标点
+ for punct in self.chinese_punctuation:
+ text = text.replace(punct, ' ')
+
+ # 应用自定义清洗模式
+ for pattern in self.compiled_custom_patterns:
+ text = pattern.sub(' ', text)
+
+ # 转为小写
+ if self.lowercase:
+ text = text.lower()
+
+ # 移除多余空格
+ if self.remove_excessive_spaces:
+ text = self.space_pattern.sub(' ', text)
+ text = text.strip()
+
+ # 过滤掉短文本
+ if self.remove_short_texts and len(text) < self.min_text_length:
+ return ""
+
+ return text
+
+ def clean_texts(self, texts: List[str]) -> List[str]:
+ """
+ 批量清洗文本
+
+ Args:
+ texts: 原始文本列表
+
+ Returns:
+ 清洗后的文本列表
+ """
+ return [self.clean_text(text) for text in texts]
+
+ def remove_redundant_texts(self, texts: List[str]) -> List[str]:
+ """
+ 移除冗余文本(空文本和长度小于阈值的文本)
+
+ Args:
+ texts: 原始文本列表
+
+ Returns:
+ 移除冗余后的文本列表
+ """
+ return [text for text in texts if text and len(text) >= self.min_text_length]
+
+ @staticmethod
+ def remove_specific_characters(text: str, chars_to_remove: Union[str, Set[str]]) -> str:
+ """
+ 移除特定字符
+
+ Args:
+ text: 原始文本
+ chars_to_remove: 要移除的字符(字符串或字符集合)
+
+ Returns:
+ 移除特定字符后的文本
+ """
+ for char in chars_to_remove:
+ text = text.replace(char, '')
+ return text
+
+ @staticmethod
+ def replace_characters(text: str, char_map: Dict[str, str]) -> str:
+ """
+ 替换特定字符
+
+ Args:
+ text: 原始文本
+ char_map: 字符映射字典,键为要替换的字符,值为替换后的字符
+
+ Returns:
+ 替换特定字符后的文本
+ """
+ for old_char, new_char in char_map.items():
+ text = text.replace(old_char, new_char)
+ return text
+
+ @staticmethod
+ def remove_empty_lines(text: str) -> str:
+ """
+ 移除空行
+
+ Args:
+ text: 原始文本
+
+ Returns:
+ 移除空行后的文本
+ """
+ lines = text.splitlines()
+ non_empty_lines = [line for line in lines if line.strip()]
+ return '\n'.join(non_empty_lines)
+
+ @staticmethod
+ def truncate_text(text: str, max_length: int, truncate_from_end: bool = True) -> str:
+ """
+ 截断文本
+
+ Args:
+ text: 原始文本
+ max_length: 最大长度
+ truncate_from_end: 是否从末尾截断,如果为False则从开头截断
+
+ Returns:
+ 截断后的文本
+ """
+ if len(text) <= max_length:
+ return text
+
+ if truncate_from_end:
+ return text[:max_length]
+ else:
+ return text[len(text) - max_length:]
+
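A short usage sketch for the cleaner; the output shown in the comment is approximate and depends on the configured flags:

```python
from preprocessing.text_cleaner import TextCleaner

cleaner = TextCleaner(
    remove_html=True,
    remove_urls=True,
    remove_punctuation=False,  # keep punctuation so downstream sentence splitting still works
)

raw = "<p>点击 https://example.com 查看更多&nbsp;详情!</p>"
print(cleaner.clean_text(raw))  # roughly: "点击 查看更多 详情!"
```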
diff --git a/preprocessing/tokenization.py b/preprocessing/tokenization.py
new file mode 100644
index 0000000..34c5a47
--- /dev/null
+++ b/preprocessing/tokenization.py
@@ -0,0 +1,248 @@
+"""
+中文分词模块:负责中文文本分词处理
+"""
+import os
+import jieba
+import re
+from typing import List, Dict, Tuple, Optional, Any, Set, Union
+import pandas as pd
+from collections import Counter
+
+from config.system_config import STOPWORDS_DIR, ENCODING
+from utils.logger import get_logger
+from utils.file_utils import read_text_file, write_text_file, ensure_dir
+
+logger = get_logger("Tokenization")
+
+
+class ChineseTokenizer:
+ """中文分词器,基于jieba实现"""
+
+ def __init__(self, user_dict_path: Optional[str] = None,
+ use_hmm: bool = True,
+ remove_stopwords: bool = True,
+ stopwords_path: Optional[str] = None,
+ add_custom_words: Optional[List[str]] = None):
+ """
+ 初始化中文分词器
+
+ Args:
+ user_dict_path: 用户自定义词典路径
+ use_hmm: 是否使用HMM模型进行分词
+ remove_stopwords: 是否移除停用词
+ stopwords_path: 停用词表路径,如果为None,则使用默认停用词表
+ add_custom_words: 要添加的自定义词语列表
+ """
+ self.use_hmm = use_hmm
+ self.remove_stopwords = remove_stopwords
+
+ # 加载用户自定义词典
+ if user_dict_path and os.path.exists(user_dict_path):
+ jieba.load_userdict(user_dict_path)
+ logger.info(f"已加载用户自定义词典:{user_dict_path}")
+
+ # 加载停用词
+ self.stopwords = set()
+ if remove_stopwords:
+ self._load_stopwords(stopwords_path)
+
+ # 添加自定义词语
+ if add_custom_words:
+ for word in add_custom_words:
+ jieba.add_word(word)
+ logger.info(f"已添加 {len(add_custom_words)} 个自定义词语")
+
+ def _load_stopwords(self, stopwords_path: Optional[str] = None) -> None:
+ """
+ 加载停用词
+
+ Args:
+ stopwords_path: 停用词表路径,如果为None,则使用默认停用词表
+ """
+ # 如果没有指定停用词表路径,则使用默认停用词表
+ if not stopwords_path:
+ stopwords_path = os.path.join(STOPWORDS_DIR, "chinese_stopwords.txt")
+
+ # 如果没有找到默认停用词表,则创建一个空的停用词表
+ if not os.path.exists(stopwords_path):
+ ensure_dir(os.path.dirname(stopwords_path))
+ # 常见中文停用词
+ default_stopwords = [
+ "的", "了", "和", "是", "就", "都", "而", "及", "与", "这", "那", "你",
+ "我", "他", "她", "它", "们", "或", "上", "下", "之", "地", "得", "着",
+ "说", "对", "在", "于", "由", "因", "为", "所", "以", "能", "可", "会"
+ ]
+ write_text_file("\n".join(default_stopwords), stopwords_path)
+ logger.info(f"未找到停用词表,已创建默认停用词表:{stopwords_path}")
+
+ # 加载停用词表
+ try:
+ with open(stopwords_path, "r", encoding=ENCODING) as f:
+ for line in f:
+ word = line.strip()
+ if word:
+ self.stopwords.add(word)
+ logger.info(f"已加载 {len(self.stopwords)} 个停用词")
+ except Exception as e:
+ logger.error(f"加载停用词表失败:{e}")
+
+ def add_stopwords(self, words: Union[str, List[str]]) -> None:
+ """
+ 添加停用词
+
+ Args:
+ words: 要添加的停用词(字符串或列表)
+ """
+ if isinstance(words, str):
+ self.stopwords.add(words.strip())
+ else:
+ for word in words:
+ self.stopwords.add(word.strip())
+
+ def remove_stopwords_from_list(self, words: List[str]) -> List[str]:
+ """
+ 从词语列表中移除停用词
+
+ Args:
+ words: 词语列表
+
+ Returns:
+ 移除停用词后的词语列表
+ """
+ if not self.remove_stopwords:
+ return words
+
+ return [word for word in words if word not in self.stopwords]
+
+ def tokenize(self, text: str, return_string: bool = False,
+ cut_all: bool = False) -> Union[List[str], str]:
+ """
+ 对文本进行分词
+
+ Args:
+ text: 要分词的文本
+ return_string: 是否返回字符串(以空格分隔的词语)
+ cut_all: 是否使用全模式(默认使用精确模式)
+
+ Returns:
+ 分词结果(词语列表或字符串)
+ """
+ if not text:
+ return "" if return_string else []
+
+ # 使用jieba进行分词
+ if cut_all:
+ words = jieba.lcut(text, cut_all=True)
+ else:
+ words = jieba.lcut(text, HMM=self.use_hmm)
+
+ # 移除停用词
+ if self.remove_stopwords:
+ words = self.remove_stopwords_from_list(words)
+
+ # 返回结果
+ if return_string:
+ return " ".join(words)
+ else:
+ return words
+
+ def batch_tokenize(self, texts: List[str], return_string: bool = False,
+ cut_all: bool = False) -> List[Union[List[str], str]]:
+ """
+ 批量分词
+
+ Args:
+ texts: 要分词的文本列表
+ return_string: 是否返回字符串(以空格分隔的词语)
+ cut_all: 是否使用全模式(默认使用精确模式)
+
+ Returns:
+ 分词结果列表
+ """
+ return [self.tokenize(text, return_string, cut_all) for text in texts]
+
+ def analyze_tokens(self, texts: List[str], top_n: int = 20) -> Dict[str, Any]:
+ """
+ 分析文本中的词频分布
+
+ Args:
+ texts: 要分析的文本列表
+ top_n: 返回前多少个高频词
+
+ Returns:
+ 包含词频分析结果的字典
+ """
+ all_tokens = []
+ for text in texts:
+ tokens = self.tokenize(text, return_string=False)
+ all_tokens.extend(tokens)
+
+ # 统计词频
+ token_counter = Counter(all_tokens)
+
+ # 获取最常见的词
+ most_common = token_counter.most_common(top_n)
+
+ # 计算唯一词数量
+ unique_tokens = len(token_counter)
+
+ return {
+ "total_tokens": len(all_tokens),
+ "unique_tokens": unique_tokens,
+ "most_common": most_common,
+ "token_counter": token_counter
+ }
+
+ def get_top_keywords(self, texts: List[str], top_n: int = 20,
+ min_freq: int = 3, min_length: int = 2) -> List[Tuple[str, int]]:
+ """
+ 获取文本中的关键词
+
+ Args:
+ texts: 要分析的文本列表
+ top_n: 返回前多少个关键词
+ min_freq: 最小词频
+ min_length: 最小词长度(字符数)
+
+ Returns:
+ 包含(关键词, 词频)的元组列表
+ """
+ tokens_analysis = self.analyze_tokens(texts)
+ token_counter = tokens_analysis["token_counter"]
+
+ # 过滤满足条件的词
+ filtered_keywords = [(word, count) for word, count in token_counter.items()
+ if count >= min_freq and len(word) >= min_length]
+
+ # 按词频排序
+ sorted_keywords = sorted(filtered_keywords, key=lambda x: x[1], reverse=True)
+
+ return sorted_keywords[:top_n]
+
+ def get_vocabulary(self, texts: List[str], min_freq: int = 1) -> List[str]:
+ """
+ 获取词汇表
+
+ Args:
+ texts: 文本列表
+ min_freq: 最小词频
+
+ Returns:
+ 词汇表(词语列表)
+ """
+ tokens_analysis = self.analyze_tokens(texts)
+ token_counter = tokens_analysis["token_counter"]
+
+ # 过滤满足最小词频的词
+ vocabulary = [word for word, count in token_counter.items() if count >= min_freq]
+
+ return vocabulary
+
+ def get_stopwords(self) -> Set[str]:
+ """
+ 获取停用词集合
+
+ Returns:
+ 停用词集合
+ """
+ return self.stopwords.copy()
\ No newline at end of file
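A usage sketch for the tokenizer. It assumes `STOPWORDS_DIR` from `config.system_config` points to a writable location, since a default stop-word list is created there on first use:

```python
from preprocessing.tokenization import ChineseTokenizer

tokenizer = ChineseTokenizer(remove_stopwords=True)

words = tokenizer.tokenize("机器学习正在改变新闻分类的方式", return_string=False)
print(words)  # e.g. ['机器', '学习', '正在', '改变', '新闻', '分类', '方式']

# Corpus-level statistics over a tiny sample
stats = tokenizer.analyze_tokens(["今天股市大涨", "股市震荡后继续大涨"], top_n=5)
print(stats["total_tokens"], stats["most_common"])
```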
diff --git a/preprocessing/vectorizer.py b/preprocessing/vectorizer.py
new file mode 100644
index 0000000..dde9301
--- /dev/null
+++ b/preprocessing/vectorizer.py
@@ -0,0 +1,774 @@
+"""
+文本向量化模块:实现文本向量化,包括词袋模型、TF-IDF和词嵌入等多种文本表示方法
+"""
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.preprocessing.text import Tokenizer
+from tensorflow.keras.preprocessing.sequence import pad_sequences
+from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer as SklearnTfidfVectorizer
+import pickle
+import os
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import gensim
+from gensim.models import Word2Vec, KeyedVectors
+
+from config.system_config import PROCESSED_DATA_DIR, EMBEDDINGS_DIR
+from config.model_config import (
+ MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS, MIN_WORD_FREQUENCY
+)
+from utils.logger import get_logger
+from utils.file_utils import save_pickle, load_pickle, ensure_dir
+from preprocessing.tokenization import ChineseTokenizer
+
+logger = get_logger("Vectorizer")
+
+
+class TextVectorizer:
+ """文本向量化基类,定义通用接口"""
+
+ def __init__(self, max_features: int = MAX_NUM_WORDS):
+ """
+ 初始化文本向量化器
+
+ Args:
+ max_features: 最大特征数(词汇表大小)
+ """
+ self.max_features = max_features
+ self.vectorizer = None
+ self.is_fitted = False
+
+ def fit(self, texts: List[str]) -> None:
+ """
+ 在文本上训练向量化器
+
+ Args:
+ texts: 文本列表
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def transform(self, texts: List[str]) -> np.ndarray:
+ """
+ 将文本转换为向量表示
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 向量表示
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def fit_transform(self, texts: List[str]) -> np.ndarray:
+ """
+ 在文本上训练向量化器,并将文本转换为向量表示
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 向量表示
+ """
+ self.fit(texts)
+ return self.transform(texts)
+
+ def save(self, path: str) -> None:
+ """
+ 保存向量化器
+
+ Args:
+ path: 保存路径
+ """
+ ensure_dir(os.path.dirname(path))
+ save_pickle(self.vectorizer, path)
+ logger.info(f"向量化器已保存到:{path}")
+
+ def load(self, path: str) -> None:
+ """
+ 加载向量化器
+
+ Args:
+ path: 加载路径
+ """
+ self.vectorizer = load_pickle(path)
+ self.is_fitted = True
+ logger.info(f"向量化器已从 {path} 加载")
+
+ def get_vocabulary(self) -> List[str]:
+ """
+ 获取词汇表
+
+ Returns:
+ 词汇表
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+ def get_vocabulary_size(self) -> int:
+ """
+ 获取词汇表大小
+
+ Returns:
+ 词汇表大小
+ """
+ raise NotImplementedError("子类必须实现此方法")
+
+
+class BagOfWordsVectorizer(TextVectorizer):
+ """词袋模型向量化器"""
+
+ def __init__(self, max_features: int = MAX_NUM_WORDS,
+ min_df: int = MIN_WORD_FREQUENCY,
+ tokenizer: Optional[Callable[[str], List[str]]] = None,
+ binary: bool = False):
+ """
+ 初始化词袋模型向量化器
+
+ Args:
+ max_features: 最大特征数(词汇表大小)
+ min_df: 最小文档频率
+ tokenizer: 分词器函数,接收文本,返回词语列表
+ binary: 是否使用二进制计数(只关注词语是否出现,不关注频率)
+ """
+ super().__init__(max_features)
+ self.min_df = min_df
+ self.binary = binary
+
+ # 创建sklearn的CountVectorizer
+ self.vectorizer = CountVectorizer(
+ max_features=max_features,
+ min_df=min_df,
+ tokenizer=tokenizer,
+ binary=binary
+ )
+
+ def fit(self, texts: List[str]) -> None:
+ """
+ 在文本上训练词袋模型
+
+ Args:
+ texts: 文本列表
+ """
+ self.vectorizer.fit(texts)
+ self.is_fitted = True
+ logger.info(f"词袋模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}")
+
+ def transform(self, texts: List[str]) -> np.ndarray:
+ """
+ 将文本转换为词袋向量表示
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 词袋向量表示(稀疏矩阵)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return self.vectorizer.transform(texts)
+
+ def get_vocabulary(self) -> List[str]:
+ """
+ 获取词汇表
+
+ Returns:
+ 词汇表(按索引排序)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # CountVectorizer的词汇表是一个字典,键为词,值为索引
+ vocab_dict = self.vectorizer.vocabulary_
+ vocab_list = [""] * len(vocab_dict)
+ for word, idx in vocab_dict.items():
+ vocab_list[idx] = word
+
+ return vocab_list
+
+ def get_vocabulary_size(self) -> int:
+ """
+ 获取词汇表大小
+
+ Returns:
+ 词汇表大小
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return len(self.vectorizer.vocabulary_)
+
+
+class TfidfVectorizer(TextVectorizer):
+ """TF-IDF向量化器"""
+
+ def __init__(self, max_features: int = MAX_NUM_WORDS,
+ min_df: int = MIN_WORD_FREQUENCY,
+ tokenizer: Optional[Callable[[str], List[str]]] = None,
+ norm: str = 'l2',
+ use_idf: bool = True,
+ smooth_idf: bool = True,
+ sublinear_tf: bool = False):
+ """
+ 初始化TF-IDF向量化器
+
+ Args:
+ max_features: 最大特征数(词汇表大小)
+ min_df: 最小文档频率
+ tokenizer: 分词器函数,接收文本,返回词语列表
+ norm: 规范化方法,默认为L2范数
+ use_idf: 是否使用IDF(逆文档频率)
+ smooth_idf: 是否平滑IDF权重
+ sublinear_tf: 是否应用sublinear scaling(对TF取对数)
+ """
+ super().__init__(max_features)
+ self.min_df = min_df
+
+ # 创建sklearn的TfidfVectorizer(导入时已重命名为SklearnTfidfVectorizer,避免与本类同名冲突)
+ self.vectorizer = SklearnTfidfVectorizer(
+ max_features=max_features,
+ min_df=min_df,
+ tokenizer=tokenizer,
+ norm=norm,
+ use_idf=use_idf,
+ smooth_idf=smooth_idf,
+ sublinear_tf=sublinear_tf
+ )
+
+ def fit(self, texts: List[str]) -> None:
+ """
+ 在文本上训练TF-IDF模型
+
+ Args:
+ texts: 文本列表
+ """
+ self.vectorizer.fit(texts)
+ self.is_fitted = True
+ logger.info(f"TF-IDF模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}")
+
+ def transform(self, texts: List[str]) -> np.ndarray:
+ """
+ 将文本转换为TF-IDF向量表示
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ TF-IDF向量表示(稀疏矩阵)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return self.vectorizer.transform(texts)
+
+ def get_vocabulary(self) -> List[str]:
+ """
+ 获取词汇表
+
+ Returns:
+ 词汇表(按索引排序)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # TfidfVectorizer的词汇表是一个字典,键为词,值为索引
+ vocab_dict = self.vectorizer.vocabulary_
+ vocab_list = [""] * len(vocab_dict)
+ for word, idx in vocab_dict.items():
+ vocab_list[idx] = word
+
+ return vocab_list
+
+ def get_vocabulary_size(self) -> int:
+ """
+ 获取词汇表大小
+
+ Returns:
+ 词汇表大小
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return len(self.vectorizer.vocabulary_)
+
+ def get_feature_names(self) -> List[str]:
+ """
+ 获取特征名称(词汇表)
+
+ Returns:
+ 特征名称列表
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return self.vectorizer.get_feature_names_out()
+
+ def get_idf(self) -> np.ndarray:
+ """
+ 获取IDF权重
+
+ Returns:
+ IDF权重数组
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return self.vectorizer.idf_
+
+
+class SequenceVectorizer(TextVectorizer):
+ """序列向量化器,使用Keras的Tokenizer"""
+
+ def __init__(self, max_features: int = MAX_NUM_WORDS,
+ max_sequence_length: int = MAX_SEQUENCE_LENGTH,
+ oov_token: str = "",
+ padding: str = "post",
+ truncating: str = "post"):
+ """
+ 初始化序列向量化器
+
+ Args:
+ max_features: 最大特征数(词汇表大小)
+ max_sequence_length: 序列最大长度
+ oov_token: 未登录词标记
+ padding: 填充方式,'pre'或'post'
+ truncating: 截断方式,'pre'或'post'
+ """
+ super().__init__(max_features)
+ self.max_sequence_length = max_sequence_length
+ self.oov_token = oov_token
+ self.padding = padding
+ self.truncating = truncating
+
+ # 创建Keras的Tokenizer
+ self.vectorizer = Tokenizer(num_words=max_features, oov_token=oov_token)
+
+ def fit(self, texts: List[str]) -> None:
+ """
+ 在文本上训练序列向量化器
+
+ Args:
+ texts: 文本列表
+ """
+ self.vectorizer.fit_on_texts(texts)
+ self.is_fitted = True
+ logger.info(f"序列向量化器已训练,词汇表大小:{len(self.vectorizer.word_index)}")
+
+ def transform(self, texts: List[str]) -> np.ndarray:
+ """
+ 将文本转换为整数序列,并进行填充
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 整数序列表示
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ sequences = self.vectorizer.texts_to_sequences(texts)
+ padded_sequences = pad_sequences(
+ sequences,
+ maxlen=self.max_sequence_length,
+ padding=self.padding,
+ truncating=self.truncating
+ )
+
+ return padded_sequences
+
+ def get_vocabulary(self) -> List[str]:
+ """
+ 获取词汇表
+
+ Returns:
+ 词汇表(按索引排序)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # Tokenizer的词汇表是一个字典,键为词,值为索引(从1开始)
+ word_index = self.vectorizer.word_index
+ index_word = {index: word for word, index in word_index.items()}
+
+ # 注意索引0保留给padding;OOV标记(如果有设置)已包含在word_index中,索引为1
+ vocab = ["<PAD>"]
+
+ max_index = min(self.max_features, len(word_index) + 1) if self.max_features else len(word_index) + 1
+ for i in range(1, max_index):
+ if i in index_word:
+ vocab.append(index_word[i])
+
+ return vocab
+
+ def get_vocabulary_size(self) -> int:
+ """
+ 获取词汇表大小
+
+ Returns:
+ 词汇表大小
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # +1是因为索引0保留给padding
+ return min(self.max_features, len(self.vectorizer.word_index) + 1) if self.max_features else len(
+ self.vectorizer.word_index) + 1
+
+ def texts_to_sequences(self, texts: List[str]) -> List[List[int]]:
+ """
+ 将文本转换为整数序列(不填充)
+
+ Args:
+ texts: 文本列表
+
+ Returns:
+ 整数序列列表
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return self.vectorizer.texts_to_sequences(texts)
+
+ def sequences_to_padded(self, sequences: List[List[int]]) -> np.ndarray:
+ """
+ 将整数序列填充到指定长度
+
+ Args:
+ sequences: 整数序列列表
+
+ Returns:
+ 填充后的整数序列
+ """
+ return pad_sequences(
+ sequences,
+ maxlen=self.max_sequence_length,
+ padding=self.padding,
+ truncating=self.truncating
+ )
+
+ def save(self, path: str) -> None:
+ """
+ 保存序列向量化器
+
+ Args:
+ path: 保存路径
+ """
+ ensure_dir(os.path.dirname(path))
+
+ # 保存配置和状态
+ tokenizer_state = {
+ 'tokenizer': self.vectorizer,
+ 'max_features': self.max_features,
+ 'max_sequence_length': self.max_sequence_length,
+ 'oov_token': self.oov_token,
+ 'padding': self.padding,
+ 'truncating': self.truncating,
+ 'is_fitted': self.is_fitted
+ }
+
+ save_pickle(tokenizer_state, path)
+ logger.info(f"序列向量化器已保存到:{path}")
+
+ def load(self, path: str) -> None:
+ """
+ 加载序列向量化器
+
+ Args:
+ path: 加载路径
+ """
+ tokenizer_state = load_pickle(path)
+
+ self.vectorizer = tokenizer_state['tokenizer']
+ self.max_features = tokenizer_state['max_features']
+ self.max_sequence_length = tokenizer_state['max_sequence_length']
+ self.oov_token = tokenizer_state['oov_token']
+ self.padding = tokenizer_state['padding']
+ self.truncating = tokenizer_state['truncating']
+ self.is_fitted = tokenizer_state['is_fitted']
+
+ logger.info(f"序列向量化器已从 {path} 加载,词汇表大小:{len(self.vectorizer.word_index)}")
+
+
+class Word2VecVectorizer(TextVectorizer):
+ """Word2Vec词嵌入向量化器"""
+
+ def __init__(self, vector_size: int = 100,
+ window: int = 5,
+ min_count: int = MIN_WORD_FREQUENCY,
+ workers: int = 4,
+ sg: int = 1, # 1表示Skip-gram模型,0表示CBOW模型
+ max_sequence_length: int = MAX_SEQUENCE_LENGTH,
+ padding: str = "post",
+ truncating: str = "post",
+ pretrained_path: Optional[str] = None):
+ """
+ 初始化Word2Vec词嵌入向量化器
+
+ Args:
+ vector_size: 词向量维度
+ window: 上下文窗口大小
+ min_count: 最小词频
+ workers: 并行训练的线程数
+ sg: 训练算法,1表示Skip-gram,0表示CBOW
+ max_sequence_length: 序列最大长度
+ padding: 填充方式,'pre'或'post'
+ truncating: 截断方式,'pre'或'post'
+ pretrained_path: 预训练词向量路径,如果不为None,则加载预训练词向量
+ """
+ super().__init__(max_features=None) # Word2Vec没有max_features限制
+ self.vector_size = vector_size
+ self.window = window
+ self.min_count = min_count
+ self.workers = workers
+ self.sg = sg
+ self.max_sequence_length = max_sequence_length
+ self.padding = padding
+ self.truncating = truncating
+ self.pretrained_path = pretrained_path
+
+ # Word2Vec模型
+ self.model = None
+
+ # 词汇表
+ self.word_index = {}
+ self.index_word = {}
+
+ # 如果有预训练词向量,加载它
+ if pretrained_path and os.path.exists(pretrained_path):
+ self._load_pretrained(pretrained_path)
+
+ def _load_pretrained(self, path: str) -> None:
+ """
+ 加载预训练词向量
+
+ Args:
+ path: 预训练词向量路径
+ """
+ try:
+ # 尝试加载Word2Vec模型
+ self.model = Word2Vec.load(path)
+ logger.info(f"已加载预训练Word2Vec模型:{path}")
+ except Exception:
+ try:
+ # 尝试加载词向量(Word2Vec、GloVe或FastText格式)
+ self.model = KeyedVectors.load_word2vec_format(path, binary=path.endswith('.bin'))
+ logger.info(f"已加载预训练词向量:{path}")
+ except Exception as e:
+ logger.error(f"加载预训练词向量失败:{e}")
+ return
+
+ # 如果加载成功,构建词汇表
+ self._build_vocab_from_model()
+ self.is_fitted = True
+
+ def _build_vocab_from_model(self) -> None:
+ """从模型构建词汇表"""
+ # 获取词汇表
+ vocabulary = list(self.model.wv.index_to_key)
+
+ # 构建词汇表索引
+ self.word_index = {word: idx + 1 for idx, word in enumerate(vocabulary)} # 索引0保留给padding
+ self.index_word = {idx + 1: word for idx, word in enumerate(vocabulary)}
+ self.index_word[0] = ""
+
+ def fit(self, tokenized_texts: List[List[str]]) -> None:
+ """
+ 在分词后的文本上训练Word2Vec模型
+
+ Args:
+ tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
+ """
+ # 如果已经有预训练模型,跳过训练
+ if self.is_fitted and self.model is not None:
+ logger.info("已有预训练模型,跳过训练")
+ return
+
+ # 训练Word2Vec模型
+ self.model = Word2Vec(
+ sentences=tokenized_texts,
+ vector_size=self.vector_size,
+ window=self.window,
+ min_count=self.min_count,
+ workers=self.workers,
+ sg=self.sg
+ )
+
+ # 构建词汇表
+ self._build_vocab_from_model()
+ self.is_fitted = True
+
+ logger.info(f"Word2Vec模型已训练,词汇表大小:{len(self.word_index)}")
+
+ def transform(self, tokenized_texts: List[List[str]]) -> np.ndarray:
+ """
+ 将分词后的文本转换为词向量序列
+
+ Args:
+ tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
+
+ Returns:
+ 词向量序列,形状为(样本数, 最大序列长度, 词向量维度)
+ """
+ if not self.is_fitted or self.model is None:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # 初始化结果数组
+ result = np.zeros((len(tokenized_texts), self.max_sequence_length, self.vector_size))
+
+ # 处理每个文本
+ for i, text in enumerate(tokenized_texts):
+ seq_len = min(len(text), self.max_sequence_length)
+
+ # 根据截断方式处理
+ if self.truncating == 'pre' and len(text) > self.max_sequence_length:
+ text = text[-self.max_sequence_length:]
+ elif self.truncating == 'post' and len(text) > self.max_sequence_length:
+ text = text[:self.max_sequence_length]
+
+ # 获取每个词的词向量
+ for j, word in enumerate(text[:seq_len]):
+ if word in self.model.wv:
+ # 根据填充方式确定位置
+ pos = j if self.padding == 'post' else self.max_sequence_length - seq_len + j
+ result[i, pos] = self.model.wv[word]
+
+ return result
+
+ def transform_to_indices(self, tokenized_texts: List[List[str]]) -> np.ndarray:
+ """
+ 将分词后的文本转换为词索引序列,并填充
+
+ Args:
+ tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
+
+ Returns:
+ 词索引序列
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ # 将词转换为索引
+ sequences = []
+ for text in tokenized_texts:
+ seq = [self.word_index.get(word, 0) for word in text] # 未登录词用0(padding)
+ sequences.append(seq)
+
+ # 填充序列
+ padded_sequences = pad_sequences(
+ sequences,
+ maxlen=self.max_sequence_length,
+ padding=self.padding,
+ truncating=self.truncating
+ )
+
+ return padded_sequences
+
+ def get_embedding_matrix(self) -> np.ndarray:
+ """
+ 获取嵌入矩阵,用于Embedding层的权重初始化
+
+ Returns:
+ 嵌入矩阵,形状为(词汇表大小, 词向量维度)
+ """
+ if not self.is_fitted or self.model is None:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ vocab_size = len(self.word_index) + 1 # +1是因为索引0保留给padding
+ embedding_matrix = np.zeros((vocab_size, self.vector_size))
+
+ # 填充嵌入矩阵
+ for word, idx in self.word_index.items():
+ if word in self.model.wv:
+ embedding_matrix[idx] = self.model.wv[word]
+
+ return embedding_matrix
+
+ def get_vocabulary(self) -> List[str]:
+ """
+ 获取词汇表
+
+ Returns:
+ 词汇表(按索引排序)
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ vocab = [""] # 索引0保留给padding
+ for idx in range(1, len(self.index_word) + 1):
+ if idx in self.index_word:
+ vocab.append(self.index_word[idx])
+
+ return vocab
+
+ def get_vocabulary_size(self) -> int:
+ """
+ 获取词汇表大小
+
+ Returns:
+ 词汇表大小
+ """
+ if not self.is_fitted:
+ raise ValueError("向量化器尚未训练,请先调用fit方法")
+
+ return len(self.word_index) + 1 # +1是因为索引0保留给padding
+
+ def save(self, path: str) -> None:
+ """
+ 保存Word2Vec向量化器
+
+ Args:
+ path: 保存路径
+ """
+ ensure_dir(os.path.dirname(path))
+
+ # 保存模型和配置
+ model_path = os.path.join(os.path.dirname(path), "word2vec_model")
+ if self.model:
+ self.model.save(model_path)
+
+ # 保存配置和状态
+ state = {
+ 'word_index': self.word_index,
+ 'index_word': self.index_word,
+ 'vector_size': self.vector_size,
+ 'window': self.window,
+ 'min_count': self.min_count,
+ 'workers': self.workers,
+ 'sg': self.sg,
+ 'max_sequence_length': self.max_sequence_length,
+ 'padding': self.padding,
+ 'truncating': self.truncating,
+ 'is_fitted': self.is_fitted,
+ 'model_path': model_path if self.model else None
+ }
+
+ save_pickle(state, path)
+ logger.info(f"Word2Vec向量化器已保存到:{path}")
+
+ def load(self, path: str) -> None:
+ """
+ 加载Word2Vec向量化器
+
+ Args:
+ path: 加载路径
+ """
+ state = load_pickle(path)
+
+ self.word_index = state['word_index']
+ self.index_word = state['index_word']
+ self.vector_size = state['vector_size']
+ self.window = state['window']
+ self.min_count = state['min_count']
+ self.workers = state['workers']
+ self.sg = state['sg']
+ self.max_sequence_length = state['max_sequence_length']
+ self.padding = state['padding']
+ self.truncating = state['truncating']
+ self.is_fitted = state['is_fitted']
+
+ # 加载模型
+ model_path = state.get('model_path')
+ if model_path and os.path.exists(model_path):
+ self.model = Word2Vec.load(model_path)
+
+ logger.info(f"Word2Vec向量化器已从 {path} 加载,词汇表大小:{len(self.word_index)}")
\ No newline at end of file
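The two vectorizers used by the training pipeline expect different inputs: `SequenceVectorizer` consumes space-joined token strings (the underlying Keras `Tokenizer` splits on whitespace), while `Word2VecVectorizer` consumes token lists. A minimal sketch, with illustrative sizes in place of the defaults from `config.model_config`:

```python
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer, Word2VecVectorizer

tokenizer = ChineseTokenizer()
texts = ["央行宣布降息", "球队在主场获胜", "新款手机正式发布"]

# Integer sequences for the Keras models
joined = [tokenizer.tokenize(t, return_string=True) for t in texts]
seq_vec = SequenceVectorizer(max_features=50000, max_sequence_length=500)  # illustrative sizes
X = seq_vec.fit_transform(joined)              # shape (3, 500)

# Word2Vec embeddings, e.g. for initializing an Embedding layer
token_lists = [tokenizer.tokenize(t, return_string=False) for t in texts]
w2v = Word2VecVectorizer(vector_size=100, min_count=1)
w2v.fit(token_lists)
embedding_matrix = w2v.get_embedding_matrix()  # shape (vocab_size, 100)
```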
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e589592
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,29 @@
+# 核心依赖
+tensorflow>=2.5.0
+numpy>=1.19.5
+pandas>=1.3.0
+scikit-learn>=0.24.2
+matplotlib>=3.4.2
+jieba>=0.42.1
+tqdm>=4.61.1
+gensim>=4.0.1
+
+# 数据处理
+nltk>=3.6.2
+symspellpy>=6.7.0
+h5py>=3.1.0
+openpyxl>=3.0.7
+xlrd>=2.0.1
+
+# Web和API依赖
+flask>=2.0.1
+fastapi>=0.68.0
+uvicorn>=0.15.0
+pydantic>=1.8.2
+werkzeug>=2.0.1
+jinja2>=3.0.1
+python-multipart>=0.0.5
+
+# 其他工具
+requests>=2.26.0
+synonyms>=3.15.0
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
new file mode 100644
index 0000000..87b80ab
--- /dev/null
+++ b/scripts/evaluate.py
@@ -0,0 +1,166 @@
+"""
+评估脚本:评估文本分类模型性能
+"""
+import os
+import sys
+import time
+import argparse
+import logging
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+import matplotlib.pyplot as plt
+
+# 将项目根目录添加到系统路径
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(project_root)
+
+from config.system_config import (
+ RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
+)
+from config.model_config import (
+ BATCH_SIZE, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
+)
+from data.dataloader import DataLoader
+from data.data_manager import DataManager
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from models.model_factory import ModelFactory
+from evaluation.evaluator import ModelEvaluator
+from utils.logger import get_logger
+
+logger = get_logger("Evaluation")
+
+
+def evaluate_model(model_path: str,
+ data_dir: Optional[str] = None,
+ batch_size: int = BATCH_SIZE,
+ output_dir: Optional[str] = None) -> Dict[str, float]:
+ """
+ 评估文本分类模型
+
+ Args:
+ model_path: 模型路径
+ data_dir: 数据目录,如果为None则使用默认目录
+ batch_size: 批大小
+ output_dir: 评估结果输出目录,如果为None则使用默认目录
+
+ Returns:
+ 评估指标
+ """
+ logger.info(f"开始评估模型: {model_path}")
+ start_time = time.time()
+
+ # 设置数据目录
+ data_dir = data_dir or RAW_DATA_DIR
+
+ # 设置输出目录
+ if output_dir:
+ output_dir = os.path.abspath(output_dir)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # 1. 加载模型
+ logger.info("加载模型...")
+ model = ModelFactory.load_model(model_path)
+
+ # 2. 加载数据
+ logger.info("加载数据...")
+ data_loader = DataLoader(data_dir=data_dir)
+ data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
+
+ # 加载测试集
+ data_manager.load_data()
+ test_texts, test_labels = data_manager.get_data(dataset="test")
+
+ # 3. 准备数据
+ # 创建分词器
+ tokenizer = ChineseTokenizer()
+
+ # 对测试文本进行分词
+ logger.info("对文本进行分词...")
+ tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts]
+
+ # 创建序列向量化器
+ logger.info("加载向量化器...")
+ # 查找向量化器文件
+ vectorizer_path = None
+ for model_type in ["cnn", "rnn", "transformer"]:
+ path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
+ if os.path.exists(path):
+ vectorizer_path = path
+ break
+
+ if not vectorizer_path:
+ # 如果找不到向量化器,创建一个新的
+ logger.warning("未找到向量化器,创建一个新的")
+ vectorizer = SequenceVectorizer(
+ max_features=MAX_NUM_WORDS,
+ max_sequence_length=MAX_SEQUENCE_LENGTH
+ )
+ else:
+ # 加载向量化器
+ vectorizer = SequenceVectorizer()
+ vectorizer.load(vectorizer_path)
+
+ # 转换测试文本
+ X_test = vectorizer.transform(tokenized_test_texts)
+
+ # 4. 创建评估器
+ logger.info("创建评估器...")
+ evaluator = ModelEvaluator(
+ model=model,
+ class_names=CATEGORIES,
+ output_dir=output_dir
+ )
+
+ # 5. 评估模型
+ logger.info("评估模型...")
+ metrics = evaluator.evaluate(X_test, test_labels, batch_size)
+
+ # 6. 保存评估结果
+ logger.info("保存评估结果...")
+ evaluator.save_evaluation_results(save_plots=True)
+
+ # 7. 可视化混淆矩阵
+ logger.info("可视化混淆矩阵...")
+ cm = evaluator.evaluation_results['confusion_matrix']
+ evaluator.metrics.plot_confusion_matrix(
+ y_true=test_labels,
+ y_pred=np.argmax(model.predict(X_test), axis=1),
+ normalize='true',
+ save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png")
+ )
+
+ # 8. 类别性能分析
+ logger.info("分析各类别性能...")
+ class_performance = evaluator.evaluate_class_performance(X_test, test_labels)
+
+ # 9. 计算评估时间
+ eval_time = time.time() - start_time
+ logger.info(f"模型评估完成,耗时: {eval_time:.2f} 秒")
+
+ # 10. 输出主要指标
+ logger.info("主要评估指标:")
+ for metric_name, metric_value in metrics.items():
+ if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']:
+ logger.info(f" {metric_name}: {metric_value:.4f}")
+
+ return metrics
+
+
+if __name__ == "__main__":
+ # 解析命令行参数
+ parser = argparse.ArgumentParser(description="评估文本分类模型")
+ parser.add_argument("--model_path", required=True, help="模型路径")
+ parser.add_argument("--data_dir", help="数据目录")
+ parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
+ parser.add_argument("--output_dir", help="评估结果输出目录")
+
+ args = parser.parse_args()
+
+ # 评估模型
+ evaluate_model(
+ model_path=args.model_path,
+ data_dir=args.data_dir,
+ batch_size=args.batch_size,
+ output_dir=args.output_dir
+ )
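The evaluation entry point can also be called programmatically. The paths below are placeholders; a real run needs a model saved by `scripts/train.py` (together with its `vectorizer_*.pkl`) and the processed data splits produced during training:

```python
from scripts.evaluate import evaluate_model

metrics = evaluate_model(
    model_path="saved_models/cnn_model_20250101_000000",  # placeholder path
    output_dir="results/eval_cnn",                        # placeholder path
    batch_size=64,
)
print(f"accuracy = {metrics['accuracy']:.4f}")
```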
diff --git a/scripts/predict.py b/scripts/predict.py
new file mode 100644
index 0000000..14407f1
--- /dev/null
+++ b/scripts/predict.py
@@ -0,0 +1,242 @@
+"""
+预测脚本:使用模型进行预测
+"""
+import os
+import sys
+import time
+import argparse
+import logging
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+import json
+
+# 将项目根目录添加到系统路径
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(project_root)
+
+from config.system_config import (
+ CATEGORIES, CLASSIFIERS_DIR
+)
+from models.model_factory import ModelFactory
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from inference.predictor import Predictor
+from inference.batch_processor import BatchProcessor
+from utils.logger import get_logger
+from utils.file_utils import read_text_file
+
+logger = get_logger("Prediction")
+
+
+def predict_text(text: str, model_path: Optional[str] = None,
+ output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
+ """
+ 预测单条文本
+
+ Args:
+ text: 要预测的文本
+ model_path: 模型路径,如果为None则使用最新的模型
+ output_path: 输出文件路径,如果为None则不保存
+ top_k: 返回概率最高的前k个类别
+
+ Returns:
+ 预测结果
+ """
+ logger.info("开始预测文本")
+
+ # 1. 加载模型
+ if model_path is None:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ if not models_info:
+ raise ValueError("未找到可用的模型")
+
+ # 使用最新的模型
+ model_path = models_info[0]['path']
+
+ logger.info(f"加载模型: {model_path}")
+ model = ModelFactory.load_model(model_path)
+
+ # 2. 创建分词器和预测器
+ tokenizer = ChineseTokenizer()
+
+ # 查找向量化器文件
+ vectorizer = None
+ for model_type in ["cnn", "rnn", "transformer"]:
+ vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
+ if os.path.exists(vectorizer_path):
+ # 加载向量化器
+ vectorizer = SequenceVectorizer()
+ vectorizer.load(vectorizer_path)
+ logger.info(f"加载向量化器: {vectorizer_path}")
+ break
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ vectorizer=vectorizer,
+ class_names=CATEGORIES
+ )
+
+ # 3. 预测
+ result = predictor.predict(
+ text=text,
+ return_top_k=top_k,
+ return_probabilities=True
+ )
+
+ # 4. 输出结果
+ if top_k > 1:
+ logger.info("预测结果:")
+ for i, pred in enumerate(result):
+ logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
+ else:
+ logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")
+
+ # 5. 保存结果
+ if output_path:
+ if output_path.endswith('.json'):
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=2)
+ else:
+ with open(output_path, 'w', encoding='utf-8') as f:
+ if top_k > 1:
+ f.write("rank,class,probability\n")
+ for i, pred in enumerate(result):
+ f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
+ else:
+ f.write(f"class,probability\n")
+ f.write(f"{result['class']},{result['probability']}\n")
+
+ logger.info(f"结果已保存到: {output_path}")
+
+ return result
+
+
+def predict_file(file_path: str, model_path: Optional[str] = None,
+ output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
+ """
+ 预测文件内容
+
+ Args:
+ file_path: 文件路径
+ model_path: 模型路径,如果为None则使用最新的模型
+ output_path: 输出文件路径,如果为None则不保存
+ top_k: 返回概率最高的前k个类别
+
+ Returns:
+ 预测结果
+ """
+ logger.info(f"开始预测文件: {file_path}")
+
+ # 检查文件是否存在
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"文件不存在: {file_path}")
+
+ # 读取文件内容
+ if file_path.endswith('.txt'):
+ # 文本文件
+ text = read_text_file(file_path)
+ return predict_text(text, model_path, output_path, top_k)
+
+ elif file_path.endswith(('.csv', '.xls', '.xlsx')):
+ # 表格文件
+ import pandas as pd
+
+ if file_path.endswith('.csv'):
+ df = pd.read_csv(file_path)
+ else:
+ df = pd.read_excel(file_path)
+
+ # 查找可能的文本列
+ text_columns = [col for col in df.columns if df[col].dtype == 'object']
+
+ if not text_columns:
+ raise ValueError("文件中没有找到可能的文本列")
+
+ # 使用第一个文本列
+ text_column = text_columns[0]
+ logger.info(f"使用文本列: {text_column}")
+
+ # 1. 加载模型
+ if model_path is None:
+ # 获取可用模型列表
+ models_info = ModelFactory.get_available_models()
+
+ if not models_info:
+ raise ValueError("未找到可用的模型")
+
+ # 使用最新的模型
+ model_path = models_info[0]['path']
+
+ logger.info(f"加载模型: {model_path}")
+ model = ModelFactory.load_model(model_path)
+
+ # 2. 创建分词器和预测器
+ tokenizer = ChineseTokenizer()
+
+ # 查找向量化器文件
+ vectorizer = None
+ for model_type in ["cnn", "rnn", "transformer"]:
+ vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
+ if os.path.exists(vectorizer_path):
+ # 加载向量化器
+ vectorizer = SequenceVectorizer()
+ vectorizer.load(vectorizer_path)
+ logger.info(f"加载向量化器: {vectorizer_path}")
+ break
+
+ # 创建预测器
+ predictor = Predictor(
+ model=model,
+ tokenizer=tokenizer,
+ vectorizer=vectorizer,
+ class_names=CATEGORIES
+ )
+
+ # 3. 创建批处理器
+ batch_processor = BatchProcessor(
+ predictor=predictor,
+ batch_size=64
+ )
+
+ # 4. 批量预测
+ result_df = batch_processor.process_dataframe(
+ df=df,
+ text_column=text_column,
+ output_path=output_path,
+ return_top_k=top_k,
+ format=output_path.split('.')[-1] if output_path else 'csv'
+ )
+
+ logger.info(f"已处理 {len(result_df)} 行数据")
+
+ # 返回结果
+ return result_df.to_dict(orient='records')
+
+ else:
+ raise ValueError(f"不支持的文件类型: {file_path}")
+
+
+if __name__ == "__main__":
+ # 解析命令行参数
+ parser = argparse.ArgumentParser(description="使用模型预测")
+ parser.add_argument("--model_path", help="模型路径")
+ parser.add_argument("--text", help="要预测的文本")
+ parser.add_argument("--file", help="要预测的文件")
+ parser.add_argument("--output", help="输出文件")
+ parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
+
+ args = parser.parse_args()
+
+ # 检查输入
+ if not args.text and not args.file:
+ parser.error("请提供要预测的文本或文件")
+
+ # 预测
+ if args.text:
+ predict_text(args.text, args.model_path, args.output, args.top_k)
+ else:
+ predict_file(args.file, args.model_path, args.output, args.top_k)
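A programmatic call equivalent to the `--text` branch of the CLI. It assumes at least one trained model is discoverable via `ModelFactory.get_available_models()`:

```python
from scripts.predict import predict_text

result = predict_text(
    text="国足在世预赛中两球完胜对手",
    top_k=3,                        # return the three most probable categories
    output_path="prediction.json",  # optional; omit to skip saving
)
```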
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..c473a80
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,203 @@
+"""
+训练脚本:训练文本分类模型
+"""
+import os
+import sys
+import time
+import argparse
+import logging
+from typing import List, Dict, Tuple, Optional, Any, Union
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+# 将项目根目录添加到系统路径
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(project_root)
+
+from config.system_config import (
+ RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
+)
+from config.model_config import (
+ BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
+)
+from data.dataloader import DataLoader
+from data.data_manager import DataManager
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer
+from models.model_factory import ModelFactory
+from training.trainer import Trainer
+from utils.logger import get_logger
+
+logger = get_logger("Training")
+
+
+def train_model(data_dir: Optional[str] = None,
+ model_type: str = "cnn",
+ epochs: int = NUM_EPOCHS,
+ batch_size: int = BATCH_SIZE,
+ save_dir: Optional[str] = None,
+ validation_split: float = 0.1,
+ use_pretrained_embedding: bool = False,
+ embedding_path: Optional[str] = None) -> str:
+ """
+ 训练文本分类模型
+
+ Args:
+ data_dir: 数据目录,如果为None则使用默认目录
+ model_type: 模型类型,'cnn', 'rnn', 或 'transformer'
+ epochs: 训练轮数
+ batch_size: 批大小
+ save_dir: 模型保存目录,如果为None则使用默认目录
+ validation_split: 验证集比例
+ use_pretrained_embedding: 是否使用预训练词向量
+ embedding_path: 预训练词向量路径
+
+ Returns:
+ 保存的模型路径
+ """
+ logger.info(f"开始训练 {model_type.upper()} 模型")
+ start_time = time.time()
+
+ # 设置数据目录
+ data_dir = data_dir or RAW_DATA_DIR
+
+ # 设置保存目录
+ if save_dir:
+ save_dir = os.path.abspath(save_dir)
+ os.makedirs(save_dir, exist_ok=True)
+ else:
+ save_dir = CLASSIFIERS_DIR
+ os.makedirs(save_dir, exist_ok=True)
+
+ # 1. 加载数据
+ logger.info("加载数据...")
+ data_loader = DataLoader(data_dir=data_dir)
+ data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
+
+ # 加载和分割数据
+ data = data_manager.load_and_split_data(
+ data_loader=data_loader,
+ val_split=validation_split,
+ sample_ratio=1.0,
+ save=True
+ )
+
+ # 获取训练集和验证集
+ train_texts, train_labels = data_manager.get_data(dataset="train")
+ val_texts, val_labels = data_manager.get_data(dataset="val")
+
+ # 2. 准备数据
+ # 创建分词器
+ tokenizer = ChineseTokenizer()
+
+ # 对训练文本进行分词
+ logger.info("对文本进行分词...")
+ tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts]
+ tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts]
+
+ # 创建序列向量化器
+ logger.info("创建序列向量化器...")
+ vectorizer = SequenceVectorizer(
+ max_features=MAX_NUM_WORDS,
+ max_sequence_length=MAX_SEQUENCE_LENGTH
+ )
+
+ # 训练向量化器并转换文本
+ vectorizer.fit(tokenized_train_texts)
+ X_train = vectorizer.transform(tokenized_train_texts)
+ X_val = vectorizer.transform(tokenized_val_texts)
+
+ # 保存向量化器
+ vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl")
+ vectorizer.save(vectorizer_path)
+ logger.info(f"向量化器已保存到: {vectorizer_path}")
+
+ # 获取一些基本参数
+ num_classes = len(CATEGORIES)
+ vocab_size = vectorizer.get_vocabulary_size()
+
+ # 3. 创建模型
+ logger.info(f"创建 {model_type.upper()} 模型...")
+
+ # 加载预训练词向量(如果指定)
+ embedding_matrix = None
+ if use_pretrained_embedding and embedding_path:
+ # 这里简化处理,实际应用中应该加载和处理预训练词向量
+ logger.info("加载预训练词向量...")
+ embedding_matrix = np.random.random((vocab_size, 200))
+
+ # 创建模型
+ model = ModelFactory.create_model(
+ model_type=model_type,
+ num_classes=num_classes,
+ vocab_size=vocab_size,
+ embedding_matrix=embedding_matrix,
+ batch_size=batch_size
+ )
+
+ # 构建模型
+ model.build()
+ model.compile()
+ model.summary()
+
+ # 4. 训练模型
+ logger.info("开始训练模型...")
+ trainer = Trainer(
+ model=model,
+ epochs=epochs,
+ batch_size=batch_size,
+ early_stopping=True,
+ tensorboard=True
+ )
+
+ # 训练
+ history = trainer.train(
+ x_train=X_train,
+ y_train=train_labels,
+ x_val=X_val,
+ y_val=val_labels
+ )
+
+ # 5. 保存模型
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
+ model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
+ model.save(model_path)
+ logger.info(f"模型已保存到: {model_path}")
+
+ # 6. 绘制训练历史
+ logger.info("绘制训练历史...")
+ model.plot_training_history(save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png"))
+
+ # 7. 计算训练时间
+ train_time = time.time() - start_time
+ logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
+
+ return model_path
+
+
+if __name__ == "__main__":
+ # 解析命令行参数
+ parser = argparse.ArgumentParser(description="训练文本分类模型")
+ parser.add_argument("--data_dir", help="数据目录")
+ parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
+ parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
+ parser.add_argument("--save_dir", help="模型保存目录")
+ parser.add_argument("--validation_split", type=float, default=0.1, help="验证集比例")
+ parser.add_argument("--use_pretrained_embedding", action="store_true", help="是否使用预训练词向量")
+ parser.add_argument("--embedding_path", help="预训练词向量路径")
+
+ args = parser.parse_args()
+
+ # 训练模型
+ train_model(
+ data_dir=args.data_dir,
+ model_type=args.model_type,
+ epochs=args.epochs,
+ batch_size=args.batch_size,
+ save_dir=args.save_dir,
+ validation_split=args.validation_split,
+ use_pretrained_embedding=args.use_pretrained_embedding,
+ embedding_path=args.embedding_path
+ )
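Programmatic equivalent of the CLI above; it assumes the raw news corpus is available under `RAW_DATA_DIR`:

```python
from scripts.train import train_model

model_path = train_model(
    model_type="cnn",
    epochs=5,
    batch_size=64,
    validation_split=0.1,
)
print(f"model saved to: {model_path}")
```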
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..4ecce2d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,40 @@
+from setuptools import setup, find_packages
+
+with open("README.md", "r", encoding="utf-8") as fh:
+ long_description = fh.read()
+
+setup(
+ name="chinese-text-classification",
+ version="1.0.0",
+ author="Your Name",
+ author_email="your.email@example.com",
+ description="基于Python的中文文本分类系统",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url="https://github.com/yourusername/chinese-text-classification",
+ packages=find_packages(),
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ ],
+ python_requires=">=3.7",
+ install_requires=[
+ "tensorflow>=2.5.0",
+ "numpy>=1.19.5",
+ "pandas>=1.3.0",
+ "scikit-learn>=0.24.2",
+ "matplotlib>=3.4.2",
+ "jieba>=0.42.1",
+ "tqdm>=4.61.1",
+ "gensim>=4.0.1",
+ "flask>=2.0.1",
+ "fastapi>=0.68.0",
+ "uvicorn>=0.15.0"
+ ],
+ entry_points={
+ "console_scripts": [
+ "text-classifier=main:main",
+ ],
+ },
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/callbacks.py b/training/callbacks.py
new file mode 100644
index 0000000..31e5813
--- /dev/null
+++ b/training/callbacks.py
@@ -0,0 +1,430 @@
+"""
+回调函数模块:提供用于模型训练的自定义回调函数
+"""
+import os
+import time
+import numpy as np
+import tensorflow as tf
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import matplotlib.pyplot as plt
+from io import BytesIO
+
+from utils.logger import get_logger
+
+logger = get_logger("Callbacks")
+
+
+class MetricsHistory(tf.keras.callbacks.Callback):
+ """跟踪训练过程中的指标历史"""
+
+ def __init__(self, validation_data: Optional[Tuple] = None,
+ metrics: Optional[List[str]] = None,
+ save_path: Optional[str] = None):
+ """
+ 初始化MetricsHistory回调
+
+ Args:
+ validation_data: 验证数据,格式为(x_val, y_val)
+ metrics: 要跟踪的指标列表
+ save_path: 指标历史的保存路径
+ """
+ super().__init__()
+ self.validation_data = validation_data
+ self.metrics = metrics or ['loss', 'accuracy']
+ self.save_path = save_path
+
+ # 历史指标
+ self.history = {metric: [] for metric in self.metrics}
+ if validation_data is not None:
+ for metric in self.metrics:
+ self.history[f'val_{metric}'] = []
+
+ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个epoch结束时调用
+
+ Args:
+ epoch: 当前epoch索引
+ logs: 训练日志
+ """
+ logs = logs or {}
+
+ # 记录训练指标
+ for metric in self.metrics:
+ if metric in logs:
+ self.history[metric].append(logs[metric])
+
+ # 记录验证指标
+ if self.validation_data is not None:
+ for metric in self.metrics:
+ val_metric = f'val_{metric}'
+ if val_metric in logs:
+ self.history[val_metric].append(logs[val_metric])
+
+ def plot_metrics(self, save_path: Optional[str] = None) -> None:
+ """
+ 绘制指标历史
+
+ Args:
+ save_path: 图像保存路径,如果为None则使用初始化时设置的路径
+ """
+ plt.figure(figsize=(12, 5))
+
+ for i, metric in enumerate(self.metrics):
+ plt.subplot(1, len(self.metrics), i + 1)
+
+ if metric in self.history:
+ plt.plot(self.history[metric], label=f'train_{metric}')
+
+ val_metric = f'val_{metric}'
+ if val_metric in self.history:
+ plt.plot(self.history[val_metric], label=f'val_{metric}')
+
+ plt.title(f'Model {metric}')
+ plt.xlabel('Epoch')
+ plt.ylabel(metric)
+ plt.legend()
+
+ plt.tight_layout()
+
+ save_path = save_path or self.save_path
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"指标历史图已保存到: {save_path}")
+ else:
+ plt.show()
+
+
+class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
+ """计算并显示验证集上的混淆矩阵"""
+
+ def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray],
+ class_names: Optional[List[str]] = None,
+ log_dir: Optional[str] = None,
+ freq: int = 1,
+ fig_size: Tuple[int, int] = (10, 8)):
+ """
+ 初始化ConfusionMatrixCallback
+
+ Args:
+ validation_data: 验证数据,格式为(x_val, y_val)
+ class_names: 类别名称列表
+ log_dir: TensorBoard日志目录
+ freq: 计算混淆矩阵的频率(每多少个epoch计算一次)
+ fig_size: 图像大小
+ """
+ super().__init__()
+ self.x_val, self.y_val = validation_data
+ self.class_names = class_names
+ self.log_dir = log_dir
+ self.freq = freq
+ self.fig_size = fig_size
+
+ # 如果提供了TensorBoard日志目录,创建一个文件写入器
+ if log_dir:
+ self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix'))
+ else:
+ self.file_writer = None
+
+ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个epoch结束时调用
+
+ Args:
+ epoch: 当前epoch索引
+ logs: 训练日志
+ """
+ # 每freq个epoch计算一次混淆矩阵
+ if (epoch + 1) % self.freq == 0 or epoch == 0:
+ # 获取预测结果
+ y_pred = np.argmax(self.model.predict(self.x_val), axis=1)
+
+ # 确保y_val是一维数组
+ y_true = self.y_val
+ if len(y_true.shape) > 1 and y_true.shape[1] > 1:
+ y_true = np.argmax(y_true, axis=1)
+
+ # 计算混淆矩阵
+ cm = tf.math.confusion_matrix(y_true, y_pred).numpy()
+
+ # 归一化混淆矩阵
+ cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
+
+ # 绘制混淆矩阵
+ fig = self._plot_confusion_matrix(cm_norm, epoch + 1)
+
+ # 如果有TensorBoard日志,将图像添加到TensorBoard
+ if self.file_writer:
+ with self.file_writer.as_default():
+ # 将matplotlib图像转换为TensorBoard图像
+ buf = BytesIO()
+ fig.savefig(buf, format='png')
+ buf.seek(0)
+
+ # 将PNG编码为字符串,并创建图像
+ image = tf.image.decode_png(buf.getvalue(), channels=4)
+ image = tf.expand_dims(image, 0)
+
+ # 添加到TensorBoard
+ tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch)
+
+ plt.close(fig)
+
+ def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure:
+ """
+ 绘制混淆矩阵
+
+ Args:
+ cm: 混淆矩阵
+ epoch: 当前epoch
+
+ Returns:
+ matplotlib图像
+ """
+ fig, ax = plt.subplots(figsize=self.fig_size)
+
+ # 使用热图显示混淆矩阵
+ im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
+ ax.figure.colorbar(im, ax=ax)
+
+ # 设置坐标轴标签
+ if self.class_names:
+ ax.set(
+ xticks=np.arange(cm.shape[1]),
+ yticks=np.arange(cm.shape[0]),
+ xticklabels=self.class_names,
+ yticklabels=self.class_names
+ )
+
+ # 旋转x轴标签
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+ # 在每个单元格中显示数值
+ thresh = cm.max() / 2.0
+ for i in range(cm.shape[0]):
+ for j in range(cm.shape[1]):
+ ax.text(j, i, format(cm[i, j], '.2f'),
+ ha="center", va="center",
+ color="white" if cm[i, j] > thresh else "black")
+
+ ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})")
+ ax.set_ylabel('True label')
+ ax.set_xlabel('Predicted label')
+
+ fig.tight_layout()
+ return fig
+
+
+class TimingCallback(tf.keras.callbacks.Callback):
+ """测量训练时间的回调函数"""
+
+ def __init__(self):
+ """初始化TimingCallback"""
+ super().__init__()
+ self.epoch_times = []
+ self.batch_times = []
+ self.epoch_start_time = None
+ self.batch_start_time = None
+ self.training_start_time = None
+
+ def on_train_begin(self, logs: Dict[str, float] = None) -> None:
+ """
+ 训练开始时调用
+
+ Args:
+ logs: 训练日志
+ """
+ self.training_start_time = time.time()
+
+ def on_train_end(self, logs: Dict[str, float] = None) -> None:
+ """
+ 训练结束时调用
+
+ Args:
+ logs: 训练日志
+ """
+ training_time = time.time() - self.training_start_time
+ logger.info(f"总训练时间: {training_time:.2f} 秒")
+
+ if self.epoch_times:
+ avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
+ logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f} 秒")
+
+ if self.batch_times:
+ avg_batch_time = sum(self.batch_times) / len(self.batch_times)
+ logger.info(f"平均每个batch时间: {avg_batch_time:.4f} 秒")
+
+ def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个epoch开始时调用
+
+ Args:
+ epoch: 当前epoch索引
+ logs: 训练日志
+ """
+ self.epoch_start_time = time.time()
+
+ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个epoch结束时调用
+
+ Args:
+ epoch: 当前epoch索引
+ logs: 训练日志
+ """
+ epoch_time = time.time() - self.epoch_start_time
+ self.epoch_times.append(epoch_time)
+
+ # 将epoch时间添加到日志中
+ if logs is not None:
+ logs['epoch_time'] = epoch_time
+
+ def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个batch开始时调用
+
+ Args:
+ batch: 当前batch索引
+ logs: 训练日志
+ """
+ self.batch_start_time = time.time()
+
+ def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个batch结束时调用
+
+ Args:
+ batch: 当前batch索引
+ logs: 训练日志
+ """
+ batch_time = time.time() - self.batch_start_time
+ self.batch_times.append(batch_time)
+
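TimingCallback 的典型用法是直接加入 `model.fit` 的回调列表,训练结束后读取耗时统计。下面是一个最小示意(模型与数据仅为演示假设,导入路径以仓库实际结构为准):

```python
import numpy as np
import tensorflow as tf

# from training.callbacks import TimingCallback  # 导入路径为假设,按仓库实际模块调整

# 仅用于演示的随机数据与简单模型
x = np.random.random((256, 20)).astype('float32')
y = np.random.randint(0, 2, size=(256,))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

timing = TimingCallback()
model.fit(x, y, epochs=3, batch_size=32, callbacks=[timing], verbose=0)

# 训练结束后可直接读取每个epoch/batch的耗时
print(f"epoch数: {len(timing.epoch_times)}, 总耗时: {sum(timing.epoch_times):.2f} 秒")
```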
+
+class LearningRateSchedulerCallback(tf.keras.callbacks.Callback):
+ """学习率调度器回调函数"""
+
+ def __init__(self, scheduler_func: Callable[[int, float], float],
+ verbose: int = 0,
+ log_dir: Optional[str] = None):
+ """
+ 初始化LearningRateSchedulerCallback
+
+ Args:
+ scheduler_func: 学习率调度函数,接收(epoch, lr)参数,返回新的学习率
+ verbose: 详细程度
+ log_dir: TensorBoard日志目录
+ """
+ super().__init__()
+ self.scheduler_func = scheduler_func
+ self.verbose = verbose
+
+ # 如果提供了TensorBoard日志目录,创建一个文件写入器
+ if log_dir:
+ self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate'))
+ else:
+ self.file_writer = None
+
+ # 学习率历史
+ self.lr_history = []
+
+ def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
+ """
+ 每个epoch开始时调用
+
+ Args:
+ epoch: 当前epoch索引
+ logs: 训练日志
+ """
+ if not hasattr(self.model.optimizer, 'lr'):
+ raise ValueError('Optimizer must have a "lr" attribute.')
+
+ # 获取当前学习率
+ current_lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
+
+ # 计算新的学习率
+ new_lr = self.scheduler_func(epoch, current_lr)
+
+ # 设置新的学习率
+ tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
+
+ # 记录学习率
+ self.lr_history.append(new_lr)
+
+ # 记录到TensorBoard
+ if self.file_writer:
+ with self.file_writer.as_default():
+ tf.summary.scalar('learning_rate', new_lr, step=epoch)
+
+ if self.verbose > 0:
+ logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}")
+
+ def get_lr_history(self) -> List[float]:
+ """
+ 获取学习率历史
+
+ Returns:
+ 学习率历史列表
+ """
+ return self.lr_history
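LearningRateSchedulerCallback 只需要一个 `(epoch, lr) -> new_lr` 的函数。下面用一个简单的“每5个epoch减半”函数演示用法(训练调用仅以注释示意):

```python
def halve_every_5_epochs(epoch: int, lr: float) -> float:
    # 每到第5、10、15……个epoch把当前学习率减半,其余epoch保持不变
    return lr * 0.5 if epoch > 0 and epoch % 5 == 0 else lr


lr_callback = LearningRateSchedulerCallback(
    scheduler_func=halve_every_5_epochs,
    verbose=1,
)

# model.fit(x_train, y_train, epochs=20, callbacks=[lr_callback])
# 训练后可通过 lr_callback.get_lr_history() 查看每个epoch实际使用的学习率
```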
+
+
+class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping):
+ """增强版早停回调函数,支持最小变化率"""
+
+ def __init__(self, monitor: str = 'val_loss',
+ min_delta: float = 0,
+ min_delta_ratio: float = 0,
+ patience: int = 0,
+ verbose: int = 0,
+ mode: str = 'auto',
+ baseline: Optional[float] = None,
+ restore_best_weights: bool = False):
+ """
+ 初始化EarlyStoppingCallback
+
+ Args:
+ monitor: 监控的指标
+ min_delta: 视为改进的最小绝对变化
+ min_delta_ratio: 视为改进的最小相对变化率
+ patience: 没有改进的轮数
+ verbose: 详细程度
+ mode: 'auto', 'min' 或 'max'
+ baseline: 基准值
+ restore_best_weights: 是否恢复最佳权重
+ """
+ super().__init__(
+ monitor=monitor,
+ min_delta=min_delta,
+ patience=patience,
+ verbose=verbose,
+ mode=mode,
+ baseline=baseline,
+ restore_best_weights=restore_best_weights
+ )
+ self.min_delta_ratio = min_delta_ratio
+
+ def _is_improvement(self, current: float, reference: float) -> bool:
+ """
+ 判断是否有所改进
+
+ Args:
+ current: 当前值
+ reference: 参考值
+
+ Returns:
+ 是否有所改进
+ """
+ # 先检查绝对变化
+ if super()._is_improvement(current, reference):
+ return True
+
+        # 再检查相对变化率(用绝对值做分母,避免参考值为负时符号反转)
+        if reference == 0:
+            return False
+
+        if self.monitor_op == np.less:
+            # 对于 'min' 模式,值越小越好
+            relative_delta = (reference - current) / abs(reference)
+        else:
+            # 对于 'max' 模式,值越大越好
+            relative_delta = (current - reference) / abs(reference)
+
+        return relative_delta > self.min_delta_ratio
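相比标准的 EarlyStopping,这里额外的 `min_delta_ratio` 要求监控指标至少有一定比例的相对改进才算有效改进。用法示意如下(阈值取值仅为演示):

```python
early_stopping = EarlyStoppingCallback(
    monitor='val_loss',
    min_delta=0.0,          # 不设置绝对改进阈值
    min_delta_ratio=0.01,   # 要求val_loss至少相对下降1%才算有效改进
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           epochs=50, callbacks=[early_stopping])
```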
diff --git a/training/optimizer.py b/training/optimizer.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/scheduler.py b/training/scheduler.py
new file mode 100644
index 0000000..e6153e0
--- /dev/null
+++ b/training/scheduler.py
@@ -0,0 +1,284 @@
+"""
+学习率调度器模块:提供各种学习率调度策略
+"""
+import numpy as np
+import math
+from typing import Callable, Optional, Union, Dict
+import tensorflow as tf
+
+from utils.logger import get_logger
+
+logger = get_logger("Scheduler")
+
+
+def step_decay(epoch: int, initial_lr: float,
+ drop_rate: float = 0.5,
+ epochs_drop: int = 10) -> float:
+ """
+ 阶梯式学习率衰减
+
+ Args:
+ epoch: 当前epoch索引
+ initial_lr: 初始学习率
+ drop_rate: 衰减率
+ epochs_drop: 每多少个epoch衰减一次
+
+ Returns:
+ 新的学习率
+ """
+ return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop))
+
+
+def exponential_decay(epoch: int, initial_lr: float,
+                      decay_rate: float = 0.9,
+                      staircase: bool = False,
+                      decay_steps: int = 1) -> float:
+    """
+    指数衰减学习率
+
+    Args:
+        epoch: 当前epoch索引
+        initial_lr: 初始学习率
+        decay_rate: 衰减率
+        staircase: 是否阶梯式衰减(每decay_steps个epoch整体衰减一次)
+        decay_steps: 衰减周期(epoch数)
+
+    Returns:
+        新的学习率
+    """
+    if staircase:
+        # 阶梯式:衰减指数取整,每decay_steps个epoch变化一次
+        return initial_lr * math.pow(decay_rate, epoch // decay_steps)
+    else:
+        # 平滑式:衰减指数随epoch连续变化
+        return initial_lr * math.pow(decay_rate, epoch / decay_steps)
+
+
+def cosine_decay(epoch: int, initial_lr: float,
+ total_epochs: int = 100,
+ min_lr: float = 0) -> float:
+ """
+ 余弦退火学习率
+
+ Args:
+ epoch: 当前epoch索引
+ initial_lr: 初始学习率
+ total_epochs: 总epoch数
+ min_lr: 最小学习率
+
+ Returns:
+ 新的学习率
+ """
+ return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs))
+
+
+def cosine_decay_with_warmup(epoch: int, initial_lr: float,
+ total_epochs: int = 100,
+ warmup_epochs: int = 5,
+ min_lr: float = 0,
+ warmup_init_lr: float = 0) -> float:
+ """
+ 带预热的余弦退火学习率
+
+ Args:
+ epoch: 当前epoch索引
+ initial_lr: 初始学习率
+ total_epochs: 总epoch数
+ warmup_epochs: 预热epoch数
+ min_lr: 最小学习率
+ warmup_init_lr: 预热初始学习率
+
+ Returns:
+ 新的学习率
+ """
+ if epoch < warmup_epochs:
+ # 线性预热
+ return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs
+ else:
+ # 余弦退火
+ return min_lr + 0.5 * (initial_lr - min_lr) * (
+ 1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
+ )
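下面的小例子打印带预热的余弦退火在前几个epoch的学习率,便于直观检查参数设置(数值仅为演示):

```python
initial_lr = 1e-3

for epoch in range(8):
    lr = cosine_decay_with_warmup(
        epoch, initial_lr,
        total_epochs=20, warmup_epochs=5,
        min_lr=1e-5, warmup_init_lr=1e-5,
    )
    print(f"epoch {epoch:02d}: lr = {lr:.6f}")

# 预期:前5个epoch从1e-5线性上升,第5个epoch达到1e-3,之后按余弦曲线衰减到1e-5
```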
+
+
+def cyclical_learning_rate(epoch: int, initial_lr: float,
+ max_lr: float = 0.1,
+ step_size: int = 8,
+ gamma: float = 1.0) -> float:
+ """
+ 循环学习率
+
+ Args:
+ epoch: 当前epoch索引
+ initial_lr: 初始学习率
+ max_lr: 最大学习率
+ step_size: 半周期大小(epoch数)
+ gamma: 循环衰减率
+
+ Returns:
+ 新的学习率
+ """
+ # 计算循环数
+ cycle = math.floor(1 + epoch / (2 * step_size))
+
+ # 计算x值,范围在[0, 2]
+ x = abs(epoch / step_size - 2 * cycle + 1)
+
+ # 应用循环衰减
+ lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1)
+
+ # 计算学习率
+ return initial_lr + lr_range * max(0, 1 - x)
+
+
+def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]:
+ """
+ 创建自定义学习率调度器
+
+ Args:
+ scheduler_type: 调度器类型,可选值: 'step', 'exp', 'cosine', 'cosine_warmup', 'cyclical'
+ **kwargs: 调度器参数
+
+ Returns:
+ 学习率调度函数
+ """
+    scheduler_type = scheduler_type.lower()
+
+    # 回调在每个epoch传入的是“当前”学习率;如果直接把它当作初始学习率代入
+    # 衰减公式,衰减会被重复叠加。这里缓存第一次调用时的学习率作为初始学习率。
+    _cache = {}
+
+    def _initial_lr(lr: float) -> float:
+        return _cache.setdefault('initial_lr', lr)
+
+    if scheduler_type == 'step':
+        drop_rate = kwargs.get('drop_rate', 0.5)
+        epochs_drop = kwargs.get('epochs_drop', 10)
+
+        def scheduler(epoch, lr):
+            return step_decay(epoch, _initial_lr(lr), drop_rate, epochs_drop)
+
+        return scheduler
+
+    elif scheduler_type == 'exp':
+        decay_rate = kwargs.get('decay_rate', 0.9)
+        staircase = kwargs.get('staircase', False)
+        decay_steps = kwargs.get('decay_steps', 1)
+
+        def scheduler(epoch, lr):
+            return exponential_decay(epoch, _initial_lr(lr), decay_rate, staircase, decay_steps)
+
+        return scheduler
+
+    elif scheduler_type == 'cosine':
+        total_epochs = kwargs.get('total_epochs', 100)
+        min_lr = kwargs.get('min_lr', 0)
+
+        def scheduler(epoch, lr):
+            return cosine_decay(epoch, _initial_lr(lr), total_epochs, min_lr)
+
+        return scheduler
+
+    elif scheduler_type == 'cosine_warmup':
+        total_epochs = kwargs.get('total_epochs', 100)
+        warmup_epochs = kwargs.get('warmup_epochs', 5)
+        min_lr = kwargs.get('min_lr', 0)
+        warmup_init_lr = kwargs.get('warmup_init_lr', 0)
+
+        def scheduler(epoch, lr):
+            return cosine_decay_with_warmup(epoch, _initial_lr(lr), total_epochs,
+                                            warmup_epochs, min_lr, warmup_init_lr)
+
+        return scheduler
+
+    elif scheduler_type == 'cyclical':
+        max_lr = kwargs.get('max_lr', 0.1)
+        step_size = kwargs.get('step_size', 8)
+        gamma = kwargs.get('gamma', 1.0)
+
+        def scheduler(epoch, lr):
+            return cyclical_learning_rate(epoch, _initial_lr(lr), max_lr, step_size, gamma)
+
+        return scheduler
+
+    else:
+        raise ValueError(f"不支持的调度器类型: {scheduler_type}")
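create_custom_scheduler 返回的调度函数既可以交给 Keras 自带的 LearningRateScheduler,也可以配合前文回调模块中的 LearningRateSchedulerCallback 使用。示意如下(回调模块的导入路径为假设):

```python
import tensorflow as tf

# from training.callbacks import LearningRateSchedulerCallback  # 导入路径为假设

scheduler_func = create_custom_scheduler(
    'cosine_warmup',
    total_epochs=30,
    warmup_epochs=3,
    min_lr=1e-6,
    warmup_init_lr=1e-6,
)

lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler_func, verbose=1)
# model.fit(x_train, y_train, epochs=30, callbacks=[lr_callback])
```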
+
+
+class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback):
+ """预热余弦退火学习率调度器"""
+
+ def __init__(self, learning_rate_base: float,
+ total_steps: int,
+ warmup_learning_rate: float = 0.0,
+ warmup_steps: int = 0,
+ hold_base_rate_steps: int = 0,
+ verbose: int = 0):
+ """
+ 初始化预热余弦退火学习率调度器
+
+ Args:
+ learning_rate_base: 基础学习率
+ total_steps: 总步数
+ warmup_learning_rate: 预热学习率
+ warmup_steps: 预热步数
+ hold_base_rate_steps: 保持基础学习率的步数
+ verbose: 详细程度
+ """
+ super().__init__()
+
+ self.learning_rate_base = learning_rate_base
+ self.total_steps = total_steps
+ self.warmup_learning_rate = warmup_learning_rate
+ self.warmup_steps = warmup_steps
+ self.hold_base_rate_steps = hold_base_rate_steps
+ self.verbose = verbose
+
+ # 学习率历史
+ self.learning_rates = []
+
+ def on_train_begin(self, logs: Optional[Dict] = None) -> None:
+ """
+ 训练开始时调用
+
+ Args:
+ logs: 训练日志
+ """
+ self.current_step = 0
+ logger.info(f"预热余弦退火学习率调度器初始化: 基础学习率={self.learning_rate_base}, "
+ f"预热步数={self.warmup_steps}, 总步数={self.total_steps}")
+
+ def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """
+ 每个batch结束时调用
+
+ Args:
+ batch: 当前batch索引
+ logs: 训练日志
+ """
+ self.current_step += 1
+ lr = self._get_learning_rate()
+
+ tf.keras.backend.set_value(self.model.optimizer.lr, lr)
+ self.learning_rates.append(lr)
+
+ def _get_learning_rate(self) -> float:
+ """
+ 计算当前学习率
+
+ Returns:
+ 当前学习率
+ """
+ if self.current_step < self.warmup_steps:
+ # 预热阶段:线性增加学习率
+ lr = self.warmup_learning_rate + self.current_step * (
+ (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
+ )
+ elif self.current_step < self.warmup_steps + self.hold_base_rate_steps:
+ # 保持基础学习率阶段
+ lr = self.learning_rate_base
+        else:
+            # 余弦退火阶段
+            cosine_steps = self.total_steps - self.warmup_steps - self.hold_base_rate_steps
+            cosine_current_step = self.current_step - self.warmup_steps - self.hold_base_rate_steps
+            # 超出总步数后直接返回0,避免余弦公式给出负的学习率
+            if cosine_current_step >= cosine_steps:
+                return 0.0
+            lr = 0.5 * self.learning_rate_base * (
+                1 + math.cos(math.pi * cosine_current_step / cosine_steps)
+            )
+
+        return lr
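WarmupCosineDecayScheduler 以 step(batch)为单位调整学习率,使用前需要根据样本数和批大小换算总步数。示意如下(样本数与批大小为演示假设):

```python
import math

num_samples = 50000
batch_size = 64
epochs = 10

steps_per_epoch = math.ceil(num_samples / batch_size)
total_steps = steps_per_epoch * epochs

warmup_scheduler = WarmupCosineDecayScheduler(
    learning_rate_base=1e-3,
    total_steps=total_steps,
    warmup_learning_rate=1e-6,
    warmup_steps=steps_per_epoch,   # 用第一个epoch做线性预热
    hold_base_rate_steps=0,
)

# model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
#           callbacks=[warmup_scheduler])
# 训练后可通过 warmup_scheduler.learning_rates 查看逐step的学习率曲线
```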
diff --git a/training/trainer.py b/training/trainer.py
new file mode 100644
index 0000000..b955e33
--- /dev/null
+++ b/training/trainer.py
@@ -0,0 +1,281 @@
+"""
+训练器模块:实现模型训练流程,包括训练循环、验证等
+"""
+import os
+import time
+from typing import List, Dict, Tuple, Optional, Any, Union, Callable
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from datetime import datetime
+
+from config.system_config import SAVED_MODELS_DIR
+from config.model_config import (
+ NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
+ VALIDATION_SPLIT, RANDOM_SEED
+)
+from models.base_model import TextClassificationModel
+from utils.logger import get_logger, TrainingLogger
+from utils.file_utils import ensure_dir
+
+logger = get_logger("Trainer")
+
+
+class Trainer:
+ """模型训练器,负责训练和验证模型"""
+
+ def __init__(self, model: TextClassificationModel,
+ epochs: int = NUM_EPOCHS,
+ batch_size: Optional[int] = None,
+ validation_split: float = VALIDATION_SPLIT,
+ early_stopping: bool = True,
+ early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
+ save_best_only: bool = True,
+ tensorboard: bool = True,
+ checkpoint: bool = True,
+ custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
+ """
+ 初始化训练器
+
+ Args:
+ model: 要训练的模型
+ epochs: 训练轮数
+ batch_size: 批大小,如果为None则使用模型默认值
+ validation_split: 验证集比例
+ early_stopping: 是否使用早停
+ early_stopping_patience: 早停耐心值
+ save_best_only: 是否只保存最佳模型
+ tensorboard: 是否使用TensorBoard
+ checkpoint: 是否保存检查点
+ custom_callbacks: 自定义回调函数列表
+ """
+ self.model = model
+ self.epochs = epochs
+ self.batch_size = batch_size or model.batch_size
+ self.validation_split = validation_split
+ self.early_stopping = early_stopping
+ self.early_stopping_patience = early_stopping_patience
+ self.save_best_only = save_best_only
+ self.tensorboard = tensorboard
+ self.checkpoint = checkpoint
+ self.custom_callbacks = custom_callbacks or []
+
+ # 训练历史
+ self.history = None
+
+ # 训练日志记录器
+ self.training_logger = TrainingLogger(model.model_name)
+
+ logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}")
+
+ def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
+ """
+ 创建回调函数列表
+
+ Returns:
+ 回调函数列表
+ """
+ callbacks = []
+
+ # 早停
+ if self.early_stopping:
+ early_stopping = tf.keras.callbacks.EarlyStopping(
+ monitor='val_loss',
+ patience=self.early_stopping_patience,
+ restore_best_weights=True,
+ verbose=1
+ )
+ callbacks.append(early_stopping)
+
+ # 学习率衰减
+ reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
+ monitor='val_loss',
+ factor=0.5,
+ patience=self.early_stopping_patience // 2,
+ min_lr=1e-6,
+ verbose=1
+ )
+ callbacks.append(reduce_lr)
+
+ # 模型检查点
+ if self.checkpoint:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
+ ensure_dir(checkpoint_dir)
+
+ checkpoint_path = os.path.join(
+ checkpoint_dir,
+ f"{self.model.model_name}_{timestamp}.h5"
+ )
+
+ model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path,
+ save_best_only=self.save_best_only,
+ monitor='val_loss',
+ verbose=1
+ )
+ callbacks.append(model_checkpoint)
+
+ # TensorBoard
+ if self.tensorboard:
+ log_dir = os.path.join(
+ SAVED_MODELS_DIR,
+ 'logs',
+ f"{self.model.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+ )
+ ensure_dir(log_dir)
+
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir=log_dir,
+ histogram_freq=1,
+ update_freq='epoch'
+ )
+ callbacks.append(tensorboard_callback)
+
+ # 添加自定义回调函数
+ callbacks.extend(self.custom_callbacks)
+
+ return callbacks
+
+ def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
+ """
+ 记录训练进度
+
+ Args:
+ epoch: 当前轮数
+ logs: 日志信息
+ """
+ self.training_logger.log_epoch(epoch, logs)
+
+ def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
+ y_train: Optional[np.ndarray] = None,
+ x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
+ y_val: Optional[np.ndarray] = None,
+ class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
+ """
+ 训练模型
+
+ Args:
+ x_train: 训练数据特征
+ y_train: 训练数据标签
+ x_val: 验证数据特征
+ y_val: 验证数据标签
+ class_weights: 类别权重
+
+ Returns:
+ 训练历史
+ """
+ logger.info(f"开始训练模型: {self.model.model_name}")
+
+ # 创建回调函数
+ callbacks = self._create_callbacks()
+
+ # 添加训练进度记录回调
+ progress_callback = tf.keras.callbacks.LambdaCallback(
+ on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
+ )
+ callbacks.append(progress_callback)
+
+ # 记录开始时间
+ start_time = time.time()
+
+ # 记录训练开始信息
+ model_config = self.model.get_config()
+ train_config = {
+ "epochs": self.epochs,
+ "batch_size": self.batch_size,
+ "validation_split": self.validation_split,
+ "early_stopping": self.early_stopping,
+ "early_stopping_patience": self.early_stopping_patience
+ }
+ self.training_logger.log_training_start({**model_config, **train_config})
+
+        # 准备验证数据(x_val 也可以是 tf.data.Dataset,此时无需单独的 y_val)
+        validation_data = None
+        if x_val is not None:
+            validation_data = (x_val, y_val) if y_val is not None else x_val
+
+ # 训练模型
+ history = self.model.fit(
+ x_train, y_train,
+ validation_data=validation_data,
+ epochs=self.epochs,
+ callbacks=callbacks,
+ class_weights=class_weights,
+ verbose=1
+ )
+
+ # 计算训练时间
+ train_time = time.time() - start_time
+
+ # 保存训练历史
+ self.history = history.history
+
+ # 找出最佳性能
+ best_val_loss = min(history.history['val_loss']) if 'val_loss' in history.history else None
+ best_val_acc = max(history.history['val_accuracy']) if 'val_accuracy' in history.history else None
+
+ best_metrics = {}
+ if best_val_loss is not None:
+ best_metrics['val_loss'] = best_val_loss
+ if best_val_acc is not None:
+ best_metrics['val_accuracy'] = best_val_acc
+
+ # 记录训练结束信息
+ self.training_logger.log_training_end(train_time, best_metrics)
+
+ logger.info(f"模型训练完成,用时: {train_time:.2f} 秒")
+
+ return history.history
+
+ def plot_training_history(self, metrics: Optional[List[str]] = None,
+ save_path: Optional[str] = None) -> None:
+ """
+ 绘制训练历史
+
+ Args:
+ metrics: 要绘制的指标列表,默认为['loss', 'accuracy']
+ save_path: 保存路径,如果为None则显示图像
+ """
+ if self.history is None:
+ raise ValueError("模型尚未训练,没有训练历史")
+
+ if metrics is None:
+ metrics = ['loss', 'accuracy']
+
+ plt.figure(figsize=(12, 5))
+
+ for i, metric in enumerate(metrics):
+ plt.subplot(1, len(metrics), i + 1)
+
+ if metric in self.history:
+ plt.plot(self.history[metric], label=f'train_{metric}')
+
+ val_metric = f'val_{metric}'
+ if val_metric in self.history:
+ plt.plot(self.history[val_metric], label=f'val_{metric}')
+
+ plt.title(f'Model {metric}')
+ plt.xlabel('Epoch')
+ plt.ylabel(metric)
+ plt.legend()
+
+ plt.tight_layout()
+
+ if save_path:
+ plt.savefig(save_path)
+ logger.info(f"训练历史图已保存到: {save_path}")
+ else:
+ plt.show()
+
+ def save_trained_model(self, filepath: Optional[str] = None) -> str:
+ """
+ 保存训练好的模型
+
+ Args:
+ filepath: 保存路径,如果为None则使用默认路径
+
+ Returns:
+ 保存路径
+ """
+ return self.model.save(filepath)
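Trainer 封装了回调创建、训练日志与训练流程。下面是一个使用示意,其中的模型类名与导入路径均为假设,实际以仓库 models 包中 TextClassificationModel 的具体子类为准:

```python
# 以下均为示意性代码,类名与导入路径为假设
# from models.cnn_model import TextCNNModel
# model = TextCNNModel(num_classes=14)

# trainer = Trainer(
#     model=model,
#     epochs=10,
#     batch_size=64,
#     early_stopping=True,
#     early_stopping_patience=3,
#     tensorboard=True,
# )

# history = trainer.train(x_train, y_train, x_val=x_val, y_val=y_val)
# trainer.plot_training_history(save_path='training_history.png')
# saved_path = trainer.save_trained_model()
```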
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/file_utils.py b/utils/file_utils.py
new file mode 100644
index 0000000..8f986b7
--- /dev/null
+++ b/utils/file_utils.py
@@ -0,0 +1,306 @@
+"""
+文件处理工具模块
+"""
+import os
+import shutil
+import json
+import pickle
+import csv
+from pathlib import Path
+import time
+import hashlib
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import zipfile
+import tarfile
+
+from config.system_config import ENCODING, DATA_LOADING_WORKERS
+from utils.logger import get_logger
+
+logger = get_logger("file_utils")
+
+
+def read_text_file(file_path, encoding=ENCODING):
+ """
+ 读取文本文件内容
+
+ Args:
+ file_path: 文件路径
+ encoding: 文件编码
+
+ Returns:
+ 文件内容
+ """
+ try:
+ with open(file_path, 'r', encoding=encoding) as file:
+ return file.read()
+ except Exception as e:
+ logger.error(f"读取文件 {file_path} 时出错: {str(e)}")
+ return None
+
+
+def write_text_file(content, file_path, encoding=ENCODING):
+ """
+ 写入文本文件
+
+ Args:
+ content: 文件内容
+ file_path: 文件路径
+ encoding: 文件编码
+
+ Returns:
+ 成功标志
+ """
+ try:
+        # 确保目录存在(file_path 可能不包含目录部分)
+        parent_dir = os.path.dirname(file_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+
+ with open(file_path, 'w', encoding=encoding) as file:
+ file.write(content)
+ return True
+ except Exception as e:
+ logger.error(f"写入文件 {file_path} 时出错: {str(e)}")
+ return False
+
+
+def save_json(data, file_path, encoding=ENCODING):
+ """
+ 保存JSON数据到文件
+
+ Args:
+ data: 要保存的数据
+ file_path: 文件路径
+ encoding: 文件编码
+
+ Returns:
+ 成功标志
+ """
+ try:
+        # 确保目录存在(file_path 可能不包含目录部分)
+        parent_dir = os.path.dirname(file_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+
+ with open(file_path, 'w', encoding=encoding) as file:
+ json.dump(data, file, ensure_ascii=False, indent=2)
+ return True
+ except Exception as e:
+ logger.error(f"保存JSON文件 {file_path} 时出错: {str(e)}")
+ return False
+
+
+def load_json(file_path, encoding=ENCODING):
+ """
+ 从文件加载JSON数据
+
+ Args:
+ file_path: 文件路径
+ encoding: 文件编码
+
+ Returns:
+ 加载的数据
+ """
+ try:
+ with open(file_path, 'r', encoding=encoding) as file:
+ return json.load(file)
+ except Exception as e:
+ logger.error(f"加载JSON文件 {file_path} 时出错: {str(e)}")
+ return None
+
+
+def save_pickle(data, file_path):
+ """
+ 使用pickle保存数据
+
+ Args:
+ data: 要保存的数据
+ file_path: 文件路径
+
+ Returns:
+ 成功标志
+ """
+ try:
+        # 确保目录存在(file_path 可能不包含目录部分)
+        parent_dir = os.path.dirname(file_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+
+ with open(file_path, 'wb') as file:
+ pickle.dump(data, file)
+ return True
+ except Exception as e:
+ logger.error(f"保存pickle文件 {file_path} 时出错: {str(e)}")
+ return False
+
+
+def load_pickle(file_path):
+ """
+ 从文件加载pickle数据
+
+ Args:
+ file_path: 文件路径
+
+ Returns:
+ 加载的数据
+ """
+ try:
+ with open(file_path, 'rb') as file:
+ return pickle.load(file)
+ except Exception as e:
+ logger.error(f"加载pickle文件 {file_path} 时出错: {str(e)}")
+ return None
+
+
+def read_files_parallel(file_paths, max_workers=DATA_LOADING_WORKERS, encoding=ENCODING):
+ """
+ 并行读取多个文本文件
+
+ Args:
+ file_paths: 文件路径列表
+ max_workers: 最大工作线程数
+ encoding: 文件编码
+
+ Returns:
+        文件内容列表(顺序与file_paths一致,读取失败的文件会被跳过)
+    """
+    start_time = time.time()
+    contents = {}
+
+    # 定义单个读取函数
+    def read_single_file(file_path):
+        return read_text_file(file_path, encoding)
+
+    # 使用线程池并行读取
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_file = {executor.submit(read_single_file, file_path): file_path
+                          for file_path in file_paths}
+
+        # 收集结果(先按路径存入字典,便于最后按输入顺序返回)
+        for future in as_completed(future_to_file):
+            file_path = future_to_file[future]
+            try:
+                content = future.result()
+                if content is not None:
+                    contents[file_path] = content
+            except Exception as e:
+                logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
+
+    # 按传入顺序组织结果,保证内容与file_paths一一对应(失败的文件被跳过)
+    results = [contents[fp] for fp in file_paths if fp in contents]
+
+    elapsed = time.time() - start_time
+    logger.info(f"并行读取 {len(file_paths)} 个文件,成功 {len(results)} 个,用时 {elapsed:.2f} 秒")
+
+    return results
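read_files_parallel 适合批量加载语料目录下的小文本文件,通常与下文的 list_files 搭配使用。示意如下(目录路径仅为演示假设):

```python
# 收集目录下所有txt文件并并行读取(目录路径仅为演示)
file_paths = list_files('data/raw/sample_corpus', pattern='*.txt', recursive=True)
contents = read_files_parallel(file_paths, max_workers=8)

print(f"共读取 {len(contents)} 个文件")
```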
+
+
+def get_file_md5(file_path):
+ """
+ 计算文件的MD5哈希值
+
+ Args:
+ file_path: 文件路径
+
+ Returns:
+ MD5哈希值
+ """
+ hash_md5 = hashlib.md5()
+
+ try:
+ with open(file_path, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_md5.update(chunk)
+ return hash_md5.hexdigest()
+ except Exception as e:
+ logger.error(f"计算文件 {file_path} 的MD5值时出错: {str(e)}")
+ return None
+
+
+def extract_archive(archive_path, extract_to=None):
+ """
+ 解压缩文件
+
+ Args:
+ archive_path: 压缩文件路径
+ extract_to: 解压目标路径,默认为同目录
+
+ Returns:
+ 成功标志
+ """
+ if extract_to is None:
+ extract_to = os.path.dirname(archive_path)
+
+ try:
+ if archive_path.endswith('.zip'):
+ with zipfile.ZipFile(archive_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_to)
+ elif archive_path.endswith(('.tar.gz', '.tgz')):
+ with tarfile.open(archive_path, 'r:gz') as tar_ref:
+ tar_ref.extractall(extract_to)
+ elif archive_path.endswith('.tar'):
+ with tarfile.open(archive_path, 'r') as tar_ref:
+ tar_ref.extractall(extract_to)
+ else:
+ logger.error(f"不支持的压缩格式: {archive_path}")
+ return False
+
+ logger.info(f"成功解压 {archive_path} 到 {extract_to}")
+ return True
+ except Exception as e:
+ logger.error(f"解压 {archive_path} 时出错: {str(e)}")
+ return False
+
+
+def list_files(directory, pattern=None, recursive=True):
+ """
+ 列出目录中的文件
+
+ Args:
+ directory: 目录路径
+ pattern: 文件名模式(支持通配符)
+ recursive: 是否递归搜索子目录
+
+ Returns:
+ 文件路径列表
+ """
+ if not os.path.exists(directory):
+ logger.error(f"目录不存在: {directory}")
+ return []
+
+ directory = Path(directory)
+
+ if pattern:
+ if recursive:
+ return [str(p) for p in directory.glob(f"**/{pattern}")]
+ else:
+ return [str(p) for p in directory.glob(pattern)]
+ else:
+ if recursive:
+ files = []
+ for p in directory.rglob("*"):
+ if p.is_file():
+ files.append(str(p))
+ return files
+ else:
+ return [str(p) for p in directory.iterdir() if p.is_file()]
+
+
+def ensure_dir(directory):
+ """
+ 确保目录存在,不存在则创建
+
+ Args:
+ directory: 目录路径
+ """
+ os.makedirs(directory, exist_ok=True)
+
+
+def remove_dir(directory):
+ """
+ 删除目录及其内容
+
+ Args:
+ directory: 目录路径
+
+ Returns:
+ 成功标志
+ """
+ try:
+ if os.path.exists(directory):
+ shutil.rmtree(directory)
+ return True
+ except Exception as e:
+ logger.error(f"删除目录 {directory} 时出错: {str(e)}")
+ return False
\ No newline at end of file
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..9518626
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,115 @@
+"""
+日志工具模块
+"""
+import logging
+import sys
+from pathlib import Path
+from datetime import datetime
+import os
+
+from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT
+
+
+def get_logger(name, level=None, log_file=None):
+ """
+ 获取logger实例
+
+ Args:
+ name: logger名称
+ level: 日志级别,默认为系统配置
+ log_file: 日志文件路径,默认为None(仅控制台输出)
+
+ Returns:
+ logger实例
+ """
+ level = level or LOG_LEVEL
+
+ # 创建logger
+ logger = logging.getLogger(name)
+ logger.setLevel(getattr(logging, level))
+
+ # 避免重复添加handler
+ if logger.handlers:
+ return logger
+
+ # 创建格式化器
+ formatter = logging.Formatter(LOG_FORMAT)
+
+ # 创建控制台处理器
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+
+ # 如果指定了日志文件,创建文件处理器
+ if log_file:
+ log_path = Path(LOG_DIR) / log_file
+ file_handler = logging.FileHandler(log_path, encoding='utf-8')
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
+
+
+def get_time_logger(name):
+ """
+ 获取带时间戳的logger实例
+
+ Args:
+ name: logger名称
+
+ Returns:
+ logger实例
+ """
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ log_file = f"{name}_{timestamp}.log"
+ return get_logger(name, log_file=log_file)
+
+
+class TrainingLogger:
+ """训练过程日志记录器"""
+
+ def __init__(self, model_name):
+ """
+ 初始化训练日志记录器
+
+ Args:
+ model_name: 模型名称
+ """
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.log_file = f"training_{model_name}_{timestamp}.log"
+ self.logger = get_logger(f"training_{model_name}", log_file=self.log_file)
+
+ # 创建CSV日志
+ self.csv_path = Path(LOG_DIR) / f"metrics_{model_name}_{timestamp}.csv"
+ with open(self.csv_path, 'w', encoding='utf-8') as f:
+ f.write("epoch,loss,accuracy,val_loss,val_accuracy\n")
+
+ def log_epoch(self, epoch, metrics):
+ """记录每个epoch的指标"""
+ # 日志记录
+ metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
+ self.logger.info(f"Epoch {epoch}: {metrics_str}")
+
+ # CSV记录
+ csv_line = f"{epoch},{metrics.get('loss', '')},{metrics.get('accuracy', '')}," \
+ f"{metrics.get('val_loss', '')},{metrics.get('val_accuracy', '')}\n"
+ with open(self.csv_path, 'a', encoding='utf-8') as f:
+ f.write(csv_line)
+
+ def log_training_start(self, config):
+ """记录训练开始信息"""
+ self.logger.info("=" * 50)
+ self.logger.info("训练开始")
+ self.logger.info("模型配置:")
+ for key, value in config.items():
+ self.logger.info(f" {key}: {value}")
+ self.logger.info("=" * 50)
+
+ def log_training_end(self, duration, best_metrics):
+ """记录训练结束信息"""
+ self.logger.info("=" * 50)
+ self.logger.info(f"训练结束,总用时: {duration:.2f}秒")
+ self.logger.info("最佳性能:")
+ for key, value in best_metrics.items():
+ self.logger.info(f" {key}: {value:.4f}")
+ self.logger.info("=" * 50)
\ No newline at end of file
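get_logger 与 TrainingLogger 的基本用法如下(指标数值仅为演示):

```python
from utils.logger import get_logger, TrainingLogger

logger = get_logger("demo", log_file="demo.log")
logger.info("这是一条示例日志")

train_logger = TrainingLogger("textcnn_demo")
train_logger.log_training_start({"model_type": "cnn", "epochs": 10, "batch_size": 64})
train_logger.log_epoch(1, {"loss": 0.52, "accuracy": 0.81,
                           "val_loss": 0.60, "val_accuracy": 0.78})
train_logger.log_training_end(123.4, {"val_loss": 0.45, "val_accuracy": 0.85})
```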
diff --git a/utils/text_utils.py b/utils/text_utils.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/time_utils.py b/utils/time_utils.py
new file mode 100644
index 0000000..e69de29