Initial commit, excluding large datasets

This commit is contained in:
superlishunqin 2025-03-08 01:34:36 +08:00
commit ba6d4c40ea
68 changed files with 10410 additions and 0 deletions

1
.gitignore vendored Normal file

@ -0,0 +1 @@
data/raw/THUCNews/

8
.idea/.gitignore generated vendored Normal file

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml


@ -0,0 +1,45 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="272" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="13">
<item index="0" class="java.lang.String" itemvalue="mysql-connector-python" />
<item index="1" class="java.lang.String" itemvalue="Flask" />
<item index="2" class="java.lang.String" itemvalue="pandas" />
<item index="3" class="java.lang.String" itemvalue="boto3" />
<item index="4" class="java.lang.String" itemvalue="botocore" />
<item index="5" class="java.lang.String" itemvalue="flask-mail" />
<item index="6" class="java.lang.String" itemvalue="flask-cors" />
<item index="7" class="java.lang.String" itemvalue="python-dotenv" />
<item index="8" class="java.lang.String" itemvalue="Flask-Bcrypt" />
<item index="9" class="java.lang.String" itemvalue="pytz" />
<item index="10" class="java.lang.String" itemvalue="Flask-Session" />
<item index="11" class="java.lang.String" itemvalue="DBUtils" />
<item index="12" class="java.lang.String" itemvalue="PyMySQL" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="E265" />
<option value="E231" />
<option value="E262" />
<option value="E225" />
<option value="E402" />
<option value="E271" />
<option value="E302" />
</list>
</option>
</inspection_tool>
</profile>
</component>


@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

4
.idea/misc.xml generated Normal file

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/text_classification.iml" filepath="$PROJECT_DIR$/.idea/text_classification.iml" />
</modules>
</component>
</project>

19
.idea/text_classification.iml generated Normal file

@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TemplatesService">
<option name="TEMPLATE_FOLDERS">
<list>
<option value="$MODULE_DIR$/interface/web/templates" />
</list>
</option>
</component>
</module>

6
.idea/vcs.xml generated Normal file

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

281
README.md Normal file

@ -0,0 +1,281 @@
# Chinese Text Classification System Based on Python
This project is a comprehensive Chinese text classification system built on TensorFlow. It uses deep learning to classify Chinese text automatically. The system follows a layered design with four main layers: a data layer, a processing layer, a model layer, and an interface layer.
## Features
- Multiple deep learning models: CNN, RNN/LSTM, Transformer
- Complete text preprocessing: Chinese word segmentation, stop-word filtering, feature engineering
- Rich evaluation metrics: accuracy, precision, recall, F1 score, and more
- Multiple interfaces: command line, web UI, REST API
- Batch processing: supports batched text input and file upload
## System Architecture
1. **Data layer**: stores and manages the raw text data
2. **Processing layer**: implements text preprocessing and feature engineering
3. **Model layer**: contains the core classification models and handles training and prediction
4. **Interface layer**: provides the user-facing interfaces for input, output, and result display
## Installation
### Requirements
- Python 3.7+
- TensorFlow 2.5+
- Other dependencies listed in `requirements.txt`
### Steps
1. Clone the repository:
```bash
git clone https://git.sq0715.com/qin/Chinese_Text_Classification_System.git
cd Chinese_Text_Classification_System
```
2. Create a virtual environment (optional):
```bash
python -m venv venv
source venv/bin/activate  # Linux/Mac
venv\Scripts\activate  # Windows
```
3. Install dependencies:
```bash
pip install -r requirements.txt
```
4. Install the project:
```bash
pip install -e .
```
## Usage
### Training a model
Train a model with:
```bash
python main.py train --data_dir path/to/data --model_type cnn --epochs 10 --batch_size 64
```
Arguments:
- `--data_dir`: data directory; defaults to the directory set in the config
- `--model_type`: model type, one of 'cnn', 'rnn', 'transformer'
- `--epochs`: number of training epochs
- `--batch_size`: batch size
- `--save_dir`: directory in which to save the model
### Evaluating a model
Evaluate a model with:
```bash
python main.py evaluate --model_path path/to/model --data_dir path/to/data
```
Arguments:
- `--model_path`: path to the model
- `--data_dir`: data directory; defaults to the directory set in the config
- `--output_dir`: output directory for evaluation results
### Predicting text
Predict a single text:
```bash
python main.py predict --text "这是一条测试文本" --model_path path/to/model
```
Predict the contents of a file:
```bash
python main.py predict --file path/to/file.txt --model_path path/to/model
```
### Starting the web service
Start the web service with:
```bash
python main.py web --host 0.0.0.0 --port 5000
```
Then open `http://localhost:5000` in your browser.
### Starting the API service
Start the API service with:
```bash
python main.py api --host 0.0.0.0 --port 8000
```
The API documentation is then available at `http://localhost:8000/docs`.
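The service can also be called programmatically once it is running. The sketch below is illustrative only: the `/predict` route and the JSON fields are assumptions, so check the generated docs at `http://localhost:8000/docs` for the actual endpoint and schema.
```python
# Hypothetical client sketch: the '/predict' route and request field are assumptions;
# consult the API docs served at /docs for the real endpoint and payload format.
import requests

resp = requests.post(
    "http://localhost:8000/predict",           # assumed route
    json={"text": "2021年羽毛球冠军是林丹"},     # assumed request field
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # expected to contain the predicted category; exact format depends on the API
```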
### Using the command-line interface
Start the interactive command-line interface with:
```bash
python main.py cli --interactive
```
## Directory Structure
```
text_classification_system/
├── config/              # Configuration files
├── data/                # Data layer
├── preprocessing/       # Processing layer
├── models/              # Model layer
├── training/            # Training code
├── evaluation/          # Evaluation code
├── inference/           # Inference code
├── interface/           # Interface layer
├── utils/               # Utilities
├── saved_models/        # Saved models
├── tests/               # Test code
├── docs/                # Documentation
├── scripts/             # Scripts
├── main.py              # Main entry point
├── requirements.txt     # Dependency list
├── setup.py             # Installation script
└── README.md            # Project description
```
## Dataset
This project uses the THUCNews dataset released by Tsinghua University, which contains news text in 14 categories. The dataset can be downloaded here:
[THUCNews download](http://thuctc.thunlp.org/)
After downloading, extract the data into the `data/raw/THUCNews` directory.
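To confirm the extracted layout matches what the loader expects (one sub-folder per category, each containing `.txt` files), a short check such as the following can be used; it assumes you run it from the project root.
```python
# Minimal sketch: verify the THUCNews layout (data/raw/THUCNews/<category>/*.txt)
from pathlib import Path

root = Path("data/raw/THUCNews")
for category_dir in sorted(p for p in root.iterdir() if p.is_dir()):
    n_files = sum(1 for _ in category_dir.glob("*.txt"))
    print(f"{category_dir.name}: {n_files} files")
```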
## License
This project is released under the MIT License; see the LICENSE file for details.
## Contributing
Contributions, bug reports, and suggestions are welcome. Please fork this repository and submit a pull request.
## Contact
For questions or suggestions, please open an issue or reach us at:
- Email: your.email@example.com
# How to Use This Text Classification System
## 1. Installation
First, make sure your environment meets the following requirements:
- Python 3.7+
- TensorFlow 2.5+
Installation steps:
```bash
# Clone the code (assuming you already have it)
cd text_classification
# Create a virtual environment (optional)
python -m venv venv
source venv/bin/activate  # Linux/Mac
# Or on Windows:
# venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
# Install the project (optional)
pip install -e .
```
## 2. Data Preparation
Make sure the THUCNews dataset is located at:
```
data/raw/THUCNews/
```
The dataset should contain one folder per category, each holding that category's text files.
## 3. Training a Model
Train a CNN model with:
```bash
python main.py train --model_type cnn --epochs 10 --batch_size 64
```
You can also try the other model types:
- RNN/LSTM model: `--model_type rnn`
- Transformer model: `--model_type transformer`
After training, the model is saved under `saved_models/classifiers/`.
## 4. Evaluating a Model
Evaluate model performance with:
```bash
python main.py evaluate --model_path path/to/your/model
```
The evaluation report includes accuracy, precision, recall, F1 score, and a confusion matrix visualization.
## 5. Making Predictions
There are several ways to use a trained model:
### Command-line prediction
```bash
# Predict a single text
python main.py predict --text "2021年羽毛球冠军是林丹"
# Predict the contents of a file
python main.py predict --file path/to/your/file.txt
```
### Web interface
```bash
python main.py web --port 5000
```
Then open `http://localhost:5000` in your browser.
### API service
```bash
python main.py api --port 8000
```
The API documentation is available at `http://localhost:8000/docs`.
### Interactive command line
```bash
python main.py cli --interactive
```
## 6. Extending the System
To extend the system:
1. Add a new model: add a model class under `models/`
2. Adjust preprocessing: modify the relevant modules under `preprocessing/`
3. Add a new interface: extend the code under `interface/`
## 7. Troubleshooting
- Out of memory: reduce `batch_size` or use the data loader's generator mode
- Slow training: reduce model complexity or enable GPU acceleration
- Low accuracy: try a different model architecture, add preprocessing steps, or tune the hyperparameters
The system is now ready to use; follow the steps above to start training and classifying.

0
config/__init__.py Normal file

72
config/model_config.py Normal file

@ -0,0 +1,72 @@
"""
Model configuration file.
"""
# Text preprocessing parameters
MAX_SEQUENCE_LENGTH = 500  # Maximum text sequence length
MAX_NUM_WORDS = 50000  # Maximum vocabulary size
MAX_CHAR_LENGTH = 2000  # Maximum length at character level
MIN_WORD_FREQUENCY = 5  # Minimum word frequency
# Model architecture parameters
CNN_CONFIG = {
    "embedding_dim": 200,
    "num_filters": 256,
    "filter_sizes": [3, 4, 5],
    "dropout_rate": 0.5,
    "l2_reg_lambda": 0.0,
}
RNN_CONFIG = {
    "embedding_dim": 200,
    "hidden_size": 256,
    "num_layers": 2,
    "bidirectional": True,
    "dropout_rate": 0.5,
}
TRANSFORMER_CONFIG = {
    "embedding_dim": 200,
    "num_heads": 8,
    "ff_dim": 512,
    "num_layers": 4,
    "dropout_rate": 0.1,
}
# Batch sizes tuned for an RTX 4090
BATCH_SIZE = 128  # The RTX 4090's 24 GB of VRAM supports a fairly large batch
EVAL_BATCH_SIZE = 256  # Evaluation can use an even larger batch
# Training parameters
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
EARLY_STOPPING_PATIENCE = 3
REDUCE_LR_PATIENCE = 2
REDUCE_LR_FACTOR = 0.5
VALIDATION_SPLIT = 0.1
TEST_SPLIT = 0.1
# Word embedding parameters
USE_PRETRAINED_EMBEDDING = True
EMBEDDING_TYPE = "word2vec"  # Options: word2vec, glove, fasttext
# Random seed for reproducibility
RANDOM_SEED = 42
# Model saving parameters
SAVE_BEST_ONLY = True
MODEL_CHECKPOINT_PATH = "best_model.h5"
# Feature engineering parameters
USE_CHAR_LEVEL = False  # Whether to use character-level features
USE_WORD_LEVEL = True  # Whether to use word-level features
USE_TFIDF = False  # Whether to use TF-IDF features
USE_POS_TAGS = False  # Whether to use part-of-speech features
# Data augmentation parameters
USE_DATA_AUGMENTATION = False
AUGMENTATION_FACTOR = 0.2  # Augment 20% of the data
# Inference parameters
PREDICTION_THRESHOLD = 0.5
TOP_K_PREDICTIONS = 3
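As a rough illustration of how these constants are meant to be consumed, the sketch below assembles a TextCNN-style Keras model from `CNN_CONFIG` and the preprocessing limits. It is a hedged sketch for orientation only; the project's actual model code lives in `models/cnn_model.py`, and the `build_text_cnn` helper here is hypothetical.
```python
# Hypothetical helper showing how CNN_CONFIG and the preprocessing constants
# could feed a Keras TextCNN; the real implementation is in models/cnn_model.py.
import tensorflow as tf
from config.model_config import CNN_CONFIG, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS

def build_text_cnn(num_classes: int) -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype="int32")
    x = tf.keras.layers.Embedding(MAX_NUM_WORDS, CNN_CONFIG["embedding_dim"])(inputs)
    # One convolution branch per filter size, each followed by global max pooling
    branches = []
    for size in CNN_CONFIG["filter_sizes"]:
        conv = tf.keras.layers.Conv1D(CNN_CONFIG["num_filters"], size, activation="relu")(x)
        branches.append(tf.keras.layers.GlobalMaxPooling1D()(conv))
    merged = tf.keras.layers.Concatenate()(branches)
    merged = tf.keras.layers.Dropout(CNN_CONFIG["dropout_rate"])(merged)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(merged)
    return tf.keras.Model(inputs, outputs)

model = build_text_cnn(num_classes=14)  # 14 THUCNews categories
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
```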

71
config/system_config.py Normal file

@ -0,0 +1,71 @@
"""
Global system configuration file.
"""
import os
import platform
from pathlib import Path
# Project root directory
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
"""
Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) resolves to the directory one level above this file's directory, i.e. the project root.
Computing it this way keeps the project portable across platforms.
"""
# Data paths
DATA_DIR = ROOT_DIR / "data"
RAW_DATA_DIR = DATA_DIR / "raw" / "THUCNews"
PROCESSED_DATA_DIR = DATA_DIR / "processed"
RESOURCES_DIR = DATA_DIR / "resources"
STOPWORDS_DIR = RESOURCES_DIR / "stopwords"
EMBEDDINGS_DIR = RESOURCES_DIR / "embeddings"
# Make sure the required data directories exist
for directory in [PROCESSED_DATA_DIR, RESOURCES_DIR, STOPWORDS_DIR, EMBEDDINGS_DIR]:
    directory.mkdir(parents=True, exist_ok=True)
# Model save paths
SAVED_MODELS_DIR = ROOT_DIR / "saved_models"
TOKENIZERS_DIR = SAVED_MODELS_DIR / "tokenizers"
CLASSIFIERS_DIR = SAVED_MODELS_DIR / "classifiers"
# Make sure the model directories exist
for directory in [SAVED_MODELS_DIR, TOKENIZERS_DIR, CLASSIFIERS_DIR]:
    directory.mkdir(parents=True, exist_ok=True)
# System resource settings
CPU_COUNT = os.cpu_count()
USE_GPU = True
MULTI_GPU = False  # Only a single GPU is used for now
# Parallelism settings sized for a 13900K
DATA_LOADING_WORKERS = min(16, CPU_COUNT)  # Number of data-loading threads
PREPROCESSING_WORKERS = min(24, CPU_COUNT)  # Number of preprocessing threads; the 13900K handles heavy multithreading well
# Memory settings based on 64 GB of RAM
MAX_MEMORY_GB = 48  # Leave some memory for the OS and other applications
MAX_TEXT_PER_BATCH = 10000  # Maximum number of texts processed per batch
# Logging configuration
LOG_DIR = ROOT_DIR / "logs"
LOG_DIR.mkdir(exist_ok=True)
LOG_LEVEL = "INFO"
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Category label mapping, consistent with the THUCNews dataset
CATEGORIES = [
    "体育", "娱乐", "家居", "彩票", "房产", "教育",
    "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"
]
CATEGORY_TO_ID = {category: idx for idx, category in enumerate(CATEGORIES)}
ID_TO_CATEGORY = {idx: category for idx, category in enumerate(CATEGORIES)}
# File encoding
ENCODING = "utf-8"
# System information
SYSTEM_INFO = {
    "platform": platform.platform(),
    "python_version": platform.python_version(),
    "processor": platform.processor(),
}
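A small usage sketch (assuming the project root is on `PYTHONPATH`): downstream modules import these constants instead of hard-coding paths or label mappings.
```python
# Minimal sketch of how other modules consume the system configuration.
from config.system_config import RAW_DATA_DIR, CATEGORY_TO_ID, ID_TO_CATEGORY, SYSTEM_INFO

print(SYSTEM_INFO["platform"])   # the platform the pipeline is running on
print(RAW_DATA_DIR)              # <project>/data/raw/THUCNews
print(CATEGORY_TO_ID["体育"])    # 0 -- label ids follow the CATEGORIES order
print(ID_TO_CATEGORY[0])         # "体育"
```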

113
create_project_structure.sh Executable file

@ -0,0 +1,113 @@
#!/bin/bash
# Project structure creation script
# Creates the complete directory layout for the text classification system
# Author: AI assistant
# Date: 2023
echo "Creating project structure..."
# Create the config directory and files
mkdir -p config
touch config/__init__.py
touch config/model_config.py
touch config/system_config.py
# Create the data-layer directories and files
mkdir -p data/raw data/processed data/resources/stopwords data/resources/embeddings
touch data/__init__.py
touch data/dataloader.py
touch data/data_manager.py
touch data/dataset.py
# Create the processing-layer directory and files
mkdir -p preprocessing
touch preprocessing/__init__.py
touch preprocessing/text_cleaner.py
touch preprocessing/tokenization.py
touch preprocessing/feature_extraction.py
touch preprocessing/vectorizer.py
touch preprocessing/data_augmentation.py
# Create the model-layer directories and files
mkdir -p models/layers
touch models/__init__.py
touch models/base_model.py
touch models/cnn_model.py
touch models/rnn_model.py
touch models/transformer_model.py
touch models/ensemble_model.py
touch models/model_factory.py
touch models/layers/__init__.py
# Create the training module
mkdir -p training
touch training/__init__.py
touch training/trainer.py
touch training/optimizer.py
touch training/callbacks.py
touch training/scheduler.py
# Create the evaluation module
mkdir -p evaluation
touch evaluation/__init__.py
touch evaluation/evaluator.py
touch evaluation/metrics.py
touch evaluation/visualization.py
# Create the inference module
mkdir -p inference
touch inference/__init__.py
touch inference/predictor.py
touch inference/batch_processor.py
# Create the interface layer
mkdir -p interface/web/templates
touch interface/__init__.py
touch interface/cli.py
touch interface/api.py
touch interface/web/__init__.py
touch interface/web/app.py
touch interface/web/routes.py
# Create the utilities module
mkdir -p utils
touch utils/__init__.py
touch utils/logger.py
touch utils/file_utils.py
touch utils/time_utils.py
touch utils/text_utils.py
# Create the model save directories
mkdir -p saved_models/tokenizers saved_models/classifiers
# Create the tests directory
mkdir -p tests
touch tests/__init__.py
touch tests/test_preprocessing.py
touch tests/test_models.py
touch tests/test_evaluation.py
# Create the docs directory
mkdir -p docs
touch docs/architecture.md
touch docs/api_reference.md
touch docs/usage_guide.md
# Create the scripts directory
mkdir -p scripts
touch scripts/train.py
touch scripts/evaluate.py
touch scripts/predict.py
# Create the top-level files
touch main.py
touch requirements.txt
touch setup.py
touch README.md
echo "Project structure created."
echo "------------------------------"
echo "Directory overview:"
find . -type d | sort
echo "------------------------------"
echo "Total files: $(find . -type f | wc -l)"

0
data/__init__.py Normal file

583
data/data_manager.py Normal file

@ -0,0 +1,583 @@
"""
数据管理模块负责数据的存储读取和转换
"""
import os
import pickle
import json
import time
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from config.system_config import (
PROCESSED_DATA_DIR, ENCODING, CATEGORY_TO_ID, ID_TO_CATEGORY
)
from config.model_config import (
VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED
)
from utils.logger import get_logger
from utils.file_utils import (
save_pickle, load_pickle, save_json, load_json, ensure_dir
)
from data.dataloader import DataLoader
logger = get_logger("DataManager")
class DataManager:
"""数据管理类,负责数据的存储、读取和转换"""
def __init__(self, processed_dir: Optional[str] = None):
"""
初始化数据管理器
Args:
processed_dir: 处理后数据的存储目录默认使用配置文件中的路径
"""
self.processed_dir = processed_dir or PROCESSED_DATA_DIR
ensure_dir(self.processed_dir)
# 数据分割后的存储
self.train_texts = []
self.train_labels = []
self.val_texts = []
self.val_labels = []
self.test_texts = []
self.test_labels = []
# 数据统计信息
self.stats = {}
# 标签编码映射
self.label_to_id = CATEGORY_TO_ID
self.id_to_label = ID_TO_CATEGORY
logger.info(f"数据管理器初始化完成,处理后数据将存储在 {self.processed_dir}")
def load_and_split_data(self, data_loader: DataLoader,
categories: Optional[List[str]] = None,
val_split: float = VALIDATION_SPLIT,
test_split: float = TEST_SPLIT,
sample_ratio: float = 1.0,
balanced: bool = False,
n_per_category: int = 1000,
save: bool = True) -> Dict[str, Any]:
"""
加载并分割数据集
Args:
data_loader: 数据加载器实例
categories: 要包含的类别列表默认为所有类别
val_split: 验证集比例
test_split: 测试集比例
sample_ratio: 采样比例默认为1.0全部
balanced: 是否平衡各类别的样本数量
n_per_category: 平衡模式下每个类别的样本数量
save: 是否保存处理后的数据
Returns:
包含分割后数据集的字典
"""
start_time = time.time()
# 加载数据
if balanced:
logger.info(f"加载平衡数据集,每个类别 {n_per_category} 个样本")
data = data_loader.load_balanced_data(
n_per_category=n_per_category,
categories=categories,
shuffle=True
)
else:
logger.info(f"加载数据集,采样比例 {sample_ratio}")
data = data_loader.load_data(
categories=categories,
sample_ratio=sample_ratio,
shuffle=True,
return_generator=False
)
logger.info(f"加载了 {len(data)} 个样本")
# 分离文本和标签
texts = [text for text, _ in data]
labels = [label for _, label in data]
# 进行标签编码
encoded_labels = np.array([self.label_to_id[label] for label in labels])
# 计算数据统计信息
self._compute_stats(texts, labels)
# 划分训练集、验证集和测试集
# 先分出测试集
if test_split > 0:
train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split(
texts, encoded_labels,
test_size=test_split,
random_state=RANDOM_SEED,
stratify=encoded_labels if len(set(encoded_labels)) > 1 else None
)
else:
train_val_texts, train_val_labels = texts, encoded_labels
self.test_texts, self.test_labels = [], []
# 再划分训练集和验证集
if val_split > 0:
self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split(
train_val_texts, train_val_labels,
test_size=val_split / (1 - test_split),
random_state=RANDOM_SEED,
stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
)
else:
self.train_texts, self.train_labels = train_val_texts, train_val_labels
self.val_texts, self.val_labels = [], []
# 打印数据集划分结果
logger.info(f"数据集划分结果:")
logger.info(f" 训练集:{len(self.train_texts)} 个样本")
logger.info(f" 验证集:{len(self.val_texts)} 个样本")
logger.info(f" 测试集:{len(self.test_texts)} 个样本")
# 保存处理后的数据
if save:
self.save_data()
elapsed = time.time() - start_time
logger.info(f"数据加载和分割完成,用时 {elapsed:.2f}")
return {
"train_texts": self.train_texts,
"train_labels": self.train_labels,
"val_texts": self.val_texts,
"val_labels": self.val_labels,
"test_texts": self.test_texts,
"test_labels": self.test_labels,
"stats": self.stats
}
def _compute_stats(self, texts: List[str], labels: List[str]) -> None:
"""
计算数据统计信息
Args:
texts: 文本列表
labels: 标签列表
"""
# 文本数量
num_samples = len(texts)
# 类别分布
label_counter = Counter(labels)
label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()}
# 文本长度统计
text_lengths = [len(text) for text in texts]
avg_length = sum(text_lengths) / len(text_lengths)
max_length = max(text_lengths)
min_length = min(text_lengths)
# 前5个最长和最短的文本的长度
sorted_lengths = sorted(text_lengths)
shortest_lengths = sorted_lengths[:5]
longest_lengths = sorted_lengths[-5:]
# 95%的文本长度分位数
percentile_95 = np.percentile(text_lengths, 95)
# 存储统计信息
self.stats = {
"num_samples": num_samples,
"num_categories": len(label_counter),
"label_counter": label_counter,
"label_distribution": label_distribution,
"text_length": {
"average": avg_length,
"max": max_length,
"min": min_length,
"percentile_95": percentile_95,
"shortest_5": shortest_lengths,
"longest_5": longest_lengths
}
}
def save_data(self, save_dir: Optional[str] = None) -> None:
"""
保存处理后的数据
Args:
save_dir: 保存目录默认使用初始化时设置的目录
"""
save_dir = save_dir or self.processed_dir
ensure_dir(save_dir)
# 保存训练集
save_pickle(
{"texts": self.train_texts, "labels": self.train_labels},
os.path.join(save_dir, "train_data.pkl")
)
# 保存验证集
if len(self.val_texts) > 0:
save_pickle(
{"texts": self.val_texts, "labels": self.val_labels},
os.path.join(save_dir, "val_data.pkl")
)
# 保存测试集
if len(self.test_texts) > 0:
save_pickle(
{"texts": self.test_texts, "labels": self.test_labels},
os.path.join(save_dir, "test_data.pkl")
)
# 保存标签编码映射
save_json(
{"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
os.path.join(save_dir, "label_mapping.json")
)
# 保存数据统计信息
# 将Counter对象转换为普通字典以便JSON序列化
stats_for_json = self.stats.copy()
if "label_counter" in stats_for_json:
stats_for_json["label_counter"] = dict(stats_for_json["label_counter"])
save_json(
stats_for_json,
os.path.join(save_dir, "data_stats.json")
)
logger.info(f"已将处理后的数据保存到 {save_dir}")
def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]:
"""
加载处理后的数据
Args:
load_dir: 加载目录默认使用初始化时设置的目录
Returns:
包含加载的数据集的字典
"""
load_dir = load_dir or self.processed_dir
# 加载训练集
train_data_path = os.path.join(load_dir, "train_data.pkl")
if os.path.exists(train_data_path):
train_data = load_pickle(train_data_path)
self.train_texts = train_data["texts"]
self.train_labels = train_data["labels"]
logger.info(f"已加载训练集,包含 {len(self.train_texts)} 个样本")
else:
logger.warning(f"训练集文件不存在: {train_data_path}")
self.train_texts, self.train_labels = [], []
# 加载验证集
val_data_path = os.path.join(load_dir, "val_data.pkl")
if os.path.exists(val_data_path):
val_data = load_pickle(val_data_path)
self.val_texts = val_data["texts"]
self.val_labels = val_data["labels"]
logger.info(f"已加载验证集,包含 {len(self.val_texts)} 个样本")
else:
logger.warning(f"验证集文件不存在: {val_data_path}")
self.val_texts, self.val_labels = [], []
# 加载测试集
test_data_path = os.path.join(load_dir, "test_data.pkl")
if os.path.exists(test_data_path):
test_data = load_pickle(test_data_path)
self.test_texts = test_data["texts"]
self.test_labels = test_data["labels"]
logger.info(f"已加载测试集,包含 {len(self.test_texts)} 个样本")
else:
logger.warning(f"测试集文件不存在: {test_data_path}")
self.test_texts, self.test_labels = [], []
# 加载标签编码映射
mapping_path = os.path.join(load_dir, "label_mapping.json")
if os.path.exists(mapping_path):
mapping = load_json(mapping_path)
self.label_to_id = mapping["label_to_id"]
self.id_to_label = mapping["id_to_label"]
# 将字符串键转换为整数JSON序列化会将所有键转为字符串
self.id_to_label = {int(k): v for k, v in self.id_to_label.items()}
logger.info(f"已加载标签编码映射,共 {len(self.label_to_id)} 个类别")
# 加载数据统计信息
stats_path = os.path.join(load_dir, "data_stats.json")
if os.path.exists(stats_path):
self.stats = load_json(stats_path)
logger.info("已加载数据统计信息")
return {
"train_texts": self.train_texts,
"train_labels": self.train_labels,
"val_texts": self.val_texts,
"val_labels": self.val_labels,
"test_texts": self.test_texts,
"test_labels": self.test_labels,
"stats": self.stats
}
def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]:
"""
获取指定数据集的标签分布
Args:
dataset: 数据集名称可选值'train', 'val', 'test'
Returns:
标签分布字典键为类别名称值为比例
"""
if dataset == "train":
labels = self.train_labels
elif dataset == "val":
labels = self.val_labels
elif dataset == "test":
labels = self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
# 计算标签分布
label_counter = Counter(labels)
num_samples = len(labels)
# 将数字标签转换为类别名称
distribution = {}
for label_id, count in label_counter.items():
label_name = self.id_to_label.get(label_id, str(label_id))
distribution[label_name] = count / num_samples * 100
return distribution
def visualize_label_distribution(self, dataset: str = "train",
save_path: Optional[str] = None) -> None:
"""
可视化标签分布
Args:
dataset: 数据集名称可选值'train', 'val', 'test', 'all'
save_path: 图表保存路径默认为None显示而不保存
"""
plt.figure(figsize=(12, 8))
if dataset == "all":
# 显示所有数据集的标签分布
train_dist = self.get_label_distribution("train")
val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {}
test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {}
# 准备数据
categories = list(train_dist.keys())
train_values = [train_dist.get(cat, 0) for cat in categories]
val_values = [val_dist.get(cat, 0) for cat in categories]
test_values = [test_dist.get(cat, 0) for cat in categories]
# 绘制条形图
x = np.arange(len(categories))
width = 0.25
plt.bar(x - width, train_values, width, label="Training")
if val_values:
plt.bar(x, val_values, width, label="Validation")
if test_values:
plt.bar(x + width, test_values, width, label="Testing")
plt.xlabel("Categories")
plt.ylabel("Percentage (%)")
plt.title("Label Distribution Across Datasets")
plt.xticks(x, categories, rotation=45, ha="right")
plt.legend()
plt.tight_layout()
else:
# 显示单个数据集的标签分布
distribution = self.get_label_distribution(dataset)
# 按值排序
sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True)
categories = [item[0] for item in sorted_items]
values = [item[1] for item in sorted_items]
# 绘制条形图
plt.bar(categories, values, color='skyblue')
plt.xlabel("Categories")
plt.ylabel("Percentage (%)")
plt.title(f"Label Distribution in {dataset.capitalize()} Dataset")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
# 保存或显示图表
if save_path:
plt.savefig(save_path)
logger.info(f"标签分布图已保存到 {save_path}")
else:
plt.show()
def visualize_text_length_distribution(self, dataset: str = "train",
bins: int = 50,
save_path: Optional[str] = None) -> None:
"""
可视化文本长度分布
Args:
dataset: 数据集名称可选值'train', 'val', 'test'
bins: 直方图的箱数
save_path: 图表保存路径默认为None显示而不保存
"""
if dataset == "train":
texts = self.train_texts
elif dataset == "val":
texts = self.val_texts
elif dataset == "test":
texts = self.test_texts
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
# 计算文本长度
text_lengths = [len(text) for text in texts]
# 绘制直方图
plt.figure(figsize=(10, 6))
plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7)
# 计算并绘制一些统计量
avg_length = sum(text_lengths) / len(text_lengths)
median_length = np.median(text_lengths)
percentile_95 = np.percentile(text_lengths, 95)
plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}')
plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}')
plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1,
label=f'95th Percentile: {percentile_95:.1f}')
plt.xlabel('Text Length (characters)')
plt.ylabel('Frequency')
plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset')
plt.legend()
plt.tight_layout()
# 保存或显示图表
if save_path:
plt.savefig(save_path)
logger.info(f"文本长度分布图已保存到 {save_path}")
else:
plt.show()
def get_data_summary(self) -> Dict[str, Any]:
"""
获取数据集的摘要信息
Returns:
包含数据摘要的字典
"""
# 获取数据集的基本信息
summary = {
"train_size": len(self.train_texts),
"val_size": len(self.val_texts),
"test_size": len(self.test_texts),
"num_categories": len(self.label_to_id),
"categories": list(self.label_to_id.keys()),
}
# 添加训练集的标签分布
if len(self.train_texts) > 0:
summary["train_label_distribution"] = self.get_label_distribution("train")
# 添加验证集的标签分布
if len(self.val_texts) > 0:
summary["val_label_distribution"] = self.get_label_distribution("val")
# 添加测试集的标签分布
if len(self.test_texts) > 0:
summary["test_label_distribution"] = self.get_label_distribution("test")
# 添加更多统计信息(如果有)
if self.stats:
# 只添加一些关键的统计信息
if "text_length" in self.stats:
summary["text_length_stats"] = self.stats["text_length"]
return summary
def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame:
"""
将数据导出为Pandas DataFrame
Args:
dataset: 数据集名称可选值'train', 'val', 'test'
Returns:
Pandas DataFrame
"""
if dataset == "train":
texts = self.train_texts
labels_ids = self.train_labels
elif dataset == "val":
texts = self.val_texts
labels_ids = self.val_labels
elif dataset == "test":
texts = self.test_texts
labels_ids = self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
# 将数字标签转换为类别名称
labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids]
# 创建DataFrame
df = pd.DataFrame({
"text": texts,
"label_id": labels_ids,
"label": labels
})
return df
def get_label_name(self, label_id: int) -> str:
"""
获取标签ID对应的类别名称
Args:
label_id: 标签ID
Returns:
类别名称
"""
return self.id_to_label.get(label_id, str(label_id))
def get_label_id(self, label_name: str) -> int:
"""
获取类别名称对应的标签ID
Args:
label_name: 类别名称
Returns:
标签ID
"""
return self.label_to_id.get(label_name, -1)
def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]:
"""
获取指定数据集的文本和标签
Args:
dataset: 数据集名称可选值'train', 'val', 'test'
Returns:
(文本列表, 标签数组)的元组
"""
if dataset == "train":
return self.train_texts, self.train_labels
elif dataset == "val":
return self.val_texts, self.val_labels
elif dataset == "test":
return self.test_texts, self.test_labels
else:
raise ValueError(f"不支持的数据集名称: {dataset}")
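A brief usage sketch of `DataLoader` together with `DataManager`, assuming the THUCNews files are present under `data/raw/THUCNews` and the code is run from the project root:
```python
# Sketch: load, split, and inspect the dataset with DataLoader + DataManager.
from data.dataloader import DataLoader
from data.data_manager import DataManager

loader = DataLoader()    # scans data/raw/THUCNews and counts files per category
manager = DataManager()  # writes processed splits to data/processed

# Load a 10% sample, split into train/val/test, and persist the splits as pickles
splits = manager.load_and_split_data(loader, sample_ratio=0.1, save=True)
print(manager.get_data_summary()["train_size"])

# Work with the training split as a pandas DataFrame (columns: text, label_id, label)
df = manager.export_to_pandas("train")
print(df.head())
```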

296
data/dataloader.py Normal file

@ -0,0 +1,296 @@
"""
数据加载模块负责从文件系统加载原始文本数据
"""
import os
import glob
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import random
import numpy as np
from config.system_config import (
RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES,
CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH
)
from config.model_config import RANDOM_SEED
from utils.logger import get_logger
from utils.file_utils import read_text_file, read_files_parallel, list_files
# 设置随机种子以保证可重复性
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
logger = get_logger("DataLoader")
class DataLoader:
"""负责加载THUCNews数据集的类"""
def __init__(self, data_dir: Optional[str] = None,
categories: Optional[List[str]] = None,
encoding: str = ENCODING,
max_workers: int = DATA_LOADING_WORKERS,
max_text_per_batch: int = MAX_TEXT_PER_BATCH):
"""
初始化数据加载器
Args:
data_dir: 数据目录默认使用配置文件中的路径
categories: 要加载的类别列表默认加载所有类别
encoding: 文件编码
max_workers: 最大工作线程数
max_text_per_batch: 每批处理的最大文本数量
"""
self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR
self.categories = categories if categories else CATEGORIES
self.encoding = encoding
self.max_workers = max_workers
self.max_text_per_batch = max_text_per_batch
# 验证数据目录是否存在
if not self.data_dir.exists():
raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
# 验证类别是否存在
for category in self.categories:
category_dir = self.data_dir / category
if not category_dir.exists():
logger.warning(f"类别目录不存在: {category_dir}")
# 存储类别目录的映射
self.category_dirs = {
category: self.data_dir / category
for category in self.categories
if (self.data_dir / category).exists()
}
# 记录类别文件数量
self.category_file_counts = {}
# 统计并记录每个类别的文件数量
self._count_files()
logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件")
def _count_files(self) -> None:
"""统计每个类别的文件数量"""
for category, category_dir in self.category_dirs.items():
files = list(category_dir.glob("*.txt"))
self.category_file_counts[category] = len(files)
logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件")
def get_file_paths(self, category: Optional[str] = None,
sample_ratio: float = 1.0,
shuffle: bool = True) -> List[Tuple[str, str]]:
"""
获取指定类别的文件路径列表
Args:
category: 类别名称如果为None则获取所有类别
sample_ratio: 采样比例默认为1.0全部
shuffle: 是否打乱文件顺序
Returns:
包含(文件路径, 类别)元组的列表
"""
file_paths = []
# 确定要处理的类别
categories_to_process = [category] if category else self.categories
# 获取每个类别的文件路径
for cat in categories_to_process:
if cat in self.category_dirs:
category_dir = self.category_dirs[cat]
cat_files = list(category_dir.glob("*.txt"))
# 采样
if sample_ratio < 1.0:
sample_size = int(len(cat_files) * sample_ratio)
if shuffle:
cat_files = random.sample(cat_files, sample_size)
else:
cat_files = cat_files[:sample_size]
# 添加文件路径和对应的类别
file_paths.extend([(str(file), cat) for file in cat_files])
# 打乱全局顺序(如果需要)
if shuffle:
random.shuffle(file_paths)
return file_paths
def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
加载指定路径的文本内容
Args:
file_paths: 包含(文件路径, 类别)元组的列表
Returns:
包含(文本内容, 类别)元组的列表
"""
start_time = time.time()
texts_with_labels = []
# 提取文件路径列表
paths = [path for path, _ in file_paths]
labels = [label for _, label in file_paths]
# 并行加载文本内容
contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding)
# 将内容与标签配对
for content, label in zip(contents, labels):
if content: # 确保内容不为空
texts_with_labels.append((content, label))
elapsed = time.time() - start_time
logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f}")
return texts_with_labels
def load_data(self, categories: Optional[List[str]] = None,
sample_ratio: float = 1.0,
shuffle: bool = True,
return_generator: bool = False) -> Any:
"""
加载指定类别的所有数据
Args:
categories: 要加载的类别列表默认为所有类别
sample_ratio: 采样比例默认为1.0全部
shuffle: 是否打乱数据顺序
return_generator: 是否返回生成器批量加载
Returns:
如果return_generator为False返回包含(文本内容, 类别)元组的列表
如果return_generator为True返回一个生成器每次产生一批数据
"""
# 确定要处理的类别
cats_to_process = categories if categories else self.categories
# 验证类别是否存在
for cat in cats_to_process:
if cat not in self.category_dirs:
logger.warning(f"类别 {cat} 不存在,将被忽略")
# 筛选存在的类别
cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
# 获取所有文件路径
all_file_paths = []
for cat in cats_to_process:
cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle)
all_file_paths.extend(cat_files)
# 打乱全局顺序(如果需要)
if shuffle:
random.shuffle(all_file_paths)
# 如果需要返回生成器,分批次加载数据
if return_generator:
def data_generator():
for i in range(0, len(all_file_paths), self.max_text_per_batch):
batch_paths = all_file_paths[i:i + self.max_text_per_batch]
batch_data = self.load_texts(batch_paths)
yield batch_data
return data_generator()
# 否则,一次性加载所有数据
return self.load_texts(all_file_paths)
def load_balanced_data(self, n_per_category: int = 1000,
categories: Optional[List[str]] = None,
shuffle: bool = True) -> List[Tuple[str, str]]:
"""
加载平衡的数据集每个类别的样本数量相同
Args:
n_per_category: 每个类别加载的样本数量
categories: 要加载的类别列表默认为所有类别
shuffle: 是否打乱数据顺序
Returns:
包含(文本内容, 类别)元组的列表
"""
# 确定要处理的类别
cats_to_process = categories if categories else self.categories
cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
balanced_data = []
for cat in cats_to_process:
# 获取该类别的文件路径
cat_files = self.get_file_paths(cat, shuffle=shuffle)
# 限制数量
cat_files = cat_files[:n_per_category]
# 加载文本
cat_data = self.load_texts(cat_files)
balanced_data.extend(cat_data)
# 打乱全局顺序(如果需要)
if shuffle:
random.shuffle(balanced_data)
return balanced_data
def get_category_distribution(self) -> Dict[str, int]:
"""
获取数据集的类别分布
Returns:
包含各类别样本数量的字典
"""
return self.category_file_counts
def get_data_stats(self) -> Dict[str, Any]:
"""
获取数据集的统计信息
Returns:
包含统计信息的字典
"""
# 计算总样本数
total_samples = sum(self.category_file_counts.values())
# 计算各类别占比
category_percentages = {
cat: count / total_samples * 100
for cat, count in self.category_file_counts.items()
}
# 采样几个文件计算平均文本长度
sample_files = []
for cat in self.categories:
if cat in self.category_dirs:
cat_files = list((self.data_dir / cat).glob("*.txt"))
if cat_files:
# 每个类别最多采样10个文件
sample_files.extend(random.sample(cat_files, min(10, len(cat_files))))
# 加载采样的文件内容
sample_contents = []
for file_path in sample_files:
content = read_text_file(str(file_path), encoding=self.encoding)
if content:
sample_contents.append(content)
# 计算平均文本长度(字符数)
avg_char_length = sum(len(content) for content in sample_contents) / len(
sample_contents) if sample_contents else 0
# 返回统计信息
return {
"total_samples": total_samples,
"category_counts": self.category_file_counts,
"category_percentages": category_percentages,
"average_text_length": avg_char_length,
"categories": self.categories,
"num_categories": len(self.categories),
}
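A short sketch of using the loader on its own, including the batched generator mode referred to in the README's troubleshooting notes (again assuming the raw dataset is present):
```python
# Sketch: stream the corpus in batches instead of loading it all into memory.
from data.dataloader import DataLoader

loader = DataLoader()
print(loader.get_data_stats()["total_samples"])  # overall file count from the directory scan

# return_generator=True yields lists of (text, category) tuples,
# at most MAX_TEXT_PER_BATCH items per batch
for batch in loader.load_data(sample_ratio=0.05, return_generator=True):
    print(len(batch), batch[0][1])  # batch size and the first sample's category
    break                           # just peek at the first batch
```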

0
data/dataset.py Normal file

BIN
data/raw/.DS_Store vendored Normal file

Binary file not shown.

0
docs/api_reference.md Normal file

0
docs/architecture.md Normal file

0
docs/usage_guide.md Normal file

0
evaluation/__init__.py Normal file

491
evaluation/evaluator.py Normal file

@ -0,0 +1,491 @@
"""
评估器模块实现模型评估流程
"""
import numpy as np
import tensorflow as tf
import time
import os
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import pandas as pd
import matplotlib.pyplot as plt
import json
from config.system_config import SAVED_MODELS_DIR
from models.base_model import TextClassificationModel
from evaluation.metrics import ClassificationMetrics
from utils.logger import get_logger
from utils.file_utils import ensure_dir, save_json
logger = get_logger("Evaluator")
class ModelEvaluator:
"""模型评估器,负责评估模型性能"""
def __init__(self, model: TextClassificationModel,
class_names: Optional[List[str]] = None,
output_dir: Optional[str] = None,
batch_size: Optional[int] = None):
"""
初始化模型评估器
Args:
model: 要评估的模型
class_names: 类别名称列表
output_dir: 输出目录用于保存评估结果
batch_size: 批大小如果为None则使用模型默认值
"""
self.model = model
self.class_names = class_names
self.batch_size = batch_size or model.batch_size
# 设置输出目录
if output_dir is None:
self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name)
else:
self.output_dir = output_dir
ensure_dir(self.output_dir)
# 创建评估指标计算器
self.metrics = ClassificationMetrics(class_names)
# 评估结果
self.evaluation_results = None
logger.info(f"初始化模型评估器,模型: {model.model_name}")
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
y_test: Optional[np.ndarray] = None,
batch_size: Optional[int] = None,
verbose: int = 1) -> Dict[str, float]:
"""
评估模型
Args:
x_test: 测试数据特征
y_test: 测试数据标签
batch_size: 批大小
verbose: 详细程度
Returns:
评估结果
"""
batch_size = batch_size or self.batch_size
logger.info(f"开始评估模型: {self.model.model_name}")
start_time = time.time()
# 使用模型评估
model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose)
# 获取预测结果
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
y_pred = np.argmax(y_prob, axis=1)
# 处理y_test确保y_test是一维数组
if isinstance(x_test, tf.data.Dataset):
# 如果是TensorFlow Dataset需要从中提取y_test
y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
y_test = y_test_extracted
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
y_test = np.argmax(y_test, axis=1)
# 计算所有指标
all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob)
# 合并模型内置指标和自定义指标
metrics_names = self.model.model.metrics_names
model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)}
all_metrics.update(model_metrics_dict)
# 记录评估时间
evaluation_time = time.time() - start_time
all_metrics['evaluation_time'] = evaluation_time
# 保存评估结果
self.evaluation_results = {
'metrics': all_metrics,
'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(),
'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True)
}
logger.info(f"模型评估完成,用时: {evaluation_time:.2f}")
logger.info(f"主要评估指标: accuracy={all_metrics.get('accuracy', 'N/A'):.4f}, "
f"f1_macro={all_metrics.get('f1_macro', 'N/A'):.4f}")
return all_metrics
def save_evaluation_results(self, save_plots: bool = True) -> str:
"""
保存评估结果
Args:
save_plots: 是否保存可视化图表
Returns:
结果保存路径
"""
if self.evaluation_results is None:
raise ValueError("请先调用evaluate方法进行评估")
# 保存评估结果为JSON
results_path = os.path.join(self.output_dir, 'evaluation_results.json')
with open(results_path, 'w', encoding='utf-8') as f:
json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4)
# 保存评估指标为CSV
metrics_df = pd.DataFrame(
self.evaluation_results['metrics'].items(),
columns=['Metric', 'Value']
).set_index('Metric')
metrics_path = os.path.join(self.output_dir, 'metrics.csv')
metrics_df.to_csv(metrics_path)
# 保存可视化图表
if save_plots:
self._save_plots()
logger.info(f"评估结果已保存到: {self.output_dir}")
return self.output_dir
def _save_plots(self) -> None:
"""保存评估结果可视化图表"""
if self.evaluation_results is None:
raise ValueError("请先调用evaluate方法进行评估")
# 创建可视化目录
plots_dir = os.path.join(self.output_dir, 'plots')
ensure_dir(plots_dir)
# 混淆矩阵图
cm_path = os.path.join(plots_dir, 'confusion_matrix.png')
cm = np.array(self.evaluation_results['confusion_matrix'])
# 将混淆矩阵转换为NumPy数组
if isinstance(cm, list):
cm = np.array(cm)
# 绘制混淆矩阵
self.metrics.plot_confusion_matrix(
np.arange(cm.shape[0]), # 假设标签
np.arange(cm.shape[1]), # 假设预测
normalize='true',
save_path=cm_path
)
# 保存评估指标条形图
metrics_path = os.path.join(plots_dir, 'metrics_bar.png')
metrics = self.evaluation_results['metrics']
# 选择要展示的主要指标
main_metrics = {
'accuracy': metrics.get('accuracy', 0),
'precision_macro': metrics.get('precision_macro', 0),
'recall_macro': metrics.get('recall_macro', 0),
'f1_macro': metrics.get('f1_macro', 0)
}
# 绘制条形图
plt.figure(figsize=(10, 6))
plt.bar(main_metrics.keys(), main_metrics.values())
plt.title('Main Evaluation Metrics')
plt.ylabel('Score')
plt.ylim(0, 1)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(metrics_path)
plt.close()
# 如果有类别级别的指标,绘制每个类别的指标
if 'classification_report' in self.evaluation_results:
report = self.evaluation_results['classification_report']
# 提取每个类别的精确率、召回率和F1值
class_metrics = {}
for key, value in report.items():
if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
if isinstance(value, dict):
class_metrics[key] = value
if class_metrics:
# 绘制每个类别的F1分数
class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png')
classes = list(class_metrics.keys())
f1_scores = [metrics['f1-score'] for metrics in class_metrics.values()]
plt.figure(figsize=(12, 6))
bars = plt.bar(classes, f1_scores)
# 在柱状图上方显示数值
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
f'{height:.2f}',
ha='center', va='bottom', rotation=0)
plt.title('F1 Score by Class')
plt.ylabel('F1 Score')
plt.ylim(0, 1.1)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(class_f1_path)
plt.close()
# 绘制每个类别的精确率和召回率
class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png')
precisions = [metrics['precision'] for metrics in class_metrics.values()]
recalls = [metrics['recall'] for metrics in class_metrics.values()]
plt.figure(figsize=(12, 6))
x = np.arange(len(classes))
width = 0.35
plt.bar(x - width / 2, precisions, width, label='Precision')
plt.bar(x + width / 2, recalls, width, label='Recall')
plt.ylabel('Score')
plt.title('Precision and Recall by Class')
plt.xticks(x, classes, rotation=45, ha='right')
plt.legend()
plt.ylim(0, 1.1)
plt.tight_layout()
plt.savefig(class_prec_rec_path)
plt.close()
logger.info(f"评估可视化图表已保存到: {plots_dir}")
def compare_models(self, other_evaluators: List['ModelEvaluator'],
metrics: Optional[List[str]] = None,
save_path: Optional[str] = None) -> pd.DataFrame:
"""
比较多个模型的评估结果
Args:
other_evaluators: 其他模型评估器列表
metrics: 要比较的指标列表默认为['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
save_path: 比较结果的保存路径
Returns:
比较结果DataFrame
"""
if self.evaluation_results is None:
raise ValueError("请先调用evaluate方法进行评估")
# 默认比较指标
if metrics is None:
metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
# 收集所有模型的评估指标
models_metrics = {}
# 当前模型
models_metrics[self.model.model_name] = {
metric: self.evaluation_results['metrics'].get(metric, 'N/A')
for metric in metrics
}
# 其他模型
for evaluator in other_evaluators:
if evaluator.evaluation_results is None:
logger.warning(f"模型 {evaluator.model.model_name} 尚未评估,跳过")
continue
models_metrics[evaluator.model.model_name] = {
metric: evaluator.evaluation_results['metrics'].get(metric, 'N/A')
for metric in metrics
}
# 创建比较DataFrame
comparison_df = pd.DataFrame(models_metrics).T
# 保存比较结果
if save_path:
comparison_df.to_csv(save_path)
logger.info(f"模型比较结果已保存到: {save_path}")
# 绘制比较条形图
plt.figure(figsize=(12, 6))
comparison_df.plot(kind='bar', figsize=(12, 6))
plt.title('Model Comparison')
plt.ylabel('Score')
plt.ylim(0, 1)
plt.legend(loc='lower right')
plt.tight_layout()
# 如果save_path是CSV文件将其替换为PNG文件
if save_path.endswith('.csv'):
plot_path = save_path.replace('.csv', '.png')
else:
plot_path = save_path + '.png'
plt.savefig(plot_path)
plt.close()
logger.info(f"模型比较图表已保存到: {plot_path}")
return comparison_df
def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset],
y_test: Optional[np.ndarray] = None,
batch_size: Optional[int] = None,
verbose: int = 0) -> pd.DataFrame:
"""
评估模型在各个类别上的性能
Args:
x_test: 测试数据特征
y_test: 测试数据标签
batch_size: 批大小
verbose: 详细程度
Returns:
各类别性能指标DataFrame
"""
batch_size = batch_size or self.batch_size
# 获取预测结果
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose)
y_pred = np.argmax(y_prob, axis=1)
# 处理y_test确保y_test是一维数组
if isinstance(x_test, tf.data.Dataset):
# 如果是TensorFlow Dataset需要从中提取y_test
y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
y_test = y_test_extracted
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
y_test = np.argmax(y_test, axis=1)
# 获取分类报告
report = self.metrics.classification_report(y_test, y_pred, output_dict=True)
# 提取各类别指标
class_metrics = {}
for key, value in report.items():
if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
if isinstance(value, dict):
class_metrics[key] = value
# 转换为DataFrame
class_performance_df = pd.DataFrame(class_metrics).T
# 添加支持度(样本数量)
class_counts = np.bincount(y_test)
for idx, count in enumerate(class_counts):
if str(idx) in class_performance_df.index:
class_performance_df.loc[str(idx), 'support'] = count
# 添加类别名称
if self.class_names:
class_performance_df['class_name'] = [
self.class_names[int(idx)] if int(idx) < len(self.class_names) else idx
for idx in class_performance_df.index
]
# 保存类别性能指标
performance_path = os.path.join(self.output_dir, 'class_performance.csv')
class_performance_df.to_csv(performance_path)
logger.info(f"各类别性能指标已保存到: {performance_path}")
return class_performance_df
def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset],
y_test: Optional[np.ndarray] = None,
batch_size: Optional[int] = None,
num_samples: int = 10,
save_path: Optional[str] = None) -> None:
"""
绘制误分类样本分析
Args:
x_test: 测试数据特征
y_test: 测试数据标签
batch_size: 批大小
num_samples: 要展示的误分类样本数量
save_path: 保存路径
"""
# 仅适用于文本数据的分析,需要原始文本
logger.info("误分类样本分析需要原始文本数据,此方法可能需要根据实际数据类型进行修改")
# 在实际应用中,这里应该根据实际数据类型进行修改
# 例如,对于序列化的文本,可能需要反序列化,或者使用词汇表将索引转换回文本
# 此处仅展示一个基本框架
batch_size = batch_size or self.batch_size
# 获取预测结果
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
y_pred = np.argmax(y_prob, axis=1)
# 处理y_test确保y_test是一维数组
if isinstance(x_test, tf.data.Dataset):
# 如果是TensorFlow Dataset需要从中提取y_test和x_test
dataset_iterator = iter(x_test)
x_test_extracted = []
y_test_extracted = []
try:
while True:
x, y = next(dataset_iterator)
x_test_extracted.append(x)
y_test_extracted.append(y)
except StopIteration:
pass
x_test = np.concatenate(x_test_extracted, axis=0)
y_test = np.concatenate(y_test_extracted, axis=0)
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
y_test = np.argmax(y_test, axis=1)
# 找出误分类样本
misclassified_indices = np.where(y_pred != y_test)[0]
# 如果没有误分类样本,返回
if len(misclassified_indices) == 0:
logger.info("没有误分类样本")
return
# 随机选择一些误分类样本
if len(misclassified_indices) > num_samples:
misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False)
# 保存误分类样本分析结果
misclassified_data = []
for idx in misclassified_indices:
true_label = y_test[idx]
pred_label = y_pred[idx]
true_class = self.class_names[true_label] if self.class_names else str(true_label)
pred_class = self.class_names[pred_label] if self.class_names else str(pred_label)
# 对于序列化的文本,此处需要进行反序列化
# 这里仅作示例,实际应用中需要根据具体数据类型修改
sample_text = f"Sample {idx}"
misclassified_data.append({
'sample_id': idx,
'true_label': true_label,
'predicted_label': pred_label,
'true_class': true_class,
'predicted_class': pred_class,
'confidence': float(y_prob[idx, pred_label]),
'sample_text': sample_text
})
# 创建DataFrame
misclassified_df = pd.DataFrame(misclassified_data)
# 保存结果
if save_path:
misclassified_df.to_csv(save_path)
logger.info(f"误分类样本分析已保存到: {save_path}")
return misclassified_df
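For context, a typical call sequence for this evaluator is sketched below; `model`, `x_test`, and `y_test` are placeholders for a trained `TextClassificationModel` and a prepared test split, not objects defined here.
```python
# Sketch: evaluate a trained model and persist the report, plots, and per-class table.
import numpy as np
from evaluation.evaluator import ModelEvaluator
from config.system_config import CATEGORIES

def run_evaluation(model, x_test: np.ndarray, y_test: np.ndarray) -> None:
    # `model` is assumed to be a trained TextClassificationModel
    evaluator = ModelEvaluator(model, class_names=CATEGORIES)
    metrics = evaluator.evaluate(x_test, y_test)  # accuracy, f1_macro, auc_roc, ...
    print(metrics["accuracy"], metrics["f1_macro"])
    evaluator.save_evaluation_results(save_plots=True)  # JSON + CSV + PNG plots under saved_models/evaluation/
    per_class = evaluator.evaluate_class_performance(x_test, y_test)  # per-class precision/recall/F1 DataFrame
    print(per_class.head())
```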

356
evaluation/metrics.py Normal file

@ -0,0 +1,356 @@
"""
评估指标模块实现各种评估指标
"""
import numpy as np
import tensorflow as tf
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report, roc_auc_score,
precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import pandas as pd
from utils.logger import get_logger
logger = get_logger("Metrics")
class ClassificationMetrics:
"""分类评估指标类,计算各种分类评估指标"""
def __init__(self, class_names: Optional[List[str]] = None):
"""
初始化分类评估指标类
Args:
class_names: 类别名称列表
"""
self.class_names = class_names
def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
计算准确率
Args:
y_true: 真实标签
y_pred: 预测标签
Returns:
准确率
"""
return accuracy_score(y_true, y_pred)
def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
average: str = 'macro') -> Union[float, np.ndarray]:
"""
计算精确率
Args:
y_true: 真实标签
y_pred: 预测标签
average: 平均方式可选值: 'micro', 'macro', 'weighted', 'samples', None
Returns:
精确率
"""
return precision_score(y_true, y_pred, average=average, zero_division=0)
def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
average: str = 'macro') -> Union[float, np.ndarray]:
"""
计算召回率
Args:
y_true: 真实标签
y_pred: 预测标签
average: 平均方式可选值: 'micro', 'macro', 'weighted', 'samples', None
Returns:
召回率
"""
return recall_score(y_true, y_pred, average=average, zero_division=0)
def f1(self, y_true: np.ndarray, y_pred: np.ndarray,
average: str = 'macro') -> Union[float, np.ndarray]:
"""
计算F1分数
Args:
y_true: 真实标签
y_pred: 预测标签
average: 平均方式可选值: 'micro', 'macro', 'weighted', 'samples', None
Returns:
F1分数
"""
return f1_score(y_true, y_pred, average=average, zero_division=0)
def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
normalize: Optional[str] = None) -> np.ndarray:
"""
计算混淆矩阵
Args:
y_true: 真实标签
y_pred: 预测标签
normalize: 归一化方式可选值: 'true', 'pred', 'all', None
Returns:
混淆矩阵
"""
cm = confusion_matrix(y_true, y_pred)
if normalize is not None:
if normalize == 'true':
cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
elif normalize == 'pred':
cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
elif normalize == 'all':
cm = cm.astype('float') / cm.sum()
return cm
def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
normalize: Optional[str] = None,
figsize: Tuple[int, int] = (10, 8),
save_path: Optional[str] = None) -> None:
"""
绘制混淆矩阵
Args:
y_true: 真实标签
y_pred: 预测标签
normalize: 归一化方式可选值: 'true', 'pred', 'all', None
figsize: 图像大小
save_path: 保存路径如果为None则显示图像
"""
# 计算混淆矩阵
cm = self.confusion_matrix(y_true, y_pred, normalize)
# 确定类别名称
if self.class_names is None:
class_names = [str(i) for i in range(cm.shape[0])]
else:
class_names = self.class_names
# 创建图像
plt.figure(figsize=figsize)
# 使用热图显示混淆矩阵
im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.colorbar(im)
# 设置坐标轴标签
plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right')
plt.yticks(np.arange(cm.shape[0]), class_names)
# 设置标题
if normalize is not None:
plt.title(f"Normalized ({normalize}) Confusion Matrix")
else:
plt.title("Confusion Matrix")
plt.ylabel('True label')
plt.xlabel('Predicted label')
# 在每个单元格中显示数值
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
if normalize is not None:
plt.text(j, i, f"{cm[i, j]:.2f}",
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
else:
plt.text(j, i, f"{cm[i, j]}",
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
# 保存或显示图像
if save_path:
plt.savefig(save_path)
logger.info(f"混淆矩阵图已保存到: {save_path}")
else:
plt.show()
def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
output_dict: bool = False) -> Union[str, Dict]:
"""
生成分类报告
Args:
y_true: 真实标签
y_pred: 预测标签
output_dict: 是否以字典形式返回
Returns:
分类报告
"""
target_names = self.class_names if self.class_names else None
return classification_report(y_true, y_pred,
target_names=target_names,
output_dict=output_dict,
zero_division=0)
def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray,
multi_class: str = 'ovr') -> Union[float, np.ndarray]:
"""
计算AUC-ROC
Args:
y_true: 真实标签
y_prob: 预测概率
multi_class: 多分类处理方式可选值: 'ovr', 'ovo'
Returns:
AUC-ROC
"""
try:
# 如果y_true是one-hot编码转换为类别索引
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 多分类
if y_prob.shape[1] > 2:
return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro')
# 二分类
else:
return roc_auc_score(y_true, y_prob[:, 1])
except Exception as e:
logger.error(f"计算AUC-ROC时出错: {e}")
return 0.0
def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray,
average: str = 'macro') -> Union[float, np.ndarray]:
"""
计算平均精确率
Args:
y_true: 真实标签
y_prob: 预测概率
average: 平均方式可选值: 'micro', 'macro', 'weighted', 'samples', None
Returns:
平均精确率
"""
try:
# 如果y_true是one-hot编码转换为类别索引
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 多分类使用sklearn的方法
return average_precision_score(
tf.keras.utils.to_categorical(y_true, num_classes=y_prob.shape[1]),
y_prob,
average=average
)
except Exception as e:
logger.error(f"计算平均精确率时出错: {e}")
return 0.0
def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray,
class_id: Optional[int] = None,
figsize: Tuple[int, int] = (10, 8),
save_path: Optional[str] = None) -> None:
"""
绘制精确率-召回率曲线
Args:
y_true: 真实标签
y_prob: 预测概率
class_id: 要绘制的类别ID如果为None则绘制所有类别
figsize: 图像大小
save_path: 保存路径如果为None则显示图像
"""
# 如果y_true是one-hot编码转换为类别索引
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 创建图像
plt.figure(figsize=figsize)
# 确定要绘制的类别
if class_id is not None and class_id < y_prob.shape[1]:
# 绘制指定类别的PR曲线
y_true_bin = (y_true == class_id).astype(int)
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, class_id])
avg_prec = average_precision_score(y_true_bin, y_prob[:, class_id])
plt.step(recall, precision, where='post',
label=f'Class {class_id} (AP = {avg_prec:.3f})')
else:
# 绘制所有类别的PR曲线
for i in range(y_prob.shape[1]):
y_true_bin = (y_true == i).astype(int)
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
avg_prec = average_precision_score(y_true_bin, y_prob[:, i])
class_name = self.class_names[i] if self.class_names else f"Class {i}"
plt.step(recall, precision, where='post',
label=f'{class_name} (AP = {avg_prec:.3f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc='lower left')
plt.grid(True)
# 保存或显示图像
if save_path:
plt.savefig(save_path)
logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
else:
plt.show()
def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
y_prob: Optional[np.ndarray] = None) -> Dict[str, float]:
"""
计算所有评估指标
Args:
y_true: 真实标签
y_pred: 预测标签
y_prob: 预测概率
Returns:
包含所有评估指标的字典
"""
metrics = {}
# 基础指标
metrics['accuracy'] = self.accuracy(y_true, y_pred)
metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro')
metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro')
metrics['f1_macro'] = self.f1(y_true, y_pred, average='macro')
# 如果提供了预测概率计算AUC-ROC和平均精确率
if y_prob is not None:
try:
metrics['auc_roc'] = self.auc_roc(y_true, y_prob)
metrics['average_precision'] = self.average_precision(y_true, y_prob)
except Exception as e:
logger.error(f"计算概率指标时出错: {e}")
# 类别级别的指标
for avg in ['micro', 'weighted']:
metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg)
metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg)
metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg)
return metrics
def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame:
"""
将评估指标转换为DataFrame
Args:
metrics: 评估指标字典
Returns:
评估指标DataFrame
"""
return pd.DataFrame(metrics.items(), columns=['Metric', 'Value']).set_index('Metric')
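A self-contained toy sketch of the metrics class on a three-class example (the real pipeline passes the THUCNews class names instead):
```python
# Toy sketch: compute the full metric set for a 3-class problem.
import numpy as np
from evaluation.metrics import ClassificationMetrics

y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 1, 2, 1, 1, 0])
# Per-class probabilities, rows summing to 1 (only needed for AUC / average precision)
y_prob = np.array([
    [0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.1, 0.2, 0.7],
    [0.2, 0.5, 0.3], [0.2, 0.6, 0.2], [0.9, 0.05, 0.05],
])

metrics = ClassificationMetrics(class_names=["A", "B", "C"])
results = metrics.calculate_all_metrics(y_true, y_pred, y_prob)
print(results["accuracy"], results["f1_macro"])
print(metrics.classification_report(y_true, y_pred))
```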

370
evaluation/visualization.py Normal file

@ -0,0 +1,370 @@
"""
可视化模块实现评估结果的可视化
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from typing import List, Dict, Tuple, Optional, Any, Union
import os
import itertools
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from utils.logger import get_logger
from utils.file_utils import ensure_dir
logger = get_logger("Visualization")
class EvaluationVisualizer:
"""评估结果可视化类"""
def __init__(self, output_dir: Optional[str] = None,
class_names: Optional[List[str]] = None,
figsize: Tuple[int, int] = (10, 8)):
"""
初始化评估结果可视化类
Args:
output_dir: 输出目录用于保存可视化结果
class_names: 类别名称列表
figsize: 图像默认大小
"""
self.output_dir = output_dir
if output_dir:
ensure_dir(output_dir)
self.class_names = class_names
self.figsize = figsize
def plot_confusion_matrix(self, cm: np.ndarray,
normalize: Optional[str] = None,
title: str = 'Confusion Matrix',
cmap: str = 'Blues',
save_path: Optional[str] = None) -> None:
"""
绘制混淆矩阵
Args:
cm: 混淆矩阵
normalize: 归一化方式可选值: 'true', 'pred', 'all', None
title: 图像标题
cmap: 颜色映射
save_path: 保存路径如果为None则使用output_dir/confusion_matrix.png
"""
if normalize is not None:
if normalize == 'true':
cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
title = 'Normalized (by true) ' + title
elif normalize == 'pred':
cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
title = 'Normalized (by pred) ' + title
elif normalize == 'all':
cm = cm.astype('float') / cm.sum()
title = 'Normalized (by all) ' + title
plt.figure(figsize=self.figsize)
# 使用seaborn绘制热图
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
cmap=cmap, square=True, cbar=True)
# 设置坐标轴标签
if self.class_names:
tick_marks = np.arange(len(self.class_names))
plt.xticks(tick_marks + 0.5, self.class_names, rotation=45, ha='right')
plt.yticks(tick_marks + 0.5, self.class_names, rotation=0)
plt.title(title)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, 'confusion_matrix.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"混淆矩阵图已保存到: {save_path}")
plt.close()
def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]],
selected_metrics: Optional[List[str]] = None,
title: str = 'Metrics Comparison',
save_path: Optional[str] = None) -> None:
"""
绘制多个模型的评估指标比较
Args:
metrics_dict: 模型评估指标字典格式为{model_name: {metric_name: value}}
selected_metrics: 要比较的指标列表如果为None则使用所有指标
title: 图像标题
save_path: 保存路径如果为None则使用output_dir/metrics_comparison.png
"""
# 创建DataFrame
df = pd.DataFrame(metrics_dict).T
# 筛选指标
if selected_metrics:
df = df[selected_metrics]
# 绘制条形图
        df.plot(kind='bar', figsize=self.figsize)
plt.title(title)
plt.ylabel('Score')
plt.ylim(0, 1)
plt.legend(loc='best')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, 'metrics_comparison.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"评估指标比较图已保存到: {save_path}")
plt.close()
def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
title: str = 'ROC Curves',
save_path: Optional[str] = None) -> None:
"""
绘制ROC曲线
Args:
y_true: 真实标签
y_prob: 预测概率
title: 图像标题
save_path: 保存路径如果为None则使用output_dir/roc_curves.png
"""
plt.figure(figsize=self.figsize)
# 确保y_true是一维数组
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 获取类别数
num_classes = y_prob.shape[1]
# 绘制每个类别的ROC曲线
for i in range(num_classes):
# 二分类转换:当前类别为正类,其他为负类
y_true_bin = (y_true == i).astype(int)
# 计算ROC曲线
fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
roc_auc = auc(fpr, tpr)
# 确定类别名称
if self.class_names and i < len(self.class_names):
class_name = self.class_names[i]
else:
class_name = f'Class {i}'
# 绘制ROC曲线
plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})')
# 绘制随机猜测的基准线
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(title)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, 'roc_curves.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"ROC曲线图已保存到: {save_path}")
plt.close()
def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
title: str = 'Precision-Recall Curves',
save_path: Optional[str] = None) -> None:
"""
绘制精确率-召回率曲线
Args:
y_true: 真实标签
y_prob: 预测概率
title: 图像标题
save_path: 保存路径如果为None则使用output_dir/precision_recall_curves.png
"""
plt.figure(figsize=self.figsize)
# 确保y_true是一维数组
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 获取类别数
num_classes = y_prob.shape[1]
# 绘制每个类别的PR曲线
for i in range(num_classes):
# 二分类转换:当前类别为正类,其他为负类
y_true_bin = (y_true == i).astype(int)
# 计算PR曲线
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
pr_auc = auc(recall, precision)
# 确定类别名称
if self.class_names and i < len(self.class_names):
class_name = self.class_names[i]
else:
class_name = f'Class {i}'
# 绘制PR曲线
plt.plot(recall, precision, lw=2, label=f'{class_name} (AUC = {pr_auc:.3f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title(title)
plt.legend(loc='best')
plt.grid(True, alpha=0.3)
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, 'precision_recall_curves.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
plt.close()
def plot_feature_importance(self, feature_names: List[str],
importance: np.ndarray,
title: str = 'Feature Importance',
top_n: int = 20,
save_path: Optional[str] = None) -> None:
"""
绘制特征重要性
Args:
feature_names: 特征名称列表
importance: 特征重要性数组
title: 图像标题
top_n: 显示前N个重要的特征
save_path: 保存路径如果为None则使用output_dir/feature_importance.png
"""
# 创建DataFrame
df = pd.DataFrame({'Feature': feature_names, 'Importance': importance})
# 按重要性排序
df = df.sort_values('Importance', ascending=False).head(top_n)
# 绘制条形图
plt.figure(figsize=self.figsize)
sns.barplot(x='Importance', y='Feature', data=df)
plt.title(title)
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, 'feature_importance.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"特征重要性图已保存到: {save_path}")
plt.close()
def plot_embedding_visualization(self, embeddings: np.ndarray,
labels: np.ndarray,
method: str = 'tsne',
title: str = 'Embedding Visualization',
save_path: Optional[str] = None) -> None:
"""
绘制嵌入向量可视化
Args:
embeddings: 嵌入向量形状为(样本数, 嵌入维度)
labels: 类别标签形状为(样本数,)
method: 降维方法'tsne''pca'
title: 图像标题
save_path: 保存路径如果为None则使用output_dir/embedding_visualization.png
"""
# 降维
if method.lower() == 'tsne':
reducer = TSNE(n_components=2, random_state=42)
elif method.lower() == 'pca':
reducer = PCA(n_components=2, random_state=42)
else:
raise ValueError(f"不支持的降维方法: {method}")
# 如果嵌入向量太多,采样一部分
max_samples = 5000
if len(embeddings) > max_samples:
indices = np.random.choice(len(embeddings), max_samples, replace=False)
embeddings_sample = embeddings[indices]
labels_sample = labels[indices]
else:
embeddings_sample = embeddings
labels_sample = labels
# 执行降维
embeddings_2d = reducer.fit_transform(embeddings_sample)
# 绘制散点图
plt.figure(figsize=self.figsize)
# 确保标签是一维数组
if len(labels_sample.shape) > 1 and labels_sample.shape[1] > 1:
labels_sample = np.argmax(labels_sample, axis=1)
# 获取唯一类别
unique_labels = np.unique(labels_sample)
# 为每个类别分配不同的颜色
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
for i, label in enumerate(unique_labels):
mask = labels_sample == label
# 确定类别名称
if self.class_names and label < len(self.class_names):
class_name = self.class_names[int(label)]
else:
class_name = f'Class {int(label)}'
plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
c=[colors[i]], label=class_name, alpha=0.7)
plt.title(title)
plt.legend(loc='best')
plt.grid(True, alpha=0.3)
plt.tight_layout()
# 保存图像
if save_path is None and self.output_dir:
save_path = os.path.join(self.output_dir, f'embedding_visualization_{method}.png')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
logger.info(f"嵌入向量可视化图已保存到: {save_path}")
plt.close()
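A short usage sketch for EvaluationVisualizer, based only on the methods defined above; the confusion matrix and metric values are made-up sample data, and the output directory is a placeholder.
import numpy as np
from evaluation.visualization import EvaluationVisualizer

# Toy 3-class confusion matrix (rows: true labels, columns: predicted labels).
cm = np.array([
    [50, 3, 2],
    [4, 45, 6],
    [1, 5, 49],
])

visualizer = EvaluationVisualizer(output_dir="output/eval_plots",
                                  class_names=["体育", "财经", "科技"])

# Saves a row-normalized heatmap to output/eval_plots/confusion_matrix.png.
visualizer.plot_confusion_matrix(cm, normalize='true')

# Compare two models on selected metrics (saved to metrics_comparison.png).
visualizer.plot_metrics_comparison(
    {"cnn": {"accuracy": 0.91, "f1_macro": 0.89},
     "rnn": {"accuracy": 0.88, "f1_macro": 0.86}},
    selected_metrics=["accuracy", "f1_macro"],
)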

0
inference/__init__.py Normal file
View File

362
inference/batch_processor.py Normal file
View File

@ -0,0 +1,362 @@
"""
批处理模块实现批量处理大规模文本数据
"""
import os
import time
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator
import concurrent.futures
from tqdm import tqdm
import glob
import json
from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH
from utils.logger import get_logger
from utils.file_utils import read_text_file, ensure_dir
from inference.predictor import Predictor
logger = get_logger("BatchProcessor")
class BatchProcessor:
"""批处理器,负责批量处理大规模文本数据"""
def __init__(self, predictor: Predictor,
batch_size: int = 64,
max_workers: int = DATA_LOADING_WORKERS,
max_batch_queue: int = 10):
"""
初始化批处理器
Args:
predictor: 预测器实例
batch_size: 批大小
max_workers: 最大工作线程数
max_batch_queue: 最大批次队列长度
"""
self.predictor = predictor
self.batch_size = batch_size
self.max_workers = max_workers
self.max_batch_queue = max_batch_queue
logger.info(f"初始化批处理器,批大小: {batch_size}, 最大工作线程数: {max_workers}")
def _extract_text_from_file(self, file_path: str) -> str:
"""
从文件中提取文本
Args:
file_path: 文件路径
Returns:
文本内容
"""
return read_text_file(file_path, encoding=ENCODING)
def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]:
"""
生成文本批次
Args:
texts: 文本列表
batch_size: 批大小
Returns:
文本批次生成器
"""
for i in range(0, len(texts), batch_size):
yield texts[i:i + batch_size]
def process_files(self, file_paths: List[str], output_path: Optional[str] = None,
return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
"""
批量处理文件
Args:
file_paths: 文件路径列表
output_path: 输出文件路径如果为None则不保存
return_top_k: 返回概率最高的前k个类别
format: 输出格式'csv''json'
Returns:
预测结果DataFrame
"""
logger.info(f"开始批量处理 {len(file_paths)} 个文件")
start_time = time.time()
# 使用线程池并行读取文件
texts = []
file_names = []
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path for file_path in
file_paths}
for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(file_paths), desc="读取文件"):
file_path = future_to_file[future]
try:
text = future.result()
if text:
texts.append(text)
file_names.append(os.path.basename(file_path))
except Exception as e:
logger.error(f"处理文件 {file_path} 时出错: {e}")
# 批量预测
all_predictions = []
for batch in tqdm(self._batch_generator(texts, self.batch_size),
total=(len(texts) + self.batch_size - 1) // self.batch_size, desc="预测"):
predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k, return_probabilities=True)
all_predictions.extend(predictions)
# 整合结果
if return_top_k > 1:
# 多个类别的情况
results = []
for i, preds in enumerate(all_predictions):
for j, pred in enumerate(preds):
results.append({
'file_name': file_names[i],
'rank': j + 1,
'predicted_class': pred['class'],
'probability': pred['probability']
})
df = pd.DataFrame(results)
else:
# 单个类别的情况
df = pd.DataFrame({
'file_name': file_names,
'predicted_class': [pred['class'] for pred in all_predictions],
'probability': [pred['probability'] for pred in all_predictions]
})
# 保存结果
if output_path:
if format.lower() == 'csv':
df.to_csv(output_path, index=False, encoding='utf-8')
elif format.lower() == 'json':
# 转换为嵌套的JSON格式
if return_top_k > 1:
# 分组后转换为嵌套格式
result = {}
for file_name in df['file_name'].unique():
sub_df = df[df['file_name'] == file_name]
predictions = []
for _, row in sub_df.iterrows():
predictions.append({
'class': row['predicted_class'],
'probability': row['probability']
})
result[file_name] = {
'predictions': predictions
}
else:
# 直接构建JSON
result = {}
for _, row in df.iterrows():
result[row['file_name']] = {
'predicted_class': row['predicted_class'],
'probability': row['probability']
}
# 保存为JSON
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
raise ValueError(f"不支持的输出格式: {format}")
logger.info(f"预测结果已保存到: {output_path}")
processing_time = time.time() - start_time
logger.info(f"批量处理完成,共处理 {len(texts)} 个文件,用时: {processing_time:.2f}")
return df
def process_directory(self, directory: str, pattern: str = "*.txt",
output_path: Optional[str] = None,
return_top_k: int = 1, format: str = 'csv',
recursive: bool = True) -> pd.DataFrame:
"""
批量处理目录中的文件
Args:
directory: 目录路径
pattern: 文件匹配模式
output_path: 输出文件路径如果为None则不保存
return_top_k: 返回概率最高的前k个类别
format: 输出格式'csv''json'
recursive: 是否递归处理子目录
Returns:
预测结果DataFrame
"""
# 获取符合模式的文件路径
if recursive:
file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True)
else:
file_paths = glob.glob(os.path.join(directory, pattern))
if not file_paths:
logger.warning(f"在目录 {directory} 中未找到符合模式 {pattern} 的文件")
return pd.DataFrame()
logger.info(f"在目录 {directory} 中找到 {len(file_paths)} 个符合模式 {pattern} 的文件")
# 调用process_files处理文件
return self.process_files(file_paths, output_path, return_top_k, format)
def process_dataframe(self, df: pd.DataFrame, text_column: str,
id_column: Optional[str] = None,
output_path: Optional[str] = None,
return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
"""
批量处理DataFrame中的文本
Args:
df: 输入DataFrame
text_column: 文本列名
id_column: ID列名如果为None则使用索引
output_path: 输出文件路径如果为None则不保存
return_top_k: 返回概率最高的前k个类别
format: 输出格式'csv''json'
Returns:
预测结果DataFrame
"""
# 获取文本和ID
texts = df[text_column].tolist()
if id_column:
ids = df[id_column].tolist()
else:
ids = df.index.tolist()
# 批量预测
result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k)
# 保存结果
if output_path:
if format.lower() == 'csv':
result_df.to_csv(output_path, index=False, encoding='utf-8')
elif format.lower() == 'json':
# 转换为嵌套的JSON格式
self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json')
else:
raise ValueError(f"不支持的输出格式: {format}")
logger.info(f"预测结果已保存到: {output_path}")
return result_df
def process_large_file(self, file_path: str, output_path: Optional[str] = None,
return_top_k: int = 1, format: str = 'csv',
chunk_size: int = MAX_TEXT_PER_BATCH,
delimiter: str = '\n\n') -> None:
"""
处理大型文本文件文件会被分块读取和处理
Args:
file_path: 文件路径
output_path: 输出文件路径如果为None则不保存
return_top_k: 返回概率最高的前k个类别
format: 输出格式'csv''json'
chunk_size: 每个块的大小文本数量
delimiter: 文本分隔符
"""
logger.info(f"开始处理大型文件: {file_path}")
start_time = time.time()
# 读取文件内容
with open(file_path, 'r', encoding=ENCODING) as f:
content = f.read()
# 分割文本
texts = content.split(delimiter)
texts = [text.strip() for text in texts if text.strip()]
logger.info(f"文件共包含 {len(texts)} 条文本")
# 创建输出文件
if output_path:
if format.lower() == 'csv':
# 创建CSV文件头
                if return_top_k > 1:
                    header = "id,text,rank,predicted_class,probability\n"
                else:
                    header = "id,text,predicted_class,probability\n"
with open(output_path, 'w', encoding=ENCODING) as f:
f.write(header)
elif format.lower() == 'json':
# 创建JSON文件
with open(output_path, 'w', encoding=ENCODING) as f:
f.write('{\n')
# 分块处理
total_chunks = (len(texts) + chunk_size - 1) // chunk_size
for i in range(0, len(texts), chunk_size):
chunk = texts[i:i + chunk_size]
chunk_ids = list(range(i, i + len(chunk)))
logger.info(f"处理第 {i // chunk_size + 1}/{total_chunks} 块,包含 {len(chunk)} 条文本")
# 批量预测
result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k)
# 追加到输出文件
if output_path:
if format.lower() == 'csv':
result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False)
elif format.lower() == 'json':
# 转换为JSON并追加
if return_top_k > 1:
# 分组后转换为嵌套格式
for id_val in result_df['id'].unique():
sub_df = result_df[result_df['id'] == id_val]
predictions = []
for _, row in sub_df.iterrows():
predictions.append({
'class': row['predicted_class'],
'probability': float(row['probability'])
})
json_str = f' "{id_val}": {{\n'
json_str += f' "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n'
json_str += f' "predictions": {json.dumps(predictions, ensure_ascii=False)}\n'
json_str += ' },'
with open(output_path, 'a', encoding=ENCODING) as f:
f.write(json_str + '\n')
else:
# 直接构建JSON
for _, row in result_df.iterrows():
json_str = f' "{row["id"]}": {{\n'
json_str += f' "text": {json.dumps(row["text"], ensure_ascii=False)},\n'
json_str += f' "predicted_class": "{row["predicted_class"]}",\n'
json_str += f' "probability": {float(row["probability"])}\n'
json_str += ' },'
with open(output_path, 'a', encoding=ENCODING) as f:
f.write(json_str + '\n')
# 完成JSON文件
if output_path and format.lower() == 'json':
with open(output_path, 'a', encoding=ENCODING) as f:
f.write('}\n')
# 修复JSON文件中的最后一个逗号
with open(output_path, 'r', encoding=ENCODING) as f:
content = f.read()
content = content.rstrip('\n}')
content = content.rstrip(',')
content += '\n}\n'
with open(output_path, 'w', encoding=ENCODING) as f:
f.write(content)
processing_time = time.time() - start_time
logger.info(f"处理大型文件完成,共处理 {len(texts)} 条文本,用时: {processing_time:.2f}")

316
inference/predictor.py Normal file
View File

@ -0,0 +1,316 @@
"""
预测器模块实现模型预测功能支持单条和批量文本预测
"""
import os
import time
import numpy as np
import tensorflow as tf
from typing import List, Dict, Tuple, Optional, Any, Union
import pandas as pd
import json
from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY
from models.base_model import TextClassificationModel
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from utils.logger import get_logger
logger = get_logger("Predictor")
class Predictor:
"""预测器,负责加载模型和进行预测"""
def __init__(self, model: TextClassificationModel,
tokenizer: Optional[ChineseTokenizer] = None,
vectorizer: Optional[SequenceVectorizer] = None,
class_names: Optional[List[str]] = None,
max_sequence_length: int = 500,
batch_size: Optional[int] = None):
"""
初始化预测器
Args:
model: 已训练的模型实例
tokenizer: 分词器实例如果为None则创建一个新的分词器
vectorizer: 文本向量化器实例如果为None则表示模型直接接收序列
class_names: 类别名称列表如果为None则使用ID_TO_CATEGORY
max_sequence_length: 最大序列长度
batch_size: 批大小如果为None则使用模型默认值
"""
self.model = model
self.tokenizer = tokenizer or ChineseTokenizer()
self.vectorizer = vectorizer
self.class_names = class_names
if class_names is None and hasattr(model, 'num_classes'):
# 如果模型具有类别数量信息从ID_TO_CATEGORY获取类别名称
self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)]
self.max_sequence_length = max_sequence_length
self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32)
logger.info(f"初始化预测器,批大小: {self.batch_size}")
def preprocess_text(self, text: str) -> Any:
"""
预处理单条文本
Args:
text: 原始文本
Returns:
预处理后的文本表示
"""
# 分词
tokenized_text = self.tokenizer.tokenize(text, return_string=True)
# 如果有向量化器,应用向量化
if self.vectorizer is not None:
return self.vectorizer.transform([tokenized_text])[0]
return tokenized_text
def preprocess_texts(self, texts: List[str]) -> Any:
"""
批量预处理文本
Args:
texts: 原始文本列表
Returns:
预处理后的批量文本表示
"""
# 分词
tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts]
# 如果有向量化器,应用向量化
if self.vectorizer is not None:
return self.vectorizer.transform(tokenized_texts)
return tokenized_texts
def predict(self, text: str, return_top_k: int = 1,
return_probabilities: bool = False) -> Union[str, Dict, List]:
"""
预测单条文本的类别
Args:
text: 原始文本
return_top_k: 返回概率最高的前k个类别
return_probabilities: 是否返回概率值
Returns:
预测结果格式取决于参数设置
"""
# 预处理文本
processed_text = self.preprocess_text(text)
# 添加批次维度
if isinstance(processed_text, str):
input_data = np.array([processed_text])
else:
input_data = np.expand_dims(processed_text, axis=0)
# 预测
start_time = time.time()
predictions = self.model.predict(input_data)
prediction_time = time.time() - start_time
# 获取前k个预测结果
if return_top_k > 1:
top_indices = np.argsort(predictions[0])[::-1][:return_top_k]
top_probs = predictions[0][top_indices]
if self.class_names:
top_classes = [self.class_names[idx] for idx in top_indices]
else:
top_classes = [str(idx) for idx in top_indices]
if return_probabilities:
return [{'class': cls, 'probability': float(prob)}
for cls, prob in zip(top_classes, top_probs)]
else:
return top_classes
else:
# 获取最高概率的类别
pred_idx = np.argmax(predictions[0])
pred_prob = float(predictions[0][pred_idx])
if self.class_names:
pred_class = self.class_names[pred_idx]
else:
pred_class = str(pred_idx)
if return_probabilities:
return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time}
else:
return pred_class
def predict_batch(self, texts: List[str], return_top_k: int = 1,
return_probabilities: bool = False) -> List:
"""
批量预测文本类别
Args:
texts: 原始文本列表
return_top_k: 返回概率最高的前k个类别
return_probabilities: 是否返回概率值
Returns:
预测结果列表
"""
# 空列表检查
if not texts:
return []
# 预处理文本
processed_texts = self.preprocess_texts(texts)
# 预测
start_time = time.time()
predictions = self.model.predict(processed_texts, batch_size=self.batch_size)
prediction_time = time.time() - start_time
# 处理预测结果
results = []
for i, pred in enumerate(predictions):
if return_top_k > 1:
top_indices = np.argsort(pred)[::-1][:return_top_k]
top_probs = pred[top_indices]
if self.class_names:
top_classes = [self.class_names[idx] for idx in top_indices]
else:
top_classes = [str(idx) for idx in top_indices]
if return_probabilities:
results.append([{'class': cls, 'probability': float(prob)}
for cls, prob in zip(top_classes, top_probs)])
else:
results.append(top_classes)
else:
# 获取最高概率的类别
pred_idx = np.argmax(pred)
pred_prob = float(pred[pred_idx])
if self.class_names:
pred_class = self.class_names[pred_idx]
else:
pred_class = str(pred_idx)
if return_probabilities:
results.append({'class': pred_class, 'probability': pred_prob})
else:
results.append(pred_class)
logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f}")
return results
def predict_to_dataframe(self, texts: List[str],
text_ids: Optional[List[Union[str, int]]] = None,
return_top_k: int = 1) -> pd.DataFrame:
"""
批量预测并返回DataFrame
Args:
texts: 原始文本列表
text_ids: 文本ID列表如果为None则使用索引
return_top_k: 返回概率最高的前k个类别
Returns:
预测结果DataFrame
"""
# 预测
predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True)
# 创建DataFrame
if text_ids is None:
text_ids = list(range(len(texts)))
if return_top_k > 1:
# 多个类别的情况
results = []
for i, preds in enumerate(predictions):
for j, pred in enumerate(preds):
results.append({
'id': text_ids[i],
'text': texts[i],
'rank': j + 1,
'predicted_class': pred['class'],
'probability': pred['probability']
})
df = pd.DataFrame(results)
else:
# 单个类别的情况
df = pd.DataFrame({
'id': text_ids,
'text': texts,
'predicted_class': [pred['class'] for pred in predictions],
'probability': [pred['probability'] for pred in predictions]
})
return df
def save_predictions(self, texts: List[str],
output_path: str,
text_ids: Optional[List[Union[str, int]]] = None,
return_top_k: int = 1,
format: str = 'csv') -> str:
"""
批量预测并保存结果
Args:
texts: 原始文本列表
output_path: 输出文件路径
text_ids: 文本ID列表如果为None则使用索引
return_top_k: 返回概率最高的前k个类别
format: 输出格式'csv''json'
Returns:
输出文件路径
"""
# 获取预测结果DataFrame
df = self.predict_to_dataframe(texts, text_ids, return_top_k)
# 保存结果
if format.lower() == 'csv':
df.to_csv(output_path, index=False, encoding='utf-8')
elif format.lower() == 'json':
# 转换为嵌套的JSON格式
if return_top_k > 1:
# 分组后转换为嵌套格式
result = {}
for id_val in df['id'].unique():
sub_df = df[df['id'] == id_val]
predictions = []
for _, row in sub_df.iterrows():
predictions.append({
'class': row['predicted_class'],
'probability': row['probability']
})
result[str(id_val)] = {
'text': sub_df.iloc[0]['text'],
'predictions': predictions
}
else:
# 直接构建JSON
result = {}
for _, row in df.iterrows():
result[str(row['id'])] = {
'text': row['text'],
'predicted_class': row['predicted_class'],
'probability': row['probability']
}
# 保存为JSON
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
raise ValueError(f"不支持的输出格式: {format}")
logger.info(f"预测结果已保存到: {output_path}")
return output_path
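A brief single-text and batch usage sketch for Predictor; the model path is a placeholder and the example texts are arbitrary.
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from inference.predictor import Predictor

model = ModelFactory.load_model("saved_models/classifiers/cnn_text_classifier")  # placeholder
predictor = Predictor(model=model, tokenizer=ChineseTokenizer())

# Top-3 classes with probabilities for a single text.
result = predictor.predict("央行宣布下调存款准备金率", return_top_k=3, return_probabilities=True)
print(result)

# Batch prediction written to CSV via save_predictions.
predictor.save_predictions(
    ["国足公布世预赛名单", "新款旗舰手机正式发布"],
    output_path="output/predictions.csv",
    return_top_k=1,
    format="csv",
)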

0
interface/__init__.py Normal file
View File

391
interface/api.py Normal file
View File

@ -0,0 +1,391 @@
"""
API接口模块提供REST API接口
"""
import os
import sys
import json
import time
from typing import List, Dict, Tuple, Optional, Any, Union
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import asyncio
import pandas as pd
import io
# 将项目根目录添加到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
logger = get_logger("API")
# 数据模型
class TextItem(BaseModel):
text: str
id: Optional[str] = None
class BatchPredictRequest(BaseModel):
texts: List[TextItem]
top_k: Optional[int] = 1
class BatchFileRequest(BaseModel):
file_paths: List[str]
top_k: Optional[int] = 1
class ModelInfo(BaseModel):
id: str
name: str
type: str
num_classes: int
created_time: str
file_size: str
# 应用实例
app = FastAPI(
title="中文文本分类系统API",
description="提供中文文本分类功能的REST API",
version="1.0.0"
)
# 允许跨域请求
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局对象
predictor = None
def get_predictor() -> Predictor:
"""
获取或创建全局Predictor实例
Returns:
Predictor实例
"""
global predictor
if predictor is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise HTTPException(status_code=500, detail="未找到可用的模型")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"API加载模型: {model_path}")
# 加载模型
model = ModelFactory.load_model(model_path)
# 创建分词器
tokenizer = ChineseTokenizer()
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
class_names=CATEGORIES,
batch_size=64
)
return predictor
@app.get("/")
async def root():
"""API根路径"""
return {"message": "欢迎使用中文文本分类系统API"}
@app.post("/predict/text")
async def predict_text(text: str = Form(...), top_k: int = Form(1)):
"""
预测单条文本
Args:
text: 要预测的文本
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info(f"接收到文本预测请求,文本长度: {len(text)}")
try:
# 获取预测器
predictor = get_predictor()
# 预测
start_time = time.time()
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
prediction_time = time.time() - start_time
# 构建响应
response = {
"success": True,
"predictions": result if top_k > 1 else [result],
"time": prediction_time
}
return response
except Exception as e:
logger.error(f"预测文本时出错: {e}")
raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}")
@app.post("/predict/batch")
async def predict_batch(request: BatchPredictRequest):
"""
批量预测文本
Args:
request: 包含文本列表和参数的请求
Returns:
批量预测结果
"""
texts = [item.text for item in request.texts]
ids = [item.id or str(i) for i, item in enumerate(request.texts)]
logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")
try:
# 获取预测器
predictor = get_predictor()
# 预测
start_time = time.time()
results = predictor.predict_batch(
texts=texts,
return_top_k=request.top_k,
return_probabilities=True
)
prediction_time = time.time() - start_time
# 构建响应
response = {
"success": True,
"total": len(texts),
"time": prediction_time,
"results": {}
}
# 将结果关联到ID
for i, (id_val, result) in enumerate(zip(ids, results)):
response["results"][id_val] = result
return response
except Exception as e:
logger.error(f"批量预测文本时出错: {e}")
raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}")
@app.post("/predict/file")
async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
"""
预测文件内容
Args:
file: 上传的文件
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info(f"接收到文件预测请求,文件名: {file.filename}")
try:
# 读取文件内容
content = await file.read()
# 根据文件类型处理内容
if file.filename.endswith(('.txt', '.md')):
# 文本文件
text = content.decode('utf-8')
# 获取预测器
predictor = get_predictor()
# 预测
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
# 构建响应
response = {
"success": True,
"filename": file.filename,
"predictions": result if top_k > 1 else [result]
}
return response
elif file.filename.endswith(('.csv', '.xls', '.xlsx')):
# 表格文件
if file.filename.endswith('.csv'):
df = pd.read_csv(io.BytesIO(content))
else:
df = pd.read_excel(io.BytesIO(content))
# 查找可能的文本列
text_columns = [col for col in df.columns if df[col].dtype == 'object']
if not text_columns:
raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")
# 使用第一个文本列
text_column = text_columns[0]
texts = df[text_column].fillna('').tolist()
# 获取预测器
predictor = get_predictor()
# 批量预测
results = predictor.predict_batch(
texts=texts,
return_top_k=top_k,
return_probabilities=True
)
# 构建响应
response = {
"success": True,
"filename": file.filename,
"text_column": text_column,
"total": len(texts),
"results": results
}
return response
else:
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")
except Exception as e:
logger.error(f"预测文件时出错: {e}")
raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}")
@app.get("/models")
async def list_models():
"""
列出可用的模型
Returns:
可用模型列表
"""
try:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
# 转换为响应格式
models = []
for info in models_info:
models.append(ModelInfo(
id=os.path.basename(info['path']),
name=info['name'],
type=info['type'],
num_classes=info['num_classes'],
created_time=info['created_time'],
file_size=info['file_size']
))
return {"models": models}
except Exception as e:
logger.error(f"获取模型列表时出错: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}")
@app.get("/categories")
async def list_categories():
"""
列出支持的类别
Returns:
支持的类别列表
"""
try:
return {"categories": CATEGORIES}
except Exception as e:
logger.error(f"获取类别列表时出错: {e}")
raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}")
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""
记录请求日志
Args:
request: 请求对象
call_next: 下一个处理函数
Returns:
响应对象
"""
start_time = time.time()
# 记录请求信息
logger.info(f"请求: {request.method} {request.url}")
# 处理请求
response = await call_next(request)
# 记录响应信息
process_time = time.time() - start_time
logger.info(f"响应: {response.status_code} ({process_time:.2f}s)")
return response
def run_server(host: str = "0.0.0.0", port: int = 8000):
"""
运行API服务器
Args:
host: 主机地址
port: 端口号
"""
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
# 解析命令行参数
import argparse
parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
parser.add_argument("--port", type=int, default=8000, help="服务器端口号")
args = parser.parse_args()
# 运行服务器
run_server(host=args.host, port=args.port)
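A small client-side sketch exercising the endpoints above with the third-party requests package; it assumes the API server is already running locally on port 8000 (for example via python interface/api.py).
import requests

BASE_URL = "http://localhost:8000"

# Single-text prediction; form fields match the /predict/text endpoint above.
resp = requests.post(f"{BASE_URL}/predict/text",
                     data={"text": "新能源汽车销量再创新高", "top_k": 3})
print(resp.json())

# Batch prediction; the JSON body mirrors BatchPredictRequest.
payload = {
    "texts": [
        {"id": "a", "text": "球队在主场取得三连胜"},
        {"id": "b", "text": "央行公布最新货币政策报告"},
    ],
    "top_k": 2,
}
print(requests.post(f"{BASE_URL}/predict/batch", json=payload).json())

# Discover available models and supported categories.
print(requests.get(f"{BASE_URL}/models").json())
print(requests.get(f"{BASE_URL}/categories").json())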

378
interface/cli.py Normal file
View File

@ -0,0 +1,378 @@
"""
命令行界面模块提供命令行交互功能
"""
import argparse
import os
import sys
import pandas as pd
from typing import List, Dict, Tuple, Optional, Any, Union
import json
# 将项目根目录添加到sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
from models.model_factory import ModelFactory
from models.base_model import TextClassificationModel
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import ensure_dir, read_text_file
logger = get_logger("CLI")
def load_model_and_components(model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
vectorizer_path: Optional[str] = None,
class_names: Optional[List[str]] = None) -> Tuple[
TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]:
"""
加载模型和相关组件
Args:
model_path: 模型路径如果为None则使用最新的模型
tokenizer_path: 分词器路径如果为None则创建一个新的分词器
vectorizer_path: 向量化器路径如果为None则不使用向量化器
class_names: 类别名称列表如果为None则使用CATEGORIES
Returns:
(模型, 分词器, 向量化器)的元组
"""
# 加载模型
if model_path is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise ValueError("未找到可用的模型,请指定模型路径")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"使用最新的模型: {model_path}")
# 加载模型
model = ModelFactory.load_model(model_path)
# 加载或创建分词器
if tokenizer_path:
tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理
logger.info(f"已加载分词器: {tokenizer_path}")
else:
tokenizer = ChineseTokenizer()
logger.info("已创建新的分词器")
# 加载向量化器
vectorizer = None
if vectorizer_path:
vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理
vectorizer.load(vectorizer_path)
logger.info(f"已加载向量化器: {vectorizer_path}")
return model, tokenizer, vectorizer
def predict_text(args):
"""处理单条文本预测命令"""
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
# 获取文本
text = args.text
# 如果提供的是文件路径而非文本内容
if args.file and os.path.exists(text):
text = read_text_file(text)
# 预测
result = predictor.predict(
text=text,
return_top_k=args.top_k,
return_probabilities=True
)
# 输出结果
if args.top_k > 1:
print("\n预测结果:")
for i, pred in enumerate(result):
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
else:
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
# 保存结果
if args.output:
if args.output.endswith('.json'):
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
with open(args.output, 'w', encoding='utf-8') as f:
if args.top_k > 1:
f.write("rank,class,probability\n")
for i, pred in enumerate(result):
f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
else:
f.write(f"class,probability\n")
f.write(f"{result['class']},{result['probability']}\n")
print(f"结果已保存到: {args.output}")
def predict_batch(args):
"""处理批量文本预测命令"""
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
# 创建批处理器
batch_processor = BatchProcessor(
predictor=predictor,
batch_size=args.batch_size,
max_workers=args.workers
)
# 确保输出目录存在
if args.output:
ensure_dir(os.path.dirname(args.output))
# 根据输入类型选择处理方法
if args.input_type == 'file' and os.path.isfile(args.input):
# 单个文件
if args.large_file:
# 大型文件,分块处理
batch_processor.process_large_file(
file_path=args.input,
output_path=args.output,
return_top_k=args.top_k,
format=args.format
)
else:
# CSV或Excel文件
if args.input.endswith('.csv'):
df = pd.read_csv(args.input, encoding='utf-8')
elif args.input.endswith(('.xls', '.xlsx')):
df = pd.read_excel(args.input)
else:
print(f"不支持的文件格式: {args.input}")
return
# 检查文本列是否存在
if args.text_column not in df.columns:
print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}")
return
# 处理DataFrame
result_df = batch_processor.process_dataframe(
df=df,
text_column=args.text_column,
id_column=args.id_column,
output_path=args.output,
return_top_k=args.top_k,
format=args.format
)
# 输出结果统计
print(f"\n已处理 {len(result_df)} 条文本")
print("类别分布:")
if args.top_k == 1:
class_counts = result_df['predicted_class'].value_counts()
for cls, count in class_counts.items():
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
elif args.input_type == 'dir' and os.path.isdir(args.input):
# 目录
result_df = batch_processor.process_directory(
directory=args.input,
pattern=args.pattern,
output_path=args.output,
return_top_k=args.top_k,
format=args.format,
recursive=args.recursive
)
# 输出结果统计
if not result_df.empty:
print(f"\n已处理 {len(result_df)} 个文件")
print("类别分布:")
if args.top_k == 1:
class_counts = result_df['predicted_class'].value_counts()
for cls, count in class_counts.items():
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
else:
print(f"无效的输入: {args.input}")
def list_models(args):
"""列出可用的模型"""
models_info = ModelFactory.get_available_models()
if not models_info:
print("未找到可用的模型")
return
print(f"找到 {len(models_info)} 个可用模型:")
for i, info in enumerate(models_info):
print(f"\n{i + 1}. {info['name']} ({info['type']})")
print(f" 路径: {info['path']}")
print(f" 创建时间: {info['created_time']}")
print(f" 类别数: {info['num_classes']}")
print(f" 文件大小: {info['file_size']}")
def interactive_mode(args):
"""交互模式"""
print("启动交互模式...")
# 加载模型和组件
model, tokenizer, vectorizer = load_model_and_components(
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
)
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=args.class_names or CATEGORIES,
batch_size=args.batch_size
)
print("\n模型已加载,可以开始交互式文本分类")
print("输入 'quit''exit' 退出交互模式\n")
while True:
try:
# 获取用户输入
text = input("请输入要分类的文本: ")
# 检查是否退出
if text.lower() in ['quit', 'exit', 'q']:
print("退出交互模式")
break
# 空输入
if not text.strip():
continue
# 预测
result = predictor.predict(
text=text,
return_top_k=args.top_k,
return_probabilities=True
)
# 输出结果
if args.top_k > 1:
print("\n预测结果:")
for i, pred in enumerate(result):
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
else:
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
print() # 空行
except KeyboardInterrupt:
print("\n退出交互模式")
break
except Exception as e:
print(f"处理过程中出错: {e}")
def main():
"""主函数,解析命令行参数并调用相应的功能"""
parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具")
# 创建子命令
subparsers = parser.add_subparsers(dest="command", help="子命令")
# 预测单条文本命令
predict_parser = subparsers.add_parser("predict", help="预测单条文本")
predict_parser.add_argument("text", help="要预测的文本或文本文件路径")
predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径")
predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
predict_parser.add_argument("--vectorizer_path", help="向量化器路径")
predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
predict_parser.add_argument("--output", help="保存预测结果的文件路径")
predict_parser.set_defaults(func=predict_text)
# 批量预测命令
batch_parser = subparsers.add_parser("batch", help="批量预测文本")
batch_parser.add_argument("input", help="输入文件或目录路径")
batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型")
batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名")
batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名")
batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式")
batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录")
batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件")
batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
batch_parser.add_argument("--vectorizer_path", help="向量化器路径")
batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数")
batch_parser.add_argument("--output", required=True, help="输出文件路径")
batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="输出格式")
batch_parser.set_defaults(func=predict_batch)
# 列出可用模型命令
list_parser = subparsers.add_parser("list", help="列出可用的模型")
list_parser.set_defaults(func=list_models)
# 交互模式命令
interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式")
interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
interactive_parser.add_argument("--vectorizer_path", help="向量化器路径")
interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小")
interactive_parser.set_defaults(func=interactive_mode)
# 解析参数
args = parser.parse_args()
# 如果没有指定命令,显示帮助
if not hasattr(args, 'func'):
parser.print_help()
return
# 执行命令
try:
args.func(args)
except Exception as e:
logger.error(f"执行命令时出错: {e}")
print(f"执行命令时出错: {e}")
return 1
return 0
if __name__ == "__main__":
sys.exit(main())
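Typical invocations of this CLI, sketched by setting sys.argv and calling main() directly (the same trick main.py uses below); the same arguments can of course be passed on a real command line.
import sys
from interface.cli import main

# Equivalent to: python interface/cli.py predict "这是一条测试新闻" --top_k 3
sys.argv = ["interface/cli.py", "predict", "这是一条测试新闻", "--top_k", "3"]
main()

# Equivalent to: python interface/cli.py list
sys.argv = ["interface/cli.py", "list"]
main()

# Equivalent to: python interface/cli.py interactive --top_k 3
sys.argv = ["interface/cli.py", "interactive", "--top_k", "3"]
main()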


260
interface/web/app.py Normal file
View File

@ -0,0 +1,260 @@
return render_template('predict_text.html', text=text, predictions=predictions)
except Exception as e:
logger.error(f"预测文本时出错: {e}")
flash(f'预测失败: {str(e)}')
return render_template('predict_text.html', text=text)
return render_template('predict_text.html')
@app.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
"""文件预测页面"""
if request.method == 'POST':
# 检查是否有文件上传
if 'file' not in request.files:
flash('未选择文件')
return render_template('predict_file.html')
file = request.files['file']
# 如果用户没有选择文件
if file.filename == '':
flash('未选择文件')
return render_template('predict_file.html')
if file and allowed_file(file.filename):
# 获取参数
top_k = int(request.form.get('top_k', 3))
text_column = request.form.get('text_column', 'text')
try:
# 安全地保存文件
filename = secure_filename(file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(file_path)
# 处理文件类型
if filename.endswith('.txt'):
# 文本文件
with open(file_path, 'r', encoding='utf-8') as f:
text = f.read()
# 获取预测器
predictor = get_predictor()
# 预测
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
# 准备结果
predictions = result if top_k > 1 else [result]
return render_template(
'predict_file.html',
filename=filename,
file_type='text',
predictions=predictions
)
elif filename.endswith(('.csv', '.xls', '.xlsx')):
# 表格文件
if filename.endswith('.csv'):
df = pd.read_csv(file_path)
else:
df = pd.read_excel(file_path)
# 检查文本列是否存在
if text_column not in df.columns:
flash(f"文件中没有找到列: {text_column}")
return render_template('predict_file.html')
# 提取文本
texts = df[text_column].fillna('').tolist()
# 获取预测器
predictor = get_predictor()
# 批量预测
results = predictor.predict_batch(
texts=texts[:100], # 仅处理前100条记录
return_top_k=top_k,
return_probabilities=True
)
# 准备结果
batch_results = []
for i, result in enumerate(results):
batch_results.append({
'id': i,
'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i],
'predictions': result if top_k > 1 else [result]
})
return render_template(
'predict_file.html',
filename=filename,
file_type='table',
total_records=len(texts),
processed_records=min(100, len(texts)),
batch_results=batch_results
)
else:
flash(f"不支持的文件类型: {filename}")
return render_template('predict_file.html')
except Exception as e:
logger.error(f"处理文件时出错: {e}")
flash(f'处理失败: {str(e)}')
return render_template('predict_file.html')
else:
flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
return render_template('predict_file.html')
return render_template('predict_file.html')
@app.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
"""批量预测页面"""
if request.method == 'POST':
texts = request.form.get('texts', '')
top_k = int(request.form.get('top_k', 3))
if not texts:
flash('请输入文本内容')
return render_template('batch_predict.html')
# 分割文本
text_list = [text.strip() for text in texts.split('\n') if text.strip()]
if not text_list:
flash('请输入有效的文本内容')
return render_template('batch_predict.html', texts=texts)
try:
# 获取预测器
predictor = get_predictor()
# 批量预测
results = predictor.predict_batch(
texts=text_list,
return_top_k=top_k,
return_probabilities=True
)
# 准备结果
batch_results = []
for i, result in enumerate(results):
batch_results.append({
'id': i + 1,
'text': text_list[i],
'predictions': result if top_k > 1 else [result]
})
return render_template(
'batch_predict.html',
texts=texts,
batch_results=batch_results
)
except Exception as e:
logger.error(f"批量预测时出错: {e}")
flash(f'预测失败: {str(e)}')
return render_template('batch_predict.html', texts=texts)
return render_template('batch_predict.html')
@app.route('/models')
def list_models():
"""模型列表页面"""
try:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
return render_template('models.html', models=models_info)
except Exception as e:
logger.error(f"获取模型列表时出错: {e}")
flash(f'获取模型列表失败: {str(e)}')
return render_template('models.html', models=[])
@app.route('/about')
def about():
"""关于页面"""
return render_template('about.html')
@app.errorhandler(404)
def page_not_found(e):
"""404错误处理"""
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_server_error(e):
"""500错误处理"""
return render_template('500.html'), 500
# 过滤器
@app.template_filter('format_time')
def format_time_filter(timestamp):
"""格式化时间戳"""
if isinstance(timestamp, str):
try:
dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
            return dt.strftime("%Y年%m月%d日 %H:%M")
except:
return timestamp
return timestamp
@app.template_filter('truncate_text')
def truncate_text_filter(text, length=100):
"""截断文本"""
if len(text) <= length:
return text
return text[:length] + '...'
@app.template_filter('format_percent')
def format_percent_filter(value):
"""格式化百分比"""
if isinstance(value, (int, float)):
return f"{value * 100:.2f}%"
return value
def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
"""
运行Web服务器
Args:
host: 主机地址
port: 端口号
debug: 是否开启调试模式
"""
app.run(host=host, port=port, debug=debug)
if __name__ == "__main__":
# 解析命令行参数
import argparse
parser = argparse.ArgumentParser(description="中文文本分类系统Web应用")
parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
parser.add_argument("--port", type=int, default=5000, help="服务器端口号")
parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
args = parser.parse_args()
# 运行服务器
run_server(host=args.host, port=args.port, debug=args.debug)
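A hedged smoke-test sketch using Flask's test client; it assumes app is the Flask instance created in the (truncated) top of this file and that at least one trained model is available for get_predictor().
from interface.web.app import app

# Exercise the routes without starting a real server.
client = app.test_client()

# Render the batch prediction form.
print(client.get("/batch_predict").status_code)

# Submit two texts (one per line) and request the top-3 classes for each.
resp = client.post("/batch_predict",
                   data={"texts": "教育部发布新课标\n股指期货小幅收涨", "top_k": "3"})
print(resp.status_code)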

166
interface/web/routes.py Normal file
View File

@ -0,0 +1,166 @@
"""
Web路由模块定义Web应用的路由
"""
from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash
from werkzeug.utils import secure_filename
import os
import pandas as pd
import io
import time
from typing import List, Dict, Tuple, Optional
from interface.web.app import get_predictor, allowed_file, app
# 创建蓝图
bp = Blueprint('routes', __name__)
@bp.route('/')
def index():
"""首页"""
return render_template('index.html')
@bp.route('/predict_text', methods=['GET', 'POST'])
def predict_text():
"""文本预测页面"""
if request.method == 'POST':
text = request.form.get('text', '')
top_k = int(request.form.get('top_k', 3))
if not text:
flash('请输入文本内容')
return render_template('predict_text.html')
try:
# 获取预测器
predictor = get_predictor()
# 预测
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
# 准备结果
predictions = result if top_k > 1 else [result]
return render_template('predict_text.html', text=text, predictions=predictions)
except Exception as e:
flash(f'预测失败: {str(e)}')
return render_template('predict_text.html', text=text)
return render_template('predict_text.html')
@bp.route('/predict_file', methods=['GET', 'POST'])
def predict_file():
"""文件预测页面"""
if request.method == 'POST':
# 检查是否有文件上传
if 'file' not in request.files:
flash('未选择文件')
return render_template('predict_file.html')
file = request.files['file']
# 如果用户没有选择文件
if file.filename == '':
flash('未选择文件')
return render_template('predict_file.html')
if file and allowed_file(file.filename):
# 获取参数
top_k = int(request.form.get('top_k', 3))
text_column = request.form.get('text_column', 'text')
try:
# 安全地保存文件
filename = secure_filename(file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(file_path)
# 处理文件
# 注意此处与app.py中的predict_file重复实际项目中应该将逻辑抽取到单独的函数中
# 这里简化处理,仅返回渲染模板
return render_template('predict_file.html', filename=filename)
except Exception as e:
flash(f'处理失败: {str(e)}')
return render_template('predict_file.html')
else:
flash(f'不支持的文件类型')
return render_template('predict_file.html')
return render_template('predict_file.html')
@bp.route('/batch_predict', methods=['GET', 'POST'])
def batch_predict():
"""批量预测页面"""
if request.method == 'POST':
texts = request.form.get('texts', '')
top_k = int(request.form.get('top_k', 3))
if not texts:
flash('请输入文本内容')
return render_template('batch_predict.html')
# 分割文本
text_list = [text.strip() for text in texts.split('\n') if text.strip()]
if not text_list:
flash('请输入有效的文本内容')
return render_template('batch_predict.html', texts=texts)
try:
# 获取预测器
predictor = get_predictor()
# 批量预测
results = predictor.predict_batch(
texts=text_list,
return_top_k=top_k,
return_probabilities=True
)
# 准备结果
batch_results = []
for i, result in enumerate(results):
batch_results.append({
'id': i + 1,
'text': text_list[i],
'predictions': result if top_k > 1 else [result]
})
return render_template(
'batch_predict.html',
texts=texts,
batch_results=batch_results
)
except Exception as e:
flash(f'预测失败: {str(e)}')
return render_template('batch_predict.html', texts=texts)
return render_template('batch_predict.html')
@bp.route('/models')
def list_models():
"""模型列表页面"""
return render_template('models.html')
@bp.route('/about')
def about():
"""关于页面"""
return render_template('about.html')
# 注册蓝图
def init_app(app):
app.register_blueprint(bp)
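init_app() above suggests a blueprint-based setup; a hypothetical application-factory sketch is shown below. Note that this commit's app.py registers its routes directly on app, so this is only an illustration of how the hook could be used.
from flask import Flask
from interface.web import routes

def create_app() -> Flask:
    # Hypothetical factory; secret_key is required for the flash() calls above.
    app = Flask(__name__)
    app.secret_key = "change-me"
    routes.init_app(app)  # registers the 'routes' blueprint
    return app

if __name__ == "__main__":
    create_app().run(port=5001, debug=True)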

142
main.py Normal file
View File

@ -0,0 +1,142 @@
"""
主入口文件整合系统的所有功能提供命令行接口
"""
import os
import sys
import argparse
import logging
from typing import List, Optional
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("main")
# 命令行参数
parser = argparse.ArgumentParser(description="中文文本分类系统")
subparsers = parser.add_subparsers(dest="command", help="命令")
# 训练命令
train_parser = subparsers.add_parser("train", help="训练模型")
train_parser.add_argument("--data_dir", help="数据目录")
train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
train_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
train_parser.add_argument("--save_dir", help="模型保存目录")
# 评估命令
evaluate_parser = subparsers.add_parser("evaluate", help="评估模型")
evaluate_parser.add_argument("--model_path", required=True, help="模型路径")
evaluate_parser.add_argument("--data_dir", help="数据目录")
evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
evaluate_parser.add_argument("--output_dir", help="评估结果输出目录")
# 预测命令
predict_parser = subparsers.add_parser("predict", help="使用模型预测")
predict_parser.add_argument("--model_path", help="模型路径")
predict_parser.add_argument("--text", help="要预测的文本")
predict_parser.add_argument("--file", help="要预测的文件")
predict_parser.add_argument("--output", help="输出文件")
# Web服务命令
web_parser = subparsers.add_parser("web", help="启动Web服务")
web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
web_parser.add_argument("--port", type=int, default=5000, help="服务器端口")
web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
# API服务命令
api_parser = subparsers.add_parser("api", help="启动API服务")
api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
api_parser.add_argument("--port", type=int, default=8000, help="服务器端口")
# CLI命令
cli_parser = subparsers.add_parser("cli", help="启动命令行接口")
cli_parser.add_argument("--model_path", help="模型路径")
cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式")
def main():
"""主函数"""
args = parser.parse_args()
# 如果没有指定命令,显示帮助信息
if not args.command:
parser.print_help()
return 0
# 根据命令调用相应的功能
if args.command == "train":
# 导入训练模块
from scripts.train import train_model
# 调用训练功能
train_model(
data_dir=args.data_dir,
model_type=args.model_type,
epochs=args.epochs,
batch_size=args.batch_size,
save_dir=args.save_dir
)
elif args.command == "evaluate":
# 导入评估模块
from scripts.evaluate import evaluate_model
# 调用评估功能
evaluate_model(
model_path=args.model_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
output_dir=args.output_dir
)
elif args.command == "predict":
# 导入预测模块
from scripts.predict import predict_text, predict_file
# 根据输入类型调用相应的预测功能
if args.text:
predict_text(args.text, args.model_path, args.output)
elif args.file:
predict_file(args.file, args.model_path, args.output)
else:
logger.error("请提供要预测的文本或文件")
return 1
elif args.command == "web":
# 导入Web服务模块
from interface.web.app import run_server
# 启动Web服务
run_server(host=args.host, port=args.port, debug=args.debug)
elif args.command == "api":
# 导入API服务模块
from interface.api import run_server
# 启动API服务
run_server(host=args.host, port=args.port)
elif args.command == "cli":
# 导入CLI模块
from interface.cli import main as cli_main
# 将命令行参数转换为CLI模块可接受的格式
sys.argv = ["interface/cli.py"]
if args.interactive:
sys.argv.append("interactive")
if args.model_path:
sys.argv.extend(["--model_path", args.model_path])
        else:
            # cli.py 的 list 子命令不接受 --model_path 参数,这里仅执行 list
            sys.argv.append("list")
# 调用CLI主函数
return cli_main()
return 0
if __name__ == "__main__":
sys.exit(main())
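The entry point above is normally driven from the command line; the subprocess sketch below just spells out a few representative invocations (each call blocks until the corresponding service or job exits, so run them one at a time).
import subprocess
import sys

# Train a CNN classifier for 10 epochs.
subprocess.run([sys.executable, "main.py", "train", "--model_type", "cnn",
                "--epochs", "10", "--batch_size", "64"])

# Evaluate a saved model (placeholder path).
subprocess.run([sys.executable, "main.py", "evaluate",
                "--model_path", "saved_models/classifiers/cnn_text_classifier"])

# Serve the Web UI on port 5000.
subprocess.run([sys.executable, "main.py", "web", "--port", "5000"])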

0
models/__init__.py Normal file
View File

419
models/base_model.py Normal file
View File

@ -0,0 +1,419 @@
"""
模型基类定义所有文本分类模型的通用接口
"""
import os
import time
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from abc import ABC, abstractmethod
from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
from config.model_config import (
BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
)
from utils.logger import get_logger
from utils.file_utils import ensure_dir, save_json
logger = get_logger("BaseModel")
class TextClassificationModel(ABC):
"""文本分类模型基类,定义所有模型的通用接口"""
def __init__(self, num_classes: int, model_name: str = "text_classifier",
batch_size: int = BATCH_SIZE,
learning_rate: float = LEARNING_RATE):
"""
初始化文本分类模型
Args:
num_classes: 类别数量
model_name: 模型名称
batch_size: 批大小
learning_rate: 学习率
"""
self.num_classes = num_classes
self.model_name = model_name
self.batch_size = batch_size
self.learning_rate = learning_rate
# 模型实例
self.model = None
# 训练历史
self.history = None
# 训练配置
self.config = {
"model_name": model_name,
"num_classes": num_classes,
"batch_size": batch_size,
"learning_rate": learning_rate
}
# 验证集合最佳性能
self.best_val_loss = float('inf')
self.best_val_accuracy = 0.0
logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")
@abstractmethod
def build(self) -> None:
"""构建模型架构,这是一个抽象方法,子类必须实现"""
pass
def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
"""
编译模型
Args:
optimizer: 优化器默认为Adam
loss: 损失函数默认为sparse_categorical_crossentropy
metrics: 评估指标默认为accuracy
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 默认优化器
if optimizer is None:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
# 默认损失函数
if loss is None:
loss = 'sparse_categorical_crossentropy'
# 默认评估指标
if metrics is None:
metrics = ['accuracy']
# 编译模型
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}")
def summary(self) -> None:
"""打印模型概要"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
self.model.summary()
def fit(self, x_train: Union[np.ndarray, tf.data.Dataset],
y_train: Optional[np.ndarray] = None,
validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
epochs: int = 10,
callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
class_weights: Optional[Dict[int, float]] = None,
verbose: int = 1) -> tf.keras.callbacks.History:
"""
训练模型
Args:
x_train: 训练数据特征
y_train: 训练数据标签
validation_data: 验证数据
epochs: 训练轮数
callbacks: 回调函数列表
class_weights: 类别权重
verbose: 详细程度
Returns:
训练历史
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 记录开始时间
start_time = time.time()
# 添加默认回调函数
if callbacks is None:
callbacks = self._get_default_callbacks()
# 训练模型
if isinstance(x_train, tf.data.Dataset):
# 如果输入是TensorFlow Dataset
history = self.model.fit(
x_train,
epochs=epochs,
validation_data=validation_data,
callbacks=callbacks,
class_weight=class_weights,
verbose=verbose
)
else:
# 如果输入是NumPy数组
history = self.model.fit(
x_train, y_train,
batch_size=self.batch_size,
epochs=epochs,
validation_data=validation_data,
callbacks=callbacks,
class_weight=class_weights,
verbose=verbose
)
# 计算训练时间
train_time = time.time() - start_time
# 保存训练历史
self.history = history.history
self.history['train_time'] = train_time
logger.info(f"模型训练完成,耗时: {train_time:.2f}")
return history
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
y_test: Optional[np.ndarray] = None,
verbose: int = 1) -> Dict[str, float]:
"""
评估模型
Args:
x_test: 测试数据特征
y_test: 测试数据标签
verbose: 详细程度
Returns:
评估结果字典
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 评估模型
if isinstance(x_test, tf.data.Dataset):
# 如果输入是TensorFlow Dataset
results = self.model.evaluate(x_test, verbose=verbose)
else:
# 如果输入是NumPy数组
results = self.model.evaluate(x_test, y_test, batch_size=self.batch_size, verbose=verbose)
# 构建评估结果字典
metrics_names = self.model.metrics_names
evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)}
logger.info(f"模型评估结果: {evaluation_results}")
return evaluation_results
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
batch_size: Optional[int] = None,
verbose: int = 0) -> np.ndarray:
"""
使用模型进行预测
Args:
x: 预测数据
batch_size: 批大小
verbose: 详细程度
Returns:
预测结果
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 使用模型进行预测
if batch_size is None:
batch_size = self.batch_size
return self.model.predict(x, batch_size=batch_size, verbose=verbose)
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
batch_size: Optional[int] = None,
verbose: int = 0) -> np.ndarray:
"""
使用模型预测类别
Args:
x: 预测数据
batch_size: 批大小
verbose: 详细程度
Returns:
预测的类别索引
"""
# 获取模型预测概率
predictions = self.predict(x, batch_size, verbose)
# 获取最大概率的类别索引
return np.argmax(predictions, axis=1)
def save(self, filepath: Optional[str] = None,
save_format: str = 'tf',
include_optimizer: bool = True) -> str:
"""
保存模型
Args:
filepath: 保存路径如果为None则使用默认路径
save_format: 保存格式'tf''h5'
include_optimizer: 是否包含优化器状态
Returns:
保存路径
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 如果未指定保存路径,使用默认路径
if filepath is None:
ensure_dir(CLASSIFIERS_DIR)
timestamp = time.strftime("%Y%m%d_%H%M%S")
filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}")
# 保存模型
self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)
# 保存模型配置
config_path = f"{filepath}_config.json"
with open(config_path, 'w', encoding='utf-8') as f:
json.dump(self.config, f, ensure_ascii=False, indent=4)
logger.info(f"模型已保存到: {filepath}")
return filepath
@classmethod
def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
"""
加载模型
Args:
filepath: 模型文件路径
custom_objects: 自定义对象字典
Returns:
加载的模型实例
"""
# 加载模型配置
config_path = f"{filepath}_config.json"
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
except FileNotFoundError:
logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
config = {}
# 创建模型实例
model_name = config.get('model_name', 'loaded_model')
num_classes = config.get('num_classes', 1)
batch_size = config.get('batch_size', BATCH_SIZE)
learning_rate = config.get('learning_rate', LEARNING_RATE)
instance = cls(num_classes, model_name, batch_size, learning_rate)
# 加载Keras模型
instance.model = load_model(filepath, custom_objects=custom_objects)
# 加载配置
instance.config = config
logger.info(f"{filepath} 加载模型成功")
return instance
def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
"""获取默认的回调函数列表"""
# 早停
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=EARLY_STOPPING_PATIENCE,
restore_best_weights=True,
verbose=1
)
# 学习率衰减
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=REDUCE_LR_FACTOR,
patience=REDUCE_LR_PATIENCE,
min_lr=1e-6,
verbose=1
)
# 模型检查点
checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
ensure_dir(os.path.dirname(checkpoint_path))
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_best_only=True,
monitor='val_loss',
verbose=1
)
# TensorBoard日志
log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
ensure_dir(log_dir)
tensorboard = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1
)
return [early_stopping, reduce_lr, model_checkpoint, tensorboard]
def get_config(self) -> Dict[str, Any]:
"""获取模型配置"""
return self.config.copy()
def get_model(self) -> Model:
"""获取Keras模型实例"""
return self.model
def get_training_history(self) -> Optional[Dict[str, List[float]]]:
"""获取训练历史"""
return self.history
def plot_training_history(self, save_path: Optional[str] = None,
metrics: Optional[List[str]] = None) -> None:
"""
绘制训练历史
Args:
save_path: 保存路径如果为None则显示图像
metrics: 要绘制的指标列表默认为['loss', 'accuracy']
"""
if self.history is None:
raise ValueError("模型尚未训练,没有训练历史")
import matplotlib.pyplot as plt
if metrics is None:
metrics = ['loss', 'accuracy']
# 创建图形
plt.figure(figsize=(12, 5))
# 绘制指标
for i, metric in enumerate(metrics):
plt.subplot(1, len(metrics), i + 1)
if metric in self.history:
plt.plot(self.history[metric], label=f'train_{metric}')
val_metric = f'val_{metric}'
if val_metric in self.history:
plt.plot(self.history[val_metric], label=f'val_{metric}')
plt.title(f'Model {metric}')
plt.xlabel('Epoch')
plt.ylabel(metric)
plt.legend()
plt.tight_layout()
# 保存或显示图像
if save_path:
plt.savefig(save_path)
logger.info(f"训练历史图已保存到: {save_path}")
else:
plt.show()
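
A minimal end-to-end sketch of the lifecycle this base class defines (build, compile, fit, evaluate, save). The CNN subclass, the random toy data and every hyperparameter below are illustrative assumptions, not part of the commit:

import numpy as np
from models.cnn_model import CNNTextClassifier

# Toy data: padded token-id sequences of length 500 and 10 classes (all made up).
x_train = np.random.randint(1, 5000, size=(800, 500))
y_train = np.random.randint(0, 10, size=(800,))
x_val = np.random.randint(1, 5000, size=(200, 500))
y_val = np.random.randint(0, 10, size=(200,))

clf = CNNTextClassifier(num_classes=10, vocab_size=5000, max_sequence_length=500)
clf.build()                                   # construct the Keras graph
clf.compile()                                 # Adam + default loss/metrics
clf.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=2)
print(clf.evaluate(x_val, y_val))             # e.g. {'loss': ..., 'accuracy': ...}
saved_path = clf.save()                       # timestamped path under CLASSIFIERS_DIR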

180
models/cnn_model.py Normal file
View File

@ -0,0 +1,180 @@
"""
CNN模型实现基于卷积神经网络的文本分类模型
"""
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input, Embedding, Conv1D, MaxPooling1D, GlobalMaxPooling1D,
Dense, Dropout, Concatenate, BatchNormalization, Activation
)
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
from config.model_config import (
MAX_SEQUENCE_LENGTH, CNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger
logger = get_logger("CNNModel")
class CNNTextClassifier(TextClassificationModel):
"""卷积神经网络文本分类模型"""
def __init__(self, num_classes: int, vocab_size: int,
embedding_dim: int = CNN_CONFIG["embedding_dim"],
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
num_filters: int = CNN_CONFIG["num_filters"],
filter_sizes: List[int] = CNN_CONFIG["filter_sizes"],
dropout_rate: float = CNN_CONFIG["dropout_rate"],
l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"],
embedding_matrix: Optional[np.ndarray] = None,
trainable_embedding: bool = True,
model_name: str = "cnn_text_classifier",
batch_size: int = 64,
learning_rate: float = 0.001):
"""
初始化CNN文本分类模型
Args:
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_dim: 词嵌入维度
max_sequence_length: 最大序列长度
num_filters: 卷积核数量
filter_sizes: 卷积核大小列表
dropout_rate: Dropout比例
l2_reg_lambda: L2正则化系数
embedding_matrix: 预训练词嵌入矩阵如果为None则使用随机初始化
trainable_embedding: 词嵌入是否可训练
model_name: 模型名称
batch_size: 批大小
learning_rate: 学习率
"""
super().__init__(num_classes, model_name, batch_size, learning_rate)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.max_sequence_length = max_sequence_length
self.num_filters = num_filters
self.filter_sizes = filter_sizes
self.dropout_rate = dropout_rate
self.l2_reg_lambda = l2_reg_lambda
self.embedding_matrix = embedding_matrix
self.trainable_embedding = trainable_embedding
# 更新配置
self.config.update({
"vocab_size": vocab_size,
"embedding_dim": embedding_dim,
"max_sequence_length": max_sequence_length,
"num_filters": num_filters,
"filter_sizes": filter_sizes,
"dropout_rate": dropout_rate,
"l2_reg_lambda": l2_reg_lambda,
"trainable_embedding": trainable_embedding,
"model_type": "CNN"
})
logger.info(f"初始化CNN文本分类模型词汇表大小: {vocab_size}, 嵌入维度: {embedding_dim}")
def build(self) -> None:
"""构建CNN模型架构"""
# Input layer
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
# Embedding layer
if self.embedding_matrix is not None:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
weights=[self.embedding_matrix],
input_length=self.max_sequence_length,
trainable=self.trainable_embedding,
name='embedding'
)
else:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
input_length=self.max_sequence_length,
trainable=True,
name='embedding'
)
embedded_sequences = embedding_layer(sequence_input)
# Convolutional layers with different filter sizes
conv_blocks = []
for filter_size in self.filter_sizes:
conv = Conv1D(
filters=self.num_filters,
kernel_size=filter_size,
padding='valid',
activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda),
name=f'conv_{filter_size}'
)(embedded_sequences)
# Max pooling
pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv)
conv_blocks.append(pooled)
# Concatenate pooled features if we have multiple filter sizes
if len(self.filter_sizes) > 1:
concatenated = Concatenate(name='concatenate')(conv_blocks)
else:
concatenated = conv_blocks[0]
# Dropout for regularization
x = Dropout(self.dropout_rate, name='dropout_1')(concatenated)
# Dense layer
x = Dense(128, name='dense_1')(x)
x = BatchNormalization(name='batch_norm_1')(x)
x = Activation('relu', name='activation_1')(x)
x = Dropout(self.dropout_rate, name='dropout_2')(x)
# Output layer
if self.num_classes == 2:
# Binary classification
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
else:
# Multi-class classification
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
# Build the model
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
logger.info(f"CNN模型构建完成过滤器大小: {self.filter_sizes}, 每种大小的过滤器数量: {self.num_filters}")
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
"""
编译CNN模型
Args:
optimizer: 优化器默认为Adam
loss: 损失函数默认根据类别数量选择
metrics: 评估指标默认为accuracy
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 默认优化器
if optimizer is None:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
# 默认损失函数
if loss is None:
if self.num_classes == 2:
loss = 'binary_crossentropy'
else:
loss = 'sparse_categorical_crossentropy'
# 默认评估指标
if metrics is None:
metrics = ['accuracy']
# 编译模型
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
logger.info(f"CNN模型已编译优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")

216
models/ensemble_model.py Normal file
View File

@ -0,0 +1,216 @@
"""
集成模型实现多个模型的集成
"""
import numpy as np
import tensorflow as tf
from typing import List, Dict, Tuple, Optional, Any, Union
import os
from config.system_config import CLASSIFIERS_DIR
from models.base_model import TextClassificationModel
from utils.logger import get_logger
logger = get_logger("EnsembleModel")
class EnsembleModel:
"""模型集成类,集成多个模型的预测结果"""
def __init__(self, models: List[TextClassificationModel],
weights: Optional[List[float]] = None,
voting: str = 'soft',
name: str = "ensemble_model"):
"""
初始化集成模型
Args:
models: 模型列表
weights: 各模型的权重默认为均等权重
voting: 投票方式'hard'表示多数投票'soft'表示概率平均
name: 集成模型名称
"""
self.models = models
self.num_models = len(models)
# 验证模型数量
if self.num_models == 0:
raise ValueError("模型列表不能为空")
# 设置权重
if weights is None:
self.weights = np.ones(self.num_models) / self.num_models
else:
if len(weights) != self.num_models:
raise ValueError("权重数量必须与模型数量相同")
# 归一化权重
self.weights = np.array(weights) / np.sum(weights)
# 验证投票方式
self.voting = voting.lower()
if self.voting not in ['hard', 'soft']:
raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'")
# 从第一个模型获取类别数
self.num_classes = models[0].num_classes
# 验证所有模型的类别数是否相同
for i, model in enumerate(models[1:], 1):
if model.num_classes != self.num_classes:
raise ValueError(
f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同")
self.name = name
logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}")
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
batch_size: Optional[int] = None,
verbose: int = 0) -> np.ndarray:
"""
使用集成模型进行预测
Args:
x: 预测数据
batch_size: 批大小
verbose: 详细程度
Returns:
预测概率
"""
# 获取每个模型的预测结果
all_predictions = []
for i, model in enumerate(self.models):
logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果")
predictions = model.predict(x, batch_size, verbose)
# 如果是二分类且输出形状是(n,1),转换为(n,2)
if self.num_classes == 2 and predictions.shape[1:] == (1,):
predictions = np.hstack([1 - predictions, predictions])
all_predictions.append(predictions)
# 根据投票方式进行集成
if self.voting == 'hard':
# 硬投票:每个模型预测的类别,取众数
individual_classes = [np.argmax(pred, axis=1) for pred in all_predictions]
# 获取带权重的预测类别频率
ensemble_result = np.zeros((all_predictions[0].shape[0], self.num_classes))
for i, classes in enumerate(individual_classes):
for j, cls in enumerate(classes):
ensemble_result[j, cls] += self.weights[i]
return ensemble_result
else: # soft voting
# 软投票:对每个模型的预测概率进行加权平均
weighted_predictions = [pred * weight for pred, weight in zip(all_predictions, self.weights)]
ensemble_result = np.sum(weighted_predictions, axis=0)
return ensemble_result
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
batch_size: Optional[int] = None,
verbose: int = 0) -> np.ndarray:
"""
使用集成模型预测类别
Args:
x: 预测数据
batch_size: 批大小
verbose: 详细程度
Returns:
预测的类别索引
"""
# 获取预测概率
predictions = self.predict(x, batch_size, verbose)
# 获取最大概率的类别索引
return np.argmax(predictions, axis=1)
def save(self, directory: Optional[str] = None) -> str:
"""
保存集成模型
Args:
directory: 保存目录默认为CLASSIFIERS_DIR
Returns:
保存路径
"""
if directory is None:
import time
timestamp = time.strftime("%Y%m%d_%H%M%S")
directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}")
os.makedirs(directory, exist_ok=True)
# 保存模型列表
model_paths = []
for i, model in enumerate(self.models):
model_path = os.path.join(directory, f"model_{i}")
model.save(model_path)
model_paths.append(model_path)
# 保存集成配置
config = {
"name": self.name,
"num_models": self.num_models,
"model_paths": model_paths,
"weights": self.weights.tolist(),
"voting": self.voting,
"num_classes": self.num_classes
}
import json
config_path = os.path.join(directory, "ensemble_config.json")
with open(config_path, 'w', encoding='utf-8') as f:
json.dump(config, f, ensure_ascii=False, indent=4)
logger.info(f"集成模型已保存到目录: {directory}")
return directory
@classmethod
def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel':
"""
加载集成模型
Args:
directory: 模型目录
custom_objects: 自定义对象字典
Returns:
加载的集成模型实例
"""
# 加载配置
config_path = os.path.join(directory, "ensemble_config.json")
import json
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
# 加载子模型
from models.model_factory import ModelFactory
models = []
model_paths = config["model_paths"]
for model_path in model_paths:
model = ModelFactory.load_model(model_path, custom_objects)
models.append(model)
# 创建集成模型
ensemble = cls(
models=models,
weights=config["weights"],
voting=config["voting"],
name=config["name"]
)
logger.info(f"从目录 {directory} 加载集成模型成功")
return ensemble
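
A usage sketch of soft voting, assuming two already-trained classifiers (`cnn` and `rnn`) and a prepared test matrix `x_test`; all three names are hypothetical:

from models.ensemble_model import EnsembleModel

ensemble = EnsembleModel(models=[cnn, rnn],    # hypothetical trained TextClassificationModel instances
                         weights=[0.6, 0.4],   # normalised internally
                         voting='soft')
probs = ensemble.predict(x_test)               # weighted average of per-model probabilities
labels = ensemble.predict_classes(x_test)      # argmax over the averaged probabilities
save_dir = ensemble.save()                     # writes each sub-model plus ensemble_config.json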

169
models/model_factory.py Normal file
View File

@ -0,0 +1,169 @@
"""
模型工厂统一创建和管理不同类型的模型
"""
from typing import List, Dict, Tuple, Optional, Any, Union
import os
import glob
import time
import numpy as np
from config.system_config import CLASSIFIERS_DIR
from config.model_config import (
BATCH_SIZE, LEARNING_RATE
)
from models.base_model import TextClassificationModel
from models.cnn_model import CNNTextClassifier
from models.rnn_model import RNNTextClassifier
from models.transformer_model import TransformerTextClassifier
from utils.logger import get_logger
logger = get_logger("ModelFactory")
class ModelFactory:
"""模型工厂,用于创建和管理不同类型的模型"""
@staticmethod
def create_model(model_type: str, num_classes: int, vocab_size: int,
embedding_matrix: Optional[np.ndarray] = None,
model_config: Optional[Dict[str, Any]] = None,
**kwargs) -> TextClassificationModel:
"""
创建指定类型的模型
Args:
model_type: 模型类型可选值: 'cnn', 'rnn', 'transformer'
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_matrix: 预训练词嵌入矩阵
model_config: 模型配置字典
**kwargs: 其他参数
Returns:
创建的模型实例
"""
model_type = model_type.lower()
# 合并配置
config = model_config or {}
config.update(kwargs)
# 创建模型
if model_type == 'cnn':
model = CNNTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
elif model_type == 'rnn':
model = RNNTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
elif model_type == 'transformer':
model = TransformerTextClassifier(
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
**config
)
else:
raise ValueError(f"不支持的模型类型: {model_type}")
logger.info(f"已创建 {model_type.upper()} 模型")
return model
@staticmethod
def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel:
"""
加载保存的模型
Args:
model_path: 模型路径
custom_objects: 自定义对象字典
Returns:
加载的模型实例
"""
# 添加Transformer相关的自定义对象
if custom_objects is None:
custom_objects = {}
if 'TransformerBlock' not in custom_objects:
from models.transformer_model import TransformerBlock
custom_objects['TransformerBlock'] = TransformerBlock
# 根据配置确定模型类型
model_config_path = f"{model_path}_config.json"
import json
with open(model_config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
model_type = config.get('model_type', '').lower()
# 根据模型类型选择加载方法
if model_type == 'cnn':
model = CNNTextClassifier.load(model_path, custom_objects)
elif model_type == 'rnn':
model = RNNTextClassifier.load(model_path, custom_objects)
elif model_type == 'transformer':
model = TransformerTextClassifier.load(model_path, custom_objects)
else:
# 如果无法确定模型类型,使用基类加载
logger.warning(f"无法确定模型类型,使用基类加载: {model_path}")
model = TextClassificationModel.load(model_path, custom_objects)
logger.info(f"已加载模型: {model_path}")
return model
@staticmethod
def get_available_models() -> List[Dict[str, Any]]:
"""
获取可用的已保存模型列表
Returns:
模型信息列表每个元素是包含模型信息的字典
"""
model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*"))
model_files = [f for f in model_files if not f.endswith("_config.json")]
models_info = []
for model_file in model_files:
config_file = f"{model_file}_config.json"
if os.path.exists(config_file):
try:
import json
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
# 获取模型文件的创建时间
created_time = os.path.getctime(model_file)
created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time))
# 获取模型文件大小
file_size = os.path.getsize(model_file) / (1024 * 1024) # MB
models_info.append({
"path": model_file,
"name": config.get("model_name", os.path.basename(model_file)),
"type": config.get("model_type", "unknown"),
"num_classes": config.get("num_classes", 0),
"created_time": created_time_str,
"file_size": f"{file_size:.2f} MB",
"config": config
})
except Exception as e:
logger.error(f"读取模型配置失败: {config_file}, 错误: {e}")
# 按创建时间降序排序
models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True)
return models_info
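
A sketch of driving the factory: creating a small Transformer through `create_model` and then listing previously saved models; the class count, vocabulary size and layer settings are example values:

from models.model_factory import ModelFactory

model = ModelFactory.create_model(
    model_type='transformer',
    num_classes=10,                               # example value
    vocab_size=50000,                             # example value
    model_config={'num_layers': 2, 'num_heads': 4},
)
model.build()
model.compile()

for info in ModelFactory.get_available_models():
    print(info['name'], info['type'], info['created_time'], info['file_size'])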

220
models/rnn_model.py Normal file
View File

@ -0,0 +1,220 @@
"""
RNN模型实现基于循环神经网络的文本分类模型
"""
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout,
BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D
)
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
from config.model_config import (
MAX_SEQUENCE_LENGTH, RNN_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger
logger = get_logger("RNNModel")
class RNNTextClassifier(TextClassificationModel):
"""循环神经网络文本分类模型"""
def __init__(self, num_classes: int, vocab_size: int,
embedding_dim: int = RNN_CONFIG["embedding_dim"],
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
hidden_size: int = RNN_CONFIG["hidden_size"],
num_layers: int = RNN_CONFIG["num_layers"],
bidirectional: bool = RNN_CONFIG["bidirectional"],
rnn_type: str = "lstm", # 'lstm' or 'gru'
dropout_rate: float = RNN_CONFIG["dropout_rate"],
embedding_matrix: Optional[np.ndarray] = None,
trainable_embedding: bool = True,
pool_type: str = "max", # 'max', 'avg', or 'both'
model_name: str = "rnn_text_classifier",
batch_size: int = 64,
learning_rate: float = 0.001):
"""
初始化RNN文本分类模型
Args:
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_dim: 词嵌入维度
max_sequence_length: 最大序列长度
hidden_size: 隐藏层大小
num_layers: RNN层数
bidirectional: 是否使用双向RNN
rnn_type: RNN类型'lstm''gru'
dropout_rate: Dropout比例
embedding_matrix: 预训练词嵌入矩阵如果为None则使用随机初始化
trainable_embedding: 词嵌入是否可训练
pool_type: 池化类型,'max'、'avg'、'both'或'last'(使用最后一个时间步的输出)
model_name: 模型名称
batch_size: 批大小
learning_rate: 学习率
"""
super().__init__(num_classes, model_name, batch_size, learning_rate)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.max_sequence_length = max_sequence_length
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = bidirectional
self.rnn_type = rnn_type.lower()
self.dropout_rate = dropout_rate
self.embedding_matrix = embedding_matrix
self.trainable_embedding = trainable_embedding
self.pool_type = pool_type
# 验证RNN类型
if self.rnn_type not in ["lstm", "gru"]:
raise ValueError("无效的RNN类型支持的类型: 'lstm', 'gru'")
# 验证池化类型
if self.pool_type not in ["max", "avg", "both", "last"]:
raise ValueError("无效的池化类型,支持的类型: 'max', 'avg', 'both', 'last'")
# 更新配置
self.config.update({
"vocab_size": vocab_size,
"embedding_dim": embedding_dim,
"max_sequence_length": max_sequence_length,
"hidden_size": hidden_size,
"num_layers": num_layers,
"bidirectional": bidirectional,
"rnn_type": rnn_type,
"dropout_rate": dropout_rate,
"trainable_embedding": trainable_embedding,
"pool_type": pool_type,
"model_type": "RNN"
})
logger.info(f"初始化RNN文本分类模型类型: {rnn_type.upper()}, 隐藏层大小: {hidden_size}, 层数: {num_layers}")
def build(self) -> None:
"""构建RNN模型架构"""
# Input layer
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
# Embedding layer
if self.embedding_matrix is not None:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
weights=[self.embedding_matrix],
input_length=self.max_sequence_length,
trainable=self.trainable_embedding,
name='embedding'
)
else:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
input_length=self.max_sequence_length,
trainable=True,
name='embedding'
)
embedded_sequences = embedding_layer(sequence_input)
# 选择RNN层类型
if self.rnn_type == "lstm":
rnn_layer = LSTM
else: # gru
rnn_layer = GRU
# 构建多层RNN
x = embedded_sequences
for i in range(self.num_layers):
return_sequences = i < self.num_layers - 1 or self.pool_type != "last"
if self.bidirectional:
x = Bidirectional(
rnn_layer(
self.hidden_size,
return_sequences=return_sequences,
dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
name=f'{self.rnn_type}_{i + 1}'
)
)(x)
else:
x = rnn_layer(
self.hidden_size,
return_sequences=return_sequences,
dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
name=f'{self.rnn_type}_{i + 1}'
)(x)
# 根据池化类型选择池化方法
if self.pool_type == "max":
# 使用全局最大池化
pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
elif self.pool_type == "avg":
# 使用全局平均池化
pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
elif self.pool_type == "both":
# 同时使用最大池化和平均池化,然后拼接
max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
pooled = tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled])
else: # "last",使用最后一个时间步的输出
# 最后一层RNN已经返回了最后一个时间步的状态不需要额外池化
pooled = x
# Dropout for regularization
x = Dropout(self.dropout_rate, name='dropout_1')(pooled)
# Dense layer
x = Dense(128, name='dense_1')(x)
x = BatchNormalization(name='batch_norm_1')(x)
x = Activation('relu', name='activation_1')(x)
x = Dropout(self.dropout_rate, name='dropout_2')(x)
# Output layer
if self.num_classes == 2:
# Binary classification
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
else:
# Multi-class classification
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
# Build the model
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
logger.info(
f"RNN模型构建完成类型: {self.rnn_type.upper()}, 双向: {self.bidirectional}, 池化类型: {self.pool_type}")
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
"""
编译RNN模型
Args:
optimizer: 优化器默认为Adam
loss: 损失函数默认根据类别数量选择
metrics: 评估指标默认为accuracy
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 默认优化器
if optimizer is None:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
# 默认损失函数
if loss is None:
if self.num_classes == 2:
loss = 'binary_crossentropy'
else:
loss = 'sparse_categorical_crossentropy'
# 默认评估指标
if metrics is None:
metrics = ['accuracy']
# 编译模型
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
logger.info(f"RNN模型已编译优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")

270
models/transformer_model.py Normal file
View File

@ -0,0 +1,270 @@
"""
Transformer模型实现基于Transformer的文本分类模型
"""
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input, Embedding, Dense, Dropout, LayerNormalization,
GlobalAveragePooling1D, MultiHeadAttention, Add
)
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
from config.model_config import (
MAX_SEQUENCE_LENGTH, TRANSFORMER_CONFIG
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger
logger = get_logger("TransformerModel")
class TransformerBlock(tf.keras.layers.Layer):
"""Transformer块包含多头注意力和前馈网络"""
def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1, **kwargs):
"""
初始化Transformer块
Args:
embed_dim: 嵌入维度
num_heads: 注意力头数
ff_dim: 前馈网络维度
dropout_rate: Dropout比例
"""
super(TransformerBlock, self).__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.ff_dim = ff_dim
self.dropout_rate = dropout_rate
self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
self.ffn = tf.keras.Sequential([
Dense(ff_dim, activation="relu"),
Dense(embed_dim),
])
self.layernorm1 = LayerNormalization(epsilon=1e-6)
self.layernorm2 = LayerNormalization(epsilon=1e-6)
self.dropout1 = Dropout(dropout_rate)
self.dropout2 = Dropout(dropout_rate)
def call(self, inputs, training=False):
"""
前向传播
Args:
inputs: 输入张量
training: 是否处于训练模式
Returns:
输出张量
"""
# 多头自注意力
attention_output = self.attention(inputs, inputs)
attention_output = self.dropout1(attention_output, training=training)
out1 = self.layernorm1(inputs + attention_output)
# 前馈网络
ffn_output = self.ffn(out1)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = self.layernorm2(out1 + ffn_output)
return out2
def get_config(self):
"""获取配置"""
config = super(TransformerBlock, self).get_config()
config.update({
"embed_dim": self.embed_dim,
"num_heads": self.num_heads,
"ff_dim": self.ff_dim,
"dropout_rate": self.dropout_rate
})
return config
class TransformerTextClassifier(TextClassificationModel):
"""Transformer文本分类模型"""
def __init__(self, num_classes: int, vocab_size: int,
embedding_dim: int = TRANSFORMER_CONFIG["embedding_dim"],
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
num_heads: int = TRANSFORMER_CONFIG["num_heads"],
ff_dim: int = TRANSFORMER_CONFIG["ff_dim"],
num_layers: int = TRANSFORMER_CONFIG["num_layers"],
dropout_rate: float = TRANSFORMER_CONFIG["dropout_rate"],
embedding_matrix: Optional[np.ndarray] = None,
trainable_embedding: bool = True,
use_positional_encoding: bool = True,
model_name: str = "transformer_text_classifier",
batch_size: int = 64,
learning_rate: float = 0.001):
"""
初始化Transformer文本分类模型
Args:
num_classes: 类别数量
vocab_size: 词汇表大小
embedding_dim: 词嵌入维度
max_sequence_length: 最大序列长度
num_heads: 注意力头数
ff_dim: 前馈网络维度
num_layers: Transformer层数
dropout_rate: Dropout比例
embedding_matrix: 预训练词嵌入矩阵如果为None则使用随机初始化
trainable_embedding: 词嵌入是否可训练
use_positional_encoding: 是否使用位置编码
model_name: 模型名称
batch_size: 批大小
learning_rate: 学习率
"""
super().__init__(num_classes, model_name, batch_size, learning_rate)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.max_sequence_length = max_sequence_length
self.num_heads = num_heads
self.ff_dim = ff_dim
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.embedding_matrix = embedding_matrix
self.trainable_embedding = trainable_embedding
self.use_positional_encoding = use_positional_encoding
# 更新配置
self.config.update({
"vocab_size": vocab_size,
"embedding_dim": embedding_dim,
"max_sequence_length": max_sequence_length,
"num_heads": num_heads,
"ff_dim": ff_dim,
"num_layers": num_layers,
"dropout_rate": dropout_rate,
"trainable_embedding": trainable_embedding,
"use_positional_encoding": use_positional_encoding,
"model_type": "Transformer"
})
logger.info(f"初始化Transformer文本分类模型头数: {num_heads}, 层数: {num_layers}")
def _positional_encoding(self, max_length: int, d_model: int) -> tf.Tensor:
"""
生成位置编码
Args:
max_length: 最大序列长度
d_model: 模型维度
Returns:
位置编码张量
"""
positions = np.arange(max_length)[:, np.newaxis]
depths = np.arange(d_model)[np.newaxis, :] // 2 * 2
angle_rates = 1 / np.power(10000, depths / d_model)
angle_rads = positions * angle_rates
# sin用于偶数索引cos用于奇数索引
sines = np.sin(angle_rads[:, 0::2])
cosines = np.cos(angle_rads[:, 1::2])
pos_encoding = np.zeros((max_length, d_model))
pos_encoding[:, 0::2] = sines
pos_encoding[:, 1::2] = cosines
return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)
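# The encoding above follows the standard sinusoidal scheme:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Each dimension pair oscillates at a different frequency, so the model can
# recover relative positions even though the embedding itself is order-free.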
def build(self) -> None:
"""构建Transformer模型架构"""
# Input layer
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
# Embedding layer
if self.embedding_matrix is not None:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
weights=[self.embedding_matrix],
input_length=self.max_sequence_length,
trainable=self.trainable_embedding,
name='embedding'
)
else:
embedding_layer = Embedding(
input_dim=self.vocab_size,
output_dim=self.embedding_dim,
input_length=self.max_sequence_length,
trainable=True,
name='embedding'
)
embedded_sequences = embedding_layer(sequence_input)
# 添加位置编码
if self.use_positional_encoding:
pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
embedded_sequences = embedded_sequences + pos_encoding
# Transformer层
x = embedded_sequences
for i in range(self.num_layers):
x = TransformerBlock(
embed_dim=self.embedding_dim,
num_heads=self.num_heads,
ff_dim=self.ff_dim,
dropout_rate=self.dropout_rate,
name=f'transformer_block_{i + 1}'
)(x)
# 全局池化
x = GlobalAveragePooling1D(name='global_avg_pooling')(x)
# Dropout for regularization
x = Dropout(self.dropout_rate, name='dropout_1')(x)
# Dense layer
x = Dense(128, activation='relu', name='dense_1')(x)
x = Dropout(self.dropout_rate, name='dropout_2')(x)
# Output layer
if self.num_classes == 2:
# Binary classification
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
else:
# Multi-class classification
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
# Build the model
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
logger.info(f"Transformer模型构建完成头数: {self.num_heads}, 层数: {self.num_layers}")
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
"""
编译Transformer模型
Args:
optimizer: 优化器默认为Adam
loss: 损失函数默认根据类别数量选择
metrics: 评估指标默认为accuracy
"""
if self.model is None:
raise ValueError("模型尚未构建请先调用build方法")
# 默认优化器
if optimizer is None:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
# 默认损失函数
if loss is None:
if self.num_classes == 2:
loss = 'binary_crossentropy'
else:
loss = 'sparse_categorical_crossentropy'
# 默认评估指标
if metrics is None:
metrics = ['accuracy']
# 编译模型
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
logger.info(f"Transformer模型已编译优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")

@ -0,0 +1,414 @@
"""
数据增强模块实现文本数据增强技术
"""
import random
import re
import jieba
import synonyms
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import copy
from config.model_config import RANDOM_SEED
from utils.logger import get_logger
from preprocessing.tokenization import ChineseTokenizer
# 设置随机种子以保证可重复性
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
logger = get_logger("DataAugmentation")
class TextAugmenter:
"""文本增强基类,定义通用接口"""
def __init__(self):
"""初始化文本增强器"""
pass
def augment(self, text: str) -> str:
"""
对文本进行增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
raise NotImplementedError("子类必须实现此方法")
def batch_augment(self, texts: List[str]) -> List[str]:
"""
批量对文本进行增强
Args:
texts: 原始文本列表
Returns:
增强后的文本列表
"""
return [self.augment(text) for text in texts]
def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]:
"""
对文本进行增强同时保留标签
Args:
text: 原始文本
label: 标签
Returns:
(增强后的文本, 标签)的元组
"""
return self.augment(text), label
def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]:
"""
批量对文本进行增强同时保留标签
Args:
texts: 原始文本列表
labels: 标签列表
Returns:
(增强后的文本, 标签)的元组列表
"""
return [self.augment_with_label(text, label) for text, label in zip(texts, labels)]
class SynonymReplacement(TextAugmenter):
"""同义词替换增强器"""
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
replace_ratio: float = 0.1,
min_similarity: float = 0.7):
"""
初始化同义词替换增强器
Args:
tokenizer: 分词器如果为None则创建一个新的分词器
replace_ratio: 替换比例表示要替换的词占总词数的比例
min_similarity: 最小相似度只有相似度大于该值的同义词才会被用于替换
"""
super().__init__()
self.tokenizer = tokenizer or ChineseTokenizer()
self.replace_ratio = replace_ratio
self.min_similarity = min_similarity
def _get_synonym(self, word: str) -> Optional[str]:
"""
获取词的同义词
Args:
word: 原始词
Returns:
同义词如果没有合适的同义词则返回None
"""
# 使用synonyms包获取同义词
try:
synonyms_list = synonyms.nearby(word)
# synonyms.nearby返回一个元组第一个元素是相似词列表第二个元素是相似度列表
words = synonyms_list[0]
similarities = synonyms_list[1]
# 过滤掉相似度低于阈值的词和原词本身
valid_synonyms = [(w, s) for w, s in zip(words, similarities)
if s >= self.min_similarity and w != word]
if valid_synonyms:
# 按相似度排序,选择最相似的词
valid_synonyms.sort(key=lambda x: x[1], reverse=True)
return valid_synonyms[0][0]
return None
except Exception:
return None
def augment(self, text: str) -> str:
"""
对文本进行同义词替换增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
if not text:
return text
# 分词
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
if not words:
return text
# 计算要替换的词数量
n_replace = max(1, int(len(words) * self.replace_ratio))
# 随机选择要替换的词索引
replace_indices = random.sample(range(len(words)), min(n_replace, len(words)))
# 替换为同义词
for idx in replace_indices:
synonym = self._get_synonym(words[idx])
if synonym:
words[idx] = synonym
# 拼接为文本
augmented_text = ''.join(words)
return augmented_text
class RandomDeletion(TextAugmenter):
"""随机删除增强器"""
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
delete_ratio: float = 0.1):
"""
初始化随机删除增强器
Args:
tokenizer: 分词器如果为None则创建一个新的分词器
delete_ratio: 删除比例表示要删除的词占总词数的比例
"""
super().__init__()
self.tokenizer = tokenizer or ChineseTokenizer()
self.delete_ratio = delete_ratio
def augment(self, text: str) -> str:
"""
对文本进行随机删除增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
if not text:
return text
# 分词
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
if len(words) <= 1:
return text
# 计算要删除的词数量
n_delete = max(1, int(len(words) * self.delete_ratio))
# 随机选择要删除的词索引
delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1))
# 删除选中的词
augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices]
# 拼接为文本
augmented_text = ''.join(augmented_words)
return augmented_text
class RandomSwap(TextAugmenter):
"""随机交换增强器"""
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
n_swaps: int = 1):
"""
初始化随机交换增强器
Args:
tokenizer: 分词器如果为None则创建一个新的分词器
n_swaps: 交换次数
"""
super().__init__()
self.tokenizer = tokenizer or ChineseTokenizer()
self.n_swaps = n_swaps
def augment(self, text: str) -> str:
"""
对文本进行随机交换增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
if not text:
return text
# 分词
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
if len(words) <= 1:
return text
# 进行n_swaps次随机交换
augmented_words = words.copy()
for _ in range(min(self.n_swaps, len(words) // 2)):
# 随机选择两个不同的索引
idx1, idx2 = random.sample(range(len(augmented_words)), 2)
# 交换两个词
augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1]
# 拼接为文本
augmented_text = ''.join(augmented_words)
return augmented_text
class CompositeAugmenter(TextAugmenter):
"""组合增强器,组合多个增强器"""
def __init__(self, augmenters: List[TextAugmenter],
probs: Optional[List[float]] = None):
"""
初始化组合增强器
Args:
augmenters: 增强器列表
probs: 各增强器被选择的概率列表如果为None则均匀选择
"""
super().__init__()
self.augmenters = augmenters
# 如果没有提供概率,则均匀分配
if probs is None:
self.probs = [1.0 / len(augmenters)] * len(augmenters)
else:
# 确保概率和为1
total = sum(probs)
self.probs = [p / total for p in probs]
assert len(self.augmenters) == len(self.probs), "增强器数量与概率数量不匹配"
def augment(self, text: str) -> str:
"""
对文本进行组合增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
if not text:
return text
# 根据概率随机选择一个增强器
augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0]
# 使用选中的增强器进行增强
return augmenter.augment(text)
class BackTranslation(TextAugmenter):
"""回译增强器"""
def __init__(self, translator=None, source_lang: str = 'zh',
target_langs: List[str] = None):
"""
初始化回译增强器
Args:
translator: 翻译器需要实现translate方法
source_lang: 源语言代码
target_langs: 目标语言代码列表如果为None则使用默认语言
"""
super().__init__()
# 如果没有提供翻译器,尝试使用第三方翻译库
if translator is None:
try:
# 尝试导入多种翻译库
# 首先尝试使用googletrans (需要单独安装: pip install googletrans==4.0.0-rc1)
try:
from googletrans import Translator
self.translator = Translator()
self.translate_func = self._google_translate
except ImportError:
# 如果googletrans不可用尝试使用py-translate
try:
import translate
self.translator = translate
self.translate_func = self._py_translate
except ImportError:
logger.warning("未安装翻译库回译功能将不可用。请安装googletrans或py-translate")
self.translator = None
self.translate_func = self._dummy_translate
except Exception as e:
logger.error(f"初始化翻译器失败: {e}")
self.translator = None
self.translate_func = self._dummy_translate
else:
self.translator = translator
self.translate_func = self._custom_translate
self.source_lang = source_lang
self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja']
def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str:
"""使用googletrans进行翻译"""
try:
result = self.translator.translate(text, src=source_lang, dest=target_lang)
return result.text
except Exception as e:
logger.error(f"翻译失败: {e}")
return text
def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str:
"""使用py-translate进行翻译"""
try:
return self.translator.translate(text, source_lang, target_lang)
except Exception as e:
logger.error(f"翻译失败: {e}")
return text
def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str:
"""使用自定义翻译器进行翻译"""
try:
return self.translator.translate(text, source_lang, target_lang)
except Exception as e:
logger.error(f"翻译失败: {e}")
return text
def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str:
"""虚拟翻译功能,仅返回原文本"""
logger.warning("翻译功能不可用,使用原文本")
return text
def augment(self, text: str) -> str:
"""
对文本进行回译增强
Args:
text: 原始文本
Returns:
增强后的文本
"""
if not text or self.translator is None:
return text
# 随机选择一个目标语言
target_lang = random.choice(self.target_langs)
try:
# 将源语言翻译为目标语言
translated = self.translate_func(text, self.source_lang, target_lang)
# 将目标语言翻译回源语言
back_translated = self.translate_func(translated, target_lang, self.source_lang)
return back_translated
except Exception as e:
logger.error(f"回译失败: {e}")
return text
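
A sketch that chains the augmenters above through `CompositeAugmenter`; the probabilities, ratios and sample text are made up, and the import path assumes the module lives at preprocessing/data_augmentation.py:

from preprocessing.data_augmentation import (   # assumed module path
    SynonymReplacement, RandomDeletion, RandomSwap, CompositeAugmenter
)

augmenter = CompositeAugmenter(
    augmenters=[SynonymReplacement(replace_ratio=0.1),
                RandomDeletion(delete_ratio=0.1),
                RandomSwap(n_swaps=1)],
    probs=[0.5, 0.3, 0.2],                       # example mixing weights
)
samples = ["央行宣布下调存款准备金率,股市应声上涨"]   # hypothetical news sentence
labels = [1]
augmented_pairs = augmenter.batch_augment_with_label(samples, labels)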

@ -0,0 +1,430 @@
"""
特征提取模块实现文本特征提取包括语法特征语义特征等
"""
import re
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union, Set
from collections import Counter
import jieba.posseg as pseg
from config.system_config import CATEGORIES
from utils.logger import get_logger
from preprocessing.tokenization import ChineseTokenizer
logger = get_logger("FeatureExtraction")
class FeatureExtractor:
"""特征提取基类,定义通用接口"""
def __init__(self):
"""初始化特征提取器"""
pass
def extract(self, text: str) -> Dict[str, Any]:
"""
从文本中提取特征
Args:
text: 文本
Returns:
特征字典
"""
raise NotImplementedError("子类必须实现此方法")
def batch_extract(self, texts: List[str]) -> List[Dict[str, Any]]:
"""
批量提取特征
Args:
texts: 文本列表
Returns:
特征字典列表
"""
return [self.extract(text) for text in texts]
def extract_as_vector(self, text: str) -> np.ndarray:
"""
从文本中提取特征并转换为向量表示
Args:
text: 文本
Returns:
特征向量
"""
raise NotImplementedError("子类必须实现此方法")
def batch_extract_as_vector(self, texts: List[str]) -> np.ndarray:
"""
批量提取特征并转换为向量表示
Args:
texts: 文本列表
Returns:
特征向量数组
"""
return np.array([self.extract_as_vector(text) for text in texts])
class StatisticalFeatureExtractor(FeatureExtractor):
"""统计特征提取器,提取文本的统计特征"""
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None):
"""
初始化统计特征提取器
Args:
tokenizer: 分词器如果为None则创建一个新的分词器
"""
super().__init__()
self.tokenizer = tokenizer or ChineseTokenizer()
def extract(self, text: str) -> Dict[str, Any]:
"""
从文本中提取统计特征
Args:
text: 文本
Returns:
特征字典包含各种统计特征
"""
if not text:
return {
"char_count": 0,
"word_count": 0,
"sentence_count": 0,
"avg_word_length": 0,
"avg_sentence_length": 0,
"contains_number": False,
"contains_english": False,
"punctuation_ratio": 0,
"top_words": []
}
# 字符数
char_count = len(text)
# 分词
words = self.tokenizer.tokenize(text, return_string=False)
word_count = len(words)
# 句子数(按标点符号分割)
sentences = re.split(r'[。!?!?]+', text)
sentences = [s for s in sentences if s.strip()]
sentence_count = len(sentences)
# 平均词长
avg_word_length = sum(len(word) for word in words) / word_count if word_count > 0 else 0
# 平均句长(以字符为单位)
avg_sentence_length = char_count / sentence_count if sentence_count > 0 else 0
# 是否包含数字
contains_number = bool(re.search(r'\d', text))
# 是否包含英文
contains_english = bool(re.search(r'[a-zA-Z]', text))
# 标点符号比例
punctuation_pattern = re.compile(r'[^\w\s]')
punctuations = punctuation_pattern.findall(text)
punctuation_ratio = len(punctuations) / char_count if char_count > 0 else 0
# 高频词
word_counter = Counter(words)
top_words = word_counter.most_common(5)
return {
"char_count": char_count,
"word_count": word_count,
"sentence_count": sentence_count,
"avg_word_length": avg_word_length,
"avg_sentence_length": avg_sentence_length,
"contains_number": contains_number,
"contains_english": contains_english,
"punctuation_ratio": punctuation_ratio,
"top_words": top_words
}
def extract_as_vector(self, text: str) -> np.ndarray:
"""
从文本中提取统计特征并转换为向量表示
Args:
text: 文本
Returns:
特征向量包含各种统计特征
"""
features = self.extract(text)
# 提取数值特征
vector = [
features['char_count'],
features['word_count'],
features['sentence_count'],
features['avg_word_length'],
features['avg_sentence_length'],
int(features['contains_number']),
int(features['contains_english']),
features['punctuation_ratio']
]
return np.array(vector, dtype=np.float32)
class POSFeatureExtractor(FeatureExtractor):
"""词性特征提取器,提取文本的词性特征"""
def __init__(self):
"""初始化词性特征提取器"""
super().__init__()
# 常见中文词性及其解释
self.pos_tags = {
'n': '名词', 'f': '方位名词', 's': '处所名词', 't': '时间名词',
'nr': '人名', 'ns': '地名', 'nt': '机构团体', 'nw': '作品名',
'nz': '其他专名', 'v': '动词', 'vd': '副动词', 'vn': '名动词',
'a': '形容词', 'ad': '副形词', 'an': '名形词', 'd': '副词',
'm': '数词', 'q': '量词', 'r': '代词', 'p': '介词',
'c': '连词', 'u': '助词', 'xc': '其他虚词', 'w': '标点符号'
}
def extract(self, text: str) -> Dict[str, Any]:
"""
从文本中提取词性特征
Args:
text: 文本
Returns:
特征字典包含各种词性特征
"""
if not text:
return {
"pos_counts": {},
"pos_ratios": {}
}
# 使用jieba进行词性标注
pos_list = pseg.cut(text)
# 统计各词性的数量
pos_counts = {}
total_count = 0
for word, pos in pos_list:
if pos in pos_counts:
pos_counts[pos] += 1
else:
pos_counts[pos] = 1
total_count += 1
# 计算各词性的比例
pos_ratios = {pos: count / total_count for pos, count in pos_counts.items()} if total_count > 0 else {}
return {
"pos_counts": pos_counts,
"pos_ratios": pos_ratios
}
def extract_as_vector(self, text: str) -> np.ndarray:
"""
从文本中提取词性特征并转换为向量表示
Args:
text: 文本
Returns:
特征向量包含各词性的比例
"""
features = self.extract(text)
pos_ratios = features['pos_ratios']
# 按照 self.pos_tags 的顺序构建向量
vector = []
for pos in self.pos_tags.keys():
vector.append(pos_ratios.get(pos, 0.0))
return np.array(vector, dtype=np.float32)
class KeywordFeatureExtractor(FeatureExtractor):
"""关键词特征提取器,基于预定义关键词提取特征"""
def __init__(self, category_keywords: Optional[Dict[str, List[str]]] = None):
"""
初始化关键词特征提取器
Args:
category_keywords: 类别关键词字典键为类别名称值为关键词列表
"""
super().__init__()
self.category_keywords = category_keywords or self._get_default_keywords()
self.tokenizer = ChineseTokenizer()
def _get_default_keywords(self) -> Dict[str, List[str]]:
"""
获取默认的类别关键词
Returns:
类别关键词字典
"""
# 为每个类别定义一些示例关键词
default_keywords = {
"体育": ["比赛", "运动", "球员", "冠军", "球队", "足球", "篮球"],
"财经": ["股票", "基金", "投资", "市场", "经济", "金融", "股市"],
"房产": ["房价", "楼市", "地产", "购房", "房贷", "物业", "小区"],
"家居": ["装修", "家具", "设计", "卧室", "客厅", "厨房", "风格"],
"教育": ["学校", "学生", "考试", "教育", "大学", "课程", "老师"],
"科技": ["互联网", "科技", "创新", "数字", "智能", "研发", "技术"],
"时尚": ["时尚", "潮流", "服装", "搭配", "品牌", "美容", "穿着"],
"时政": ["政府", "政策", "国家", "发展", "会议", "主席", "总理"],
"游戏": ["游戏", "玩家", "电竞", "网游", "手游", "角色", "任务"],
"娱乐": ["明星", "电影", "节目", "综艺", "电视", "演员", "导演"],
"其他": ["其他", "一般", "常见", "普通", "正常", "通常", "传统"]
}
# 确保 CATEGORIES 中的每个类别都有关键词
for category in CATEGORIES:
if category not in default_keywords:
default_keywords[category] = [category]
return default_keywords
def extract(self, text: str) -> Dict[str, Any]:
"""
从文本中提取关键词特征
Args:
text: 文本
Returns:
特征字典包含各类别的关键词匹配情况
"""
if not text:
return {
"keyword_matches": {cat: 0 for cat in self.category_keywords},
"keyword_match_ratios": {cat: 0.0 for cat in self.category_keywords}
}
# 对文本分词
words = set(self.tokenizer.tokenize(text, return_string=False))
# 统计各类别的关键词匹配数量
keyword_matches = {}
for category, keywords in self.category_keywords.items():
# 计算文本中包含的该类别关键词数量
matches = sum(1 for kw in keywords if kw in words)
keyword_matches[category] = matches
# 计算匹配比例(归一化)
total_matches = sum(keyword_matches.values())
keyword_match_ratios = {
cat: matches / total_matches if total_matches > 0 else 0.0
for cat, matches in keyword_matches.items()
}
return {
"keyword_matches": keyword_matches,
"keyword_match_ratios": keyword_match_ratios
}
def extract_as_vector(self, text: str) -> np.ndarray:
"""
从文本中提取关键词特征并转换为向量表示
Args:
text: 文本
Returns:
特征向量包含各类别的关键词匹配比例
"""
features = self.extract(text)
match_ratios = features['keyword_match_ratios']
# 按照 CATEGORIES 的顺序构建向量
vector = [match_ratios.get(cat, 0.0) for cat in CATEGORIES]
return np.array(vector, dtype=np.float32)
def update_keywords(self, category: str, keywords: List[str]) -> None:
"""
更新指定类别的关键词
Args:
category: 类别名称
keywords: 关键词列表
"""
self.category_keywords[category] = keywords
logger.info(f"已更新类别 {category} 的关键词,共 {len(keywords)}")
def add_keywords(self, category: str, keywords: List[str]) -> None:
"""
向指定类别添加关键词
Args:
category: 类别名称
keywords: 要添加的关键词列表
"""
if category in self.category_keywords:
existing_keywords = set(self.category_keywords[category])
for keyword in keywords:
existing_keywords.add(keyword)
self.category_keywords[category] = list(existing_keywords)
else:
self.category_keywords[category] = keywords
logger.info(f"已向类别 {category} 添加关键词,当前共 {len(self.category_keywords[category])}")
class CombinedFeatureExtractor(FeatureExtractor):
"""组合特征提取器,组合多个特征提取器的结果"""
def __init__(self, extractors: List[FeatureExtractor]):
"""
初始化组合特征提取器
Args:
extractors: 特征提取器列表
"""
super().__init__()
self.extractors = extractors
def extract(self, text: str) -> Dict[str, Any]:
"""
从文本中提取组合特征
Args:
text: 文本
Returns:
特征字典包含所有特征提取器的结果
"""
combined_features = {}
for i, extractor in enumerate(self.extractors):
extractor_name = type(extractor).__name__
features = extractor.extract(text)
combined_features[extractor_name] = features
return combined_features
def extract_as_vector(self, text: str) -> np.ndarray:
"""
从文本中提取组合特征并转换为向量表示
Args:
text: 文本
Returns:
特征向量包含所有特征提取器的向量拼接
"""
# 获取所有特征提取器的向量
feature_vectors = [extractor.extract_as_vector(text) for extractor in self.extractors]
# 拼接向量
return np.concatenate(feature_vectors)
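
A sketch that concatenates statistical, part-of-speech and keyword features into a single vector; the import path assumes preprocessing/feature_extraction.py and the sample sentence is invented:

from preprocessing.feature_extraction import (   # assumed module path
    StatisticalFeatureExtractor, POSFeatureExtractor,
    KeywordFeatureExtractor, CombinedFeatureExtractor
)

extractor = CombinedFeatureExtractor([
    StatisticalFeatureExtractor(),
    POSFeatureExtractor(),
    KeywordFeatureExtractor(),
])
vec = extractor.extract_as_vector("球队在昨晚的比赛中逆转夺冠")   # hypothetical text
print(vec.shape)   # 8 statistical dims + one per POS tag + one per category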

@ -0,0 +1,229 @@
"""
文本清洗模块实现文本清洗去除无用字符HTML标签等
"""
import re
import unicodedata
import html
from typing import List, Dict, Tuple, Optional, Any, Callable, Set, Union
import string
from utils.logger import get_logger
logger = get_logger("TextCleaner")
class TextCleaner:
"""文本清洗类,提供各种文本清洗方法"""
def __init__(self, remove_html: bool = True,
remove_urls: bool = True,
remove_emails: bool = True,
remove_numbers: bool = False,
remove_punctuation: bool = False,
lowercase: bool = False,
normalize_unicode: bool = True,
remove_excessive_spaces: bool = True,
remove_short_texts: bool = False,
min_text_length: int = 10,
custom_patterns: Optional[List[str]] = None):
"""
初始化文本清洗器
Args:
remove_html: 是否移除HTML标签
remove_urls: 是否移除URL
remove_emails: 是否移除电子邮件地址
remove_numbers: 是否移除数字
remove_punctuation: 是否移除标点符号
lowercase: 是否转为小写对中文无效
normalize_unicode: 是否规范化Unicode字符
remove_excessive_spaces: 是否移除多余空格
remove_short_texts: 是否过滤掉短文本
min_text_length: 最小文本长度当remove_short_texts=True时有效
custom_patterns: 自定义的正则表达式模式列表用于额外的文本清洗
"""
self.remove_html = remove_html
self.remove_urls = remove_urls
self.remove_emails = remove_emails
self.remove_numbers = remove_numbers
self.remove_punctuation = remove_punctuation
self.lowercase = lowercase
self.normalize_unicode = normalize_unicode
self.remove_excessive_spaces = remove_excessive_spaces
self.remove_short_texts = remove_short_texts
self.min_text_length = min_text_length
self.custom_patterns = custom_patterns or []
# 编译正则表达式
self.html_pattern = re.compile(r'<.*?>')
self.url_pattern = re.compile(r'https?://\S+|www\.\S+')
self.email_pattern = re.compile(r'\S+@\S+\.\S+')
self.number_pattern = re.compile(r'\d+')
self.space_pattern = re.compile(r'\s+')
# 编译自定义模式
self.compiled_custom_patterns = [re.compile(pattern) for pattern in self.custom_patterns]
# 中文标点符号
self.chinese_punctuation = ",。!?;:""''【】《》()、…—~·"
logger.info("文本清洗器初始化完成")
def clean_text(self, text: str) -> str:
"""
清洗文本应用所有已配置的清洗方法
Args:
text: 原始文本
Returns:
清洗后的文本
"""
if not text:
return ""
# HTML解码
if self.remove_html:
text = html.unescape(text)
text = self.html_pattern.sub(' ', text)
# 移除URL
if self.remove_urls:
text = self.url_pattern.sub(' ', text)
# 移除电子邮件
if self.remove_emails:
text = self.email_pattern.sub(' ', text)
# Unicode规范化
if self.normalize_unicode:
text = unicodedata.normalize('NFKC', text)
# 移除数字
if self.remove_numbers:
text = self.number_pattern.sub(' ', text)
# 移除标点符号
if self.remove_punctuation:
# 处理英文标点
for punct in string.punctuation:
text = text.replace(punct, ' ')
# 处理中文标点
for punct in self.chinese_punctuation:
text = text.replace(punct, ' ')
# 应用自定义清洗模式
for pattern in self.compiled_custom_patterns:
text = pattern.sub(' ', text)
# 转为小写
if self.lowercase:
text = text.lower()
# 移除多余空格
if self.remove_excessive_spaces:
text = self.space_pattern.sub(' ', text)
text = text.strip()
# 过滤掉短文本
if self.remove_short_texts and len(text) < self.min_text_length:
return ""
return text
def clean_texts(self, texts: List[str]) -> List[str]:
"""
批量清洗文本
Args:
texts: 原始文本列表
Returns:
清洗后的文本列表
"""
return [self.clean_text(text) for text in texts]
def remove_redundant_texts(self, texts: List[str]) -> List[str]:
"""
移除冗余文本空文本和长度小于阈值的文本
Args:
texts: 原始文本列表
Returns:
移除冗余后的文本列表
"""
return [text for text in texts if text and len(text) >= self.min_text_length]
@staticmethod
def remove_specific_characters(text: str, chars_to_remove: Union[str, Set[str]]) -> str:
"""
移除特定字符
Args:
text: 原始文本
chars_to_remove: 要移除的字符字符串或字符集合
Returns:
移除特定字符后的文本
"""
if isinstance(chars_to_remove, str):
for char in chars_to_remove:
text = text.replace(char, '')
else:
for char in chars_to_remove:
text = text.replace(char, '')
return text
@staticmethod
def replace_characters(text: str, char_map: Dict[str, str]) -> str:
"""
替换特定字符
Args:
text: 原始文本
char_map: 字符映射字典键为要替换的字符值为替换后的字符
Returns:
替换特定字符后的文本
"""
for old_char, new_char in char_map.items():
text = text.replace(old_char, new_char)
return text
@staticmethod
def remove_empty_lines(text: str) -> str:
"""
移除空行
Args:
text: 原始文本
Returns:
移除空行后的文本
"""
lines = text.splitlines()
non_empty_lines = [line for line in lines if line.strip()]
return '\n'.join(non_empty_lines)
@staticmethod
def truncate_text(text: str, max_length: int, truncate_from_end: bool = True) -> str:
"""
截断文本
Args:
text: 原始文本
max_length: 最大长度
truncate_from_end: 是否从末尾截断如果为False则从开头截断
Returns:
截断后的文本
"""
if len(text) <= max_length:
return text
if truncate_from_end:
return text[:max_length]
else:
return text[len(text) - max_length:]
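
A sketch of the cleaner on a noisy HTML snippet; the input string is invented and the import path assumes preprocessing/text_cleaner.py:

from preprocessing.text_cleaner import TextCleaner   # assumed module path

cleaner = TextCleaner(remove_html=True, remove_urls=True, remove_emails=True)
raw = "<p>详情见 https://example.com &nbsp; 或联系 editor@example.com</p>"
clean = cleaner.clean_text(raw)                       # roughly "详情见 或联系"
clean = TextCleaner.truncate_text(clean, max_length=500)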

@ -0,0 +1,248 @@
"""
中文分词模块负责中文文本分词处理
"""
import os
import jieba
import re
from typing import List, Dict, Tuple, Optional, Any, Set, Union
import pandas as pd
from collections import Counter
from config.system_config import STOPWORDS_DIR, ENCODING
from utils.logger import get_logger
from utils.file_utils import read_text_file, write_text_file, ensure_dir
logger = get_logger("Tokenization")
class ChineseTokenizer:
"""中文分词器基于jieba实现"""
def __init__(self, user_dict_path: Optional[str] = None,
use_hmm: bool = True,
remove_stopwords: bool = True,
stopwords_path: Optional[str] = None,
add_custom_words: Optional[List[str]] = None):
"""
初始化中文分词器
Args:
user_dict_path: 用户自定义词典路径
use_hmm: 是否使用HMM模型进行分词
remove_stopwords: 是否移除停用词
stopwords_path: 停用词表路径如果为None则使用默认停用词表
add_custom_words: 要添加的自定义词语列表
"""
self.use_hmm = use_hmm
self.remove_stopwords = remove_stopwords
# 加载用户自定义词典
if user_dict_path and os.path.exists(user_dict_path):
jieba.load_userdict(user_dict_path)
logger.info(f"已加载用户自定义词典:{user_dict_path}")
# 加载停用词
self.stopwords = set()
if remove_stopwords:
self._load_stopwords(stopwords_path)
# 添加自定义词语
if add_custom_words:
for word in add_custom_words:
jieba.add_word(word)
logger.info(f"已添加 {len(add_custom_words)} 个自定义词语")
def _load_stopwords(self, stopwords_path: Optional[str] = None) -> None:
"""
加载停用词
Args:
stopwords_path: 停用词表路径如果为None则使用默认停用词表
"""
# 如果没有指定停用词表路径,则使用默认停用词表
if not stopwords_path:
stopwords_path = os.path.join(STOPWORDS_DIR, "chinese_stopwords.txt")
# 如果没有找到默认停用词表,则创建一个空的停用词表
if not os.path.exists(stopwords_path):
ensure_dir(os.path.dirname(stopwords_path))
# 常见中文停用词
default_stopwords = [
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", ""
]
write_text_file("\n".join(default_stopwords), stopwords_path)
logger.info(f"未找到停用词表,已创建默认停用词表:{stopwords_path}")
# 加载停用词表
try:
with open(stopwords_path, "r", encoding=ENCODING) as f:
for line in f:
word = line.strip()
if word:
self.stopwords.add(word)
logger.info(f"已加载 {len(self.stopwords)} 个停用词")
except Exception as e:
logger.error(f"加载停用词表失败:{e}")
def add_stopwords(self, words: Union[str, List[str]]) -> None:
"""
添加停用词
Args:
words: 要添加的停用词字符串或列表
"""
if isinstance(words, str):
self.stopwords.add(words.strip())
else:
for word in words:
self.stopwords.add(word.strip())
def remove_stopwords_from_list(self, words: List[str]) -> List[str]:
"""
从词语列表中移除停用词
Args:
words: 词语列表
Returns:
移除停用词后的词语列表
"""
if not self.remove_stopwords:
return words
return [word for word in words if word not in self.stopwords]
def tokenize(self, text: str, return_string: bool = False,
cut_all: bool = False) -> Union[List[str], str]:
"""
对文本进行分词
Args:
text: 要分词的文本
return_string: 是否返回字符串以空格分隔的词语
cut_all: 是否使用全模式默认使用精确模式
Returns:
分词结果词语列表或字符串
"""
if not text:
return "" if return_string else []
# 使用jieba进行分词
if cut_all:
words = jieba.lcut(text, cut_all=True)
else:
words = jieba.lcut(text, HMM=self.use_hmm)
# 移除停用词
if self.remove_stopwords:
words = self.remove_stopwords_from_list(words)
# 返回结果
if return_string:
return " ".join(words)
else:
return words
def batch_tokenize(self, texts: List[str], return_string: bool = False,
cut_all: bool = False) -> List[Union[List[str], str]]:
"""
批量分词
Args:
texts: 要分词的文本列表
return_string: 是否返回字符串以空格分隔的词语
cut_all: 是否使用全模式默认使用精确模式
Returns:
分词结果列表
"""
return [self.tokenize(text, return_string, cut_all) for text in texts]
def analyze_tokens(self, texts: List[str], top_n: int = 20) -> Dict[str, Any]:
"""
分析文本中的词频分布
Args:
texts: 要分析的文本列表
top_n: 返回前多少个高频词
Returns:
包含词频分析结果的字典
"""
all_tokens = []
for text in texts:
tokens = self.tokenize(text, return_string=False)
all_tokens.extend(tokens)
# 统计词频
token_counter = Counter(all_tokens)
# 获取最常见的词
most_common = token_counter.most_common(top_n)
# 计算唯一词数量
unique_tokens = len(token_counter)
return {
"total_tokens": len(all_tokens),
"unique_tokens": unique_tokens,
"most_common": most_common,
"token_counter": token_counter
}
def get_top_keywords(self, texts: List[str], top_n: int = 20,
min_freq: int = 3, min_length: int = 2) -> List[Tuple[str, int]]:
"""
获取文本中的关键词
Args:
texts: 要分析的文本列表
top_n: 返回前多少个关键词
min_freq: 最小词频
min_length: 最小词长度字符数
Returns:
包含(关键词, 词频)的元组列表
"""
tokens_analysis = self.analyze_tokens(texts)
token_counter = tokens_analysis["token_counter"]
# 过滤满足条件的词
filtered_keywords = [(word, count) for word, count in token_counter.items()
if count >= min_freq and len(word) >= min_length]
# 按词频排序
sorted_keywords = sorted(filtered_keywords, key=lambda x: x[1], reverse=True)
return sorted_keywords[:top_n]
def get_vocabulary(self, texts: List[str], min_freq: int = 1) -> List[str]:
"""
获取词汇表
Args:
texts: 文本列表
min_freq: 最小词频
Returns:
词汇表词语列表
"""
tokens_analysis = self.analyze_tokens(texts)
token_counter = tokens_analysis["token_counter"]
# 过滤满足最小词频的词
vocabulary = [word for word, count in token_counter.items() if count >= min_freq]
return vocabulary
def get_stopwords(self) -> Set[str]:
"""
获取停用词集合
Returns:
停用词集合
"""
return self.stopwords.copy()
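
A sketch of the tokenizer on a tiny corpus, showing stop-word removal and the frequency-statistics helper; the sample sentences are invented:

from preprocessing.tokenization import ChineseTokenizer

tokenizer = ChineseTokenizer(remove_stopwords=True)
corpus = ["今天股市出现大幅上涨", "球队在昨晚的比赛中获胜"]          # hypothetical samples
tokenized = tokenizer.batch_tokenize(corpus, return_string=True)   # space-joined tokens
stats = tokenizer.analyze_tokens(corpus, top_n=10)
print(stats["total_tokens"], stats["unique_tokens"], stats["most_common"])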

774
preprocessing/vectorizer.py Normal file
View File

@ -0,0 +1,774 @@
"""
文本向量化模块实现文本向量化包括词袋模型TF-IDF和词嵌入等多种文本表示方法
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer as SklearnTfidfVectorizer
import pickle
import os
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import gensim
from gensim.models import Word2Vec, KeyedVectors
from config.system_config import PROCESSED_DATA_DIR, EMBEDDINGS_DIR
from config.model_config import (
MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS, MIN_WORD_FREQUENCY
)
from utils.logger import get_logger
from utils.file_utils import save_pickle, load_pickle, ensure_dir
from preprocessing.tokenization import ChineseTokenizer
logger = get_logger("Vectorizer")
class TextVectorizer:
"""文本向量化基类,定义通用接口"""
def __init__(self, max_features: int = MAX_NUM_WORDS):
"""
初始化文本向量化器
Args:
max_features: 最大特征数词汇表大小
"""
self.max_features = max_features
self.vectorizer = None
self.is_fitted = False
def fit(self, texts: List[str]) -> None:
"""
在文本上训练向量化器
Args:
texts: 文本列表
"""
raise NotImplementedError("子类必须实现此方法")
def transform(self, texts: List[str]) -> np.ndarray:
"""
将文本转换为向量表示
Args:
texts: 文本列表
Returns:
向量表示
"""
raise NotImplementedError("子类必须实现此方法")
def fit_transform(self, texts: List[str]) -> np.ndarray:
"""
在文本上训练向量化器并将文本转换为向量表示
Args:
texts: 文本列表
Returns:
向量表示
"""
self.fit(texts)
return self.transform(texts)
def save(self, path: str) -> None:
"""
保存向量化器
Args:
path: 保存路径
"""
ensure_dir(os.path.dirname(path))
save_pickle(self.vectorizer, path)
logger.info(f"向量化器已保存到:{path}")
def load(self, path: str) -> None:
"""
加载向量化器
Args:
path: 加载路径
"""
self.vectorizer = load_pickle(path)
self.is_fitted = True
logger.info(f"向量化器已从 {path} 加载")
def get_vocabulary(self) -> List[str]:
"""
获取词汇表
Returns:
词汇表
"""
raise NotImplementedError("子类必须实现此方法")
def get_vocabulary_size(self) -> int:
"""
获取词汇表大小
Returns:
词汇表大小
"""
raise NotImplementedError("子类必须实现此方法")
class BagOfWordsVectorizer(TextVectorizer):
"""词袋模型向量化器"""
def __init__(self, max_features: int = MAX_NUM_WORDS,
min_df: int = MIN_WORD_FREQUENCY,
tokenizer: Optional[Callable[[str], List[str]]] = None,
binary: bool = False):
"""
初始化词袋模型向量化器
Args:
max_features: 最大特征数词汇表大小
min_df: 最小文档频率
tokenizer: 分词器函数接收文本返回词语列表
binary: 是否使用二进制计数只关注词语是否出现不关注频率
"""
super().__init__(max_features)
self.min_df = min_df
self.binary = binary
# 创建sklearn的CountVectorizer
self.vectorizer = CountVectorizer(
max_features=max_features,
min_df=min_df,
tokenizer=tokenizer,
binary=binary
)
def fit(self, texts: List[str]) -> None:
"""
在文本上训练词袋模型
Args:
texts: 文本列表
"""
self.vectorizer.fit(texts)
self.is_fitted = True
logger.info(f"词袋模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}")
def transform(self, texts: List[str]) -> np.ndarray:
"""
将文本转换为词袋向量表示
Args:
texts: 文本列表
Returns:
词袋向量表示稀疏矩阵
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return self.vectorizer.transform(texts)
def get_vocabulary(self) -> List[str]:
"""
获取词汇表
Returns:
词汇表按索引排序
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
# CountVectorizer的词汇表是一个字典键为词值为索引
vocab_dict = self.vectorizer.vocabulary_
vocab_list = [""] * len(vocab_dict)
for word, idx in vocab_dict.items():
vocab_list[idx] = word
return vocab_list
def get_vocabulary_size(self) -> int:
"""
获取词汇表大小
Returns:
词汇表大小
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return len(self.vectorizer.vocabulary_)
class TfidfVectorizer(TextVectorizer):
"""TF-IDF向量化器"""
def __init__(self, max_features: int = MAX_NUM_WORDS,
min_df: int = MIN_WORD_FREQUENCY,
tokenizer: Optional[Callable[[str], List[str]]] = None,
norm: str = 'l2',
use_idf: bool = True,
smooth_idf: bool = True,
sublinear_tf: bool = False):
"""
初始化TF-IDF向量化器
Args:
max_features: 最大特征数词汇表大小
min_df: 最小文档频率
tokenizer: 分词器函数接收文本返回词语列表
norm: 规范化方法默认为L2范数
use_idf: 是否使用IDF逆文档频率
smooth_idf: 是否平滑IDF权重
sublinear_tf: 是否应用sublinear scaling对TF取对数
"""
super().__init__(max_features)
self.min_df = min_df
# 创建sklearn的TfidfVectorizer(使用导入别名,避免与本模块的同名类冲突)
self.vectorizer = SklearnTfidfVectorizer(
max_features=max_features,
min_df=min_df,
tokenizer=tokenizer,
norm=norm,
use_idf=use_idf,
smooth_idf=smooth_idf,
sublinear_tf=sublinear_tf
)
def fit(self, texts: List[str]) -> None:
"""
在文本上训练TF-IDF模型
Args:
texts: 文本列表
"""
self.vectorizer.fit(texts)
self.is_fitted = True
logger.info(f"TF-IDF模型已训练词汇表大小{len(self.vectorizer.vocabulary_)}")
def transform(self, texts: List[str]) -> np.ndarray:
"""
将文本转换为TF-IDF向量表示
Args:
texts: 文本列表
Returns:
TF-IDF向量表示稀疏矩阵
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return self.vectorizer.transform(texts)
def get_vocabulary(self) -> List[str]:
"""
获取词汇表
Returns:
词汇表按索引排序
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
# TfidfVectorizer的词汇表是一个字典键为词值为索引
vocab_dict = self.vectorizer.vocabulary_
vocab_list = [""] * len(vocab_dict)
for word, idx in vocab_dict.items():
vocab_list[idx] = word
return vocab_list
def get_vocabulary_size(self) -> int:
"""
获取词汇表大小
Returns:
词汇表大小
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return len(self.vectorizer.vocabulary_)
def get_feature_names(self) -> List[str]:
"""
获取特征名称词汇表
Returns:
特征名称列表
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return self.vectorizer.get_feature_names_out()
def get_idf(self) -> np.ndarray:
"""
获取IDF权重
Returns:
IDF权重数组
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return self.vectorizer.idf_
class SequenceVectorizer(TextVectorizer):
"""序列向量化器使用Keras的Tokenizer"""
def __init__(self, max_features: int = MAX_NUM_WORDS,
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
oov_token: str = "<OOV>",
padding: str = "post",
truncating: str = "post"):
"""
初始化序列向量化器
Args:
max_features: 最大特征数词汇表大小
max_sequence_length: 序列最大长度
oov_token: 未登录词标记
padding: 填充方式'pre''post'
truncating: 截断方式'pre''post'
"""
super().__init__(max_features)
self.max_sequence_length = max_sequence_length
self.oov_token = oov_token
self.padding = padding
self.truncating = truncating
# 创建Keras的Tokenizer
self.vectorizer = Tokenizer(num_words=max_features, oov_token=oov_token)
def fit(self, texts: List[str]) -> None:
"""
在文本上训练序列向量化器
Args:
texts: 文本列表
"""
self.vectorizer.fit_on_texts(texts)
self.is_fitted = True
logger.info(f"序列向量化器已训练,词汇表大小:{len(self.vectorizer.word_index)}")
def transform(self, texts: List[str]) -> np.ndarray:
"""
将文本转换为整数序列并进行填充
Args:
texts: 文本列表
Returns:
整数序列表示
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
sequences = self.vectorizer.texts_to_sequences(texts)
padded_sequences = pad_sequences(
sequences,
maxlen=self.max_sequence_length,
padding=self.padding,
truncating=self.truncating
)
return padded_sequences
def get_vocabulary(self) -> List[str]:
"""
获取词汇表
Returns:
词汇表按索引排序
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
# Tokenizer的词汇表是一个字典键为词值为索引从1开始
word_index = self.vectorizer.word_index
index_word = {index: word for word, index in word_index.items()}
# 注意:索引0保留给padding;若设置了oov_token,Keras已将其放在word_index的索引1,无需重复添加
vocab = ["<PAD>"]
max_index = min(self.max_features, len(word_index) + 1) if self.max_features else len(word_index) + 1
for i in range(1, max_index):
if i in index_word:
vocab.append(index_word[i])
return vocab
def get_vocabulary_size(self) -> int:
"""
获取词汇表大小
Returns:
词汇表大小
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
# +1是因为索引0保留给padding
return min(self.max_features, len(self.vectorizer.word_index) + 1) if self.max_features else len(
self.vectorizer.word_index) + 1
def texts_to_sequences(self, texts: List[str]) -> List[List[int]]:
"""
将文本转换为整数序列不填充
Args:
texts: 文本列表
Returns:
整数序列列表
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return self.vectorizer.texts_to_sequences(texts)
def sequences_to_padded(self, sequences: List[List[int]]) -> np.ndarray:
"""
将整数序列填充到指定长度
Args:
sequences: 整数序列列表
Returns:
填充后的整数序列
"""
return pad_sequences(
sequences,
maxlen=self.max_sequence_length,
padding=self.padding,
truncating=self.truncating
)
def save(self, path: str) -> None:
"""
保存序列向量化器
Args:
path: 保存路径
"""
ensure_dir(os.path.dirname(path))
# 保存配置和状态
tokenizer_state = {
'tokenizer': self.vectorizer,
'max_features': self.max_features,
'max_sequence_length': self.max_sequence_length,
'oov_token': self.oov_token,
'padding': self.padding,
'truncating': self.truncating,
'is_fitted': self.is_fitted
}
save_pickle(tokenizer_state, path)
logger.info(f"序列向量化器已保存到:{path}")
def load(self, path: str) -> None:
"""
加载序列向量化器
Args:
path: 加载路径
"""
tokenizer_state = load_pickle(path)
self.vectorizer = tokenizer_state['tokenizer']
self.max_features = tokenizer_state['max_features']
self.max_sequence_length = tokenizer_state['max_sequence_length']
self.oov_token = tokenizer_state['oov_token']
self.padding = tokenizer_state['padding']
self.truncating = tokenizer_state['truncating']
self.is_fitted = tokenizer_state['is_fitted']
logger.info(f"序列向量化器已从 {path} 加载,词汇表大小:{len(self.vectorizer.word_index)}")
class Word2VecVectorizer(TextVectorizer):
"""Word2Vec词嵌入向量化器"""
def __init__(self, vector_size: int = 100,
window: int = 5,
min_count: int = MIN_WORD_FREQUENCY,
workers: int = 4,
sg: int = 1, # 1表示Skip-gram模型0表示CBOW模型
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
padding: str = "post",
truncating: str = "post",
pretrained_path: Optional[str] = None):
"""
初始化Word2Vec词嵌入向量化器
Args:
vector_size: 词向量维度
window: 上下文窗口大小
min_count: 最小词频
workers: 并行训练的线程数
sg: 训练算法1表示Skip-gram0表示CBOW
max_sequence_length: 序列最大长度
padding: 填充方式'pre''post'
truncating: 截断方式'pre''post'
pretrained_path: 预训练词向量路径如果不为None则加载预训练词向量
"""
super().__init__(max_features=None) # Word2Vec没有max_features限制
self.vector_size = vector_size
self.window = window
self.min_count = min_count
self.workers = workers
self.sg = sg
self.max_sequence_length = max_sequence_length
self.padding = padding
self.truncating = truncating
self.pretrained_path = pretrained_path
# Word2Vec模型
self.model = None
# 词汇表
self.word_index = {}
self.index_word = {}
# 如果有预训练词向量,加载它
if pretrained_path and os.path.exists(pretrained_path):
self._load_pretrained(pretrained_path)
def _load_pretrained(self, path: str) -> None:
"""
加载预训练词向量
Args:
path: 预训练词向量路径
"""
try:
# 尝试加载Word2Vec模型
self.model = Word2Vec.load(path)
logger.info(f"已加载预训练Word2Vec模型{path}")
except Exception:
try:
# 尝试加载词向量Word2Vec、GloVe或FastText格式
self.model = KeyedVectors.load_word2vec_format(path, binary=path.endswith('.bin'))
logger.info(f"已加载预训练词向量:{path}")
except Exception as e:
logger.error(f"加载预训练词向量失败:{e}")
return
# 如果加载成功,构建词汇表
self._build_vocab_from_model()
self.is_fitted = True
def _build_vocab_from_model(self) -> None:
"""从模型构建词汇表"""
# 获取词汇表
vocabulary = list(self.model.wv.index_to_key)
# 构建词汇表索引
self.word_index = {word: idx + 1 for idx, word in enumerate(vocabulary)} # 索引0保留给padding
self.index_word = {idx + 1: word for idx, word in enumerate(vocabulary)}
self.index_word[0] = "<PAD>"
def fit(self, tokenized_texts: List[List[str]]) -> None:
"""
在分词后的文本上训练Word2Vec模型
Args:
tokenized_texts: 分词后的文本列表每个文本是一个词语列表
"""
# 如果已经有预训练模型,跳过训练
if self.is_fitted and self.model is not None:
logger.info("已有预训练模型,跳过训练")
return
# 训练Word2Vec模型
self.model = Word2Vec(
sentences=tokenized_texts,
vector_size=self.vector_size,
window=self.window,
min_count=self.min_count,
workers=self.workers,
sg=self.sg
)
# 构建词汇表
self._build_vocab_from_model()
self.is_fitted = True
logger.info(f"Word2Vec模型已训练词汇表大小{len(self.word_index)}")
def transform(self, tokenized_texts: List[List[str]]) -> np.ndarray:
"""
将分词后的文本转换为词向量序列
Args:
tokenized_texts: 分词后的文本列表每个文本是一个词语列表
Returns:
词向量序列形状为(样本数, 最大序列长度, 词向量维度)
"""
if not self.is_fitted or self.model is None:
raise ValueError("向量化器尚未训练请先调用fit方法")
# 初始化结果数组
result = np.zeros((len(tokenized_texts), self.max_sequence_length, self.vector_size))
# 处理每个文本
for i, text in enumerate(tokenized_texts):
seq_len = min(len(text), self.max_sequence_length)
# 根据截断方式处理
if self.truncating == 'pre' and len(text) > self.max_sequence_length:
text = text[-self.max_sequence_length:]
elif self.truncating == 'post' and len(text) > self.max_sequence_length:
text = text[:self.max_sequence_length]
# 获取每个词的词向量
for j, word in enumerate(text[:seq_len]):
if word in self.model.wv:
# 根据填充方式确定位置
pos = j if self.padding == 'post' else self.max_sequence_length - seq_len + j
result[i, pos] = self.model.wv[word]
return result
def transform_to_indices(self, tokenized_texts: List[List[str]]) -> np.ndarray:
"""
将分词后的文本转换为词索引序列并填充
Args:
tokenized_texts: 分词后的文本列表每个文本是一个词语列表
Returns:
词索引序列
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
# 将词转换为索引
sequences = []
for text in tokenized_texts:
seq = [self.word_index.get(word, 0) for word in text] # 未登录词用0padding
sequences.append(seq)
# 填充序列
padded_sequences = pad_sequences(
sequences,
maxlen=self.max_sequence_length,
padding=self.padding,
truncating=self.truncating
)
return padded_sequences
def get_embedding_matrix(self) -> np.ndarray:
"""
获取嵌入矩阵用于Embedding层的权重初始化
Returns:
嵌入矩阵形状为(词汇表大小, 词向量维度)
"""
if not self.is_fitted or self.model is None:
raise ValueError("向量化器尚未训练请先调用fit方法")
vocab_size = len(self.word_index) + 1 # +1是因为索引0保留给padding
embedding_matrix = np.zeros((vocab_size, self.vector_size))
# 填充嵌入矩阵
for word, idx in self.word_index.items():
if word in self.model.wv:
embedding_matrix[idx] = self.model.wv[word]
return embedding_matrix
def get_vocabulary(self) -> List[str]:
"""
获取词汇表
Returns:
词汇表按索引排序
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
vocab = ["<PAD>"] # 索引0保留给padding
for idx in range(1, len(self.index_word) + 1):
if idx in self.index_word:
vocab.append(self.index_word[idx])
return vocab
def get_vocabulary_size(self) -> int:
"""
获取词汇表大小
Returns:
词汇表大小
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练请先调用fit方法")
return len(self.word_index) + 1 # +1是因为索引0保留给padding
def save(self, path: str) -> None:
"""
保存Word2Vec向量化器
Args:
path: 保存路径
"""
ensure_dir(os.path.dirname(path))
# 保存模型和配置
model_path = os.path.join(os.path.dirname(path), "word2vec_model")
if self.model:
self.model.save(model_path)
# 保存配置和状态
state = {
'word_index': self.word_index,
'index_word': self.index_word,
'vector_size': self.vector_size,
'window': self.window,
'min_count': self.min_count,
'workers': self.workers,
'sg': self.sg,
'max_sequence_length': self.max_sequence_length,
'padding': self.padding,
'truncating': self.truncating,
'is_fitted': self.is_fitted,
'model_path': model_path if self.model else None
}
save_pickle(state, path)
logger.info(f"Word2Vec向量化器已保存到{path}")
def load(self, path: str) -> None:
"""
加载Word2Vec向量化器
Args:
path: 加载路径
"""
state = load_pickle(path)
self.word_index = state['word_index']
self.index_word = state['index_word']
self.vector_size = state['vector_size']
self.window = state['window']
self.min_count = state['min_count']
self.workers = state['workers']
self.sg = state['sg']
self.max_sequence_length = state['max_sequence_length']
self.padding = state['padding']
self.truncating = state['truncating']
self.is_fitted = state['is_fitted']
# 加载模型
model_path = state.get('model_path')
if model_path and os.path.exists(model_path):
self.model = Word2Vec.load(model_path)
logger.info(f"Word2Vec向量化器已从 {path} 加载,词汇表大小:{len(self.word_index)}")

29
requirements.txt Normal file
View File

@ -0,0 +1,29 @@
# 核心依赖
tensorflow>=2.5.0
numpy>=1.19.5
pandas>=1.3.0
scikit-learn>=0.24.2
matplotlib>=3.4.2
jieba>=0.42.1
tqdm>=4.61.1
gensim>=4.0.1
# 数据处理
nltk>=3.6.2
symspellpy>=6.7.0
h5py>=3.1.0
openpyxl>=3.0.7
xlrd>=2.0.1
# Web和API依赖
flask>=2.0.1
fastapi>=0.68.0
uvicorn>=0.15.0
pydantic>=1.8.2
werkzeug>=2.0.1
jinja2>=3.0.1
python-multipart>=0.0.5
# 其他工具
requests>=2.26.0
synonyms>=3.15.0

166
scripts/evaluate.py Normal file
View File

@ -0,0 +1,166 @@
"""
评估脚本评估文本分类模型性能
"""
import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import matplotlib.pyplot as plt
# 将项目根目录添加到系统路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from config.system_config import (
RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
)
from config.model_config import (
BATCH_SIZE, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
from data.dataloader import DataLoader
from data.data_manager import DataManager
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from models.model_factory import ModelFactory
from evaluation.evaluator import ModelEvaluator
from utils.logger import get_logger
logger = get_logger("Evaluation")
def evaluate_model(model_path: str,
data_dir: Optional[str] = None,
batch_size: int = BATCH_SIZE,
output_dir: Optional[str] = None) -> Dict[str, float]:
"""
评估文本分类模型
Args:
model_path: 模型路径
data_dir: 数据目录如果为None则使用默认目录
batch_size: 批大小
output_dir: 评估结果输出目录如果为None则使用默认目录
Returns:
评估指标
"""
logger.info(f"开始评估模型: {model_path}")
start_time = time.time()
# 设置数据目录
data_dir = data_dir or RAW_DATA_DIR
# 设置输出目录
if output_dir:
output_dir = os.path.abspath(output_dir)
os.makedirs(output_dir, exist_ok=True)
# 1. 加载模型
logger.info("加载模型...")
model = ModelFactory.load_model(model_path)
# 2. 加载数据
logger.info("加载数据...")
data_loader = DataLoader(data_dir=data_dir)
data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
# 加载测试集
data_manager.load_data()
test_texts, test_labels = data_manager.get_data(dataset="test")
# 3. 准备数据
# 创建分词器
tokenizer = ChineseTokenizer()
# 对测试文本进行分词
logger.info("对文本进行分词...")
tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts]
# 创建序列向量化器
logger.info("加载向量化器...")
# 查找向量化器文件
vectorizer_path = None
for model_type in ["cnn", "rnn", "transformer"]:
path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
if os.path.exists(path):
vectorizer_path = path
break
if not vectorizer_path:
# 如果找不到向量化器,创建一个新的并在测试文本上拟合(仅作回退,词索引可能与训练时不一致)
logger.warning("未找到向量化器,创建一个新的并在测试文本上拟合")
vectorizer = SequenceVectorizer(
max_features=MAX_NUM_WORDS,
max_sequence_length=MAX_SEQUENCE_LENGTH
)
vectorizer.fit(tokenized_test_texts)
else:
# 加载向量化器
vectorizer = SequenceVectorizer()
vectorizer.load(vectorizer_path)
# 转换测试文本
X_test = vectorizer.transform(tokenized_test_texts)
# 4. 创建评估器
logger.info("创建评估器...")
evaluator = ModelEvaluator(
model=model,
class_names=CATEGORIES,
output_dir=output_dir
)
# 5. 评估模型
logger.info("评估模型...")
metrics = evaluator.evaluate(X_test, test_labels, batch_size)
# 6. 保存评估结果
logger.info("保存评估结果...")
evaluator.save_evaluation_results(save_plots=True)
# 7. 可视化混淆矩阵
logger.info("可视化混淆矩阵...")
cm = evaluator.evaluation_results['confusion_matrix']
evaluator.metrics.plot_confusion_matrix(
y_true=test_labels,
y_pred=np.argmax(model.predict(X_test), axis=1),
normalize='true',
save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png")
)
# 8. 类别性能分析
logger.info("分析各类别性能...")
class_performance = evaluator.evaluate_class_performance(X_test, test_labels)
# 9. 计算评估时间
eval_time = time.time() - start_time
logger.info(f"模型评估完成,耗时: {eval_time:.2f}")
# 10. 输出主要指标
logger.info("主要评估指标:")
for metric_name, metric_value in metrics.items():
if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']:
logger.info(f" {metric_name}: {metric_value:.4f}")
return metrics
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="评估文本分类模型")
parser.add_argument("--model_path", required=True, help="模型路径")
parser.add_argument("--data_dir", help="数据目录")
parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
parser.add_argument("--output_dir", help="评估结果输出目录")
args = parser.parse_args()
# 评估模型
evaluate_model(
model_path=args.model_path,
data_dir=args.data_dir,
batch_size=args.batch_size,
output_dir=args.output_dir
)
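For completeness, a hypothetical programmatic call of the function above (the model path and output directory are placeholders, and this assumes the project root is on sys.path so that scripts.evaluate is importable):
# Hypothetical call; paths are placeholders
from scripts.evaluate import evaluate_model
metrics = evaluate_model(
    model_path="saved_models/classifiers/cnn_model_20250308_010000",
    batch_size=64,
    output_dir="evaluation_results")
print(metrics.get("accuracy"))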

242
scripts/predict.py Normal file
View File

@ -0,0 +1,242 @@
"""
预测脚本使用模型进行预测
"""
import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import json
# 将项目根目录添加到系统路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from config.system_config import (
CATEGORIES, CLASSIFIERS_DIR
)
from models.model_factory import ModelFactory
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from inference.predictor import Predictor
from inference.batch_processor import BatchProcessor
from utils.logger import get_logger
from utils.file_utils import read_text_file
logger = get_logger("Prediction")
def predict_text(text: str, model_path: Optional[str] = None,
output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
"""
预测单条文本
Args:
text: 要预测的文本
model_path: 模型路径如果为None则使用最新的模型
output_path: 输出文件路径如果为None则不保存
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info("开始预测文本")
# 1. 加载模型
if model_path is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise ValueError("未找到可用的模型")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"加载模型: {model_path}")
model = ModelFactory.load_model(model_path)
# 2. 创建分词器和预测器
tokenizer = ChineseTokenizer()
# 查找向量化器文件
vectorizer = None
for model_type in ["cnn", "rnn", "transformer"]:
vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
if os.path.exists(vectorizer_path):
# 加载向量化器
vectorizer = SequenceVectorizer()
vectorizer.load(vectorizer_path)
logger.info(f"加载向量化器: {vectorizer_path}")
break
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=CATEGORIES
)
# 3. 预测
result = predictor.predict(
text=text,
return_top_k=top_k,
return_probabilities=True
)
# 4. 输出结果
if top_k > 1:
logger.info("预测结果:")
for i, pred in enumerate(result):
logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
else:
logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")
# 5. 保存结果
if output_path:
if output_path.endswith('.json'):
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
else:
with open(output_path, 'w', encoding='utf-8') as f:
if top_k > 1:
f.write("rank,class,probability\n")
for i, pred in enumerate(result):
f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
else:
f.write(f"class,probability\n")
f.write(f"{result['class']},{result['probability']}\n")
logger.info(f"结果已保存到: {output_path}")
return result
def predict_file(file_path: str, model_path: Optional[str] = None,
output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
"""
预测文件内容
Args:
file_path: 文件路径
model_path: 模型路径如果为None则使用最新的模型
output_path: 输出文件路径如果为None则不保存
top_k: 返回概率最高的前k个类别
Returns:
预测结果
"""
logger.info(f"开始预测文件: {file_path}")
# 检查文件是否存在
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
# 读取文件内容
if file_path.endswith('.txt'):
# 文本文件
text = read_text_file(file_path)
return predict_text(text, model_path, output_path, top_k)
elif file_path.endswith(('.csv', '.xls', '.xlsx')):
# 表格文件
import pandas as pd
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
else:
df = pd.read_excel(file_path)
# 查找可能的文本列
text_columns = [col for col in df.columns if df[col].dtype == 'object']
if not text_columns:
raise ValueError("文件中没有找到可能的文本列")
# 使用第一个文本列
text_column = text_columns[0]
logger.info(f"使用文本列: {text_column}")
# 1. 加载模型
if model_path is None:
# 获取可用模型列表
models_info = ModelFactory.get_available_models()
if not models_info:
raise ValueError("未找到可用的模型")
# 使用最新的模型
model_path = models_info[0]['path']
logger.info(f"加载模型: {model_path}")
model = ModelFactory.load_model(model_path)
# 2. 创建分词器和预测器
tokenizer = ChineseTokenizer()
# 查找向量化器文件
vectorizer = None
for model_type in ["cnn", "rnn", "transformer"]:
vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
if os.path.exists(vectorizer_path):
# 加载向量化器
vectorizer = SequenceVectorizer()
vectorizer.load(vectorizer_path)
logger.info(f"加载向量化器: {vectorizer_path}")
break
# 创建预测器
predictor = Predictor(
model=model,
tokenizer=tokenizer,
vectorizer=vectorizer,
class_names=CATEGORIES
)
# 3. 创建批处理器
batch_processor = BatchProcessor(
predictor=predictor,
batch_size=64
)
# 4. 批量预测
result_df = batch_processor.process_dataframe(
df=df,
text_column=text_column,
output_path=output_path,
return_top_k=top_k,
format=output_path.split('.')[-1] if output_path else 'csv'
)
logger.info(f"已处理 {len(result_df)} 行数据")
# 返回结果
return result_df.to_dict(orient='records')
else:
raise ValueError(f"不支持的文件类型: {file_path}")
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="使用模型预测")
parser.add_argument("--model_path", help="模型路径")
parser.add_argument("--text", help="要预测的文本")
parser.add_argument("--file", help="要预测的文件")
parser.add_argument("--output", help="输出文件")
parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
args = parser.parse_args()
# 检查输入
if not args.text and not args.file:
parser.error("请提供要预测的文本或文件")
# 预测
if args.text:
predict_text(args.text, args.model_path, args.output, args.top_k)
else:
predict_file(args.file, args.model_path, args.output, args.top_k)
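A hypothetical programmatic call of predict_text defined above (the sample sentence is made up; with top_k > 1 the return value is a list of {"class": ..., "probability": ...} entries, as consumed by the logging code above):
# Hypothetical call; the model is auto-selected from the latest saved model
from scripts.predict import predict_text
result = predict_text("央行宣布下调存款准备金率", top_k=3)
for pred in result:
    print(pred["class"], pred["probability"])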

203
scripts/train.py Normal file
View File

@ -0,0 +1,203 @@
"""
训练脚本训练文本分类模型
"""
import os
import sys
import time
import argparse
import logging
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 将项目根目录添加到系统路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from config.system_config import (
RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
)
from config.model_config import (
BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
from data.dataloader import DataLoader
from data.data_manager import DataManager
from preprocessing.tokenization import ChineseTokenizer
from preprocessing.vectorizer import SequenceVectorizer
from models.model_factory import ModelFactory
from training.trainer import Trainer
from utils.logger import get_logger
logger = get_logger("Training")
def train_model(data_dir: Optional[str] = None,
model_type: str = "cnn",
epochs: int = NUM_EPOCHS,
batch_size: int = BATCH_SIZE,
save_dir: Optional[str] = None,
validation_split: float = 0.1,
use_pretrained_embedding: bool = False,
embedding_path: Optional[str] = None) -> str:
"""
训练文本分类模型
Args:
data_dir: 数据目录如果为None则使用默认目录
model_type: 模型类型'cnn', 'rnn', 'transformer'
epochs: 训练轮数
batch_size: 批大小
save_dir: 模型保存目录如果为None则使用默认目录
validation_split: 验证集比例
use_pretrained_embedding: 是否使用预训练词向量
embedding_path: 预训练词向量路径
Returns:
保存的模型路径
"""
logger.info(f"开始训练 {model_type.upper()} 模型")
start_time = time.time()
# 设置数据目录
data_dir = data_dir or RAW_DATA_DIR
# 设置保存目录
if save_dir:
save_dir = os.path.abspath(save_dir)
os.makedirs(save_dir, exist_ok=True)
else:
save_dir = CLASSIFIERS_DIR
os.makedirs(save_dir, exist_ok=True)
# 1. 加载数据
logger.info("加载数据...")
data_loader = DataLoader(data_dir=data_dir)
data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
# 加载和分割数据
data = data_manager.load_and_split_data(
data_loader=data_loader,
val_split=validation_split,
sample_ratio=1.0,
save=True
)
# 获取训练集和验证集
train_texts, train_labels = data_manager.get_data(dataset="train")
val_texts, val_labels = data_manager.get_data(dataset="val")
# 2. 准备数据
# 创建分词器
tokenizer = ChineseTokenizer()
# 对训练文本进行分词
logger.info("对文本进行分词...")
tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts]
tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts]
# 创建序列向量化器
logger.info("创建序列向量化器...")
vectorizer = SequenceVectorizer(
max_features=MAX_NUM_WORDS,
max_sequence_length=MAX_SEQUENCE_LENGTH
)
# 训练向量化器并转换文本
vectorizer.fit(tokenized_train_texts)
X_train = vectorizer.transform(tokenized_train_texts)
X_val = vectorizer.transform(tokenized_val_texts)
# 保存向量化器
vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl")
vectorizer.save(vectorizer_path)
logger.info(f"向量化器已保存到: {vectorizer_path}")
# 获取一些基本参数
num_classes = len(CATEGORIES)
vocab_size = vectorizer.get_vocabulary_size()
# 3. 创建模型
logger.info(f"创建 {model_type.upper()} 模型...")
# 加载预训练词向量(如果指定)
embedding_matrix = None
if use_pretrained_embedding and embedding_path:
# 这里简化处理,实际应用中应该加载和处理预训练词向量
logger.info("加载预训练词向量...")
embedding_matrix = np.random.random((vocab_size, 200))
# 创建模型
model = ModelFactory.create_model(
model_type=model_type,
num_classes=num_classes,
vocab_size=vocab_size,
embedding_matrix=embedding_matrix,
batch_size=batch_size
)
# 构建模型
model.build()
model.compile()
model.summary()
# 4. 训练模型
logger.info("开始训练模型...")
trainer = Trainer(
model=model,
epochs=epochs,
batch_size=batch_size,
early_stopping=True,
tensorboard=True
)
# 训练
history = trainer.train(
x_train=X_train,
y_train=train_labels,
x_val=X_val,
y_val=val_labels
)
# 5. 保存模型
timestamp = time.strftime("%Y%m%d_%H%M%S")
model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
model.save(model_path)
logger.info(f"模型已保存到: {model_path}")
# 6. 绘制训练历史
logger.info("绘制训练历史...")
model.plot_training_history(save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png"))
# 7. 计算训练时间
train_time = time.time() - start_time
logger.info(f"模型训练完成,耗时: {train_time:.2f}")
return model_path
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="训练文本分类模型")
parser.add_argument("--data_dir", help="数据目录")
parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
parser.add_argument("--save_dir", help="模型保存目录")
parser.add_argument("--validation_split", type=float, default=0.1, help="验证集比例")
parser.add_argument("--use_pretrained_embedding", action="store_true", help="是否使用预训练词向量")
parser.add_argument("--embedding_path", help="预训练词向量路径")
args = parser.parse_args()
# 训练模型
train_model(
data_dir=args.data_dir,
model_type=args.model_type,
epochs=args.epochs,
batch_size=args.batch_size,
save_dir=args.save_dir,
validation_split=args.validation_split,
use_pretrained_embedding=args.use_pretrained_embedding,
embedding_path=args.embedding_path
)
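And a hypothetical programmatic call of train_model mirroring the CLI above (hyperparameters are illustrative; assumes the project root is on sys.path):
# Hypothetical call; returns the path of the saved model
from scripts.train import train_model
model_path = train_model(model_type="cnn", epochs=5, batch_size=64, validation_split=0.1)
print("model saved to:", model_path)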

40
setup.py Normal file
View File

@ -0,0 +1,40 @@
from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
setup(
name="chinese-text-classification",
version="1.0.0",
author="Your Name",
author_email="your.email@example.com",
description="基于Python的中文文本分类系统",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/chinese-text-classification",
packages=find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires=">=3.7",
install_requires=[
"tensorflow>=2.5.0",
"numpy>=1.19.5",
"pandas>=1.3.0",
"scikit-learn>=0.24.2",
"matplotlib>=3.4.2",
"jieba>=0.42.1",
"tqdm>=4.61.1",
"gensim>=4.0.1",
"flask>=2.0.1",
"fastapi>=0.68.0",
"uvicorn>=0.15.0"
],
entry_points={
"console_scripts": [
"text-classifier=main:main",
],
},
)

0
tests/__init__.py Normal file
View File

0
tests/test_evaluation.py Normal file
View File

0
tests/test_models.py Normal file
View File

0
training/__init__.py Normal file
View File

430
training/callbacks.py Normal file
View File

@ -0,0 +1,430 @@
"""
回调函数模块提供用于模型训练的自定义回调函数
"""
import os
import time
import numpy as np
import tensorflow as tf
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import matplotlib.pyplot as plt
from io import BytesIO
from utils.logger import get_logger
logger = get_logger("Callbacks")
class MetricsHistory(tf.keras.callbacks.Callback):
"""跟踪训练过程中的指标历史"""
def __init__(self, validation_data: Optional[Tuple] = None,
metrics: Optional[List[str]] = None,
save_path: Optional[str] = None):
"""
初始化MetricsHistory回调
Args:
validation_data: 验证数据格式为(x_val, y_val)
metrics: 要跟踪的指标列表
save_path: 指标历史的保存路径
"""
super().__init__()
self.validation_data = validation_data
self.metrics = metrics or ['loss', 'accuracy']
self.save_path = save_path
# 历史指标
self.history = {metric: [] for metric in self.metrics}
if validation_data is not None:
for metric in self.metrics:
self.history[f'val_{metric}'] = []
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
"""
每个epoch结束时调用
Args:
epoch: 当前epoch索引
logs: 训练日志
"""
logs = logs or {}
# 记录训练指标
for metric in self.metrics:
if metric in logs:
self.history[metric].append(logs[metric])
# 记录验证指标
if self.validation_data is not None:
for metric in self.metrics:
val_metric = f'val_{metric}'
if val_metric in logs:
self.history[val_metric].append(logs[val_metric])
def plot_metrics(self, save_path: Optional[str] = None) -> None:
"""
绘制指标历史
Args:
save_path: 图像保存路径如果为None则使用初始化时设置的路径
"""
plt.figure(figsize=(12, 5))
for i, metric in enumerate(self.metrics):
plt.subplot(1, len(self.metrics), i + 1)
if metric in self.history:
plt.plot(self.history[metric], label=f'train_{metric}')
val_metric = f'val_{metric}'
if val_metric in self.history:
plt.plot(self.history[val_metric], label=f'val_{metric}')
plt.title(f'Model {metric}')
plt.xlabel('Epoch')
plt.ylabel(metric)
plt.legend()
plt.tight_layout()
save_path = save_path or self.save_path
if save_path:
plt.savefig(save_path)
logger.info(f"指标历史图已保存到: {save_path}")
else:
plt.show()
class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
"""计算并显示验证集上的混淆矩阵"""
def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray],
class_names: Optional[List[str]] = None,
log_dir: Optional[str] = None,
freq: int = 1,
fig_size: Tuple[int, int] = (10, 8)):
"""
初始化ConfusionMatrixCallback
Args:
validation_data: 验证数据格式为(x_val, y_val)
class_names: 类别名称列表
log_dir: TensorBoard日志目录
freq: 计算混淆矩阵的频率每多少个epoch计算一次
fig_size: 图像大小
"""
super().__init__()
self.x_val, self.y_val = validation_data
self.class_names = class_names
self.log_dir = log_dir
self.freq = freq
self.fig_size = fig_size
# 如果提供了TensorBoard日志目录创建一个文件写入器
if log_dir:
self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix'))
else:
self.file_writer = None
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
"""
每个epoch结束时调用
Args:
epoch: 当前epoch索引
logs: 训练日志
"""
# 每freq个epoch计算一次混淆矩阵
if (epoch + 1) % self.freq == 0 or epoch == 0:
# 获取预测结果
y_pred = np.argmax(self.model.predict(self.x_val), axis=1)
# 确保y_val是一维数组
y_true = self.y_val
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 计算混淆矩阵
cm = tf.math.confusion_matrix(y_true, y_pred).numpy()
# 归一化混淆矩阵
cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
# 绘制混淆矩阵
fig = self._plot_confusion_matrix(cm_norm, epoch + 1)
# 如果有TensorBoard日志将图像添加到TensorBoard
if self.file_writer:
with self.file_writer.as_default():
# 将matplotlib图像转换为TensorBoard图像
buf = BytesIO()
fig.savefig(buf, format='png')
buf.seek(0)
# 将PNG编码为字符串并创建图像
image = tf.image.decode_png(buf.getvalue(), channels=4)
image = tf.expand_dims(image, 0)
# 添加到TensorBoard
tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch)
plt.close(fig)
def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure:
"""
绘制混淆矩阵
Args:
cm: 混淆矩阵
epoch: 当前epoch
Returns:
matplotlib图像
"""
fig, ax = plt.subplots(figsize=self.fig_size)
# 使用热图显示混淆矩阵
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
# 设置坐标轴标签
if self.class_names:
ax.set(
xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=self.class_names,
yticklabels=self.class_names
)
# 旋转x轴标签
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# 在每个单元格中显示数值
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], '.2f'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})")
ax.set_ylabel('True label')
ax.set_xlabel('Predicted label')
fig.tight_layout()
return fig
class TimingCallback(tf.keras.callbacks.Callback):
"""测量训练时间的回调函数"""
def __init__(self):
"""初始化TimingCallback"""
super().__init__()
self.epoch_times = []
self.batch_times = []
self.epoch_start_time = None
self.batch_start_time = None
self.training_start_time = None
def on_train_begin(self, logs: Dict[str, float] = None) -> None:
"""
训练开始时调用
Args:
logs: 训练日志
"""
self.training_start_time = time.time()
def on_train_end(self, logs: Dict[str, float] = None) -> None:
"""
训练结束时调用
Args:
logs: 训练日志
"""
training_time = time.time() - self.training_start_time
logger.info(f"总训练时间: {training_time:.2f}")
if self.epoch_times:
avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f}")
if self.batch_times:
avg_batch_time = sum(self.batch_times) / len(self.batch_times)
logger.info(f"平均每个batch时间: {avg_batch_time:.4f}")
def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
"""
每个epoch开始时调用
Args:
epoch: 当前epoch索引
logs: 训练日志
"""
self.epoch_start_time = time.time()
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
"""
每个epoch结束时调用
Args:
epoch: 当前epoch索引
logs: 训练日志
"""
epoch_time = time.time() - self.epoch_start_time
self.epoch_times.append(epoch_time)
# 将epoch时间添加到日志中
if logs is not None:
logs['epoch_time'] = epoch_time
def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None:
"""
每个batch开始时调用
Args:
batch: 当前batch索引
logs: 训练日志
"""
self.batch_start_time = time.time()
def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None:
"""
每个batch结束时调用
Args:
batch: 当前batch索引
logs: 训练日志
"""
batch_time = time.time() - self.batch_start_time
self.batch_times.append(batch_time)
class LearningRateSchedulerCallback(tf.keras.callbacks.Callback):
"""学习率调度器回调函数"""
def __init__(self, scheduler_func: Callable[[int, float], float],
verbose: int = 0,
log_dir: Optional[str] = None):
"""
初始化LearningRateSchedulerCallback
Args:
scheduler_func: 学习率调度函数接收(epoch, lr)参数返回新的学习率
verbose: 详细程度
log_dir: TensorBoard日志目录
"""
super().__init__()
self.scheduler_func = scheduler_func
self.verbose = verbose
# 如果提供了TensorBoard日志目录创建一个文件写入器
if log_dir:
self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate'))
else:
self.file_writer = None
# 学习率历史
self.lr_history = []
def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
"""
每个epoch开始时调用
Args:
epoch: 当前epoch索引
logs: 训练日志
"""
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
# 获取当前学习率
current_lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
# 计算新的学习率
new_lr = self.scheduler_func(epoch, current_lr)
# 设置新的学习率
tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
# 记录学习率
self.lr_history.append(new_lr)
# 记录到TensorBoard
if self.file_writer:
with self.file_writer.as_default():
tf.summary.scalar('learning_rate', new_lr, step=epoch)
if self.verbose > 0:
logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}")
def get_lr_history(self) -> List[float]:
"""
获取学习率历史
Returns:
学习率历史列表
"""
return self.lr_history
class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping):
"""增强版早停回调函数,支持最小变化率"""
def __init__(self, monitor: str = 'val_loss',
min_delta: float = 0,
min_delta_ratio: float = 0,
patience: int = 0,
verbose: int = 0,
mode: str = 'auto',
baseline: Optional[float] = None,
restore_best_weights: bool = False):
"""
初始化EarlyStoppingCallback
Args:
monitor: 监控的指标
min_delta: 视为改进的最小绝对变化
min_delta_ratio: 视为改进的最小相对变化率
patience: 没有改进的轮数
verbose: 详细程度
mode: 'auto', 'min' 'max'
baseline: 基准值
restore_best_weights: 是否恢复最佳权重
"""
super().__init__(
monitor=monitor,
min_delta=min_delta,
patience=patience,
verbose=verbose,
mode=mode,
baseline=baseline,
restore_best_weights=restore_best_weights
)
self.min_delta_ratio = min_delta_ratio
def _is_improvement(self, current: float, reference: float) -> bool:
"""
判断是否有所改进
Args:
current: 当前值
reference: 参考值
Returns:
是否有所改进
"""
# 先检查绝对变化
if super()._is_improvement(current, reference):
return True
# 再检查相对变化率
if self.monitor_op == np.less:
# 对于 'min' 模式,值越小越好
relative_delta = (reference - current) / reference if reference != 0 else 0
return relative_delta > self.min_delta_ratio
else:
# 对于 'max' 模式,值越大越好
relative_delta = (current - reference) / reference if reference != 0 else 0
return relative_delta > self.min_delta_ratio
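A sketch of how the callbacks above could be wired into a training run, either directly through Keras fit or via the Trainer's custom_callbacks argument; the dummy validation arrays and class names below are placeholders:
# Illustrative wiring only; x_val / y_val / class_names are dummy placeholders
import numpy as np
from training.callbacks import MetricsHistory, ConfusionMatrixCallback, TimingCallback
x_val = np.zeros((8, 100)); y_val = np.zeros(8, dtype=int)
class_names = ["体育", "财经", "科技"]
callbacks = [
    MetricsHistory(validation_data=(x_val, y_val), save_path="metrics.png"),
    ConfusionMatrixCallback(validation_data=(x_val, y_val), class_names=class_names, freq=2),
    TimingCallback(),
]
# model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, callbacks=callbacks)
# or: Trainer(model, custom_callbacks=callbacks).train(x_train, y_train, x_val, y_val)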

0
training/optimizer.py Normal file
View File

284
training/scheduler.py Normal file
View File

@ -0,0 +1,284 @@
"""
学习率调度器模块提供各种学习率调度策略
"""
import numpy as np
import math
from typing import Callable, Optional, Union, Dict
import tensorflow as tf
from utils.logger import get_logger
logger = get_logger("Scheduler")
def step_decay(epoch: int, initial_lr: float,
drop_rate: float = 0.5,
epochs_drop: int = 10) -> float:
"""
阶梯式学习率衰减
Args:
epoch: 当前epoch索引
initial_lr: 初始学习率
drop_rate: 衰减率
epochs_drop: 每多少个epoch衰减一次
Returns:
新的学习率
"""
return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop))
def exponential_decay(epoch: int, initial_lr: float,
decay_rate: float = 0.9,
staircase: bool = False) -> float:
"""
指数衰减学习率
Args:
epoch: 当前epoch索引
initial_lr: 初始学习率
decay_rate: 衰减率
staircase: 是否阶梯式衰减
Returns:
新的学习率
"""
if staircase:
return initial_lr * math.pow(decay_rate, math.floor(epoch))
else:
return initial_lr * math.pow(decay_rate, epoch)
def cosine_decay(epoch: int, initial_lr: float,
total_epochs: int = 100,
min_lr: float = 0) -> float:
"""
余弦退火学习率
Args:
epoch: 当前epoch索引
initial_lr: 初始学习率
total_epochs: 总epoch数
min_lr: 最小学习率
Returns:
新的学习率
"""
return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs))
def cosine_decay_with_warmup(epoch: int, initial_lr: float,
total_epochs: int = 100,
warmup_epochs: int = 5,
min_lr: float = 0,
warmup_init_lr: float = 0) -> float:
"""
带预热的余弦退火学习率
Args:
epoch: 当前epoch索引
initial_lr: 初始学习率
total_epochs: 总epoch数
warmup_epochs: 预热epoch数
min_lr: 最小学习率
warmup_init_lr: 预热初始学习率
Returns:
新的学习率
"""
if epoch < warmup_epochs:
# 线性预热
return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs
else:
# 余弦退火
return min_lr + 0.5 * (initial_lr - min_lr) * (
1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
)
def cyclical_learning_rate(epoch: int, initial_lr: float,
max_lr: float = 0.1,
step_size: int = 8,
gamma: float = 1.0) -> float:
"""
循环学习率
Args:
epoch: 当前epoch索引
initial_lr: 初始学习率
max_lr: 最大学习率
step_size: 半周期大小epoch数
gamma: 循环衰减率
Returns:
新的学习率
"""
# 计算循环数
cycle = math.floor(1 + epoch / (2 * step_size))
# 计算x值范围在[0, 2]
x = abs(epoch / step_size - 2 * cycle + 1)
# 应用循环衰减
lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1)
# 计算学习率
return initial_lr + lr_range * max(0, 1 - x)
def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]:
"""
创建自定义学习率调度器
Args:
scheduler_type: 调度器类型可选值: 'step', 'exp', 'cosine', 'cosine_warmup', 'cyclical'
**kwargs: 调度器参数
Returns:
学习率调度函数
"""
scheduler_type = scheduler_type.lower()
if scheduler_type == 'step':
drop_rate = kwargs.get('drop_rate', 0.5)
epochs_drop = kwargs.get('epochs_drop', 10)
def scheduler(epoch, lr):
return step_decay(epoch, lr, drop_rate, epochs_drop)
return scheduler
elif scheduler_type == 'exp':
decay_rate = kwargs.get('decay_rate', 0.9)
staircase = kwargs.get('staircase', False)
def scheduler(epoch, lr):
if epoch == 0:
# 第一个epoch使用初始学习率
return lr
return exponential_decay(epoch, lr, decay_rate, staircase)
return scheduler
elif scheduler_type == 'cosine':
total_epochs = kwargs.get('total_epochs', 100)
min_lr = kwargs.get('min_lr', 0)
def scheduler(epoch, lr):
if epoch == 0:
return lr
return cosine_decay(epoch, lr, total_epochs, min_lr)
return scheduler
elif scheduler_type == 'cosine_warmup':
total_epochs = kwargs.get('total_epochs', 100)
warmup_epochs = kwargs.get('warmup_epochs', 5)
min_lr = kwargs.get('min_lr', 0)
warmup_init_lr = kwargs.get('warmup_init_lr', 0)
def scheduler(epoch, lr):
if epoch == 0:
return warmup_init_lr
return cosine_decay_with_warmup(epoch, lr, total_epochs, warmup_epochs, min_lr, warmup_init_lr)
return scheduler
elif scheduler_type == 'cyclical':
max_lr = kwargs.get('max_lr', 0.1)
step_size = kwargs.get('step_size', 8)
gamma = kwargs.get('gamma', 1.0)
def scheduler(epoch, lr):
if epoch == 0:
return lr
return cyclical_learning_rate(epoch, lr, max_lr, step_size, gamma)
return scheduler
else:
raise ValueError(f"不支持的调度器类型: {scheduler_type}")
class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback):
"""预热余弦退火学习率调度器"""
def __init__(self, learning_rate_base: float,
total_steps: int,
warmup_learning_rate: float = 0.0,
warmup_steps: int = 0,
hold_base_rate_steps: int = 0,
verbose: int = 0):
"""
初始化预热余弦退火学习率调度器
Args:
learning_rate_base: 基础学习率
total_steps: 总步数
warmup_learning_rate: 预热学习率
warmup_steps: 预热步数
hold_base_rate_steps: 保持基础学习率的步数
verbose: 详细程度
"""
super().__init__()
self.learning_rate_base = learning_rate_base
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.hold_base_rate_steps = hold_base_rate_steps
self.verbose = verbose
# 学习率历史
self.learning_rates = []
def on_train_begin(self, logs: Optional[Dict] = None) -> None:
"""
训练开始时调用
Args:
logs: 训练日志
"""
self.current_step = 0
logger.info(f"预热余弦退火学习率调度器初始化: 基础学习率={self.learning_rate_base}, "
f"预热步数={self.warmup_steps}, 总步数={self.total_steps}")
def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""
每个batch结束时调用
Args:
batch: 当前batch索引
logs: 训练日志
"""
self.current_step += 1
lr = self._get_learning_rate()
tf.keras.backend.set_value(self.model.optimizer.lr, lr)
self.learning_rates.append(lr)
def _get_learning_rate(self) -> float:
"""
计算当前学习率
Returns:
当前学习率
"""
if self.current_step < self.warmup_steps:
# 预热阶段:线性增加学习率
lr = self.warmup_learning_rate + self.current_step * (
(self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
)
elif self.current_step < self.warmup_steps + self.hold_base_rate_steps:
# 保持基础学习率阶段
lr = self.learning_rate_base
else:
# 余弦退火阶段
cosine_steps = self.total_steps - self.warmup_steps - self.hold_base_rate_steps
cosine_current_step = self.current_step - self.warmup_steps - self.hold_base_rate_steps
lr = 0.5 * self.learning_rate_base * (
1 + math.cos(math.pi * cosine_current_step / cosine_steps)
)
return lr
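Two hedged usage sketches for the schedules above: the epoch-level factory plugs into Keras' LearningRateScheduler callback (which passes the optimizer's current learning rate as the second argument), and the batch-level WarmupCosineDecayScheduler needs the total step count. Two values worked out directly from the formulas, assuming an initial learning rate of 0.01: step_decay(epoch=9, 0.01, drop_rate=0.5, epochs_drop=10) = 0.01 * 0.5 = 0.005, and cosine_decay(epoch=50, 0.01, total_epochs=100) = 0.5 * 0.01 * (1 + cos(pi/2)) = 0.005.
# Illustrative wiring only; learning rates and step counts are placeholders
import tensorflow as tf
from training.scheduler import create_custom_scheduler, WarmupCosineDecayScheduler
schedule = create_custom_scheduler("cosine_warmup", total_epochs=50, warmup_epochs=5, warmup_init_lr=1e-5)
epoch_lr_cb = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
steps_per_epoch, num_epochs = 100, 50
batch_lr_cb = WarmupCosineDecayScheduler(
    learning_rate_base=1e-3,
    total_steps=num_epochs * steps_per_epoch,
    warmup_learning_rate=1e-5,
    warmup_steps=2 * steps_per_epoch)
# model.fit(..., callbacks=[epoch_lr_cb])  or  model.fit(..., callbacks=[batch_lr_cb])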

281
training/trainer.py Normal file
View File

@ -0,0 +1,281 @@
"""
训练器模块实现模型训练流程包括训练循环验证等
"""
import os
import time
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from config.system_config import SAVED_MODELS_DIR
from config.model_config import (
NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
VALIDATION_SPLIT, RANDOM_SEED
)
from models.base_model import TextClassificationModel
from utils.logger import get_logger, TrainingLogger
from utils.file_utils import ensure_dir
logger = get_logger("Trainer")
class Trainer:
"""模型训练器,负责训练和验证模型"""
def __init__(self, model: TextClassificationModel,
epochs: int = NUM_EPOCHS,
batch_size: Optional[int] = None,
validation_split: float = VALIDATION_SPLIT,
early_stopping: bool = True,
early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
save_best_only: bool = True,
tensorboard: bool = True,
checkpoint: bool = True,
custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
"""
初始化训练器
Args:
model: 要训练的模型
epochs: 训练轮数
batch_size: 批大小如果为None则使用模型默认值
validation_split: 验证集比例
early_stopping: 是否使用早停
early_stopping_patience: 早停耐心值
save_best_only: 是否只保存最佳模型
tensorboard: 是否使用TensorBoard
checkpoint: 是否保存检查点
custom_callbacks: 自定义回调函数列表
"""
self.model = model
self.epochs = epochs
self.batch_size = batch_size or model.batch_size
self.validation_split = validation_split
self.early_stopping = early_stopping
self.early_stopping_patience = early_stopping_patience
self.save_best_only = save_best_only
self.tensorboard = tensorboard
self.checkpoint = checkpoint
self.custom_callbacks = custom_callbacks or []
# 训练历史
self.history = None
# 训练日志记录器
self.training_logger = TrainingLogger(model.model_name)
logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}")
def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
"""
创建回调函数列表
Returns:
回调函数列表
"""
callbacks = []
# 早停
if self.early_stopping:
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=self.early_stopping_patience,
restore_best_weights=True,
verbose=1
)
callbacks.append(early_stopping)
# 学习率衰减
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=self.early_stopping_patience // 2,
min_lr=1e-6,
verbose=1
)
callbacks.append(reduce_lr)
# 模型检查点
if self.checkpoint:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
ensure_dir(checkpoint_dir)
checkpoint_path = os.path.join(
checkpoint_dir,
f"{self.model.model_name}_{timestamp}.h5"
)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_best_only=self.save_best_only,
monitor='val_loss',
verbose=1
)
callbacks.append(model_checkpoint)
# TensorBoard
if self.tensorboard:
log_dir = os.path.join(
SAVED_MODELS_DIR,
'logs',
f"{self.model.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)
ensure_dir(log_dir)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1,
update_freq='epoch'
)
callbacks.append(tensorboard_callback)
# 添加自定义回调函数
callbacks.extend(self.custom_callbacks)
return callbacks
def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
"""
记录训练进度
Args:
epoch: 当前轮数
logs: 日志信息
"""
self.training_logger.log_epoch(epoch, logs)
def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
y_train: Optional[np.ndarray] = None,
x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
y_val: Optional[np.ndarray] = None,
class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
"""
训练模型
Args:
x_train: 训练数据特征
y_train: 训练数据标签
x_val: 验证数据特征
y_val: 验证数据标签
class_weights: 类别权重
Returns:
训练历史
"""
logger.info(f"开始训练模型: {self.model.model_name}")
# 创建回调函数
callbacks = self._create_callbacks()
# 添加训练进度记录回调
progress_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
)
callbacks.append(progress_callback)
# 记录开始时间
start_time = time.time()
# 记录训练开始信息
model_config = self.model.get_config()
train_config = {
"epochs": self.epochs,
"batch_size": self.batch_size,
"validation_split": self.validation_split,
"early_stopping": self.early_stopping,
"early_stopping_patience": self.early_stopping_patience
}
self.training_logger.log_training_start({**model_config, **train_config})
# 准备验证数据
validation_data = None
if x_val is not None and y_val is not None:
validation_data = (x_val, y_val)
# 训练模型
history = self.model.fit(
x_train, y_train,
validation_data=validation_data,
epochs=self.epochs,
callbacks=callbacks,
class_weights=class_weights,
verbose=1
)
# 计算训练时间
train_time = time.time() - start_time
# 保存训练历史
self.history = history.history
# 找出最佳性能
best_val_loss = min(history.history['val_loss']) if 'val_loss' in history.history else None
best_val_acc = max(history.history['val_accuracy']) if 'val_accuracy' in history.history else None
best_metrics = {}
if best_val_loss is not None:
best_metrics['val_loss'] = best_val_loss
if best_val_acc is not None:
best_metrics['val_accuracy'] = best_val_acc
# 记录训练结束信息
self.training_logger.log_training_end(train_time, best_metrics)
logger.info(f"模型训练完成,用时: {train_time:.2f}")
return history.history
def plot_training_history(self, metrics: Optional[List[str]] = None,
save_path: Optional[str] = None) -> None:
"""
绘制训练历史
Args:
metrics: 要绘制的指标列表默认为['loss', 'accuracy']
save_path: 保存路径如果为None则显示图像
"""
if self.history is None:
raise ValueError("模型尚未训练,没有训练历史")
if metrics is None:
metrics = ['loss', 'accuracy']
plt.figure(figsize=(12, 5))
for i, metric in enumerate(metrics):
plt.subplot(1, len(metrics), i + 1)
if metric in self.history:
plt.plot(self.history[metric], label=f'train_{metric}')
val_metric = f'val_{metric}'
if val_metric in self.history:
plt.plot(self.history[val_metric], label=f'val_{metric}')
plt.title(f'Model {metric}')
plt.xlabel('Epoch')
plt.ylabel(metric)
plt.legend()
plt.tight_layout()
if save_path:
plt.savefig(save_path)
logger.info(f"训练历史图已保存到: {save_path}")
else:
plt.show()
def save_trained_model(self, filepath: Optional[str] = None) -> str:
"""
保存训练好的模型
Args:
filepath: 保存路径如果为None则使用默认路径
Returns:
保存路径
"""
return self.model.save(filepath)

0
utils/__init__.py Normal file
View File

306
utils/file_utils.py Normal file
View File

@ -0,0 +1,306 @@
"""
文件处理工具模块
"""
import os
import shutil
import json
import pickle
import csv
from pathlib import Path
import time
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed
import zipfile
import tarfile
from config.system_config import ENCODING, DATA_LOADING_WORKERS
from utils.logger import get_logger
logger = get_logger("file_utils")
def read_text_file(file_path, encoding=ENCODING):
"""
读取文本文件内容
Args:
file_path: 文件路径
encoding: 文件编码
Returns:
文件内容
"""
try:
with open(file_path, 'r', encoding=encoding) as file:
return file.read()
except Exception as e:
logger.error(f"读取文件 {file_path} 时出错: {str(e)}")
return None
def write_text_file(content, file_path, encoding=ENCODING):
"""
写入文本文件
Args:
content: 文件内容
file_path: 文件路径
encoding: 文件编码
Returns:
成功标志
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding=encoding) as file:
file.write(content)
return True
except Exception as e:
logger.error(f"写入文件 {file_path} 时出错: {str(e)}")
return False
def save_json(data, file_path, encoding=ENCODING):
"""
保存JSON数据到文件
Args:
data: 要保存的数据
file_path: 文件路径
encoding: 文件编码
Returns:
成功标志
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding=encoding) as file:
json.dump(data, file, ensure_ascii=False, indent=2)
return True
except Exception as e:
logger.error(f"保存JSON文件 {file_path} 时出错: {str(e)}")
return False
def load_json(file_path, encoding=ENCODING):
"""
从文件加载JSON数据
Args:
file_path: 文件路径
encoding: 文件编码
Returns:
加载的数据
"""
try:
with open(file_path, 'r', encoding=encoding) as file:
return json.load(file)
except Exception as e:
logger.error(f"加载JSON文件 {file_path} 时出错: {str(e)}")
return None
def save_pickle(data, file_path):
"""
使用pickle保存数据
Args:
data: 要保存的数据
file_path: 文件路径
Returns:
成功标志
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as file:
pickle.dump(data, file)
return True
except Exception as e:
logger.error(f"保存pickle文件 {file_path} 时出错: {str(e)}")
return False
def load_pickle(file_path):
"""
从文件加载pickle数据
Args:
file_path: 文件路径
Returns:
加载的数据
"""
try:
with open(file_path, 'rb') as file:
return pickle.load(file)
except Exception as e:
logger.error(f"加载pickle文件 {file_path} 时出错: {str(e)}")
return None
def read_files_parallel(file_paths, max_workers=DATA_LOADING_WORKERS, encoding=ENCODING):
"""
并行读取多个文本文件
Args:
file_paths: 文件路径列表
max_workers: 最大工作线程数
encoding: 文件编码
Returns:
文件内容列表
"""
start_time = time.time()
results = []
# 定义单个读取函数
def read_single_file(file_path):
return read_text_file(file_path, encoding)
# 使用线程池并行读取
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {executor.submit(read_single_file, file_path): file_path
for file_path in file_paths}
# 收集结果
for future in as_completed(future_to_file):
file_path = future_to_file[future]
try:
content = future.result()
if content is not None:
results.append(content)
except Exception as e:
logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
elapsed = time.time() - start_time
logger.info(f"并行读取 {len(file_paths)} 个文件,成功 {len(results)} 个,用时 {elapsed:.2f}")
return results
def get_file_md5(file_path):
"""
计算文件的MD5哈希值
Args:
file_path: 文件路径
Returns:
MD5哈希值
"""
hash_md5 = hashlib.md5()
try:
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
except Exception as e:
logger.error(f"计算文件 {file_path} 的MD5值时出错: {str(e)}")
return None
def extract_archive(archive_path, extract_to=None):
"""
解压缩文件
Args:
archive_path: 压缩文件路径
extract_to: 解压目标路径默认为同目录
Returns:
成功标志
"""
if extract_to is None:
extract_to = os.path.dirname(archive_path)
try:
if archive_path.endswith('.zip'):
with zipfile.ZipFile(archive_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
elif archive_path.endswith(('.tar.gz', '.tgz')):
with tarfile.open(archive_path, 'r:gz') as tar_ref:
tar_ref.extractall(extract_to)
elif archive_path.endswith('.tar'):
with tarfile.open(archive_path, 'r') as tar_ref:
tar_ref.extractall(extract_to)
else:
logger.error(f"不支持的压缩格式: {archive_path}")
return False
logger.info(f"成功解压 {archive_path}{extract_to}")
return True
except Exception as e:
logger.error(f"解压 {archive_path} 时出错: {str(e)}")
return False
def list_files(directory, pattern=None, recursive=True):
"""
列出目录中的文件
Args:
directory: 目录路径
pattern: 文件名模式支持通配符
recursive: 是否递归搜索子目录
Returns:
文件路径列表
"""
if not os.path.exists(directory):
logger.error(f"目录不存在: {directory}")
return []
directory = Path(directory)
if pattern:
if recursive:
return [str(p) for p in directory.glob(f"**/{pattern}")]
else:
return [str(p) for p in directory.glob(pattern)]
else:
if recursive:
files = []
for p in directory.rglob("*"):
if p.is_file():
files.append(str(p))
return files
else:
return [str(p) for p in directory.iterdir() if p.is_file()]
def ensure_dir(directory):
"""
确保目录存在不存在则创建
Args:
directory: 目录路径
"""
os.makedirs(directory, exist_ok=True)
def remove_dir(directory):
"""
删除目录及其内容
Args:
directory: 目录路径
Returns:
成功标志
"""
try:
if os.path.exists(directory):
shutil.rmtree(directory)
return True
except Exception as e:
logger.error(f"删除目录 {directory} 时出错: {str(e)}")
return False
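A short sketch of the helpers above used together to collect a raw corpus and cache it; the directory and output path are example values only:
# Illustrative sketch; paths are examples
from utils.file_utils import list_files, read_files_parallel, save_pickle
txt_files = list_files("data/raw/THUCNews", pattern="*.txt", recursive=True)
contents = read_files_parallel(txt_files[:1000])
save_pickle(contents, "data/processed/sample_texts.pkl")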

115
utils/logger.py Normal file
View File

@ -0,0 +1,115 @@
"""
日志工具模块
"""
import logging
import sys
from pathlib import Path
from datetime import datetime
import os
from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT
def get_logger(name, level=None, log_file=None):
"""
获取logger实例
Args:
name: logger名称
level: 日志级别默认为系统配置
log_file: 日志文件路径默认为None仅控制台输出
Returns:
logger实例
"""
level = level or LOG_LEVEL
# 创建logger
logger = logging.getLogger(name)
logger.setLevel(getattr(logging, level))
# 避免重复添加handler
if logger.handlers:
return logger
# 创建格式化器
formatter = logging.Formatter(LOG_FORMAT)
# 创建控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 如果指定了日志文件,创建文件处理器
if log_file:
log_path = Path(LOG_DIR) / log_file
file_handler = logging.FileHandler(log_path, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
def get_time_logger(name):
"""
获取带时间戳的logger实例
Args:
name: logger名称
Returns:
logger实例
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = f"{name}_{timestamp}.log"
return get_logger(name, log_file=log_file)
class TrainingLogger:
"""训练过程日志记录器"""
def __init__(self, model_name):
"""
初始化训练日志记录器
Args:
model_name: 模型名称
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.log_file = f"training_{model_name}_{timestamp}.log"
self.logger = get_logger(f"training_{model_name}", log_file=self.log_file)
# 创建CSV日志
self.csv_path = Path(LOG_DIR) / f"metrics_{model_name}_{timestamp}.csv"
with open(self.csv_path, 'w', encoding='utf-8') as f:
f.write("epoch,loss,accuracy,val_loss,val_accuracy\n")
def log_epoch(self, epoch, metrics):
"""记录每个epoch的指标"""
# 日志记录
metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
self.logger.info(f"Epoch {epoch}: {metrics_str}")
# CSV记录
csv_line = f"{epoch},{metrics.get('loss', '')},{metrics.get('accuracy', '')}," \
f"{metrics.get('val_loss', '')},{metrics.get('val_accuracy', '')}\n"
with open(self.csv_path, 'a', encoding='utf-8') as f:
f.write(csv_line)
def log_training_start(self, config):
"""记录训练开始信息"""
self.logger.info("=" * 50)
self.logger.info("训练开始")
self.logger.info("模型配置:")
for key, value in config.items():
self.logger.info(f" {key}: {value}")
self.logger.info("=" * 50)
def log_training_end(self, duration, best_metrics):
"""记录训练结束信息"""
self.logger.info("=" * 50)
self.logger.info(f"训练结束,总用时: {duration:.2f}")
self.logger.info("最佳性能:")
for key, value in best_metrics.items():
self.logger.info(f" {key}: {value:.4f}")
self.logger.info("=" * 50)
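A short sketch of the logging helpers above; the metric values are made up:
# Illustrative sketch; metric values are placeholders
from utils.logger import get_logger, TrainingLogger
log = get_logger("demo", log_file="demo.log")
log.info("pipeline started")
tlog = TrainingLogger("cnn")
tlog.log_training_start({"batch_size": 64, "epochs": 10})
tlog.log_epoch(1, {"loss": 0.52, "accuracy": 0.81, "val_loss": 0.61, "val_accuracy": 0.78})
tlog.log_training_end(123.4, {"val_loss": 0.58, "val_accuracy": 0.80})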

0
utils/text_utils.py Normal file
View File

0
utils/time_utils.py Normal file
View File