================================================================================
文件: export_all_pyfile.py
================================================================================

import os


def find_all_python_files(root_dir='.'):
    """
    查找指定目录及其所有子目录下的所有Python文件

    Args:
        root_dir: 根目录路径,默认为当前目录

    Returns:
        包含所有Python文件路径的列表
    """
    python_files = []

    # 遍历根目录及所有子目录
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # 查找所有.py文件
        for filename in filenames:
            if filename.endswith('.py'):
                # 构建完整文件路径
                full_path = os.path.join(dirpath, filename)
                python_files.append(full_path)

    return python_files


def export_file_contents_to_txt(python_files, output_file='python_contents.txt'):
    """
    将所有Python文件的内容导出到一个TXT文件

    Args:
        python_files: Python文件路径列表
        output_file: 输出TXT文件名
    """
    with open(output_file, 'w', encoding='utf-8') as outfile:
        for file_path in python_files:
            # 获取相对路径以便更好地显示
            rel_path = os.path.relpath(file_path)

            # 写入文件分隔符
            outfile.write(f"\n{'=' * 80}\n")
            outfile.write(f"文件: {rel_path}\n")
            outfile.write(f"{'=' * 80}\n\n")

            # 读取Python文件内容并写入输出文件
            try:
                with open(file_path, 'r', encoding='utf-8') as infile:
                    content = infile.read()
                    outfile.write(content)
                    outfile.write("\n")  # 文件末尾添加换行
            except Exception as e:
                outfile.write(f"[无法读取文件内容: {str(e)}]\n")

    print(f"已将{len(python_files)}个Python文件的内容导出到 {output_file}")


def main():
    # 获取当前工作目录
    current_dir = os.getcwd()
    print(f"正在搜索目录: {current_dir}")

    # 查找所有Python文件
    python_files = find_all_python_files(current_dir)

    if python_files:
        print(f"找到 {len(python_files)} 个Python文件")

        # 导出所有文件内容到TXT
        export_file_contents_to_txt(python_files)

        # 打印处理的文件列表
        for i, file_path in enumerate(python_files[:10]):
            rel_path = os.path.relpath(file_path)
            print(f"{i + 1}. {rel_path}")

        if len(python_files) > 10:
            print(f"... 还有 {len(python_files) - 10} 个文件")
    else:
        print("未找到任何Python文件")


if __name__ == "__main__":
    main()

================================================================================
文件: setup.py
================================================================================

from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="chinese-text-classification",
    version="1.0.0",
    author="Your Name",
    author_email="your.email@example.com",
    description="基于Python的中文文本分类系统",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/chinese-text-classification",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.7",
    install_requires=[
        "tensorflow>=2.5.0",
        "numpy>=1.19.5",
        "pandas>=1.3.0",
        "scikit-learn>=0.24.2",
        "matplotlib>=3.4.2",
        "jieba>=0.42.1",
        "tqdm>=4.61.1",
        "gensim>=4.0.1",
        "flask>=2.0.1",
        "fastapi>=0.68.0",
        "uvicorn>=0.15.0"
    ],
    entry_points={
        "console_scripts": [
            "text-classifier=main:main",
        ],
    },
)

================================================================================
文件: main.py
================================================================================

"""
|
||
主入口文件:整合系统的所有功能,提供命令行接口
|
||
"""
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import logging
|
||
from typing import List, Optional
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger("main")
|
||
|
||
# 命令行参数
|
||
parser = argparse.ArgumentParser(description="中文文本分类系统")
|
||
subparsers = parser.add_subparsers(dest="command", help="命令")
|
||
|
||
# 训练命令
|
||
train_parser = subparsers.add_parser("train", help="训练模型")
|
||
train_parser.add_argument("--data_dir", help="数据目录")
|
||
train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
|
||
train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
|
||
train_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
||
train_parser.add_argument("--save_dir", help="模型保存目录")
|
||
|
||
# 评估命令
|
||
evaluate_parser = subparsers.add_parser("evaluate", help="评估模型")
|
||
evaluate_parser.add_argument("--model_path", required=True, help="模型路径")
|
||
evaluate_parser.add_argument("--data_dir", help="数据目录")
|
||
evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
||
evaluate_parser.add_argument("--output_dir", help="评估结果输出目录")
|
||
|
||
# 预测命令
|
||
predict_parser = subparsers.add_parser("predict", help="使用模型预测")
|
||
predict_parser.add_argument("--model_path", help="模型路径")
|
||
predict_parser.add_argument("--text", help="要预测的文本")
|
||
predict_parser.add_argument("--file", help="要预测的文件")
|
||
predict_parser.add_argument("--output", help="输出文件")
|
||
|
||
# Web服务命令
|
||
web_parser = subparsers.add_parser("web", help="启动Web服务")
|
||
web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
|
||
web_parser.add_argument("--port", type=int, default=5000, help="服务器端口")
|
||
web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
|
||
|
||
# API服务命令
|
||
api_parser = subparsers.add_parser("api", help="启动API服务")
|
||
api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机")
|
||
api_parser.add_argument("--port", type=int, default=8000, help="服务器端口")
|
||
|
||
# CLI命令
|
||
cli_parser = subparsers.add_parser("cli", help="启动命令行接口")
|
||
cli_parser.add_argument("--model_path", help="模型路径")
|
||
cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
args = parser.parse_args()
|
||
|
||
# 如果没有指定命令,显示帮助信息
|
||
if not args.command:
|
||
parser.print_help()
|
||
return 0
|
||
|
||
# 根据命令调用相应的功能
|
||
if args.command == "train":
|
||
# 导入训练模块
|
||
from scripts.train import train_model
|
||
|
||
# 调用训练功能
|
||
train_model(
|
||
data_dir=args.data_dir,
|
||
model_type=args.model_type,
|
||
epochs=args.epochs,
|
||
batch_size=args.batch_size,
|
||
save_dir=args.save_dir
|
||
)
|
||
|
||
elif args.command == "evaluate":
|
||
# 导入评估模块
|
||
from scripts.evaluate import evaluate_model
|
||
|
||
# 调用评估功能
|
||
evaluate_model(
|
||
model_path=args.model_path,
|
||
data_dir=args.data_dir,
|
||
batch_size=args.batch_size,
|
||
output_dir=args.output_dir
|
||
)
|
||
|
||
elif args.command == "predict":
|
||
# 导入预测模块
|
||
from scripts.predict import predict_text, predict_file
|
||
|
||
# 根据输入类型调用相应的预测功能
|
||
if args.text:
|
||
predict_text(args.text, args.model_path, args.output)
|
||
elif args.file:
|
||
predict_file(args.file, args.model_path, args.output)
|
||
else:
|
||
logger.error("请提供要预测的文本或文件")
|
||
return 1
|
||
|
||
elif args.command == "web":
|
||
# 导入Web服务模块
|
||
from interface.web.app import run_server
|
||
|
||
# 启动Web服务
|
||
run_server(host=args.host, port=args.port, debug=args.debug)
|
||
|
||
elif args.command == "api":
|
||
# 导入API服务模块
|
||
from interface.api import run_server
|
||
|
||
# 启动API服务
|
||
run_server(host=args.host, port=args.port)
|
||
|
||
elif args.command == "cli":
|
||
# 导入CLI模块
|
||
from interface.cli import main as cli_main
|
||
|
||
# 将命令行参数转换为CLI模块可接受的格式
|
||
sys.argv = ["interface/cli.py"]
|
||
|
||
if args.interactive:
|
||
sys.argv.append("interactive")
|
||
if args.model_path:
|
||
sys.argv.extend(["--model_path", args.model_path])
|
||
elif args.model_path:
|
||
sys.argv.extend(["list", "--model_path", args.model_path])
|
||
else:
|
||
sys.argv.append("list")
|
||
|
||
# 调用CLI主函数
|
||
return cli_main()
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main())
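
# === 用法示例(新增示意注释,非 main.py 原有内容)===
# 以下调用方式根据上面定义的 argparse 子命令整理,路径与参数取值仅为示例,
# 实际可用参数以 `python main.py <command> --help` 的输出为准。
#
#   python main.py train --data_dir data/raw/THUCNews --model_type cnn --epochs 10 --batch_size 64
#   python main.py evaluate --model_path saved_models/classifiers/cnn_model --output_dir results
#   python main.py predict --text "一段待分类的中文新闻文本"
#   python main.py predict --file news.txt --output result.json
#   python main.py web --host 0.0.0.0 --port 5000 --debug
#   python main.py api --host 0.0.0.0 --port 8000
#   python main.py cli --interactive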

================================================================================
文件: interface/__init__.py
================================================================================


================================================================================
文件: interface/api.py
================================================================================

"""
|
||
API接口模块:提供REST API接口
|
||
"""
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
import uvicorn
|
||
import asyncio
|
||
import pandas as pd
|
||
import io
|
||
|
||
# 将项目根目录添加到sys.path
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
|
||
from models.model_factory import ModelFactory
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from inference.predictor import Predictor
|
||
from inference.batch_processor import BatchProcessor
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("API")
|
||
|
||
|
||
# 数据模型
|
||
class TextItem(BaseModel):
|
||
text: str
|
||
id: Optional[str] = None
|
||
|
||
|
||
class BatchPredictRequest(BaseModel):
|
||
texts: List[TextItem]
|
||
top_k: Optional[int] = 1
|
||
|
||
|
||
class BatchFileRequest(BaseModel):
|
||
file_paths: List[str]
|
||
top_k: Optional[int] = 1
|
||
|
||
|
||
class ModelInfo(BaseModel):
|
||
id: str
|
||
name: str
|
||
type: str
|
||
num_classes: int
|
||
created_time: str
|
||
file_size: str
|
||
|
||
|
||
# 应用实例
|
||
app = FastAPI(
|
||
title="中文文本分类系统API",
|
||
description="提供中文文本分类功能的REST API",
|
||
version="1.0.0"
|
||
)
|
||
|
||
# 允许跨域请求
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"],
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
# 全局对象
|
||
predictor = None
|
||
|
||
|
||
def get_predictor() -> Predictor:
|
||
"""
|
||
获取或创建全局Predictor实例
|
||
|
||
Returns:
|
||
Predictor实例
|
||
"""
|
||
global predictor
|
||
if predictor is None:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
if not models_info:
|
||
raise HTTPException(status_code=500, detail="未找到可用的模型")
|
||
|
||
# 使用最新的模型
|
||
model_path = models_info[0]['path']
|
||
logger.info(f"API加载模型: {model_path}")
|
||
|
||
# 加载模型
|
||
model = ModelFactory.load_model(model_path)
|
||
|
||
# 创建分词器
|
||
tokenizer = ChineseTokenizer()
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
class_names=CATEGORIES,
|
||
batch_size=64
|
||
)
|
||
|
||
return predictor
|
||
|
||
|
||
@app.get("/")
|
||
async def root():
|
||
"""API根路径"""
|
||
return {"message": "欢迎使用中文文本分类系统API"}
|
||
|
||
|
||
@app.post("/predict/text")
|
||
async def predict_text(text: str = Form(...), top_k: int = Form(1)):
|
||
"""
|
||
预测单条文本
|
||
|
||
Args:
|
||
text: 要预测的文本
|
||
top_k: 返回概率最高的前k个类别
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
logger.info(f"接收到文本预测请求,文本长度: {len(text)}")
|
||
|
||
try:
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 预测
|
||
start_time = time.time()
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
prediction_time = time.time() - start_time
|
||
|
||
# 构建响应
|
||
response = {
|
||
"success": True,
|
||
"predictions": result if top_k > 1 else [result],
|
||
"time": prediction_time
|
||
}
|
||
|
||
return response
|
||
|
||
except Exception as e:
|
||
logger.error(f"预测文本时出错: {e}")
|
||
raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}")
|
||
|
||
|
||
@app.post("/predict/batch")
|
||
async def predict_batch(request: BatchPredictRequest):
|
||
"""
|
||
批量预测文本
|
||
|
||
Args:
|
||
request: 包含文本列表和参数的请求
|
||
|
||
Returns:
|
||
批量预测结果
|
||
"""
|
||
texts = [item.text for item in request.texts]
|
||
ids = [item.id or str(i) for i, item in enumerate(request.texts)]
|
||
|
||
logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本")
|
||
|
||
try:
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 预测
|
||
start_time = time.time()
|
||
results = predictor.predict_batch(
|
||
texts=texts,
|
||
return_top_k=request.top_k,
|
||
return_probabilities=True
|
||
)
|
||
prediction_time = time.time() - start_time
|
||
|
||
# 构建响应
|
||
response = {
|
||
"success": True,
|
||
"total": len(texts),
|
||
"time": prediction_time,
|
||
"results": {}
|
||
}
|
||
|
||
# 将结果关联到ID
|
||
for i, (id_val, result) in enumerate(zip(ids, results)):
|
||
response["results"][id_val] = result
|
||
|
||
return response
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量预测文本时出错: {e}")
|
||
raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}")
|
||
|
||
|
||
@app.post("/predict/file")
|
||
async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)):
|
||
"""
|
||
预测文件内容
|
||
|
||
Args:
|
||
file: 上传的文件
|
||
top_k: 返回概率最高的前k个类别
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
logger.info(f"接收到文件预测请求,文件名: {file.filename}")
|
||
|
||
try:
|
||
# 读取文件内容
|
||
content = await file.read()
|
||
|
||
# 根据文件类型处理内容
|
||
if file.filename.endswith(('.txt', '.md')):
|
||
# 文本文件
|
||
text = content.decode('utf-8')
|
||
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 构建响应
|
||
response = {
|
||
"success": True,
|
||
"filename": file.filename,
|
||
"predictions": result if top_k > 1 else [result]
|
||
}
|
||
|
||
return response
|
||
|
||
elif file.filename.endswith(('.csv', '.xls', '.xlsx')):
|
||
# 表格文件
|
||
if file.filename.endswith('.csv'):
|
||
df = pd.read_csv(io.BytesIO(content))
|
||
else:
|
||
df = pd.read_excel(io.BytesIO(content))
|
||
|
||
# 查找可能的文本列
|
||
text_columns = [col for col in df.columns if df[col].dtype == 'object']
|
||
|
||
if not text_columns:
|
||
raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列")
|
||
|
||
# 使用第一个文本列
|
||
text_column = text_columns[0]
|
||
texts = df[text_column].fillna('').tolist()
|
||
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 批量预测
|
||
results = predictor.predict_batch(
|
||
texts=texts,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 构建响应
|
||
response = {
|
||
"success": True,
|
||
"filename": file.filename,
|
||
"text_column": text_column,
|
||
"total": len(texts),
|
||
"results": results
|
||
}
|
||
|
||
return response
|
||
|
||
else:
|
||
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"预测文件时出错: {e}")
|
||
raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}")
|
||
|
||
|
||
@app.get("/models")
|
||
async def list_models():
|
||
"""
|
||
列出可用的模型
|
||
|
||
Returns:
|
||
可用模型列表
|
||
"""
|
||
try:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
# 转换为响应格式
|
||
models = []
|
||
for info in models_info:
|
||
models.append(ModelInfo(
|
||
id=os.path.basename(info['path']),
|
||
name=info['name'],
|
||
type=info['type'],
|
||
num_classes=info['num_classes'],
|
||
created_time=info['created_time'],
|
||
file_size=info['file_size']
|
||
))
|
||
|
||
return {"models": models}
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取模型列表时出错: {e}")
|
||
raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}")
|
||
|
||
|
||
@app.get("/categories")
|
||
async def list_categories():
|
||
"""
|
||
列出支持的类别
|
||
|
||
Returns:
|
||
支持的类别列表
|
||
"""
|
||
try:
|
||
return {"categories": CATEGORIES}
|
||
except Exception as e:
|
||
logger.error(f"获取类别列表时出错: {e}")
|
||
raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}")
|
||
|
||
|
||
@app.middleware("http")
|
||
async def log_requests(request: Request, call_next):
|
||
"""
|
||
记录请求日志
|
||
|
||
Args:
|
||
request: 请求对象
|
||
call_next: 下一个处理函数
|
||
|
||
Returns:
|
||
响应对象
|
||
"""
|
||
start_time = time.time()
|
||
|
||
# 记录请求信息
|
||
logger.info(f"请求: {request.method} {request.url}")
|
||
|
||
# 处理请求
|
||
response = await call_next(request)
|
||
|
||
# 记录响应信息
|
||
process_time = time.time() - start_time
|
||
logger.info(f"响应: {response.status_code} ({process_time:.2f}s)")
|
||
|
||
return response
|
||
|
||
|
||
def run_server(host: str = "0.0.0.0", port: int = 8000):
|
||
"""
|
||
运行API服务器
|
||
|
||
Args:
|
||
host: 主机地址
|
||
port: 端口号
|
||
"""
|
||
uvicorn.run(app, host=host, port=port)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 解析命令行参数
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description="中文文本分类系统API服务器")
|
||
parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
|
||
parser.add_argument("--port", type=int, default=8000, help="服务器端口号")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 运行服务器
|
||
run_server(host=args.host, port=args.port)
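
# === 调用示例(新增示意代码,非 interface/api.py 原有内容)===
# 用 requests 演示如何调用上文定义的 /predict/text、/predict/batch、/models、
# /categories 接口;BASE_URL 为假设值,请按实际部署地址修改。
import requests

BASE_URL = "http://127.0.0.1:8000"  # 假设:API 服务监听本机 8000 端口

# 单条文本预测:/predict/text 接收表单字段 text 与 top_k
resp = requests.post(f"{BASE_URL}/predict/text",
                     data={"text": "一条待分类的中文新闻文本", "top_k": 3})
print(resp.json())

# 批量预测:/predict/batch 接收 JSON,结构对应上面的 BatchPredictRequest
payload = {
    "texts": [{"text": "第一条文本", "id": "a"},
              {"text": "第二条文本", "id": "b"}],
    "top_k": 1,
}
resp = requests.post(f"{BASE_URL}/predict/batch", json=payload)
print(resp.json())

# 查询可用模型与支持的类别
print(requests.get(f"{BASE_URL}/models").json())
print(requests.get(f"{BASE_URL}/categories").json())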

================================================================================
文件: interface/cli.py
================================================================================

"""
|
||
命令行界面模块:提供命令行交互功能
|
||
"""
|
||
import argparse
|
||
import os
|
||
import sys
|
||
import pandas as pd
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import json
|
||
|
||
# 将项目根目录添加到sys.path
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from config.system_config import CLASSIFIERS_DIR, CATEGORIES
|
||
from models.model_factory import ModelFactory
|
||
from models.base_model import TextClassificationModel
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from inference.predictor import Predictor
|
||
from inference.batch_processor import BatchProcessor
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir, read_text_file
|
||
|
||
logger = get_logger("CLI")
|
||
|
||
|
||
def load_model_and_components(model_path: Optional[str] = None,
|
||
tokenizer_path: Optional[str] = None,
|
||
vectorizer_path: Optional[str] = None,
|
||
class_names: Optional[List[str]] = None) -> Tuple[
|
||
TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]:
|
||
"""
|
||
加载模型和相关组件
|
||
|
||
Args:
|
||
model_path: 模型路径,如果为None则使用最新的模型
|
||
tokenizer_path: 分词器路径,如果为None则创建一个新的分词器
|
||
vectorizer_path: 向量化器路径,如果为None则不使用向量化器
|
||
class_names: 类别名称列表,如果为None则使用CATEGORIES
|
||
|
||
Returns:
|
||
(模型, 分词器, 向量化器)的元组
|
||
"""
|
||
# 加载模型
|
||
if model_path is None:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
if not models_info:
|
||
raise ValueError("未找到可用的模型,请指定模型路径")
|
||
|
||
# 使用最新的模型
|
||
model_path = models_info[0]['path']
|
||
logger.info(f"使用最新的模型: {model_path}")
|
||
|
||
# 加载模型
|
||
model = ModelFactory.load_model(model_path)
|
||
|
||
# 加载或创建分词器
|
||
if tokenizer_path:
|
||
tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理
|
||
logger.info(f"已加载分词器: {tokenizer_path}")
|
||
else:
|
||
tokenizer = ChineseTokenizer()
|
||
logger.info("已创建新的分词器")
|
||
|
||
# 加载向量化器
|
||
vectorizer = None
|
||
if vectorizer_path:
|
||
vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理
|
||
vectorizer.load(vectorizer_path)
|
||
logger.info(f"已加载向量化器: {vectorizer_path}")
|
||
|
||
return model, tokenizer, vectorizer
|
||
|
||
|
||
def predict_text(args):
|
||
"""处理单条文本预测命令"""
|
||
# 加载模型和组件
|
||
model, tokenizer, vectorizer = load_model_and_components(
|
||
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
|
||
)
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
vectorizer=vectorizer,
|
||
class_names=args.class_names or CATEGORIES,
|
||
batch_size=args.batch_size
|
||
)
|
||
|
||
# 获取文本
|
||
text = args.text
|
||
|
||
# 如果提供的是文件路径而非文本内容
|
||
if args.file and os.path.exists(text):
|
||
text = read_text_file(text)
|
||
|
||
# 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=args.top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 输出结果
|
||
if args.top_k > 1:
|
||
print("\n预测结果:")
|
||
for i, pred in enumerate(result):
|
||
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
|
||
else:
|
||
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
|
||
|
||
# 保存结果
|
||
if args.output:
|
||
if args.output.endswith('.json'):
|
||
with open(args.output, 'w', encoding='utf-8') as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||
else:
|
||
with open(args.output, 'w', encoding='utf-8') as f:
|
||
if args.top_k > 1:
|
||
f.write("rank,class,probability\n")
|
||
for i, pred in enumerate(result):
|
||
f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
|
||
else:
|
||
f.write(f"class,probability\n")
|
||
f.write(f"{result['class']},{result['probability']}\n")
|
||
|
||
print(f"结果已保存到: {args.output}")
|
||
|
||
|
||
def predict_batch(args):
|
||
"""处理批量文本预测命令"""
|
||
# 加载模型和组件
|
||
model, tokenizer, vectorizer = load_model_and_components(
|
||
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
|
||
)
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
vectorizer=vectorizer,
|
||
class_names=args.class_names or CATEGORIES,
|
||
batch_size=args.batch_size
|
||
)
|
||
|
||
# 创建批处理器
|
||
batch_processor = BatchProcessor(
|
||
predictor=predictor,
|
||
batch_size=args.batch_size,
|
||
max_workers=args.workers
|
||
)
|
||
|
||
# 确保输出目录存在
|
||
if args.output:
|
||
ensure_dir(os.path.dirname(args.output))
|
||
|
||
# 根据输入类型选择处理方法
|
||
if args.input_type == 'file' and os.path.isfile(args.input):
|
||
# 单个文件
|
||
if args.large_file:
|
||
# 大型文件,分块处理
|
||
batch_processor.process_large_file(
|
||
file_path=args.input,
|
||
output_path=args.output,
|
||
return_top_k=args.top_k,
|
||
format=args.format
|
||
)
|
||
else:
|
||
# CSV或Excel文件
|
||
if args.input.endswith('.csv'):
|
||
df = pd.read_csv(args.input, encoding='utf-8')
|
||
elif args.input.endswith(('.xls', '.xlsx')):
|
||
df = pd.read_excel(args.input)
|
||
else:
|
||
print(f"不支持的文件格式: {args.input}")
|
||
return
|
||
|
||
# 检查文本列是否存在
|
||
if args.text_column not in df.columns:
|
||
print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}")
|
||
return
|
||
|
||
# 处理DataFrame
|
||
result_df = batch_processor.process_dataframe(
|
||
df=df,
|
||
text_column=args.text_column,
|
||
id_column=args.id_column,
|
||
output_path=args.output,
|
||
return_top_k=args.top_k,
|
||
format=args.format
|
||
)
|
||
|
||
# 输出结果统计
|
||
print(f"\n已处理 {len(result_df)} 条文本")
|
||
print("类别分布:")
|
||
if args.top_k == 1:
|
||
class_counts = result_df['predicted_class'].value_counts()
|
||
for cls, count in class_counts.items():
|
||
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
|
||
|
||
elif args.input_type == 'dir' and os.path.isdir(args.input):
|
||
# 目录
|
||
result_df = batch_processor.process_directory(
|
||
directory=args.input,
|
||
pattern=args.pattern,
|
||
output_path=args.output,
|
||
return_top_k=args.top_k,
|
||
format=args.format,
|
||
recursive=args.recursive
|
||
)
|
||
|
||
# 输出结果统计
|
||
if not result_df.empty:
|
||
print(f"\n已处理 {len(result_df)} 个文件")
|
||
print("类别分布:")
|
||
if args.top_k == 1:
|
||
class_counts = result_df['predicted_class'].value_counts()
|
||
for cls, count in class_counts.items():
|
||
print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)")
|
||
|
||
else:
|
||
print(f"无效的输入: {args.input}")
|
||
|
||
|
||
def list_models(args):
|
||
"""列出可用的模型"""
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
if not models_info:
|
||
print("未找到可用的模型")
|
||
return
|
||
|
||
print(f"找到 {len(models_info)} 个可用模型:")
|
||
for i, info in enumerate(models_info):
|
||
print(f"\n{i + 1}. {info['name']} ({info['type']})")
|
||
print(f" 路径: {info['path']}")
|
||
print(f" 创建时间: {info['created_time']}")
|
||
print(f" 类别数: {info['num_classes']}")
|
||
print(f" 文件大小: {info['file_size']}")
|
||
|
||
|
||
def interactive_mode(args):
|
||
"""交互模式"""
|
||
print("启动交互模式...")
|
||
|
||
# 加载模型和组件
|
||
model, tokenizer, vectorizer = load_model_and_components(
|
||
args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names
|
||
)
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
vectorizer=vectorizer,
|
||
class_names=args.class_names or CATEGORIES,
|
||
batch_size=args.batch_size
|
||
)
|
||
|
||
print("\n模型已加载,可以开始交互式文本分类")
|
||
print("输入 'quit' 或 'exit' 退出交互模式\n")
|
||
|
||
while True:
|
||
try:
|
||
# 获取用户输入
|
||
text = input("请输入要分类的文本: ")
|
||
|
||
# 检查是否退出
|
||
if text.lower() in ['quit', 'exit', 'q']:
|
||
print("退出交互模式")
|
||
break
|
||
|
||
# 空输入
|
||
if not text.strip():
|
||
continue
|
||
|
||
# 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=args.top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 输出结果
|
||
if args.top_k > 1:
|
||
print("\n预测结果:")
|
||
for i, pred in enumerate(result):
|
||
print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
|
||
else:
|
||
print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})")
|
||
|
||
print() # 空行
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n退出交互模式")
|
||
break
|
||
except Exception as e:
|
||
print(f"处理过程中出错: {e}")
|
||
|
||
|
||
def main():
|
||
"""主函数,解析命令行参数并调用相应的功能"""
|
||
parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具")
|
||
|
||
# 创建子命令
|
||
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
||
|
||
# 预测单条文本命令
|
||
predict_parser = subparsers.add_parser("predict", help="预测单条文本")
|
||
predict_parser.add_argument("text", help="要预测的文本或文本文件路径")
|
||
predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径")
|
||
predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
|
||
predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
|
||
predict_parser.add_argument("--vectorizer_path", help="向量化器路径")
|
||
predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
|
||
predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
|
||
predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
||
predict_parser.add_argument("--output", help="保存预测结果的文件路径")
|
||
predict_parser.set_defaults(func=predict_text)
|
||
|
||
# 批量预测命令
|
||
batch_parser = subparsers.add_parser("batch", help="批量预测文本")
|
||
batch_parser.add_argument("input", help="输入文件或目录路径")
|
||
batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型")
|
||
batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名")
|
||
batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名")
|
||
batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式")
|
||
batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录")
|
||
batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件")
|
||
batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
|
||
batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
|
||
batch_parser.add_argument("--vectorizer_path", help="向量化器路径")
|
||
batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
|
||
batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别")
|
||
batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小")
|
||
batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数")
|
||
batch_parser.add_argument("--output", required=True, help="输出文件路径")
|
||
batch_parser.add_argument("--format", choices=["csv", "json"], default="csv", help="输出格式")
|
||
batch_parser.set_defaults(func=predict_batch)
|
||
|
||
# 列出可用模型命令
|
||
list_parser = subparsers.add_parser("list", help="列出可用的模型")
|
||
list_parser.set_defaults(func=list_models)
|
||
|
||
# 交互模式命令
|
||
interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式")
|
||
interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型")
|
||
interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器")
|
||
interactive_parser.add_argument("--vectorizer_path", help="向量化器路径")
|
||
interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表")
|
||
interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
|
||
interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小")
|
||
interactive_parser.set_defaults(func=interactive_mode)
|
||
|
||
# 解析参数
|
||
args = parser.parse_args()
|
||
|
||
# 如果没有指定命令,显示帮助
|
||
if not hasattr(args, 'func'):
|
||
parser.print_help()
|
||
return
|
||
|
||
# 执行命令
|
||
try:
|
||
args.func(args)
|
||
except Exception as e:
|
||
logger.error(f"执行命令时出错: {e}")
|
||
print(f"执行命令时出错: {e}")
|
||
return 1
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main())
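
# === 用法示例(新增示意注释,非 interface/cli.py 原有内容)===
# 以下命令行调用根据上面定义的 predict / batch / list / interactive 子命令整理,
# 示例中的文件路径均为假设值,实际参数以 `python interface/cli.py --help` 为准。
#
#   python interface/cli.py list
#   python interface/cli.py predict "一段待分类的中文文本" --top_k 3
#   python interface/cli.py predict data/sample.txt --file --output result.json
#   python interface/cli.py batch data/news.csv --text_column text --top_k 1 --output out.csv
#   python interface/cli.py batch data/docs --input_type dir --pattern "*.txt" --recursive --output out.csv
#   python interface/cli.py interactive --top_k 3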

================================================================================
文件: interface/web/__init__.py
================================================================================


================================================================================
文件: interface/web/app.py
================================================================================

return render_template('predict_text.html', text=text, predictions=predictions)
|
||
|
||
except Exception as e:
|
||
logger.error(f"预测文本时出错: {e}")
|
||
flash(f'预测失败: {str(e)}')
|
||
return render_template('predict_text.html', text=text)
|
||
|
||
return render_template('predict_text.html')
|
||
|
||
|
||
@app.route('/predict_file', methods=['GET', 'POST'])
|
||
def predict_file():
|
||
"""文件预测页面"""
|
||
if request.method == 'POST':
|
||
# 检查是否有文件上传
|
||
if 'file' not in request.files:
|
||
flash('未选择文件')
|
||
return render_template('predict_file.html')
|
||
|
||
file = request.files['file']
|
||
|
||
# 如果用户没有选择文件
|
||
if file.filename == '':
|
||
flash('未选择文件')
|
||
return render_template('predict_file.html')
|
||
|
||
if file and allowed_file(file.filename):
|
||
# 获取参数
|
||
top_k = int(request.form.get('top_k', 3))
|
||
text_column = request.form.get('text_column', 'text')
|
||
|
||
try:
|
||
# 安全地保存文件
|
||
filename = secure_filename(file.filename)
|
||
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
||
file.save(file_path)
|
||
|
||
# 处理文件类型
|
||
if filename.endswith('.txt'):
|
||
# 文本文件
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
text = f.read()
|
||
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 准备结果
|
||
predictions = result if top_k > 1 else [result]
|
||
|
||
return render_template(
|
||
'predict_file.html',
|
||
filename=filename,
|
||
file_type='text',
|
||
predictions=predictions
|
||
)
|
||
|
||
elif filename.endswith(('.csv', '.xls', '.xlsx')):
|
||
# 表格文件
|
||
if filename.endswith('.csv'):
|
||
df = pd.read_csv(file_path)
|
||
else:
|
||
df = pd.read_excel(file_path)
|
||
|
||
# 检查文本列是否存在
|
||
if text_column not in df.columns:
|
||
flash(f"文件中没有找到列: {text_column}")
|
||
return render_template('predict_file.html')
|
||
|
||
# 提取文本
|
||
texts = df[text_column].fillna('').tolist()
|
||
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 批量预测
|
||
results = predictor.predict_batch(
|
||
texts=texts[:100], # 仅处理前100条记录
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 准备结果
|
||
batch_results = []
|
||
for i, result in enumerate(results):
|
||
batch_results.append({
|
||
'id': i,
|
||
'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i],
|
||
'predictions': result if top_k > 1 else [result]
|
||
})
|
||
|
||
return render_template(
|
||
'predict_file.html',
|
||
filename=filename,
|
||
file_type='table',
|
||
total_records=len(texts),
|
||
processed_records=min(100, len(texts)),
|
||
batch_results=batch_results
|
||
)
|
||
|
||
else:
|
||
flash(f"不支持的文件类型: {filename}")
|
||
return render_template('predict_file.html')
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理文件时出错: {e}")
|
||
flash(f'处理失败: {str(e)}')
|
||
return render_template('predict_file.html')
|
||
else:
|
||
flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}')
|
||
return render_template('predict_file.html')
|
||
|
||
return render_template('predict_file.html')
|
||
|
||
|
||
@app.route('/batch_predict', methods=['GET', 'POST'])
|
||
def batch_predict():
|
||
"""批量预测页面"""
|
||
if request.method == 'POST':
|
||
texts = request.form.get('texts', '')
|
||
top_k = int(request.form.get('top_k', 3))
|
||
|
||
if not texts:
|
||
flash('请输入文本内容')
|
||
return render_template('batch_predict.html')
|
||
|
||
# 分割文本
|
||
text_list = [text.strip() for text in texts.split('\n') if text.strip()]
|
||
|
||
if not text_list:
|
||
flash('请输入有效的文本内容')
|
||
return render_template('batch_predict.html', texts=texts)
|
||
|
||
try:
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 批量预测
|
||
results = predictor.predict_batch(
|
||
texts=text_list,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 准备结果
|
||
batch_results = []
|
||
for i, result in enumerate(results):
|
||
batch_results.append({
|
||
'id': i + 1,
|
||
'text': text_list[i],
|
||
'predictions': result if top_k > 1 else [result]
|
||
})
|
||
|
||
return render_template(
|
||
'batch_predict.html',
|
||
texts=texts,
|
||
batch_results=batch_results
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量预测时出错: {e}")
|
||
flash(f'预测失败: {str(e)}')
|
||
return render_template('batch_predict.html', texts=texts)
|
||
|
||
return render_template('batch_predict.html')
|
||
|
||
|
||
@app.route('/models')
|
||
def list_models():
|
||
"""模型列表页面"""
|
||
try:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
return render_template('models.html', models=models_info)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取模型列表时出错: {e}")
|
||
flash(f'获取模型列表失败: {str(e)}')
|
||
return render_template('models.html', models=[])
|
||
|
||
|
||
@app.route('/about')
|
||
def about():
|
||
"""关于页面"""
|
||
return render_template('about.html')
|
||
|
||
|
||
@app.errorhandler(404)
|
||
def page_not_found(e):
|
||
"""404错误处理"""
|
||
return render_template('404.html'), 404
|
||
|
||
|
||
@app.errorhandler(500)
|
||
def internal_server_error(e):
|
||
"""500错误处理"""
|
||
return render_template('500.html'), 500
|
||
|
||
|
||
# 过滤器
|
||
@app.template_filter('format_time')
|
||
def format_time_filter(timestamp):
|
||
"""格式化时间戳"""
|
||
if isinstance(timestamp, str):
|
||
try:
|
||
dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
|
||
return dt.strftime("%Y年%m月%d日 %H:%M")
|
||
except:
|
||
return timestamp
|
||
return timestamp
|
||
|
||
|
||
@app.template_filter('truncate_text')
|
||
def truncate_text_filter(text, length=100):
|
||
"""截断文本"""
|
||
if len(text) <= length:
|
||
return text
|
||
return text[:length] + '...'
|
||
|
||
|
||
@app.template_filter('format_percent')
|
||
def format_percent_filter(value):
|
||
"""格式化百分比"""
|
||
if isinstance(value, (int, float)):
|
||
return f"{value * 100:.2f}%"
|
||
return value
|
||
|
||
|
||
def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False):
|
||
"""
|
||
运行Web服务器
|
||
|
||
Args:
|
||
host: 主机地址
|
||
port: 端口号
|
||
debug: 是否开启调试模式
|
||
"""
|
||
app.run(host=host, port=port, debug=debug)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 解析命令行参数
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description="中文文本分类系统Web应用")
|
||
parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
|
||
parser.add_argument("--port", type=int, default=5000, help="服务器端口号")
|
||
parser.add_argument("--debug", action="store_true", help="是否开启调试模式")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 运行服务器
|
||
run_server(host=args.host, port=args.port, debug=args.debug)

================================================================================
文件: interface/web/routes.py
================================================================================

"""
|
||
Web路由模块:定义Web应用的路由
|
||
"""
|
||
from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash
|
||
from werkzeug.utils import secure_filename
|
||
import os
|
||
import pandas as pd
|
||
import io
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional
|
||
|
||
from interface.web.app import get_predictor, allowed_file, app
|
||
|
||
# 创建蓝图
|
||
bp = Blueprint('routes', __name__)
|
||
|
||
|
||
@bp.route('/')
|
||
def index():
|
||
"""首页"""
|
||
return render_template('index.html')
|
||
|
||
|
||
@bp.route('/predict_text', methods=['GET', 'POST'])
|
||
def predict_text():
|
||
"""文本预测页面"""
|
||
if request.method == 'POST':
|
||
text = request.form.get('text', '')
|
||
top_k = int(request.form.get('top_k', 3))
|
||
|
||
if not text:
|
||
flash('请输入文本内容')
|
||
return render_template('predict_text.html')
|
||
|
||
try:
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 准备结果
|
||
predictions = result if top_k > 1 else [result]
|
||
|
||
return render_template('predict_text.html', text=text, predictions=predictions)
|
||
|
||
except Exception as e:
|
||
flash(f'预测失败: {str(e)}')
|
||
return render_template('predict_text.html', text=text)
|
||
|
||
return render_template('predict_text.html')
|
||
|
||
|
||
@bp.route('/predict_file', methods=['GET', 'POST'])
|
||
def predict_file():
|
||
"""文件预测页面"""
|
||
if request.method == 'POST':
|
||
# 检查是否有文件上传
|
||
if 'file' not in request.files:
|
||
flash('未选择文件')
|
||
return render_template('predict_file.html')
|
||
|
||
file = request.files['file']
|
||
|
||
# 如果用户没有选择文件
|
||
if file.filename == '':
|
||
flash('未选择文件')
|
||
return render_template('predict_file.html')
|
||
|
||
if file and allowed_file(file.filename):
|
||
# 获取参数
|
||
top_k = int(request.form.get('top_k', 3))
|
||
text_column = request.form.get('text_column', 'text')
|
||
|
||
try:
|
||
# 安全地保存文件
|
||
filename = secure_filename(file.filename)
|
||
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
||
file.save(file_path)
|
||
|
||
# 处理文件
|
||
# 注意:此处与app.py中的predict_file重复,实际项目中应该将逻辑抽取到单独的函数中
|
||
# 这里简化处理,仅返回渲染模板
|
||
|
||
return render_template('predict_file.html', filename=filename)
|
||
|
||
except Exception as e:
|
||
flash(f'处理失败: {str(e)}')
|
||
return render_template('predict_file.html')
|
||
else:
|
||
flash(f'不支持的文件类型')
|
||
return render_template('predict_file.html')
|
||
|
||
return render_template('predict_file.html')
|
||
|
||
|
||
@bp.route('/batch_predict', methods=['GET', 'POST'])
|
||
def batch_predict():
|
||
"""批量预测页面"""
|
||
if request.method == 'POST':
|
||
texts = request.form.get('texts', '')
|
||
top_k = int(request.form.get('top_k', 3))
|
||
|
||
if not texts:
|
||
flash('请输入文本内容')
|
||
return render_template('batch_predict.html')
|
||
|
||
# 分割文本
|
||
text_list = [text.strip() for text in texts.split('\n') if text.strip()]
|
||
|
||
if not text_list:
|
||
flash('请输入有效的文本内容')
|
||
return render_template('batch_predict.html', texts=texts)
|
||
|
||
try:
|
||
# 获取预测器
|
||
predictor = get_predictor()
|
||
|
||
# 批量预测
|
||
results = predictor.predict_batch(
|
||
texts=text_list,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 准备结果
|
||
batch_results = []
|
||
for i, result in enumerate(results):
|
||
batch_results.append({
|
||
'id': i + 1,
|
||
'text': text_list[i],
|
||
'predictions': result if top_k > 1 else [result]
|
||
})
|
||
|
||
return render_template(
|
||
'batch_predict.html',
|
||
texts=texts,
|
||
batch_results=batch_results
|
||
)
|
||
|
||
except Exception as e:
|
||
flash(f'预测失败: {str(e)}')
|
||
return render_template('batch_predict.html', texts=texts)
|
||
|
||
return render_template('batch_predict.html')
|
||
|
||
|
||
@bp.route('/models')
|
||
def list_models():
|
||
"""模型列表页面"""
|
||
return render_template('models.html')
|
||
|
||
|
||
@bp.route('/about')
|
||
def about():
|
||
"""关于页面"""
|
||
return render_template('about.html')
|
||
|
||
|
||
# 注册蓝图
|
||
def init_app(app):
|
||
app.register_blueprint(bp)

================================================================================
文件: config/model_config.py
================================================================================

"""
|
||
模型配置文件
|
||
"""
|
||
|
||
# 文本预处理参数
|
||
MAX_SEQUENCE_LENGTH = 500 # 文本序列最大长度
|
||
MAX_NUM_WORDS = 50000 # 词汇表最大大小
|
||
MAX_CHAR_LENGTH = 2000 # 字符级最大长度
|
||
MIN_WORD_FREQUENCY = 5 # 最小词频
|
||
|
||
# 模型架构参数
|
||
CNN_CONFIG = {
|
||
"embedding_dim": 200,
|
||
"num_filters": 256,
|
||
"filter_sizes": [3, 4, 5],
|
||
"dropout_rate": 0.5,
|
||
"l2_reg_lambda": 0.0,
|
||
}
|
||
|
||
RNN_CONFIG = {
|
||
"embedding_dim": 200,
|
||
"hidden_size": 256,
|
||
"num_layers": 2,
|
||
"bidirectional": True,
|
||
"dropout_rate": 0.5,
|
||
}
|
||
|
||
TRANSFORMER_CONFIG = {
|
||
"embedding_dim": 200,
|
||
"num_heads": 8,
|
||
"ff_dim": 512,
|
||
"num_layers": 4,
|
||
"dropout_rate": 0.1,
|
||
}
|
||
|
||
# 针对RTX 4090的优化设置
|
||
BATCH_SIZE = 128 # RTX 4090有24GB显存,可以支持较大的batch
|
||
EVAL_BATCH_SIZE = 256 # 评估时可以用更大的batch
|
||
|
||
# 训练参数
|
||
LEARNING_RATE = 1e-3
|
||
NUM_EPOCHS = 20
|
||
EARLY_STOPPING_PATIENCE = 3
|
||
REDUCE_LR_PATIENCE = 2
|
||
REDUCE_LR_FACTOR = 0.5
|
||
VALIDATION_SPLIT = 0.1
|
||
TEST_SPLIT = 0.1
|
||
|
||
# 词嵌入参数
|
||
USE_PRETRAINED_EMBEDDING = True
|
||
EMBEDDING_TYPE = "word2vec" # 可选: word2vec, glove, fasttext
|
||
|
||
# 随机种子,保证实验可重复性
|
||
RANDOM_SEED = 42
|
||
|
||
# 模型保存参数
|
||
SAVE_BEST_ONLY = True
|
||
MODEL_CHECKPOINT_PATH = "best_model.h5"
|
||
|
||
# 特征工程参数
|
||
USE_CHAR_LEVEL = False # 是否使用字符级特征
|
||
USE_WORD_LEVEL = True # 是否使用词级特征
|
||
USE_TFIDF = False # 是否使用TF-IDF特征
|
||
USE_POS_TAGS = False # 是否使用词性标注特征
|
||
|
||
# 数据增强参数
|
||
USE_DATA_AUGMENTATION = False
|
||
AUGMENTATION_FACTOR = 0.2 # 增强20%的数据
|
||
|
||
# 推理参数
|
||
PREDICTION_THRESHOLD = 0.5
|
||
TOP_K_PREDICTIONS = 3
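
# === 使用示意(新增示例代码,非 config/model_config.py 原有内容)===
# 展示训练脚本按模型类型选取上述超参数字典的一种可能写法;
# get_model_config 为此处假设的辅助函数,项目实际的读取方式以 models/ 与 scripts/ 下的代码为准。
from config.model_config import CNN_CONFIG, RNN_CONFIG, TRANSFORMER_CONFIG, BATCH_SIZE, LEARNING_RATE


def get_model_config(model_type: str) -> dict:
    """根据模型类型返回对应的结构超参数(假设性辅助函数)。"""
    configs = {"cnn": CNN_CONFIG, "rnn": RNN_CONFIG, "transformer": TRANSFORMER_CONFIG}
    if model_type not in configs:
        raise ValueError(f"未知的模型类型: {model_type}")
    return configs[model_type]


if __name__ == "__main__":
    cfg = get_model_config("cnn")
    print(f"batch_size={BATCH_SIZE}, lr={LEARNING_RATE}, embedding_dim={cfg['embedding_dim']}")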

================================================================================
文件: config/__init__.py
================================================================================


================================================================================
文件: config/system_config.py
================================================================================

"""
|
||
系统全局配置文件
|
||
"""
|
||
import os
|
||
import platform
|
||
from pathlib import Path
|
||
|
||
# 项目根目录
|
||
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
"""
|
||
Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 是当前文件的上一级目录
|
||
这种写法主要是为了方便移植项目到不同的平台运行
|
||
"""
|
||
|
||
# 数据相关路径
|
||
DATA_DIR = ROOT_DIR / "data"
|
||
RAW_DATA_DIR = DATA_DIR / "raw" / "THUCNews"
|
||
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
||
RESOURCES_DIR = DATA_DIR / "resources"
|
||
STOPWORDS_DIR = RESOURCES_DIR / "stopwords"
|
||
EMBEDDINGS_DIR = RESOURCES_DIR / "embeddings"
|
||
|
||
# 确保必要的目录存在
|
||
for directory in [PROCESSED_DATA_DIR, RESOURCES_DIR, STOPWORDS_DIR, EMBEDDINGS_DIR]:
|
||
directory.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 保存模型的路径
|
||
SAVED_MODELS_DIR = ROOT_DIR / "saved_models"
|
||
TOKENIZERS_DIR = SAVED_MODELS_DIR / "tokenizers"
|
||
CLASSIFIERS_DIR = SAVED_MODELS_DIR / "classifiers"
|
||
|
||
# 确保模型保存目录存在
|
||
for directory in [SAVED_MODELS_DIR, TOKENIZERS_DIR, CLASSIFIERS_DIR]:
|
||
directory.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 系统资源配置
|
||
CPU_COUNT = os.cpu_count()
|
||
USE_GPU = True
|
||
MULTI_GPU = False # 目前只使用单个GPU
|
||
|
||
# 基于13900K性能设置并行处理参数
|
||
DATA_LOADING_WORKERS = min(16, CPU_COUNT) # 数据加载线程数
|
||
PREPROCESSING_WORKERS = min(24, CPU_COUNT) # 预处理线程数,13900K有强大的多线程能力
|
||
|
||
# 基于64GB内存设置内存相关参数
|
||
MAX_MEMORY_GB = 48 # 保留部分内存给系统和其他应用
|
||
MAX_TEXT_PER_BATCH = 10000 # 每批处理的最大文本数量
|
||
|
||
# 日志配置
|
||
LOG_DIR = ROOT_DIR / "logs"
|
||
LOG_DIR.mkdir(exist_ok=True)
|
||
LOG_LEVEL = "INFO"
|
||
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||
|
||
# 类别标签映射(与THUCNews数据集一致)
|
||
CATEGORIES = [
|
||
"体育", "娱乐", "家居", "彩票", "房产", "教育",
|
||
"时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"
|
||
]
|
||
CATEGORY_TO_ID = {category: idx for idx, category in enumerate(CATEGORIES)}
|
||
ID_TO_CATEGORY = {idx: category for idx, category in enumerate(CATEGORIES)}
|
||
|
||
# 文件编码
|
||
ENCODING = "utf-8"
|
||
|
||
# 系统信息
|
||
SYSTEM_INFO = {
|
||
"platform": platform.platform(),
|
||
"python_version": platform.python_version(),
|
||
"processor": platform.processor(),
|
||
}
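
# === 使用示意(新增示例代码,非 config/system_config.py 原有内容)===
# 演示类别名称与整数标签之间的相互转换,可用于数据预处理与推理阶段的标签映射。
from config.system_config import CATEGORIES, CATEGORY_TO_ID, ID_TO_CATEGORY

if __name__ == "__main__":
    label_id = CATEGORY_TO_ID["体育"]        # 类别名 -> 整数标签
    label_name = ID_TO_CATEGORY[label_id]    # 整数标签 -> 类别名
    print(f"共 {len(CATEGORIES)} 个类别;体育 -> {label_id} -> {label_name}")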

================================================================================
文件: training/__init__.py
================================================================================


================================================================================
文件: training/callbacks.py
================================================================================

"""
|
||
回调函数模块:提供用于模型训练的自定义回调函数
|
||
"""
|
||
import os
|
||
import time
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable  # Callable 供 LearningRateSchedulerCallback 的类型标注使用
|
||
import matplotlib.pyplot as plt
|
||
from io import BytesIO
|
||
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Callbacks")
|
||
|
||
|
||
class MetricsHistory(tf.keras.callbacks.Callback):
|
||
"""跟踪训练过程中的指标历史"""
|
||
|
||
def __init__(self, validation_data: Optional[Tuple] = None,
|
||
metrics: Optional[List[str]] = None,
|
||
save_path: Optional[str] = None):
|
||
"""
|
||
初始化MetricsHistory回调
|
||
|
||
Args:
|
||
validation_data: 验证数据,格式为(x_val, y_val)
|
||
metrics: 要跟踪的指标列表
|
||
save_path: 指标历史的保存路径
|
||
"""
|
||
super().__init__()
|
||
self.validation_data = validation_data
|
||
self.metrics = metrics or ['loss', 'accuracy']
|
||
self.save_path = save_path
|
||
|
||
# 历史指标
|
||
self.history = {metric: [] for metric in self.metrics}
|
||
if validation_data is not None:
|
||
for metric in self.metrics:
|
||
self.history[f'val_{metric}'] = []
|
||
|
||
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个epoch结束时调用
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
logs: 训练日志
|
||
"""
|
||
logs = logs or {}
|
||
|
||
# 记录训练指标
|
||
for metric in self.metrics:
|
||
if metric in logs:
|
||
self.history[metric].append(logs[metric])
|
||
|
||
# 记录验证指标
|
||
if self.validation_data is not None:
|
||
for metric in self.metrics:
|
||
val_metric = f'val_{metric}'
|
||
if val_metric in logs:
|
||
self.history[val_metric].append(logs[val_metric])
|
||
|
||
def plot_metrics(self, save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制指标历史
|
||
|
||
Args:
|
||
save_path: 图像保存路径,如果为None则使用初始化时设置的路径
|
||
"""
|
||
plt.figure(figsize=(12, 5))
|
||
|
||
for i, metric in enumerate(self.metrics):
|
||
plt.subplot(1, len(self.metrics), i + 1)
|
||
|
||
if metric in self.history:
|
||
plt.plot(self.history[metric], label=f'train_{metric}')
|
||
|
||
val_metric = f'val_{metric}'
|
||
if val_metric in self.history:
|
||
plt.plot(self.history[val_metric], label=f'val_{metric}')
|
||
|
||
plt.title(f'Model {metric}')
|
||
plt.xlabel('Epoch')
|
||
plt.ylabel(metric)
|
||
plt.legend()
|
||
|
||
plt.tight_layout()
|
||
|
||
save_path = save_path or self.save_path
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"指标历史图已保存到: {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
|
||
class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
|
||
"""计算并显示验证集上的混淆矩阵"""
|
||
|
||
def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray],
|
||
class_names: Optional[List[str]] = None,
|
||
log_dir: Optional[str] = None,
|
||
freq: int = 1,
|
||
fig_size: Tuple[int, int] = (10, 8)):
|
||
"""
|
||
初始化ConfusionMatrixCallback
|
||
|
||
Args:
|
||
validation_data: 验证数据,格式为(x_val, y_val)
|
||
class_names: 类别名称列表
|
||
log_dir: TensorBoard日志目录
|
||
freq: 计算混淆矩阵的频率(每多少个epoch计算一次)
|
||
fig_size: 图像大小
|
||
"""
|
||
super().__init__()
|
||
self.x_val, self.y_val = validation_data
|
||
self.class_names = class_names
|
||
self.log_dir = log_dir
|
||
self.freq = freq
|
||
self.fig_size = fig_size
|
||
|
||
# 如果提供了TensorBoard日志目录,创建一个文件写入器
|
||
if log_dir:
|
||
self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix'))
|
||
else:
|
||
self.file_writer = None
|
||
|
||
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个epoch结束时调用
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
logs: 训练日志
|
||
"""
|
||
# 每freq个epoch计算一次混淆矩阵
|
||
if (epoch + 1) % self.freq == 0 or epoch == 0:
|
||
# 获取预测结果
|
||
y_pred = np.argmax(self.model.predict(self.x_val), axis=1)
|
||
|
||
# 确保y_val是一维数组
|
||
y_true = self.y_val
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 计算混淆矩阵
|
||
cm = tf.math.confusion_matrix(y_true, y_pred).numpy()
|
||
|
||
# 归一化混淆矩阵
|
||
cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
|
||
|
||
# 绘制混淆矩阵
|
||
fig = self._plot_confusion_matrix(cm_norm, epoch + 1)
|
||
|
||
# 如果有TensorBoard日志,将图像添加到TensorBoard
|
||
if self.file_writer:
|
||
with self.file_writer.as_default():
|
||
# 将matplotlib图像转换为TensorBoard图像
|
||
buf = BytesIO()
|
||
fig.savefig(buf, format='png')
|
||
buf.seek(0)
|
||
|
||
# 将PNG编码为字符串,并创建图像
|
||
image = tf.image.decode_png(buf.getvalue(), channels=4)
|
||
image = tf.expand_dims(image, 0)
|
||
|
||
# 添加到TensorBoard
|
||
tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch)
|
||
|
||
plt.close(fig)
|
||
|
||
def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure:
|
||
"""
|
||
绘制混淆矩阵
|
||
|
||
Args:
|
||
cm: 混淆矩阵
|
||
epoch: 当前epoch
|
||
|
||
Returns:
|
||
matplotlib图像
|
||
"""
|
||
fig, ax = plt.subplots(figsize=self.fig_size)
|
||
|
||
# 使用热图显示混淆矩阵
|
||
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||
ax.figure.colorbar(im, ax=ax)
|
||
|
||
# 设置坐标轴标签
|
||
if self.class_names:
|
||
ax.set(
|
||
xticks=np.arange(cm.shape[1]),
|
||
yticks=np.arange(cm.shape[0]),
|
||
xticklabels=self.class_names,
|
||
yticklabels=self.class_names
|
||
)
|
||
|
||
# 旋转x轴标签
|
||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
||
|
||
# 在每个单元格中显示数值
|
||
thresh = cm.max() / 2.0
|
||
for i in range(cm.shape[0]):
|
||
for j in range(cm.shape[1]):
|
||
ax.text(j, i, format(cm[i, j], '.2f'),
|
||
ha="center", va="center",
|
||
color="white" if cm[i, j] > thresh else "black")
|
||
|
||
ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})")
|
||
ax.set_ylabel('True label')
|
||
ax.set_xlabel('Predicted label')
|
||
|
||
fig.tight_layout()
|
||
return fig
|
||
|
||
|
||
class TimingCallback(tf.keras.callbacks.Callback):
|
||
"""测量训练时间的回调函数"""
|
||
|
||
def __init__(self):
|
||
"""初始化TimingCallback"""
|
||
super().__init__()
|
||
self.epoch_times = []
|
||
self.batch_times = []
|
||
self.epoch_start_time = None
|
||
self.batch_start_time = None
|
||
self.training_start_time = None
|
||
|
||
def on_train_begin(self, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
训练开始时调用
|
||
|
||
Args:
|
||
logs: 训练日志
|
||
"""
|
||
self.training_start_time = time.time()
|
||
|
||
def on_train_end(self, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
训练结束时调用
|
||
|
||
Args:
|
||
logs: 训练日志
|
||
"""
|
||
training_time = time.time() - self.training_start_time
|
||
logger.info(f"总训练时间: {training_time:.2f} 秒")
|
||
|
||
if self.epoch_times:
|
||
avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
|
||
logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f} 秒")
|
||
|
||
if self.batch_times:
|
||
avg_batch_time = sum(self.batch_times) / len(self.batch_times)
|
||
logger.info(f"平均每个batch时间: {avg_batch_time:.4f} 秒")
|
||
|
||
def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个epoch开始时调用
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
logs: 训练日志
|
||
"""
|
||
self.epoch_start_time = time.time()
|
||
|
||
def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个epoch结束时调用
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
logs: 训练日志
|
||
"""
|
||
epoch_time = time.time() - self.epoch_start_time
|
||
self.epoch_times.append(epoch_time)
|
||
|
||
# 将epoch时间添加到日志中
|
||
if logs is not None:
|
||
logs['epoch_time'] = epoch_time
|
||
|
||
def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个batch开始时调用
|
||
|
||
Args:
|
||
batch: 当前batch索引
|
||
logs: 训练日志
|
||
"""
|
||
self.batch_start_time = time.time()
|
||
|
||
def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个batch结束时调用
|
||
|
||
Args:
|
||
batch: 当前batch索引
|
||
logs: 训练日志
|
||
"""
|
||
batch_time = time.time() - self.batch_start_time
|
||
self.batch_times.append(batch_time)
|
||
|
||
|
||
class LearningRateSchedulerCallback(tf.keras.callbacks.Callback):
|
||
"""学习率调度器回调函数"""
|
||
|
||
def __init__(self, scheduler_func: Callable[[int, float], float],
|
||
verbose: int = 0,
|
||
log_dir: Optional[str] = None):
|
||
"""
|
||
初始化LearningRateSchedulerCallback
|
||
|
||
Args:
|
||
scheduler_func: 学习率调度函数,接收(epoch, lr)参数,返回新的学习率
|
||
verbose: 详细程度
|
||
log_dir: TensorBoard日志目录
|
||
"""
|
||
super().__init__()
|
||
self.scheduler_func = scheduler_func
|
||
self.verbose = verbose
|
||
|
||
# 如果提供了TensorBoard日志目录,创建一个文件写入器
|
||
if log_dir:
|
||
self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate'))
|
||
else:
|
||
self.file_writer = None
|
||
|
||
# 学习率历史
|
||
self.lr_history = []
|
||
|
||
def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None:
|
||
"""
|
||
每个epoch开始时调用
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
logs: 训练日志
|
||
"""
|
||
if not hasattr(self.model.optimizer, 'lr'):
|
||
raise ValueError('Optimizer must have a "lr" attribute.')
|
||
|
||
# 获取当前学习率
|
||
current_lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
|
||
|
||
# 计算新的学习率
|
||
new_lr = self.scheduler_func(epoch, current_lr)
|
||
|
||
# 设置新的学习率
|
||
tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
|
||
|
||
# 记录学习率
|
||
self.lr_history.append(new_lr)
|
||
|
||
# 记录到TensorBoard
|
||
if self.file_writer:
|
||
with self.file_writer.as_default():
|
||
tf.summary.scalar('learning_rate', new_lr, step=epoch)
|
||
|
||
if self.verbose > 0:
|
||
logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}")
|
||
|
||
def get_lr_history(self) -> List[float]:
|
||
"""
|
||
获取学习率历史
|
||
|
||
Returns:
|
||
学习率历史列表
|
||
"""
|
||
return self.lr_history
|
||
|
||
|
||
class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping):
|
||
"""增强版早停回调函数,支持最小变化率"""
|
||
|
||
def __init__(self, monitor: str = 'val_loss',
|
||
min_delta: float = 0,
|
||
min_delta_ratio: float = 0,
|
||
patience: int = 0,
|
||
verbose: int = 0,
|
||
mode: str = 'auto',
|
||
baseline: Optional[float] = None,
|
||
restore_best_weights: bool = False):
|
||
"""
|
||
初始化EarlyStoppingCallback
|
||
|
||
Args:
|
||
monitor: 监控的指标
|
||
min_delta: 视为改进的最小绝对变化
|
||
min_delta_ratio: 视为改进的最小相对变化率
|
||
patience: 没有改进的轮数
|
||
verbose: 详细程度
|
||
mode: 'auto', 'min' 或 'max'
|
||
baseline: 基准值
|
||
restore_best_weights: 是否恢复最佳权重
|
||
"""
|
||
super().__init__(
|
||
monitor=monitor,
|
||
min_delta=min_delta,
|
||
patience=patience,
|
||
verbose=verbose,
|
||
mode=mode,
|
||
baseline=baseline,
|
||
restore_best_weights=restore_best_weights
|
||
)
|
||
self.min_delta_ratio = min_delta_ratio
|
||
|
||
def _is_improvement(self, current: float, reference: float) -> bool:
|
||
"""
|
||
判断是否有所改进
|
||
|
||
Args:
|
||
current: 当前值
|
||
reference: 参考值
|
||
|
||
Returns:
|
||
是否有所改进
|
||
"""
|
||
# 先检查绝对变化
|
||
if super()._is_improvement(current, reference):
|
||
return True
|
||
|
||
# 再检查相对变化率
|
||
if self.monitor_op == np.less:
|
||
# 对于 'min' 模式,值越小越好
|
||
relative_delta = (reference - current) / reference if reference != 0 else 0
|
||
return relative_delta > self.min_delta_ratio
|
||
else:
|
||
# 对于 'max' 模式,值越大越好
|
||
relative_delta = (current - reference) / reference if reference != 0 else 0
|
||
return relative_delta > self.min_delta_ratio
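
# Minimal usage sketch (illustrative): how the enhanced early stopping above might be
# wired into training. `model`, `x_train`, `y_train`, `x_val`, `y_val` are assumed to
# exist in the calling code; with min_delta_ratio=0.01 an epoch only counts as an
# improvement if the monitored value improves by more than 1% relative to the best
# value seen so far (or by more than min_delta in absolute terms).
def _early_stopping_usage_example(model, x_train, y_train, x_val, y_val):
    callback = EarlyStoppingCallback(
        monitor='val_loss',
        min_delta_ratio=0.01,   # require at least a 1% relative improvement
        patience=3,
        restore_best_weights=True,
        verbose=1
    )
    return model.fit(x_train, y_train,
                     validation_data=(x_val, y_val),
                     epochs=50,
                     callbacks=[callback])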


================================================================================
文件: training/optimizer.py
================================================================================



================================================================================
文件: training/scheduler.py
================================================================================

"""
|
||
学习率调度器模块:提供各种学习率调度策略
|
||
"""
|
||
import numpy as np
|
||
import math
|
||
from typing import Callable, Optional, Union, Dict
|
||
import tensorflow as tf
|
||
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Scheduler")
|
||
|
||
|
||
def step_decay(epoch: int, initial_lr: float,
|
||
drop_rate: float = 0.5,
|
||
epochs_drop: int = 10) -> float:
|
||
"""
|
||
阶梯式学习率衰减
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
initial_lr: 初始学习率
|
||
drop_rate: 衰减率
|
||
epochs_drop: 每多少个epoch衰减一次
|
||
|
||
Returns:
|
||
新的学习率
|
||
"""
|
||
return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop))
|
||
|
||
|
||
def exponential_decay(epoch: int, initial_lr: float,
|
||
decay_rate: float = 0.9,
|
||
staircase: bool = False) -> float:
|
||
"""
|
||
指数衰减学习率
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
initial_lr: 初始学习率
|
||
decay_rate: 衰减率
|
||
staircase: 是否阶梯式衰减
|
||
|
||
Returns:
|
||
新的学习率
|
||
"""
|
||
if staircase:
|
||
return initial_lr * math.pow(decay_rate, math.floor(epoch))
|
||
else:
|
||
return initial_lr * math.pow(decay_rate, epoch)
|
||
|
||
|
||
def cosine_decay(epoch: int, initial_lr: float,
|
||
total_epochs: int = 100,
|
||
min_lr: float = 0) -> float:
|
||
"""
|
||
余弦退火学习率
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
initial_lr: 初始学习率
|
||
total_epochs: 总epoch数
|
||
min_lr: 最小学习率
|
||
|
||
Returns:
|
||
新的学习率
|
||
"""
|
||
return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs))
|
||
|
||
|
||
def cosine_decay_with_warmup(epoch: int, initial_lr: float,
|
||
total_epochs: int = 100,
|
||
warmup_epochs: int = 5,
|
||
min_lr: float = 0,
|
||
warmup_init_lr: float = 0) -> float:
|
||
"""
|
||
带预热的余弦退火学习率
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
initial_lr: 初始学习率
|
||
total_epochs: 总epoch数
|
||
warmup_epochs: 预热epoch数
|
||
min_lr: 最小学习率
|
||
warmup_init_lr: 预热初始学习率
|
||
|
||
Returns:
|
||
新的学习率
|
||
"""
|
||
if epoch < warmup_epochs:
|
||
# 线性预热
|
||
return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs
|
||
else:
|
||
# 余弦退火
|
||
return min_lr + 0.5 * (initial_lr - min_lr) * (
|
||
1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
|
||
)
|
||
|
||
|
||
def cyclical_learning_rate(epoch: int, initial_lr: float,
|
||
max_lr: float = 0.1,
|
||
step_size: int = 8,
|
||
gamma: float = 1.0) -> float:
|
||
"""
|
||
循环学习率
|
||
|
||
Args:
|
||
epoch: 当前epoch索引
|
||
initial_lr: 初始学习率
|
||
max_lr: 最大学习率
|
||
step_size: 半周期大小(epoch数)
|
||
gamma: 循环衰减率
|
||
|
||
Returns:
|
||
新的学习率
|
||
"""
|
||
# 计算循环数
|
||
cycle = math.floor(1 + epoch / (2 * step_size))
|
||
|
||
# 计算x值,范围在[0, 2]
|
||
x = abs(epoch / step_size - 2 * cycle + 1)
|
||
|
||
# 应用循环衰减
|
||
lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1)
|
||
|
||
# 计算学习率
|
||
return initial_lr + lr_range * max(0, 1 - x)
|
||
|
||
|
||
def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]:
|
||
"""
|
||
创建自定义学习率调度器
|
||
|
||
Args:
|
||
scheduler_type: 调度器类型,可选值: 'step', 'exp', 'cosine', 'cosine_warmup', 'cyclical'
|
||
**kwargs: 调度器参数
|
||
|
||
Returns:
|
||
学习率调度函数
|
||
"""
|
||
scheduler_type = scheduler_type.lower()
|
||
|
||
if scheduler_type == 'step':
|
||
drop_rate = kwargs.get('drop_rate', 0.5)
|
||
epochs_drop = kwargs.get('epochs_drop', 10)
|
||
|
||
def scheduler(epoch, lr):
|
||
return step_decay(epoch, lr, drop_rate, epochs_drop)
|
||
|
||
return scheduler
|
||
|
||
elif scheduler_type == 'exp':
|
||
decay_rate = kwargs.get('decay_rate', 0.9)
|
||
staircase = kwargs.get('staircase', False)
|
||
|
||
def scheduler(epoch, lr):
|
||
if epoch == 0:
|
||
# 第一个epoch使用初始学习率
|
||
return lr
|
||
return exponential_decay(epoch, lr, decay_rate, staircase)
|
||
|
||
return scheduler
|
||
|
||
elif scheduler_type == 'cosine':
|
||
total_epochs = kwargs.get('total_epochs', 100)
|
||
min_lr = kwargs.get('min_lr', 0)
|
||
|
||
def scheduler(epoch, lr):
|
||
if epoch == 0:
|
||
return lr
|
||
return cosine_decay(epoch, lr, total_epochs, min_lr)
|
||
|
||
return scheduler
|
||
|
||
elif scheduler_type == 'cosine_warmup':
|
||
total_epochs = kwargs.get('total_epochs', 100)
|
||
warmup_epochs = kwargs.get('warmup_epochs', 5)
|
||
min_lr = kwargs.get('min_lr', 0)
|
||
warmup_init_lr = kwargs.get('warmup_init_lr', 0)
|
||
|
||
def scheduler(epoch, lr):
|
||
if epoch == 0:
|
||
return warmup_init_lr
|
||
return cosine_decay_with_warmup(epoch, lr, total_epochs, warmup_epochs, min_lr, warmup_init_lr)
|
||
|
||
return scheduler
|
||
|
||
elif scheduler_type == 'cyclical':
|
||
max_lr = kwargs.get('max_lr', 0.1)
|
||
step_size = kwargs.get('step_size', 8)
|
||
gamma = kwargs.get('gamma', 1.0)
|
||
|
||
def scheduler(epoch, lr):
|
||
if epoch == 0:
|
||
return lr
|
||
return cyclical_learning_rate(epoch, lr, max_lr, step_size, gamma)
|
||
|
||
return scheduler
|
||
|
||
else:
|
||
raise ValueError(f"不支持的调度器类型: {scheduler_type}")
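
# Minimal usage sketch (illustrative): build a schedule function with
# create_custom_scheduler and drive it through the stock Keras LearningRateScheduler
# callback (the project's own LearningRateSchedulerCallback accepts the same
# (epoch, lr) -> lr function). `model` is assumed to be a compiled tf.keras model and
# `x_train`/`y_train` to be prepared arrays.
def _scheduler_usage_example(model, x_train, y_train):
    # cosine annealing with a 5-epoch linear warmup
    scheduler_func = create_custom_scheduler(
        'cosine_warmup',
        total_epochs=50,
        warmup_epochs=5,
        min_lr=1e-6,
        warmup_init_lr=1e-5
    )
    lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler_func, verbose=1)
    return model.fit(x_train, y_train, epochs=50, callbacks=[lr_callback])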
|
||
|
||
|
||
class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback):
|
||
"""预热余弦退火学习率调度器"""
|
||
|
||
def __init__(self, learning_rate_base: float,
|
||
total_steps: int,
|
||
warmup_learning_rate: float = 0.0,
|
||
warmup_steps: int = 0,
|
||
hold_base_rate_steps: int = 0,
|
||
verbose: int = 0):
|
||
"""
|
||
初始化预热余弦退火学习率调度器
|
||
|
||
Args:
|
||
learning_rate_base: 基础学习率
|
||
total_steps: 总步数
|
||
warmup_learning_rate: 预热学习率
|
||
warmup_steps: 预热步数
|
||
hold_base_rate_steps: 保持基础学习率的步数
|
||
verbose: 详细程度
|
||
"""
|
||
super().__init__()
|
||
|
||
self.learning_rate_base = learning_rate_base
|
||
self.total_steps = total_steps
|
||
self.warmup_learning_rate = warmup_learning_rate
|
||
self.warmup_steps = warmup_steps
|
||
self.hold_base_rate_steps = hold_base_rate_steps
|
||
self.verbose = verbose
|
||
|
||
# 学习率历史
|
||
self.learning_rates = []
|
||
|
||
def on_train_begin(self, logs: Optional[Dict] = None) -> None:
|
||
"""
|
||
训练开始时调用
|
||
|
||
Args:
|
||
logs: 训练日志
|
||
"""
|
||
self.current_step = 0
|
||
logger.info(f"预热余弦退火学习率调度器初始化: 基础学习率={self.learning_rate_base}, "
|
||
f"预热步数={self.warmup_steps}, 总步数={self.total_steps}")
|
||
|
||
def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
|
||
"""
|
||
每个batch结束时调用
|
||
|
||
Args:
|
||
batch: 当前batch索引
|
||
logs: 训练日志
|
||
"""
|
||
self.current_step += 1
|
||
lr = self._get_learning_rate()
|
||
|
||
tf.keras.backend.set_value(self.model.optimizer.lr, lr)
|
||
self.learning_rates.append(lr)
|
||
|
||
def _get_learning_rate(self) -> float:
|
||
"""
|
||
计算当前学习率
|
||
|
||
Returns:
|
||
当前学习率
|
||
"""
|
||
if self.current_step < self.warmup_steps:
|
||
# 预热阶段:线性增加学习率
|
||
lr = self.warmup_learning_rate + self.current_step * (
|
||
(self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
|
||
)
|
||
elif self.current_step < self.warmup_steps + self.hold_base_rate_steps:
|
||
# 保持基础学习率阶段
|
||
lr = self.learning_rate_base
|
||
else:
|
||
# 余弦退火阶段
|
||
cosine_steps = self.total_steps - self.warmup_steps - self.hold_base_rate_steps
|
||
cosine_current_step = self.current_step - self.warmup_steps - self.hold_base_rate_steps
|
||
lr = 0.5 * self.learning_rate_base * (
|
||
1 + math.cos(math.pi * cosine_current_step / cosine_steps)
|
||
)
|
||
|
||
return lr
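
# Minimal usage sketch (illustrative): this scheduler works at batch granularity, so
# total_steps should be roughly ceil(num_samples / batch_size) * epochs. `model` is
# assumed to be a compiled tf.keras model; `x_train`/`y_train` are prepared arrays.
def _warmup_cosine_usage_example(model, x_train, y_train, batch_size=64, epochs=10):
    steps_per_epoch = int(np.ceil(len(x_train) / batch_size))
    warmup_scheduler = WarmupCosineDecayScheduler(
        learning_rate_base=1e-3,
        total_steps=steps_per_epoch * epochs,
        warmup_learning_rate=1e-5,
        warmup_steps=steps_per_epoch,      # warm up for one epoch
        hold_base_rate_steps=0,
        verbose=1
    )
    return model.fit(x_train, y_train,
                     batch_size=batch_size,
                     epochs=epochs,
                     callbacks=[warmup_scheduler])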


================================================================================
文件: training/trainer.py
================================================================================

"""
|
||
训练器模块:实现模型训练流程,包括训练循环、验证等
|
||
"""
|
||
import os
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import matplotlib.pyplot as plt
|
||
from datetime import datetime
|
||
|
||
from config.system_config import SAVED_MODELS_DIR
|
||
from config.model_config import (
|
||
NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE,
|
||
VALIDATION_SPLIT, RANDOM_SEED
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger, TrainingLogger
|
||
from utils.file_utils import ensure_dir
|
||
|
||
logger = get_logger("Trainer")
|
||
|
||
|
||
class Trainer:
|
||
"""模型训练器,负责训练和验证模型"""
|
||
|
||
def __init__(self, model: TextClassificationModel,
|
||
epochs: int = NUM_EPOCHS,
|
||
batch_size: Optional[int] = None,
|
||
validation_split: float = VALIDATION_SPLIT,
|
||
early_stopping: bool = True,
|
||
early_stopping_patience: int = EARLY_STOPPING_PATIENCE,
|
||
save_best_only: bool = True,
|
||
tensorboard: bool = True,
|
||
checkpoint: bool = True,
|
||
custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None):
|
||
"""
|
||
初始化训练器
|
||
|
||
Args:
|
||
model: 要训练的模型
|
||
epochs: 训练轮数
|
||
batch_size: 批大小,如果为None则使用模型默认值
|
||
validation_split: 验证集比例
|
||
early_stopping: 是否使用早停
|
||
early_stopping_patience: 早停耐心值
|
||
save_best_only: 是否只保存最佳模型
|
||
tensorboard: 是否使用TensorBoard
|
||
checkpoint: 是否保存检查点
|
||
custom_callbacks: 自定义回调函数列表
|
||
"""
|
||
self.model = model
|
||
self.epochs = epochs
|
||
self.batch_size = batch_size or model.batch_size
|
||
self.validation_split = validation_split
|
||
self.early_stopping = early_stopping
|
||
self.early_stopping_patience = early_stopping_patience
|
||
self.save_best_only = save_best_only
|
||
self.tensorboard = tensorboard
|
||
self.checkpoint = checkpoint
|
||
self.custom_callbacks = custom_callbacks or []
|
||
|
||
# 训练历史
|
||
self.history = None
|
||
|
||
# 训练日志记录器
|
||
self.training_logger = TrainingLogger(model.model_name)
|
||
|
||
logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}")
|
||
|
||
def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]:
|
||
"""
|
||
创建回调函数列表
|
||
|
||
Returns:
|
||
回调函数列表
|
||
"""
|
||
callbacks = []
|
||
|
||
# 早停
|
||
if self.early_stopping:
|
||
early_stopping = tf.keras.callbacks.EarlyStopping(
|
||
monitor='val_loss',
|
||
patience=self.early_stopping_patience,
|
||
restore_best_weights=True,
|
||
verbose=1
|
||
)
|
||
callbacks.append(early_stopping)
|
||
|
||
# 学习率衰减
|
||
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
|
||
monitor='val_loss',
|
||
factor=0.5,
|
||
patience=self.early_stopping_patience // 2,
|
||
min_lr=1e-6,
|
||
verbose=1
|
||
)
|
||
callbacks.append(reduce_lr)
|
||
|
||
# 模型检查点
|
||
if self.checkpoint:
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints')
|
||
ensure_dir(checkpoint_dir)
|
||
|
||
checkpoint_path = os.path.join(
|
||
checkpoint_dir,
|
||
f"{self.model.model_name}_{timestamp}.h5"
|
||
)
|
||
|
||
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||
filepath=checkpoint_path,
|
||
save_best_only=self.save_best_only,
|
||
monitor='val_loss',
|
||
verbose=1
|
||
)
|
||
callbacks.append(model_checkpoint)
|
||
|
||
# TensorBoard
|
||
if self.tensorboard:
|
||
log_dir = os.path.join(
|
||
SAVED_MODELS_DIR,
|
||
'logs',
|
||
f"{self.model.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||
)
|
||
ensure_dir(log_dir)
|
||
|
||
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||
log_dir=log_dir,
|
||
histogram_freq=1,
|
||
update_freq='epoch'
|
||
)
|
||
callbacks.append(tensorboard_callback)
|
||
|
||
# 添加自定义回调函数
|
||
callbacks.extend(self.custom_callbacks)
|
||
|
||
return callbacks
|
||
|
||
def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None:
|
||
"""
|
||
记录训练进度
|
||
|
||
Args:
|
||
epoch: 当前轮数
|
||
logs: 日志信息
|
||
"""
|
||
self.training_logger.log_epoch(epoch, logs)
|
||
|
||
def train(self, x_train: Union[np.ndarray, tf.data.Dataset],
|
||
y_train: Optional[np.ndarray] = None,
|
||
x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None,
|
||
y_val: Optional[np.ndarray] = None,
|
||
class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]:
|
||
"""
|
||
训练模型
|
||
|
||
Args:
|
||
x_train: 训练数据特征
|
||
y_train: 训练数据标签
|
||
x_val: 验证数据特征
|
||
y_val: 验证数据标签
|
||
class_weights: 类别权重
|
||
|
||
Returns:
|
||
训练历史
|
||
"""
|
||
logger.info(f"开始训练模型: {self.model.model_name}")
|
||
|
||
# 创建回调函数
|
||
callbacks = self._create_callbacks()
|
||
|
||
# 添加训练进度记录回调
|
||
progress_callback = tf.keras.callbacks.LambdaCallback(
|
||
on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs)
|
||
)
|
||
callbacks.append(progress_callback)
|
||
|
||
# 记录开始时间
|
||
start_time = time.time()
|
||
|
||
# 记录训练开始信息
|
||
model_config = self.model.get_config()
|
||
train_config = {
|
||
"epochs": self.epochs,
|
||
"batch_size": self.batch_size,
|
||
"validation_split": self.validation_split,
|
||
"early_stopping": self.early_stopping,
|
||
"early_stopping_patience": self.early_stopping_patience
|
||
}
|
||
self.training_logger.log_training_start({**model_config, **train_config})
|
||
|
||
# 准备验证数据
|
||
validation_data = None
|
||
if x_val is not None and y_val is not None:
|
||
validation_data = (x_val, y_val)
|
||
|
||
# 训练模型
|
||
history = self.model.fit(
|
||
x_train, y_train,
|
||
validation_data=validation_data,
|
||
epochs=self.epochs,
|
||
callbacks=callbacks,
|
||
class_weights=class_weights,
|
||
verbose=1
|
||
)
|
||
|
||
# 计算训练时间
|
||
train_time = time.time() - start_time
|
||
|
||
# 保存训练历史
|
||
self.history = history.history
|
||
|
||
# 找出最佳性能
|
||
best_val_loss = min(history.history['val_loss']) if 'val_loss' in history.history else None
|
||
best_val_acc = max(history.history['val_accuracy']) if 'val_accuracy' in history.history else None
|
||
|
||
best_metrics = {}
|
||
if best_val_loss is not None:
|
||
best_metrics['val_loss'] = best_val_loss
|
||
if best_val_acc is not None:
|
||
best_metrics['val_accuracy'] = best_val_acc
|
||
|
||
# 记录训练结束信息
|
||
self.training_logger.log_training_end(train_time, best_metrics)
|
||
|
||
logger.info(f"模型训练完成,用时: {train_time:.2f} 秒")
|
||
|
||
return history.history
|
||
|
||
def plot_training_history(self, metrics: Optional[List[str]] = None,
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制训练历史
|
||
|
||
Args:
|
||
metrics: 要绘制的指标列表,默认为['loss', 'accuracy']
|
||
save_path: 保存路径,如果为None则显示图像
|
||
"""
|
||
if self.history is None:
|
||
raise ValueError("模型尚未训练,没有训练历史")
|
||
|
||
if metrics is None:
|
||
metrics = ['loss', 'accuracy']
|
||
|
||
plt.figure(figsize=(12, 5))
|
||
|
||
for i, metric in enumerate(metrics):
|
||
plt.subplot(1, len(metrics), i + 1)
|
||
|
||
if metric in self.history:
|
||
plt.plot(self.history[metric], label=f'train_{metric}')
|
||
|
||
val_metric = f'val_{metric}'
|
||
if val_metric in self.history:
|
||
plt.plot(self.history[val_metric], label=f'val_{metric}')
|
||
|
||
plt.title(f'Model {metric}')
|
||
plt.xlabel('Epoch')
|
||
plt.ylabel(metric)
|
||
plt.legend()
|
||
|
||
plt.tight_layout()
|
||
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"训练历史图已保存到: {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def save_trained_model(self, filepath: Optional[str] = None) -> str:
|
||
"""
|
||
保存训练好的模型
|
||
|
||
Args:
|
||
filepath: 保存路径,如果为None则使用默认路径
|
||
|
||
Returns:
|
||
保存路径
|
||
"""
|
||
return self.model.save(filepath)
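
# Minimal usage sketch (illustrative): wire a concrete TextClassificationModel
# subclass into the Trainer. The numpy arrays are assumed to come from the project's
# preprocessing pipeline.
def _trainer_usage_example(model: TextClassificationModel,
                           x_train, y_train, x_val, y_val):
    model.build()
    model.compile()

    trainer = Trainer(model, epochs=20, early_stopping=True)
    history = trainer.train(x_train, y_train, x_val=x_val, y_val=y_val)

    trainer.plot_training_history(save_path='training_history.png')
    saved_path = trainer.save_trained_model()
    return history, saved_path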


================================================================================
文件: tests/__init__.py
================================================================================



================================================================================
文件: tests/test_evaluation.py
================================================================================



================================================================================
文件: tests/test_preprocessing.py
================================================================================



================================================================================
文件: tests/test_models.py
================================================================================



================================================================================
文件: utils/text_utils.py
================================================================================



================================================================================
文件: utils/__init__.py
================================================================================



================================================================================
文件: utils/time_utils.py
================================================================================



================================================================================
文件: utils/logger.py
================================================================================

"""
|
||
日志工具模块
|
||
"""
|
||
import logging
|
||
import sys
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import os
|
||
|
||
from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT
|
||
|
||
|
||
def get_logger(name, level=None, log_file=None):
|
||
"""
|
||
获取logger实例
|
||
|
||
Args:
|
||
name: logger名称
|
||
level: 日志级别,默认为系统配置
|
||
log_file: 日志文件路径,默认为None(仅控制台输出)
|
||
|
||
Returns:
|
||
logger实例
|
||
"""
|
||
level = level or LOG_LEVEL
|
||
|
||
# 创建logger
|
||
logger = logging.getLogger(name)
|
||
logger.setLevel(getattr(logging, level))
|
||
|
||
# 避免重复添加handler
|
||
if logger.handlers:
|
||
return logger
|
||
|
||
# 创建格式化器
|
||
formatter = logging.Formatter(LOG_FORMAT)
|
||
|
||
# 创建控制台处理器
|
||
console_handler = logging.StreamHandler(sys.stdout)
|
||
console_handler.setFormatter(formatter)
|
||
logger.addHandler(console_handler)
|
||
|
||
# 如果指定了日志文件,创建文件处理器
|
||
if log_file:
|
||
log_path = Path(LOG_DIR) / log_file
|
||
file_handler = logging.FileHandler(log_path, encoding='utf-8')
|
||
file_handler.setFormatter(formatter)
|
||
logger.addHandler(file_handler)
|
||
|
||
return logger
|
||
|
||
|
||
def get_time_logger(name):
|
||
"""
|
||
获取带时间戳的logger实例
|
||
|
||
Args:
|
||
name: logger名称
|
||
|
||
Returns:
|
||
logger实例
|
||
"""
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
log_file = f"{name}_{timestamp}.log"
|
||
return get_logger(name, log_file=log_file)
|
||
|
||
|
||
class TrainingLogger:
|
||
"""训练过程日志记录器"""
|
||
|
||
def __init__(self, model_name):
|
||
"""
|
||
初始化训练日志记录器
|
||
|
||
Args:
|
||
model_name: 模型名称
|
||
"""
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
self.log_file = f"training_{model_name}_{timestamp}.log"
|
||
self.logger = get_logger(f"training_{model_name}", log_file=self.log_file)
|
||
|
||
# 创建CSV日志
|
||
self.csv_path = Path(LOG_DIR) / f"metrics_{model_name}_{timestamp}.csv"
|
||
with open(self.csv_path, 'w', encoding='utf-8') as f:
|
||
f.write("epoch,loss,accuracy,val_loss,val_accuracy\n")
|
||
|
||
def log_epoch(self, epoch, metrics):
|
||
"""记录每个epoch的指标"""
|
||
# 日志记录
|
||
metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
|
||
self.logger.info(f"Epoch {epoch}: {metrics_str}")
|
||
|
||
# CSV记录
|
||
csv_line = f"{epoch},{metrics.get('loss', '')},{metrics.get('accuracy', '')}," \
|
||
f"{metrics.get('val_loss', '')},{metrics.get('val_accuracy', '')}\n"
|
||
with open(self.csv_path, 'a', encoding='utf-8') as f:
|
||
f.write(csv_line)
|
||
|
||
def log_training_start(self, config):
|
||
"""记录训练开始信息"""
|
||
self.logger.info("=" * 50)
|
||
self.logger.info("训练开始")
|
||
self.logger.info("模型配置:")
|
||
for key, value in config.items():
|
||
self.logger.info(f" {key}: {value}")
|
||
self.logger.info("=" * 50)
|
||
|
||
def log_training_end(self, duration, best_metrics):
|
||
"""记录训练结束信息"""
|
||
self.logger.info("=" * 50)
|
||
self.logger.info(f"训练结束,总用时: {duration:.2f}秒")
|
||
self.logger.info("最佳性能:")
|
||
for key, value in best_metrics.items():
|
||
self.logger.info(f" {key}: {value:.4f}")
|
||
self.logger.info("=" * 50)
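
# Minimal usage sketch (illustrative): module-level logging plus per-run training
# logs. Assumes LOG_DIR from config.system_config exists; the metric values below are
# placeholders.
def _logger_usage_example():
    log = get_logger("Example", log_file="example.log")
    log.info("an ordinary log message")

    train_log = TrainingLogger("demo_model")
    train_log.log_training_start({"batch_size": 64, "learning_rate": 0.001})
    train_log.log_epoch(1, {"loss": 0.52, "accuracy": 0.81,
                            "val_loss": 0.48, "val_accuracy": 0.83})
    train_log.log_training_end(duration=123.4,
                               best_metrics={"val_loss": 0.48, "val_accuracy": 0.83})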

================================================================================
文件: utils/file_utils.py
================================================================================

"""
|
||
文件处理工具模块
|
||
"""
|
||
import os
|
||
import shutil
|
||
import json
|
||
import pickle
|
||
import csv
|
||
from pathlib import Path
|
||
import time
|
||
import hashlib
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import zipfile
|
||
import tarfile
|
||
|
||
from config.system_config import ENCODING, DATA_LOADING_WORKERS
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("file_utils")
|
||
|
||
|
||
def read_text_file(file_path, encoding=ENCODING):
|
||
"""
|
||
读取文本文件内容
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
encoding: 文件编码
|
||
|
||
Returns:
|
||
文件内容
|
||
"""
|
||
try:
|
||
with open(file_path, 'r', encoding=encoding) as file:
|
||
return file.read()
|
||
except Exception as e:
|
||
logger.error(f"读取文件 {file_path} 时出错: {str(e)}")
|
||
return None
|
||
|
||
|
||
def write_text_file(content, file_path, encoding=ENCODING):
|
||
"""
|
||
写入文本文件
|
||
|
||
Args:
|
||
content: 文件内容
|
||
file_path: 文件路径
|
||
encoding: 文件编码
|
||
|
||
Returns:
|
||
成功标志
|
||
"""
|
||
try:
|
||
# 确保目录存在
|
||
        # fall back to '.' when file_path has no directory component
        os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
|
||
|
||
with open(file_path, 'w', encoding=encoding) as file:
|
||
file.write(content)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"写入文件 {file_path} 时出错: {str(e)}")
|
||
return False
|
||
|
||
|
||
def save_json(data, file_path, encoding=ENCODING):
|
||
"""
|
||
保存JSON数据到文件
|
||
|
||
Args:
|
||
data: 要保存的数据
|
||
file_path: 文件路径
|
||
encoding: 文件编码
|
||
|
||
Returns:
|
||
成功标志
|
||
"""
|
||
try:
|
||
# 确保目录存在
|
||
        # fall back to '.' when file_path has no directory component
        os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
|
||
|
||
with open(file_path, 'w', encoding=encoding) as file:
|
||
json.dump(data, file, ensure_ascii=False, indent=2)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"保存JSON文件 {file_path} 时出错: {str(e)}")
|
||
return False
|
||
|
||
|
||
def load_json(file_path, encoding=ENCODING):
|
||
"""
|
||
从文件加载JSON数据
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
encoding: 文件编码
|
||
|
||
Returns:
|
||
加载的数据
|
||
"""
|
||
try:
|
||
with open(file_path, 'r', encoding=encoding) as file:
|
||
return json.load(file)
|
||
except Exception as e:
|
||
logger.error(f"加载JSON文件 {file_path} 时出错: {str(e)}")
|
||
return None
|
||
|
||
|
||
def save_pickle(data, file_path):
|
||
"""
|
||
使用pickle保存数据
|
||
|
||
Args:
|
||
data: 要保存的数据
|
||
file_path: 文件路径
|
||
|
||
Returns:
|
||
成功标志
|
||
"""
|
||
try:
|
||
# 确保目录存在
|
||
        # fall back to '.' when file_path has no directory component
        os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
|
||
|
||
with open(file_path, 'wb') as file:
|
||
pickle.dump(data, file)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"保存pickle文件 {file_path} 时出错: {str(e)}")
|
||
return False
|
||
|
||
|
||
def load_pickle(file_path):
|
||
"""
|
||
从文件加载pickle数据
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
|
||
Returns:
|
||
加载的数据
|
||
"""
|
||
try:
|
||
with open(file_path, 'rb') as file:
|
||
return pickle.load(file)
|
||
except Exception as e:
|
||
logger.error(f"加载pickle文件 {file_path} 时出错: {str(e)}")
|
||
return None
|
||
|
||
|
||
def read_files_parallel(file_paths, max_workers=DATA_LOADING_WORKERS, encoding=ENCODING):
|
||
"""
|
||
并行读取多个文本文件
|
||
|
||
Args:
|
||
file_paths: 文件路径列表
|
||
max_workers: 最大工作线程数
|
||
encoding: 文件编码
|
||
|
||
Returns:
|
||
文件内容列表
|
||
"""
|
||
start_time = time.time()
|
||
results = []
|
||
|
||
# 定义单个读取函数
|
||
def read_single_file(file_path):
|
||
return read_text_file(file_path, encoding)
|
||
|
||
# 使用线程池并行读取
|
||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||
future_to_file = {executor.submit(read_single_file, file_path): file_path
|
||
for file_path in file_paths}
|
||
|
||
# 收集结果
|
||
for future in as_completed(future_to_file):
|
||
file_path = future_to_file[future]
|
||
try:
|
||
content = future.result()
|
||
if content is not None:
|
||
results.append(content)
|
||
except Exception as e:
|
||
logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
|
||
|
||
elapsed = time.time() - start_time
|
||
logger.info(f"并行读取 {len(file_paths)} 个文件,成功 {len(results)} 个,用时 {elapsed:.2f} 秒")
|
||
|
||
return results
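
# Minimal usage sketch (illustrative): collect all .txt files under a directory (the
# default path below is a placeholder) and read them concurrently with the helpers
# defined in this module.
def _parallel_read_example(data_dir="data/raw"):
    txt_files = list_files(data_dir, pattern="*.txt", recursive=True)
    contents = read_files_parallel(txt_files, max_workers=8)
    logger.info(f"loaded {len(contents)} documents")
    return contents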
|
||
|
||
|
||
def get_file_md5(file_path):
|
||
"""
|
||
计算文件的MD5哈希值
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
|
||
Returns:
|
||
MD5哈希值
|
||
"""
|
||
hash_md5 = hashlib.md5()
|
||
|
||
try:
|
||
with open(file_path, "rb") as f:
|
||
for chunk in iter(lambda: f.read(4096), b""):
|
||
hash_md5.update(chunk)
|
||
return hash_md5.hexdigest()
|
||
except Exception as e:
|
||
logger.error(f"计算文件 {file_path} 的MD5值时出错: {str(e)}")
|
||
return None
|
||
|
||
|
||
def extract_archive(archive_path, extract_to=None):
|
||
"""
|
||
解压缩文件
|
||
|
||
Args:
|
||
archive_path: 压缩文件路径
|
||
extract_to: 解压目标路径,默认为同目录
|
||
|
||
Returns:
|
||
成功标志
|
||
"""
|
||
if extract_to is None:
|
||
extract_to = os.path.dirname(archive_path)
|
||
|
||
try:
|
||
if archive_path.endswith('.zip'):
|
||
with zipfile.ZipFile(archive_path, 'r') as zip_ref:
|
||
zip_ref.extractall(extract_to)
|
||
elif archive_path.endswith(('.tar.gz', '.tgz')):
|
||
with tarfile.open(archive_path, 'r:gz') as tar_ref:
|
||
tar_ref.extractall(extract_to)
|
||
elif archive_path.endswith('.tar'):
|
||
with tarfile.open(archive_path, 'r') as tar_ref:
|
||
tar_ref.extractall(extract_to)
|
||
else:
|
||
logger.error(f"不支持的压缩格式: {archive_path}")
|
||
return False
|
||
|
||
logger.info(f"成功解压 {archive_path} 到 {extract_to}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"解压 {archive_path} 时出错: {str(e)}")
|
||
return False
|
||
|
||
|
||
def list_files(directory, pattern=None, recursive=True):
|
||
"""
|
||
列出目录中的文件
|
||
|
||
Args:
|
||
directory: 目录路径
|
||
pattern: 文件名模式(支持通配符)
|
||
recursive: 是否递归搜索子目录
|
||
|
||
Returns:
|
||
文件路径列表
|
||
"""
|
||
if not os.path.exists(directory):
|
||
logger.error(f"目录不存在: {directory}")
|
||
return []
|
||
|
||
directory = Path(directory)
|
||
|
||
if pattern:
|
||
if recursive:
|
||
return [str(p) for p in directory.glob(f"**/{pattern}")]
|
||
else:
|
||
return [str(p) for p in directory.glob(pattern)]
|
||
else:
|
||
if recursive:
|
||
files = []
|
||
for p in directory.rglob("*"):
|
||
if p.is_file():
|
||
files.append(str(p))
|
||
return files
|
||
else:
|
||
return [str(p) for p in directory.iterdir() if p.is_file()]
|
||
|
||
|
||
def ensure_dir(directory):
|
||
"""
|
||
确保目录存在,不存在则创建
|
||
|
||
Args:
|
||
directory: 目录路径
|
||
"""
|
||
os.makedirs(directory, exist_ok=True)
|
||
|
||
|
||
def remove_dir(directory):
|
||
"""
|
||
删除目录及其内容
|
||
|
||
Args:
|
||
directory: 目录路径
|
||
|
||
Returns:
|
||
成功标志
|
||
"""
|
||
try:
|
||
if os.path.exists(directory):
|
||
shutil.rmtree(directory)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"删除目录 {directory} 时出错: {str(e)}")
|
||
return False

================================================================================
文件: models/ensemble_model.py
================================================================================

"""
|
||
集成模型:实现多个模型的集成
|
||
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import os
|
||
|
||
from config.system_config import CLASSIFIERS_DIR
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("EnsembleModel")
|
||
|
||
|
||
class EnsembleModel:
|
||
"""模型集成类,集成多个模型的预测结果"""
|
||
|
||
def __init__(self, models: List[TextClassificationModel],
|
||
weights: Optional[List[float]] = None,
|
||
voting: str = 'soft',
|
||
name: str = "ensemble_model"):
|
||
"""
|
||
初始化集成模型
|
||
|
||
Args:
|
||
models: 模型列表
|
||
weights: 各模型的权重,默认为均等权重
|
||
voting: 投票方式,'hard'表示多数投票,'soft'表示概率平均
|
||
name: 集成模型名称
|
||
"""
|
||
self.models = models
|
||
self.num_models = len(models)
|
||
|
||
# 验证模型数量
|
||
if self.num_models == 0:
|
||
raise ValueError("模型列表不能为空")
|
||
|
||
# 设置权重
|
||
if weights is None:
|
||
self.weights = np.ones(self.num_models) / self.num_models
|
||
else:
|
||
if len(weights) != self.num_models:
|
||
raise ValueError("权重数量必须与模型数量相同")
|
||
|
||
# 归一化权重
|
||
self.weights = np.array(weights) / np.sum(weights)
|
||
|
||
# 验证投票方式
|
||
self.voting = voting.lower()
|
||
if self.voting not in ['hard', 'soft']:
|
||
raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'")
|
||
|
||
# 从第一个模型获取类别数
|
||
self.num_classes = models[0].num_classes
|
||
|
||
# 验证所有模型的类别数是否相同
|
||
for i, model in enumerate(models[1:], 1):
|
||
if model.num_classes != self.num_classes:
|
||
raise ValueError(
|
||
f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同")
|
||
|
||
self.name = name
|
||
|
||
logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}")
|
||
|
||
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用集成模型进行预测
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测概率
|
||
"""
|
||
# 获取每个模型的预测结果
|
||
all_predictions = []
|
||
|
||
for i, model in enumerate(self.models):
|
||
logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果")
|
||
predictions = model.predict(x, batch_size, verbose)
|
||
|
||
# 如果是二分类且输出形状是(n,1),转换为(n,2)
|
||
if self.num_classes == 2 and predictions.shape[1:] == (1,):
|
||
predictions = np.hstack([1 - predictions, predictions])
|
||
|
||
all_predictions.append(predictions)
|
||
|
||
# 根据投票方式进行集成
|
||
if self.voting == 'hard':
|
||
# 硬投票:每个模型预测的类别,取众数
|
||
individual_classes = [np.argmax(pred, axis=1) for pred in all_predictions]
|
||
|
||
# 获取带权重的预测类别频率
|
||
            # size from the stacked predictions; len(x) is undefined for tf.data.Dataset inputs
            ensemble_result = np.zeros((all_predictions[0].shape[0], self.num_classes))
|
||
|
||
for i, classes in enumerate(individual_classes):
|
||
for j, cls in enumerate(classes):
|
||
ensemble_result[j, cls] += self.weights[i]
|
||
|
||
return ensemble_result
|
||
else: # soft voting
|
||
# 软投票:对每个模型的预测概率进行加权平均
|
||
weighted_predictions = [pred * weight for pred, weight in zip(all_predictions, self.weights)]
|
||
ensemble_result = np.sum(weighted_predictions, axis=0)
|
||
|
||
return ensemble_result
|
||
|
||
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用集成模型预测类别
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测的类别索引
|
||
"""
|
||
# 获取预测概率
|
||
predictions = self.predict(x, batch_size, verbose)
|
||
|
||
# 获取最大概率的类别索引
|
||
return np.argmax(predictions, axis=1)
|
||
|
||
def save(self, directory: Optional[str] = None) -> str:
|
||
"""
|
||
保存集成模型
|
||
|
||
Args:
|
||
directory: 保存目录,默认为CLASSIFIERS_DIR
|
||
|
||
Returns:
|
||
保存路径
|
||
"""
|
||
if directory is None:
|
||
import time
|
||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||
directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}")
|
||
|
||
os.makedirs(directory, exist_ok=True)
|
||
|
||
# 保存模型列表
|
||
model_paths = []
|
||
for i, model in enumerate(self.models):
|
||
model_path = os.path.join(directory, f"model_{i}")
|
||
model.save(model_path)
|
||
model_paths.append(model_path)
|
||
|
||
# 保存集成配置
|
||
config = {
|
||
"name": self.name,
|
||
"num_models": self.num_models,
|
||
"model_paths": model_paths,
|
||
"weights": self.weights.tolist(),
|
||
"voting": self.voting,
|
||
"num_classes": self.num_classes
|
||
}
|
||
|
||
import json
|
||
config_path = os.path.join(directory, "ensemble_config.json")
|
||
with open(config_path, 'w', encoding='utf-8') as f:
|
||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||
|
||
logger.info(f"集成模型已保存到目录: {directory}")
|
||
|
||
return directory
|
||
|
||
@classmethod
|
||
def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel':
|
||
"""
|
||
加载集成模型
|
||
|
||
Args:
|
||
directory: 模型目录
|
||
custom_objects: 自定义对象字典
|
||
|
||
Returns:
|
||
加载的集成模型实例
|
||
"""
|
||
# 加载配置
|
||
config_path = os.path.join(directory, "ensemble_config.json")
|
||
|
||
import json
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
|
||
# 加载子模型
|
||
from models.model_factory import ModelFactory
|
||
|
||
models = []
|
||
model_paths = config["model_paths"]
|
||
|
||
for model_path in model_paths:
|
||
model = ModelFactory.load_model(model_path, custom_objects)
|
||
models.append(model)
|
||
|
||
# 创建集成模型
|
||
ensemble = cls(
|
||
models=models,
|
||
weights=config["weights"],
|
||
voting=config["voting"],
|
||
name=config["name"]
|
||
)
|
||
|
||
logger.info(f"从目录 {directory} 加载集成模型成功")
|
||
|
||
return ensemble
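
# Minimal usage sketch (illustrative): combine two already-trained models (any
# concrete TextClassificationModel subclasses) into a weighted soft-voting ensemble.
# `model_a`, `model_b` and `x_test` are assumed to be created elsewhere.
def _ensemble_usage_example(model_a: TextClassificationModel,
                            model_b: TextClassificationModel,
                            x_test):
    ensemble = EnsembleModel(
        models=[model_a, model_b],
        weights=[0.6, 0.4],      # weights are normalized internally
        voting='soft'
    )
    probabilities = ensemble.predict(x_test)
    predicted_classes = ensemble.predict_classes(x_test)
    save_dir = ensemble.save()
    return probabilities, predicted_classes, save_dir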


================================================================================
文件: models/transformer_model.py
================================================================================

"""
|
||
Transformer模型:实现基于Transformer的文本分类模型
|
||
"""
|
||
import tensorflow as tf
|
||
from tensorflow.keras.models import Model
|
||
from tensorflow.keras.layers import (
|
||
Input, Embedding, Dense, Dropout, LayerNormalization,
|
||
GlobalAveragePooling1D, MultiHeadAttention, Add
|
||
)
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
|
||
from config.model_config import (
|
||
MAX_SEQUENCE_LENGTH, TRANSFORMER_CONFIG
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("TransformerModel")
|
||
|
||
|
||
class TransformerBlock(tf.keras.layers.Layer):
|
||
"""Transformer块,包含多头注意力和前馈网络"""
|
||
|
||
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1, **kwargs):
|
||
"""
|
||
初始化Transformer块
|
||
|
||
Args:
|
||
embed_dim: 嵌入维度
|
||
num_heads: 注意力头数
|
||
ff_dim: 前馈网络维度
|
||
dropout_rate: Dropout比例
|
||
"""
|
||
        # forward **kwargs (e.g. the layer name passed in build()) to the Keras base Layer
        super(TransformerBlock, self).__init__(**kwargs)
|
||
self.embed_dim = embed_dim
|
||
self.num_heads = num_heads
|
||
self.ff_dim = ff_dim
|
||
self.dropout_rate = dropout_rate
|
||
|
||
self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
|
||
self.ffn = tf.keras.Sequential([
|
||
Dense(ff_dim, activation="relu"),
|
||
Dense(embed_dim),
|
||
])
|
||
self.layernorm1 = LayerNormalization(epsilon=1e-6)
|
||
self.layernorm2 = LayerNormalization(epsilon=1e-6)
|
||
self.dropout1 = Dropout(dropout_rate)
|
||
self.dropout2 = Dropout(dropout_rate)
|
||
|
||
def call(self, inputs, training=False):
|
||
"""
|
||
前向传播
|
||
|
||
Args:
|
||
inputs: 输入张量
|
||
training: 是否处于训练模式
|
||
|
||
Returns:
|
||
输出张量
|
||
"""
|
||
# 多头自注意力
|
||
attention_output = self.attention(inputs, inputs)
|
||
attention_output = self.dropout1(attention_output, training=training)
|
||
out1 = self.layernorm1(inputs + attention_output)
|
||
|
||
# 前馈网络
|
||
ffn_output = self.ffn(out1)
|
||
ffn_output = self.dropout2(ffn_output, training=training)
|
||
out2 = self.layernorm2(out1 + ffn_output)
|
||
|
||
return out2
|
||
|
||
def get_config(self):
|
||
"""获取配置"""
|
||
config = super(TransformerBlock, self).get_config()
|
||
config.update({
|
||
"embed_dim": self.embed_dim,
|
||
"num_heads": self.num_heads,
|
||
"ff_dim": self.ff_dim,
|
||
"dropout_rate": self.dropout_rate
|
||
})
|
||
return config
|
||
|
||
|
||
class TransformerTextClassifier(TextClassificationModel):
|
||
"""Transformer文本分类模型"""
|
||
|
||
def __init__(self, num_classes: int, vocab_size: int,
|
||
embedding_dim: int = TRANSFORMER_CONFIG["embedding_dim"],
|
||
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
|
||
num_heads: int = TRANSFORMER_CONFIG["num_heads"],
|
||
ff_dim: int = TRANSFORMER_CONFIG["ff_dim"],
|
||
num_layers: int = TRANSFORMER_CONFIG["num_layers"],
|
||
dropout_rate: float = TRANSFORMER_CONFIG["dropout_rate"],
|
||
embedding_matrix: Optional[np.ndarray] = None,
|
||
trainable_embedding: bool = True,
|
||
use_positional_encoding: bool = True,
|
||
model_name: str = "transformer_text_classifier",
|
||
batch_size: int = 64,
|
||
learning_rate: float = 0.001):
|
||
"""
|
||
初始化Transformer文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_dim: 词嵌入维度
|
||
max_sequence_length: 最大序列长度
|
||
num_heads: 注意力头数
|
||
ff_dim: 前馈网络维度
|
||
num_layers: Transformer层数
|
||
dropout_rate: Dropout比例
|
||
embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
|
||
trainable_embedding: 词嵌入是否可训练
|
||
use_positional_encoding: 是否使用位置编码
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
super().__init__(num_classes, model_name, batch_size, learning_rate)
|
||
|
||
self.vocab_size = vocab_size
|
||
self.embedding_dim = embedding_dim
|
||
self.max_sequence_length = max_sequence_length
|
||
self.num_heads = num_heads
|
||
self.ff_dim = ff_dim
|
||
self.num_layers = num_layers
|
||
self.dropout_rate = dropout_rate
|
||
self.embedding_matrix = embedding_matrix
|
||
self.trainable_embedding = trainable_embedding
|
||
self.use_positional_encoding = use_positional_encoding
|
||
|
||
# 更新配置
|
||
self.config.update({
|
||
"vocab_size": vocab_size,
|
||
"embedding_dim": embedding_dim,
|
||
"max_sequence_length": max_sequence_length,
|
||
"num_heads": num_heads,
|
||
"ff_dim": ff_dim,
|
||
"num_layers": num_layers,
|
||
"dropout_rate": dropout_rate,
|
||
"trainable_embedding": trainable_embedding,
|
||
"use_positional_encoding": use_positional_encoding,
|
||
"model_type": "Transformer"
|
||
})
|
||
|
||
logger.info(f"初始化Transformer文本分类模型,头数: {num_heads}, 层数: {num_layers}")
|
||
|
||
def _positional_encoding(self, max_length: int, d_model: int) -> tf.Tensor:
|
||
"""
|
||
生成位置编码
|
||
|
||
Args:
|
||
max_length: 最大序列长度
|
||
d_model: 模型维度
|
||
|
||
Returns:
|
||
位置编码张量
|
||
"""
|
||
positions = np.arange(max_length)[:, np.newaxis]
|
||
depths = np.arange(d_model)[np.newaxis, :] // 2 * 2
|
||
angle_rates = 1 / np.power(10000, depths / d_model)
|
||
angle_rads = positions * angle_rates
|
||
|
||
# sin用于偶数索引,cos用于奇数索引
|
||
sines = np.sin(angle_rads[:, 0::2])
|
||
cosines = np.cos(angle_rads[:, 1::2])
|
||
|
||
pos_encoding = np.zeros((max_length, d_model))
|
||
pos_encoding[:, 0::2] = sines
|
||
pos_encoding[:, 1::2] = cosines
|
||
|
||
return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)
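
    # Note (illustrative): the code above implements the standard sinusoidal encoding
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # and returns a tensor of shape (1, max_length, d_model) so it broadcasts over the
    # batch dimension when added to the embedded sequences.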
|
||
|
||
def build(self) -> None:
|
||
"""构建Transformer模型架构"""
|
||
# Input layer
|
||
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
|
||
|
||
# Embedding layer
|
||
if self.embedding_matrix is not None:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
weights=[self.embedding_matrix],
|
||
input_length=self.max_sequence_length,
|
||
trainable=self.trainable_embedding,
|
||
name='embedding'
|
||
)
|
||
else:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
input_length=self.max_sequence_length,
|
||
trainable=True,
|
||
name='embedding'
|
||
)
|
||
|
||
embedded_sequences = embedding_layer(sequence_input)
|
||
|
||
# 添加位置编码
|
||
if self.use_positional_encoding:
|
||
pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
|
||
embedded_sequences = embedded_sequences + pos_encoding
|
||
|
||
# Transformer层
|
||
x = embedded_sequences
|
||
for i in range(self.num_layers):
|
||
x = TransformerBlock(
|
||
embed_dim=self.embedding_dim,
|
||
num_heads=self.num_heads,
|
||
ff_dim=self.ff_dim,
|
||
dropout_rate=self.dropout_rate,
|
||
name=f'transformer_block_{i + 1}'
|
||
)(x)
|
||
|
||
# 全局池化
|
||
x = GlobalAveragePooling1D(name='global_avg_pooling')(x)
|
||
|
||
# Dropout for regularization
|
||
x = Dropout(self.dropout_rate, name='dropout_1')(x)
|
||
|
||
# Dense layer
|
||
x = Dense(128, activation='relu', name='dense_1')(x)
|
||
x = Dropout(self.dropout_rate, name='dropout_2')(x)
|
||
|
||
# Output layer
|
||
if self.num_classes == 2:
|
||
# Binary classification
|
||
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
|
||
else:
|
||
# Multi-class classification
|
||
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
|
||
|
||
# Build the model
|
||
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
|
||
|
||
logger.info(f"Transformer模型构建完成,头数: {self.num_heads}, 层数: {self.num_layers}")
|
||
|
||
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
|
||
"""
|
||
编译Transformer模型
|
||
|
||
Args:
|
||
optimizer: 优化器,默认为Adam
|
||
loss: 损失函数,默认根据类别数量选择
|
||
metrics: 评估指标,默认为accuracy
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 默认优化器
|
||
if optimizer is None:
|
||
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
||
|
||
# 默认损失函数
|
||
if loss is None:
|
||
if self.num_classes == 2:
|
||
loss = 'binary_crossentropy'
|
||
else:
|
||
loss = 'sparse_categorical_crossentropy'
|
||
|
||
# 默认评估指标
|
||
if metrics is None:
|
||
metrics = ['accuracy']
|
||
|
||
# 编译模型
|
||
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||
logger.info(f"Transformer模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")

================================================================================
文件: models/__init__.py
================================================================================



================================================================================
文件: models/base_model.py
================================================================================

"""
|
||
模型基类:定义所有文本分类模型的通用接口
|
||
"""
|
||
import os
|
||
import time
|
||
import json
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from tensorflow.keras.models import Model, load_model
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
from abc import ABC, abstractmethod
|
||
|
||
from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR
|
||
from config.model_config import (
|
||
BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE,
|
||
REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR
|
||
)
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir, save_json
|
||
|
||
logger = get_logger("BaseModel")
|
||
|
||
|
||
class TextClassificationModel(ABC):
|
||
"""文本分类模型基类,定义所有模型的通用接口"""
|
||
|
||
def __init__(self, num_classes: int, model_name: str = "text_classifier",
|
||
batch_size: int = BATCH_SIZE,
|
||
learning_rate: float = LEARNING_RATE):
|
||
"""
|
||
初始化文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
self.num_classes = num_classes
|
||
self.model_name = model_name
|
||
self.batch_size = batch_size
|
||
self.learning_rate = learning_rate
|
||
|
||
# 模型实例
|
||
self.model = None
|
||
|
||
# 训练历史
|
||
self.history = None
|
||
|
||
# 训练配置
|
||
self.config = {
|
||
"model_name": model_name,
|
||
"num_classes": num_classes,
|
||
"batch_size": batch_size,
|
||
"learning_rate": learning_rate
|
||
}
|
||
|
||
# 验证集合最佳性能
|
||
self.best_val_loss = float('inf')
|
||
self.best_val_accuracy = 0.0
|
||
|
||
logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}")
|
||
|
||
@abstractmethod
|
||
def build(self) -> None:
|
||
"""构建模型架构,这是一个抽象方法,子类必须实现"""
|
||
pass
|
||
|
||
def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
|
||
loss: Optional[Union[str, tf.keras.losses.Loss]] = None,
|
||
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None:
|
||
"""
|
||
编译模型
|
||
|
||
Args:
|
||
optimizer: 优化器,默认为Adam
|
||
loss: 损失函数,默认为sparse_categorical_crossentropy
|
||
metrics: 评估指标,默认为accuracy
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 默认优化器
|
||
if optimizer is None:
|
||
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
||
|
||
# 默认损失函数
|
||
if loss is None:
|
||
loss = 'sparse_categorical_crossentropy'
|
||
|
||
# 默认评估指标
|
||
if metrics is None:
|
||
metrics = ['accuracy']
|
||
|
||
# 编译模型
|
||
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||
logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}")
|
||
|
||
def summary(self) -> None:
|
||
"""打印模型概要"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
self.model.summary()
|
||
|
||
def fit(self, x_train: Union[np.ndarray, tf.data.Dataset],
|
||
y_train: Optional[np.ndarray] = None,
|
||
validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None,
|
||
epochs: int = 10,
|
||
callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
|
||
class_weights: Optional[Dict[int, float]] = None,
|
||
verbose: int = 1) -> tf.keras.callbacks.History:
|
||
"""
|
||
训练模型
|
||
|
||
Args:
|
||
x_train: 训练数据特征
|
||
y_train: 训练数据标签
|
||
validation_data: 验证数据
|
||
epochs: 训练轮数
|
||
callbacks: 回调函数列表
|
||
class_weights: 类别权重
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
训练历史
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 记录开始时间
|
||
start_time = time.time()
|
||
|
||
# 添加默认回调函数
|
||
if callbacks is None:
|
||
callbacks = self._get_default_callbacks()
|
||
|
||
# 训练模型
|
||
if isinstance(x_train, tf.data.Dataset):
|
||
# 如果输入是TensorFlow Dataset
|
||
history = self.model.fit(
|
||
x_train,
|
||
epochs=epochs,
|
||
validation_data=validation_data,
|
||
callbacks=callbacks,
|
||
class_weight=class_weights,
|
||
verbose=verbose
|
||
)
|
||
else:
|
||
# 如果输入是NumPy数组
|
||
history = self.model.fit(
|
||
x_train, y_train,
|
||
batch_size=self.batch_size,
|
||
epochs=epochs,
|
||
validation_data=validation_data,
|
||
callbacks=callbacks,
|
||
class_weight=class_weights,
|
||
verbose=verbose
|
||
)
|
||
|
||
# 计算训练时间
|
||
train_time = time.time() - start_time
|
||
|
||
# 保存训练历史
|
||
self.history = history.history
|
||
self.history['train_time'] = train_time
|
||
|
||
logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
|
||
|
||
return history
|
||
|
||
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
|
||
y_test: Optional[np.ndarray] = None,
|
||
verbose: int = 1) -> Dict[str, float]:
|
||
"""
|
||
评估模型
|
||
|
||
Args:
|
||
x_test: 测试数据特征
|
||
y_test: 测试数据标签
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
评估结果字典
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 评估模型
|
||
if isinstance(x_test, tf.data.Dataset):
|
||
# 如果输入是TensorFlow Dataset
|
||
results = self.model.evaluate(x_test, verbose=verbose)
|
||
else:
|
||
# 如果输入是NumPy数组
|
||
results = self.model.evaluate(x_test, y_test, batch_size=self.batch_size, verbose=verbose)
|
||
|
||
# 构建评估结果字典
|
||
metrics_names = self.model.metrics_names
|
||
evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)}
|
||
|
||
logger.info(f"模型评估结果: {evaluation_results}")
|
||
|
||
return evaluation_results
|
||
|
||
def predict(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用模型进行预测
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 使用模型进行预测
|
||
if batch_size is None:
|
||
batch_size = self.batch_size
|
||
|
||
return self.model.predict(x, batch_size=batch_size, verbose=verbose)
|
||
|
||
def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List],
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> np.ndarray:
|
||
"""
|
||
使用模型预测类别
|
||
|
||
Args:
|
||
x: 预测数据
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
预测的类别索引
|
||
"""
|
||
# 获取模型预测概率
|
||
predictions = self.predict(x, batch_size, verbose)
|
||
|
||
# 获取最大概率的类别索引
|
||
return np.argmax(predictions, axis=1)
|
||
|
||
def save(self, filepath: Optional[str] = None,
|
||
save_format: str = 'tf',
|
||
include_optimizer: bool = True) -> str:
|
||
"""
|
||
保存模型
|
||
|
||
Args:
|
||
filepath: 保存路径,如果为None则使用默认路径
|
||
save_format: 保存格式,'tf'或'h5'
|
||
include_optimizer: 是否包含优化器状态
|
||
|
||
Returns:
|
||
保存路径
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 如果未指定保存路径,使用默认路径
|
||
if filepath is None:
|
||
ensure_dir(CLASSIFIERS_DIR)
|
||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||
filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}")
|
||
|
||
# 保存模型
|
||
self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer)
|
||
|
||
# 保存模型配置
|
||
config_path = f"{filepath}_config.json"
|
||
with open(config_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self.config, f, ensure_ascii=False, indent=4)
|
||
|
||
logger.info(f"模型已保存到: {filepath}")
|
||
|
||
return filepath
|
||
|
||
@classmethod
|
||
def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel':
|
||
"""
|
||
加载模型
|
||
|
||
Args:
|
||
filepath: 模型文件路径
|
||
custom_objects: 自定义对象字典
|
||
|
||
Returns:
|
||
加载的模型实例
|
||
"""
|
||
# 加载模型配置
|
||
config_path = f"{filepath}_config.json"
|
||
|
||
try:
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
except FileNotFoundError:
|
||
logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置")
|
||
config = {}
|
||
|
||
# 创建模型实例
|
||
model_name = config.get('model_name', 'loaded_model')
|
||
num_classes = config.get('num_classes', 1)
|
||
batch_size = config.get('batch_size', BATCH_SIZE)
|
||
learning_rate = config.get('learning_rate', LEARNING_RATE)
|
||
|
||
instance = cls(num_classes, model_name, batch_size, learning_rate)
|
||
|
||
# 加载Keras模型
|
||
instance.model = load_model(filepath, custom_objects=custom_objects)
|
||
|
||
# 加载配置
|
||
instance.config = config
|
||
|
||
logger.info(f"从 {filepath} 加载模型成功")
|
||
|
||
return instance
|
||
|
||
def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]:
|
||
"""获取默认的回调函数列表"""
|
||
# 早停
|
||
early_stopping = tf.keras.callbacks.EarlyStopping(
|
||
monitor='val_loss',
|
||
patience=EARLY_STOPPING_PATIENCE,
|
||
restore_best_weights=True,
|
||
verbose=1
|
||
)
|
||
|
||
# 学习率衰减
|
||
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
|
||
monitor='val_loss',
|
||
factor=REDUCE_LR_FACTOR,
|
||
patience=REDUCE_LR_PATIENCE,
|
||
min_lr=1e-6,
|
||
verbose=1
|
||
)
|
||
|
||
# 模型检查点
|
||
checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name)
|
||
ensure_dir(os.path.dirname(checkpoint_path))
|
||
|
||
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||
filepath=checkpoint_path,
|
||
save_best_only=True,
|
||
monitor='val_loss',
|
||
verbose=1
|
||
)
|
||
|
||
# TensorBoard日志
|
||
log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}")
|
||
ensure_dir(log_dir)
|
||
|
||
tensorboard = tf.keras.callbacks.TensorBoard(
|
||
log_dir=log_dir,
|
||
histogram_freq=1
|
||
)
|
||
|
||
return [early_stopping, reduce_lr, model_checkpoint, tensorboard]
|
||
|
||
def get_config(self) -> Dict[str, Any]:
|
||
"""获取模型配置"""
|
||
return self.config.copy()
|
||
|
||
def get_model(self) -> Model:
|
||
"""获取Keras模型实例"""
|
||
return self.model
|
||
|
||
def get_training_history(self) -> Optional[Dict[str, List[float]]]:
|
||
"""获取训练历史"""
|
||
return self.history
|
||
|
||
def plot_training_history(self, save_path: Optional[str] = None,
|
||
metrics: Optional[List[str]] = None) -> None:
|
||
"""
|
||
绘制训练历史
|
||
|
||
Args:
|
||
save_path: 保存路径,如果为None则显示图像
|
||
metrics: 要绘制的指标列表,默认为['loss', 'accuracy']
|
||
"""
|
||
if self.history is None:
|
||
raise ValueError("模型尚未训练,没有训练历史")
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
if metrics is None:
|
||
metrics = ['loss', 'accuracy']
|
||
|
||
# 创建图形
|
||
plt.figure(figsize=(12, 5))
|
||
|
||
# 绘制指标
|
||
for i, metric in enumerate(metrics):
|
||
plt.subplot(1, len(metrics), i + 1)
|
||
|
||
if metric in self.history:
|
||
plt.plot(self.history[metric], label=f'train_{metric}')
|
||
|
||
val_metric = f'val_{metric}'
|
||
if val_metric in self.history:
|
||
plt.plot(self.history[val_metric], label=f'val_{metric}')
|
||
|
||
plt.title(f'Model {metric}')
|
||
plt.xlabel('Epoch')
|
||
plt.ylabel(metric)
|
||
plt.legend()
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图像
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"训练历史图已保存到: {save_path}")
|
||
else:
|
||
plt.show()
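
# Minimal subclass sketch (illustrative): build() is the only abstract method, so a
# concrete model just needs to create a Keras Model and assign it to self.model. The
# tiny bag-of-embeddings architecture below is an example, not part of the project.
class _TinyAveragePoolingClassifier(TextClassificationModel):
    """Toy subclass showing how build() is expected to populate self.model."""

    def __init__(self, num_classes: int, vocab_size: int, max_sequence_length: int = 500):
        super().__init__(num_classes, model_name="tiny_avg_pool_classifier")
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length

    def build(self) -> None:
        inputs = tf.keras.layers.Input(shape=(self.max_sequence_length,), dtype='int32')
        x = tf.keras.layers.Embedding(self.vocab_size, 128)(inputs)
        x = tf.keras.layers.GlobalAveragePooling1D()(x)
        outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(x)
        self.model = Model(inputs=inputs, outputs=outputs, name=self.model_name)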

================================================================================
文件: models/rnn_model.py
================================================================================

"""
|
||
RNN模型:实现基于循环神经网络的文本分类模型
|
||
"""
|
||
import tensorflow as tf
|
||
from tensorflow.keras.models import Model
|
||
from tensorflow.keras.layers import (
|
||
Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout,
|
||
BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D
|
||
)
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
|
||
from config.model_config import (
|
||
MAX_SEQUENCE_LENGTH, RNN_CONFIG
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("RNNModel")
|
||
|
||
|
||
class RNNTextClassifier(TextClassificationModel):
|
||
"""循环神经网络文本分类模型"""
|
||
|
||
def __init__(self, num_classes: int, vocab_size: int,
|
||
embedding_dim: int = RNN_CONFIG["embedding_dim"],
|
||
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
|
||
hidden_size: int = RNN_CONFIG["hidden_size"],
|
||
num_layers: int = RNN_CONFIG["num_layers"],
|
||
bidirectional: bool = RNN_CONFIG["bidirectional"],
|
||
rnn_type: str = "lstm", # 'lstm' or 'gru'
|
||
dropout_rate: float = RNN_CONFIG["dropout_rate"],
|
||
embedding_matrix: Optional[np.ndarray] = None,
|
||
trainable_embedding: bool = True,
|
||
pool_type: str = "max", # 'max', 'avg', or 'both'
|
||
model_name: str = "rnn_text_classifier",
|
||
batch_size: int = 64,
|
||
learning_rate: float = 0.001):
|
||
"""
|
||
初始化RNN文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_dim: 词嵌入维度
|
||
max_sequence_length: 最大序列长度
|
||
hidden_size: 隐藏层大小
|
||
num_layers: RNN层数
|
||
bidirectional: 是否使用双向RNN
|
||
rnn_type: RNN类型,'lstm'或'gru'
|
||
dropout_rate: Dropout比例
|
||
embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
|
||
trainable_embedding: 词嵌入是否可训练
|
||
pool_type: 池化类型,'max'、'avg'或'both'
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
super().__init__(num_classes, model_name, batch_size, learning_rate)
|
||
|
||
self.vocab_size = vocab_size
|
||
self.embedding_dim = embedding_dim
|
||
self.max_sequence_length = max_sequence_length
|
||
self.hidden_size = hidden_size
|
||
self.num_layers = num_layers
|
||
self.bidirectional = bidirectional
|
||
self.rnn_type = rnn_type.lower()
|
||
self.dropout_rate = dropout_rate
|
||
self.embedding_matrix = embedding_matrix
|
||
self.trainable_embedding = trainable_embedding
|
||
self.pool_type = pool_type
|
||
|
||
# 验证RNN类型
|
||
if self.rnn_type not in ["lstm", "gru"]:
|
||
raise ValueError("无效的RNN类型,支持的类型: 'lstm', 'gru'")
|
||
|
||
# 验证池化类型
|
||
        if self.pool_type not in ["max", "avg", "both", "last"]:
            raise ValueError("无效的池化类型,支持的类型: 'max', 'avg', 'both', 'last'")
|
||
|
||
# 更新配置
|
||
self.config.update({
|
||
"vocab_size": vocab_size,
|
||
"embedding_dim": embedding_dim,
|
||
"max_sequence_length": max_sequence_length,
|
||
"hidden_size": hidden_size,
|
||
"num_layers": num_layers,
|
||
"bidirectional": bidirectional,
|
||
"rnn_type": rnn_type,
|
||
"dropout_rate": dropout_rate,
|
||
"trainable_embedding": trainable_embedding,
|
||
"pool_type": pool_type,
|
||
"model_type": "RNN"
|
||
})
|
||
|
||
logger.info(f"初始化RNN文本分类模型,类型: {rnn_type.upper()}, 隐藏层大小: {hidden_size}, 层数: {num_layers}")
|
||
|
||
def build(self) -> None:
|
||
"""构建RNN模型架构"""
|
||
# Input layer
|
||
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
|
||
|
||
# Embedding layer
|
||
if self.embedding_matrix is not None:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
weights=[self.embedding_matrix],
|
||
input_length=self.max_sequence_length,
|
||
trainable=self.trainable_embedding,
|
||
name='embedding'
|
||
)
|
||
else:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
input_length=self.max_sequence_length,
|
||
trainable=True,
|
||
name='embedding'
|
||
)
|
||
|
||
embedded_sequences = embedding_layer(sequence_input)
|
||
|
||
# 选择RNN层类型
|
||
if self.rnn_type == "lstm":
|
||
rnn_layer = LSTM
|
||
else: # gru
|
||
rnn_layer = GRU
|
||
|
||
# 构建多层RNN
|
||
x = embedded_sequences
|
||
for i in range(self.num_layers):
|
||
return_sequences = i < self.num_layers - 1 or self.pool_type != "last"
|
||
|
||
if self.bidirectional:
|
||
x = Bidirectional(
|
||
rnn_layer(
|
||
self.hidden_size,
|
||
return_sequences=return_sequences,
|
||
dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
|
||
name=f'{self.rnn_type}_{i + 1}'
|
||
)
|
||
)(x)
|
||
else:
|
||
x = rnn_layer(
|
||
self.hidden_size,
|
||
return_sequences=return_sequences,
|
||
dropout=self.dropout_rate if i < self.num_layers - 1 else 0,
|
||
name=f'{self.rnn_type}_{i + 1}'
|
||
)(x)
|
||
|
||
# 根据池化类型选择池化方法
|
||
if self.pool_type == "max":
|
||
# 使用全局最大池化
|
||
pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
|
||
elif self.pool_type == "avg":
|
||
# 使用全局平均池化
|
||
pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
|
||
elif self.pool_type == "both":
|
||
# 同时使用最大池化和平均池化,然后拼接
|
||
max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x)
|
||
avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x)
|
||
pooled = tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled])
|
||
else: # "last",使用最后一个时间步的输出
|
||
# 最后一层RNN已经返回了最后一个时间步的状态,不需要额外池化
|
||
pooled = x
|
||
|
||
# Dropout for regularization
|
||
x = Dropout(self.dropout_rate, name='dropout_1')(pooled)
|
||
|
||
# Dense layer
|
||
x = Dense(128, name='dense_1')(x)
|
||
x = BatchNormalization(name='batch_norm_1')(x)
|
||
x = Activation('relu', name='activation_1')(x)
|
||
x = Dropout(self.dropout_rate, name='dropout_2')(x)
|
||
|
||
# Output layer
|
||
if self.num_classes == 2:
|
||
# Binary classification
|
||
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
|
||
else:
|
||
# Multi-class classification
|
||
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
|
||
|
||
# Build the model
|
||
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
|
||
|
||
logger.info(
|
||
f"RNN模型构建完成,类型: {self.rnn_type.upper()}, 双向: {self.bidirectional}, 池化类型: {self.pool_type}")
|
||
|
||
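
    # -------------------------------------------------------------------
    # Dimension note (illustrative addition, not in the original source):
    # with bidirectional=True each recurrent layer emits 2 * hidden_size
    # features per time step; pool_type='both' concatenates global max- and
    # average-pooling of that sequence, so the dense head receives
    # 4 * hidden_size features (e.g. hidden_size=128 -> 512). With
    # pool_type='max' or 'avg' it receives 2 * hidden_size, or hidden_size
    # when bidirectional=False.
    # -------------------------------------------------------------------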
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
|
||
"""
|
||
编译RNN模型
|
||
|
||
Args:
|
||
optimizer: 优化器,默认为Adam
|
||
loss: 损失函数,默认根据类别数量选择
|
||
metrics: 评估指标,默认为accuracy
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 默认优化器
|
||
if optimizer is None:
|
||
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
||
|
||
# 默认损失函数
|
||
if loss is None:
|
||
if self.num_classes == 2:
|
||
loss = 'binary_crossentropy'
|
||
else:
|
||
loss = 'sparse_categorical_crossentropy'
|
||
|
||
# 默认评估指标
|
||
if metrics is None:
|
||
metrics = ['accuracy']
|
||
|
||
# 编译模型
|
||
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||
logger.info(f"RNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")
|
||
|
||
|
||
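
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, with placeholder numbers): how this classifier is
# typically constructed, built and compiled. scripts/train.py does the same
# thing through ModelFactory; num_classes and vocab_size below are not real
# project values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_model = RNNTextClassifier(
        num_classes=10,        # placeholder: number of target categories
        vocab_size=50000,      # placeholder: size of the fitted vocabulary
        rnn_type="gru",        # 'lstm' or 'gru'
        bidirectional=True,
        pool_type="both",      # concatenate global max- and average-pooling
    )
    demo_model.build()         # construct the Keras graph
    demo_model.compile()       # defaults: Adam + sparse_categorical_crossentropy
    demo_model.summary()       # same call used in scripts/train.py
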
================================================================================
文件: models/model_factory.py
================================================================================

"""
模型工厂:统一创建和管理不同类型的模型
"""
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import os
|
||
import glob
|
||
import time
|
||
import numpy as np
|
||
|
||
from config.system_config import CLASSIFIERS_DIR
|
||
from config.model_config import (
|
||
BATCH_SIZE, LEARNING_RATE
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from models.cnn_model import CNNTextClassifier
|
||
from models.rnn_model import RNNTextClassifier
|
||
from models.transformer_model import TransformerTextClassifier
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("ModelFactory")
|
||
|
||
|
||
class ModelFactory:
|
||
"""模型工厂,用于创建和管理不同类型的模型"""
|
||
|
||
@staticmethod
|
||
def create_model(model_type: str, num_classes: int, vocab_size: int,
|
||
embedding_matrix: Optional[np.ndarray] = None,
|
||
model_config: Optional[Dict[str, Any]] = None,
|
||
**kwargs) -> TextClassificationModel:
|
||
"""
|
||
创建指定类型的模型
|
||
|
||
Args:
|
||
model_type: 模型类型,可选值: 'cnn', 'rnn', 'transformer'
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_matrix: 预训练词嵌入矩阵
|
||
model_config: 模型配置字典
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
创建的模型实例
|
||
"""
|
||
model_type = model_type.lower()
|
||
|
||
# 合并配置
|
||
config = model_config or {}
|
||
config.update(kwargs)
|
||
|
||
# 创建模型
|
||
if model_type == 'cnn':
|
||
model = CNNTextClassifier(
|
||
num_classes=num_classes,
|
||
vocab_size=vocab_size,
|
||
embedding_matrix=embedding_matrix,
|
||
**config
|
||
)
|
||
elif model_type == 'rnn':
|
||
model = RNNTextClassifier(
|
||
num_classes=num_classes,
|
||
vocab_size=vocab_size,
|
||
embedding_matrix=embedding_matrix,
|
||
**config
|
||
)
|
||
elif model_type == 'transformer':
|
||
model = TransformerTextClassifier(
|
||
num_classes=num_classes,
|
||
vocab_size=vocab_size,
|
||
embedding_matrix=embedding_matrix,
|
||
**config
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||
|
||
logger.info(f"已创建 {model_type.upper()} 模型")
|
||
|
||
return model
|
||
|
||
@staticmethod
|
||
def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel:
|
||
"""
|
||
加载保存的模型
|
||
|
||
Args:
|
||
model_path: 模型路径
|
||
custom_objects: 自定义对象字典
|
||
|
||
Returns:
|
||
加载的模型实例
|
||
"""
|
||
# 添加Transformer相关的自定义对象
|
||
if custom_objects is None:
|
||
custom_objects = {}
|
||
|
||
if 'TransformerBlock' not in custom_objects:
|
||
from models.transformer_model import TransformerBlock
|
||
custom_objects['TransformerBlock'] = TransformerBlock
|
||
|
||
# 根据配置确定模型类型
|
||
model_config_path = f"{model_path}_config.json"
|
||
|
||
import json
|
||
with open(model_config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
|
||
model_type = config.get('model_type', '').lower()
|
||
|
||
# 根据模型类型选择加载方法
|
||
if model_type == 'cnn':
|
||
model = CNNTextClassifier.load(model_path, custom_objects)
|
||
elif model_type == 'rnn':
|
||
model = RNNTextClassifier.load(model_path, custom_objects)
|
||
elif model_type == 'transformer':
|
||
model = TransformerTextClassifier.load(model_path, custom_objects)
|
||
else:
|
||
# 如果无法确定模型类型,使用基类加载
|
||
logger.warning(f"无法确定模型类型,使用基类加载: {model_path}")
|
||
model = TextClassificationModel.load(model_path, custom_objects)
|
||
|
||
logger.info(f"已加载模型: {model_path}")
|
||
|
||
return model
|
||
|
||
@staticmethod
|
||
def get_available_models() -> List[Dict[str, Any]]:
|
||
"""
|
||
获取可用的已保存模型列表
|
||
|
||
Returns:
|
||
模型信息列表,每个元素是包含模型信息的字典
|
||
"""
|
||
model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*"))
|
||
model_files = [f for f in model_files if not f.endswith("_config.json")]
|
||
|
||
models_info = []
|
||
|
||
for model_file in model_files:
|
||
config_file = f"{model_file}_config.json"
|
||
|
||
if os.path.exists(config_file):
|
||
try:
|
||
import json
|
||
with open(config_file, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
|
||
# 获取模型文件的创建时间
|
||
created_time = os.path.getctime(model_file)
|
||
created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time))
|
||
|
||
# 获取模型文件大小
|
||
file_size = os.path.getsize(model_file) / (1024 * 1024) # MB
|
||
|
||
models_info.append({
|
||
"path": model_file,
|
||
"name": config.get("model_name", os.path.basename(model_file)),
|
||
"type": config.get("model_type", "unknown"),
|
||
"num_classes": config.get("num_classes", 0),
|
||
"created_time": created_time_str,
|
||
"file_size": f"{file_size:.2f} MB",
|
||
"config": config
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"读取模型配置失败: {config_file}, 错误: {e}")
|
||
|
||
# 按创建时间降序排序
|
||
models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True)
|
||
|
||
return models_info
|
||
|
||
|
||
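
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, placeholder values): the factory is the single
# entry point used by scripts/train.py and scripts/predict.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Create and prepare a fresh model of a given type
    demo = ModelFactory.create_model(
        model_type="cnn",      # 'cnn', 'rnn' or 'transformer'
        num_classes=10,        # placeholder
        vocab_size=50000,      # placeholder
    )
    demo.build()
    demo.compile()

    # List previously saved models (sorted newest first) and reload the latest
    saved = ModelFactory.get_available_models()
    if saved:
        latest = ModelFactory.load_model(saved[0]["path"])
        print(latest.model_name)
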
================================================================================
文件: models/cnn_model.py
================================================================================

"""
CNN模型:实现基于卷积神经网络的文本分类模型
"""
import tensorflow as tf
|
||
from tensorflow.keras.models import Model
|
||
from tensorflow.keras.layers import (
|
||
Input, Embedding, Conv1D, MaxPooling1D, GlobalMaxPooling1D,
|
||
Dense, Dropout, Concatenate, BatchNormalization, Activation
|
||
)
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
import numpy as np

from config.model_config import (
|
||
MAX_SEQUENCE_LENGTH, CNN_CONFIG
|
||
)
|
||
from models.base_model import TextClassificationModel
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("CNNModel")
|
||
|
||
|
||
class CNNTextClassifier(TextClassificationModel):
|
||
"""卷积神经网络文本分类模型"""
|
||
|
||
def __init__(self, num_classes: int, vocab_size: int,
|
||
embedding_dim: int = CNN_CONFIG["embedding_dim"],
|
||
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
|
||
num_filters: int = CNN_CONFIG["num_filters"],
|
||
filter_sizes: List[int] = CNN_CONFIG["filter_sizes"],
|
||
dropout_rate: float = CNN_CONFIG["dropout_rate"],
|
||
l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"],
|
||
embedding_matrix: Optional[np.ndarray] = None,
|
||
trainable_embedding: bool = True,
|
||
model_name: str = "cnn_text_classifier",
|
||
batch_size: int = 64,
|
||
learning_rate: float = 0.001):
|
||
"""
|
||
初始化CNN文本分类模型
|
||
|
||
Args:
|
||
num_classes: 类别数量
|
||
vocab_size: 词汇表大小
|
||
embedding_dim: 词嵌入维度
|
||
max_sequence_length: 最大序列长度
|
||
num_filters: 卷积核数量
|
||
filter_sizes: 卷积核大小列表
|
||
dropout_rate: Dropout比例
|
||
l2_reg_lambda: L2正则化系数
|
||
embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化
|
||
trainable_embedding: 词嵌入是否可训练
|
||
model_name: 模型名称
|
||
batch_size: 批大小
|
||
learning_rate: 学习率
|
||
"""
|
||
super().__init__(num_classes, model_name, batch_size, learning_rate)
|
||
|
||
self.vocab_size = vocab_size
|
||
self.embedding_dim = embedding_dim
|
||
self.max_sequence_length = max_sequence_length
|
||
self.num_filters = num_filters
|
||
self.filter_sizes = filter_sizes
|
||
self.dropout_rate = dropout_rate
|
||
self.l2_reg_lambda = l2_reg_lambda
|
||
self.embedding_matrix = embedding_matrix
|
||
self.trainable_embedding = trainable_embedding
|
||
|
||
# 更新配置
|
||
self.config.update({
|
||
"vocab_size": vocab_size,
|
||
"embedding_dim": embedding_dim,
|
||
"max_sequence_length": max_sequence_length,
|
||
"num_filters": num_filters,
|
||
"filter_sizes": filter_sizes,
|
||
"dropout_rate": dropout_rate,
|
||
"l2_reg_lambda": l2_reg_lambda,
|
||
"trainable_embedding": trainable_embedding,
|
||
"model_type": "CNN"
|
||
})
|
||
|
||
logger.info(f"初始化CNN文本分类模型,词汇表大小: {vocab_size}, 嵌入维度: {embedding_dim}")
|
||
|
||
def build(self) -> None:
|
||
"""构建CNN模型架构"""
|
||
# Input layer
|
||
sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input')
|
||
|
||
# Embedding layer
|
||
if self.embedding_matrix is not None:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
weights=[self.embedding_matrix],
|
||
input_length=self.max_sequence_length,
|
||
trainable=self.trainable_embedding,
|
||
name='embedding'
|
||
)
|
||
else:
|
||
embedding_layer = Embedding(
|
||
input_dim=self.vocab_size,
|
||
output_dim=self.embedding_dim,
|
||
input_length=self.max_sequence_length,
|
||
trainable=True,
|
||
name='embedding'
|
||
)
|
||
|
||
embedded_sequences = embedding_layer(sequence_input)
|
||
|
||
# Convolutional layers with different filter sizes
|
||
conv_blocks = []
|
||
for filter_size in self.filter_sizes:
|
||
conv = Conv1D(
|
||
filters=self.num_filters,
|
||
kernel_size=filter_size,
|
||
padding='valid',
|
||
activation='relu',
|
||
kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda),
|
||
name=f'conv_{filter_size}'
|
||
)(embedded_sequences)
|
||
|
||
# Max pooling
|
||
pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv)
|
||
conv_blocks.append(pooled)
|
||
|
||
# Concatenate pooled features if we have multiple filter sizes
|
||
if len(self.filter_sizes) > 1:
|
||
concatenated = Concatenate(name='concatenate')(conv_blocks)
|
||
else:
|
||
concatenated = conv_blocks[0]
|
||
|
||
# Dropout for regularization
|
||
x = Dropout(self.dropout_rate, name='dropout_1')(concatenated)
|
||
|
||
# Dense layer
|
||
x = Dense(128, name='dense_1')(x)
|
||
x = BatchNormalization(name='batch_norm_1')(x)
|
||
x = Activation('relu', name='activation_1')(x)
|
||
x = Dropout(self.dropout_rate, name='dropout_2')(x)
|
||
|
||
# Output layer
|
||
if self.num_classes == 2:
|
||
# Binary classification
|
||
predictions = Dense(1, activation='sigmoid', name='predictions')(x)
|
||
else:
|
||
# Multi-class classification
|
||
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x)
|
||
|
||
# Build the model
|
||
self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name)
|
||
|
||
logger.info(f"CNN模型构建完成,过滤器大小: {self.filter_sizes}, 每种大小的过滤器数量: {self.num_filters}")
|
||
|
||
def compile(self, optimizer=None, loss=None, metrics=None) -> None:
|
||
"""
|
||
编译CNN模型
|
||
|
||
Args:
|
||
optimizer: 优化器,默认为Adam
|
||
loss: 损失函数,默认根据类别数量选择
|
||
metrics: 评估指标,默认为accuracy
|
||
"""
|
||
if self.model is None:
|
||
raise ValueError("模型尚未构建,请先调用build方法")
|
||
|
||
# 默认优化器
|
||
if optimizer is None:
|
||
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
|
||
|
||
# 默认损失函数
|
||
if loss is None:
|
||
if self.num_classes == 2:
|
||
loss = 'binary_crossentropy'
|
||
else:
|
||
loss = 'sparse_categorical_crossentropy'
|
||
|
||
# 默认评估指标
|
||
if metrics is None:
|
||
metrics = ['accuracy']
|
||
|
||
# 编译模型
|
||
self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||
logger.info(f"CNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}")
|
||
|
||
|
||
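
# ---------------------------------------------------------------------------
# Architecture sketch (illustrative, not part of the original module): build()
# above follows the classic multi-filter-size TextCNN pattern. The stand-alone
# function below shows only that core idea, with placeholder dimensions, using
# the same Keras layers already imported at the top of this file.
# ---------------------------------------------------------------------------
def _textcnn_core_sketch(vocab_size: int = 50000, seq_len: int = 500,
                         embed_dim: int = 128, num_filters: int = 64) -> Model:
    """Minimal stand-alone version of the parallel Conv1D branches."""
    inputs = Input(shape=(seq_len,), dtype='int32')
    x = Embedding(vocab_size, embed_dim)(inputs)
    branches = []
    for k in (3, 4, 5):                              # one branch per n-gram width
        conv = Conv1D(num_filters, k, activation='relu')(x)
        branches.append(GlobalMaxPooling1D()(conv))  # strongest response per filter
    features = Concatenate()(branches)               # fixed-size text representation
    outputs = Dense(10, activation='softmax')(features)  # placeholder: 10 classes
    return Model(inputs, outputs)
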
================================================================================
文件: models/layers/__init__.py
================================================================================



================================================================================
文件: scripts/predict.py
================================================================================

"""
预测脚本:使用模型进行预测
"""
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import logging
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
import json
|
||
|
||
# 将项目根目录添加到系统路径
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.append(project_root)
|
||
|
||
from config.system_config import (
|
||
CATEGORIES, CLASSIFIERS_DIR
|
||
)
|
||
from models.model_factory import ModelFactory
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from inference.predictor import Predictor
|
||
from inference.batch_processor import BatchProcessor
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file
|
||
|
||
logger = get_logger("Prediction")
|
||
|
||
|
||
def predict_text(text: str, model_path: Optional[str] = None,
|
||
output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
|
||
"""
|
||
预测单条文本
|
||
|
||
Args:
|
||
text: 要预测的文本
|
||
model_path: 模型路径,如果为None则使用最新的模型
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
top_k: 返回概率最高的前k个类别
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
logger.info("开始预测文本")
|
||
|
||
# 1. 加载模型
|
||
if model_path is None:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
if not models_info:
|
||
raise ValueError("未找到可用的模型")
|
||
|
||
# 使用最新的模型
|
||
model_path = models_info[0]['path']
|
||
|
||
logger.info(f"加载模型: {model_path}")
|
||
model = ModelFactory.load_model(model_path)
|
||
|
||
# 2. 创建分词器和预测器
|
||
tokenizer = ChineseTokenizer()
|
||
|
||
# 查找向量化器文件
|
||
vectorizer = None
|
||
for model_type in ["cnn", "rnn", "transformer"]:
|
||
vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
|
||
if os.path.exists(vectorizer_path):
|
||
# 加载向量化器
|
||
vectorizer = SequenceVectorizer()
|
||
vectorizer.load(vectorizer_path)
|
||
logger.info(f"加载向量化器: {vectorizer_path}")
|
||
break
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
vectorizer=vectorizer,
|
||
class_names=CATEGORIES
|
||
)
|
||
|
||
# 3. 预测
|
||
result = predictor.predict(
|
||
text=text,
|
||
return_top_k=top_k,
|
||
return_probabilities=True
|
||
)
|
||
|
||
# 4. 输出结果
|
||
if top_k > 1:
|
||
logger.info("预测结果:")
|
||
for i, pred in enumerate(result):
|
||
logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})")
|
||
else:
|
||
logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})")
|
||
|
||
# 5. 保存结果
|
||
if output_path:
|
||
if output_path.endswith('.json'):
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||
else:
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
if top_k > 1:
|
||
f.write("rank,class,probability\n")
|
||
for i, pred in enumerate(result):
|
||
f.write(f"{i + 1},{pred['class']},{pred['probability']}\n")
|
||
else:
|
||
f.write(f"class,probability\n")
|
||
f.write(f"{result['class']},{result['probability']}\n")
|
||
|
||
logger.info(f"结果已保存到: {output_path}")
|
||
|
||
return result
|
||
|
||
|
||
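
# ---------------------------------------------------------------------------
# Return-shape note (illustrative helper, not part of the original script):
# predict_text mirrors Predictor.predict, so the result is a single dict for
# top_k == 1 and a probability-ordered list of dicts for top_k > 1. Calling it
# assumes at least one saved model (and its vectorizer) exists on disk.
# ---------------------------------------------------------------------------
def _demo_predict_text_shapes(sample_text: str = "示例文本") -> None:
    """Show the two result shapes returned by predict_text."""
    single = predict_text(sample_text, top_k=1)
    # -> {"class": ..., "probability": ..., "time": ...}
    ranked = predict_text(sample_text, top_k=3)
    # -> [{"class": ..., "probability": ...}, ...] sorted by probability
    print(type(single), type(ranked))
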
def predict_file(file_path: str, model_path: Optional[str] = None,
|
||
output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]:
|
||
"""
|
||
预测文件内容
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
model_path: 模型路径,如果为None则使用最新的模型
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
top_k: 返回概率最高的前k个类别
|
||
|
||
Returns:
|
||
预测结果
|
||
"""
|
||
logger.info(f"开始预测文件: {file_path}")
|
||
|
||
# 检查文件是否存在
|
||
if not os.path.exists(file_path):
|
||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||
|
||
# 读取文件内容
|
||
if file_path.endswith('.txt'):
|
||
# 文本文件
|
||
text = read_text_file(file_path)
|
||
return predict_text(text, model_path, output_path, top_k)
|
||
|
||
elif file_path.endswith(('.csv', '.xls', '.xlsx')):
|
||
# 表格文件
|
||
import pandas as pd
|
||
|
||
if file_path.endswith('.csv'):
|
||
df = pd.read_csv(file_path)
|
||
else:
|
||
df = pd.read_excel(file_path)
|
||
|
||
# 查找可能的文本列
|
||
text_columns = [col for col in df.columns if df[col].dtype == 'object']
|
||
|
||
if not text_columns:
|
||
raise ValueError("文件中没有找到可能的文本列")
|
||
|
||
# 使用第一个文本列
|
||
text_column = text_columns[0]
|
||
logger.info(f"使用文本列: {text_column}")
|
||
|
||
# 1. 加载模型
|
||
if model_path is None:
|
||
# 获取可用模型列表
|
||
models_info = ModelFactory.get_available_models()
|
||
|
||
if not models_info:
|
||
raise ValueError("未找到可用的模型")
|
||
|
||
# 使用最新的模型
|
||
model_path = models_info[0]['path']
|
||
|
||
logger.info(f"加载模型: {model_path}")
|
||
model = ModelFactory.load_model(model_path)
|
||
|
||
# 2. 创建分词器和预测器
|
||
tokenizer = ChineseTokenizer()
|
||
|
||
# 查找向量化器文件
|
||
vectorizer = None
|
||
for model_type in ["cnn", "rnn", "transformer"]:
|
||
vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
|
||
if os.path.exists(vectorizer_path):
|
||
# 加载向量化器
|
||
vectorizer = SequenceVectorizer()
|
||
vectorizer.load(vectorizer_path)
|
||
logger.info(f"加载向量化器: {vectorizer_path}")
|
||
break
|
||
|
||
# 创建预测器
|
||
predictor = Predictor(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
vectorizer=vectorizer,
|
||
class_names=CATEGORIES
|
||
)
|
||
|
||
# 3. 创建批处理器
|
||
batch_processor = BatchProcessor(
|
||
predictor=predictor,
|
||
batch_size=64
|
||
)
|
||
|
||
# 4. 批量预测
|
||
result_df = batch_processor.process_dataframe(
|
||
df=df,
|
||
text_column=text_column,
|
||
output_path=output_path,
|
||
return_top_k=top_k,
|
||
format=output_path.split('.')[-1] if output_path else 'csv'
|
||
)
|
||
|
||
logger.info(f"已处理 {len(result_df)} 行数据")
|
||
|
||
# 返回结果
|
||
return result_df.to_dict(orient='records')
|
||
|
||
else:
|
||
raise ValueError(f"不支持的文件类型: {file_path}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description="使用模型预测")
|
||
parser.add_argument("--model_path", help="模型路径")
|
||
parser.add_argument("--text", help="要预测的文本")
|
||
parser.add_argument("--file", help="要预测的文件")
|
||
parser.add_argument("--output", help="输出文件")
|
||
parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 检查输入
|
||
if not args.text and not args.file:
|
||
parser.error("请提供要预测的文本或文件")
|
||
|
||
# 预测
|
||
if args.text:
|
||
predict_text(args.text, args.model_path, args.output, args.top_k)
|
||
else:
|
||
predict_file(args.file, args.model_path, args.output, args.top_k)
|
||
|
||
|
||
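
# ---------------------------------------------------------------------------
# Command-line usage (illustrative, file names are placeholders), matching the
# argparse definition above:
#
#   python scripts/predict.py --text "待分类的新闻文本" --top_k 3
#   python scripts/predict.py --file articles.csv --output results.json
#
# When --file points at a .csv/.xls/.xlsx, the first object-dtype column is
# treated as the text column and predictions are written via BatchProcessor.
# ---------------------------------------------------------------------------
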
================================================================================
文件: scripts/train.py
================================================================================

"""
训练脚本:训练文本分类模型
"""
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import logging
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import matplotlib.pyplot as plt
|
||
|
||
# 将项目根目录添加到系统路径
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.append(project_root)
|
||
|
||
from config.system_config import (
|
||
RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
|
||
)
|
||
from config.model_config import (
|
||
BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
|
||
)
|
||
from data.dataloader import DataLoader
|
||
from data.data_manager import DataManager
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from models.model_factory import ModelFactory
|
||
from training.trainer import Trainer
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Training")
|
||
|
||
|
||
def train_model(data_dir: Optional[str] = None,
|
||
model_type: str = "cnn",
|
||
epochs: int = NUM_EPOCHS,
|
||
batch_size: int = BATCH_SIZE,
|
||
save_dir: Optional[str] = None,
|
||
validation_split: float = 0.1,
|
||
use_pretrained_embedding: bool = False,
|
||
embedding_path: Optional[str] = None) -> str:
|
||
"""
|
||
训练文本分类模型
|
||
|
||
Args:
|
||
data_dir: 数据目录,如果为None则使用默认目录
|
||
model_type: 模型类型,'cnn', 'rnn', 或 'transformer'
|
||
epochs: 训练轮数
|
||
batch_size: 批大小
|
||
save_dir: 模型保存目录,如果为None则使用默认目录
|
||
validation_split: 验证集比例
|
||
use_pretrained_embedding: 是否使用预训练词向量
|
||
embedding_path: 预训练词向量路径
|
||
|
||
Returns:
|
||
保存的模型路径
|
||
"""
|
||
logger.info(f"开始训练 {model_type.upper()} 模型")
|
||
start_time = time.time()
|
||
|
||
# 设置数据目录
|
||
data_dir = data_dir or RAW_DATA_DIR
|
||
|
||
# 设置保存目录
|
||
if save_dir:
|
||
save_dir = os.path.abspath(save_dir)
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
else:
|
||
save_dir = CLASSIFIERS_DIR
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
|
||
# 1. 加载数据
|
||
logger.info("加载数据...")
|
||
data_loader = DataLoader(data_dir=data_dir)
|
||
data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
|
||
|
||
# 加载和分割数据
|
||
data = data_manager.load_and_split_data(
|
||
data_loader=data_loader,
|
||
val_split=validation_split,
|
||
sample_ratio=1.0,
|
||
save=True
|
||
)
|
||
|
||
# 获取训练集和验证集
|
||
train_texts, train_labels = data_manager.get_data(dataset="train")
|
||
val_texts, val_labels = data_manager.get_data(dataset="val")
|
||
|
||
# 2. 准备数据
|
||
# 创建分词器
|
||
tokenizer = ChineseTokenizer()
|
||
|
||
# 对训练文本进行分词
|
||
logger.info("对文本进行分词...")
|
||
tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts]
|
||
tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts]
|
||
|
||
# 创建序列向量化器
|
||
logger.info("创建序列向量化器...")
|
||
vectorizer = SequenceVectorizer(
|
||
max_features=MAX_NUM_WORDS,
|
||
max_sequence_length=MAX_SEQUENCE_LENGTH
|
||
)
|
||
|
||
# 训练向量化器并转换文本
|
||
vectorizer.fit(tokenized_train_texts)
|
||
X_train = vectorizer.transform(tokenized_train_texts)
|
||
X_val = vectorizer.transform(tokenized_val_texts)
|
||
|
||
# 保存向量化器
|
||
vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl")
|
||
vectorizer.save(vectorizer_path)
|
||
logger.info(f"向量化器已保存到: {vectorizer_path}")
|
||
|
||
# 获取一些基本参数
|
||
num_classes = len(CATEGORIES)
|
||
vocab_size = vectorizer.get_vocabulary_size()
|
||
|
||
# 3. 创建模型
|
||
logger.info(f"创建 {model_type.upper()} 模型...")
|
||
|
||
# 加载预训练词向量(如果指定)
|
||
embedding_matrix = None
|
||
if use_pretrained_embedding and embedding_path:
|
||
# 这里简化处理,实际应用中应该加载和处理预训练词向量
|
||
logger.info("加载预训练词向量...")
|
||
embedding_matrix = np.random.random((vocab_size, 200))
|
||
|
||
# 创建模型
|
||
model = ModelFactory.create_model(
|
||
model_type=model_type,
|
||
num_classes=num_classes,
|
||
vocab_size=vocab_size,
|
||
embedding_matrix=embedding_matrix,
|
||
batch_size=batch_size
|
||
)
|
||
|
||
# 构建模型
|
||
model.build()
|
||
model.compile()
|
||
model.summary()
|
||
|
||
# 4. 训练模型
|
||
logger.info("开始训练模型...")
|
||
trainer = Trainer(
|
||
model=model,
|
||
epochs=epochs,
|
||
batch_size=batch_size,
|
||
early_stopping=True,
|
||
tensorboard=True
|
||
)
|
||
|
||
# 训练
|
||
history = trainer.train(
|
||
x_train=X_train,
|
||
y_train=train_labels,
|
||
x_val=X_val,
|
||
y_val=val_labels
|
||
)
|
||
|
||
# 5. 保存模型
|
||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||
model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}")
|
||
model.save(model_path)
|
||
logger.info(f"模型已保存到: {model_path}")
|
||
|
||
# 6. 绘制训练历史
|
||
logger.info("绘制训练历史...")
|
||
model.plot_training_history(save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png"))
|
||
|
||
# 7. 计算训练时间
|
||
train_time = time.time() - start_time
|
||
logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒")
|
||
|
||
return model_path
|
||
|
||
|
||
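
# ---------------------------------------------------------------------------
# Sketch (illustrative, not part of the original script): train_model above
# fills embedding_matrix with random numbers as a stand-in. One way to build a
# real matrix from word2vec-format vectors is shown below. It assumes gensim
# (already listed in setup.py) and a word -> index mapping for the fitted
# vocabulary; how to extract that mapping from SequenceVectorizer is not shown
# in this repository, so it is passed in explicitly here.
# ---------------------------------------------------------------------------
def build_embedding_matrix(embedding_path: str, word_index: Dict[str, int],
                           vocab_size: int) -> np.ndarray:
    """Build an embedding matrix aligned with the vectorizer vocabulary."""
    from gensim.models import KeyedVectors

    kv = KeyedVectors.load_word2vec_format(
        embedding_path, binary=embedding_path.endswith(".bin"))
    matrix = np.random.normal(scale=0.1, size=(vocab_size, kv.vector_size))
    hits = 0
    for word, idx in word_index.items():
        if idx < vocab_size and word in kv:
            matrix[idx] = kv[word]          # copy the pretrained vector
            hits += 1
    logger.info(f"预训练词向量命中 {hits}/{len(word_index)} 个词")
    return matrix
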
if __name__ == "__main__":
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description="训练文本分类模型")
|
||
parser.add_argument("--data_dir", help="数据目录")
|
||
parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型")
|
||
parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="训练轮数")
|
||
parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
|
||
parser.add_argument("--save_dir", help="模型保存目录")
|
||
parser.add_argument("--validation_split", type=float, default=0.1, help="验证集比例")
|
||
parser.add_argument("--use_pretrained_embedding", action="store_true", help="是否使用预训练词向量")
|
||
parser.add_argument("--embedding_path", help="预训练词向量路径")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 训练模型
|
||
train_model(
|
||
data_dir=args.data_dir,
|
||
model_type=args.model_type,
|
||
epochs=args.epochs,
|
||
batch_size=args.batch_size,
|
||
save_dir=args.save_dir,
|
||
validation_split=args.validation_split,
|
||
use_pretrained_embedding=args.use_pretrained_embedding,
|
||
embedding_path=args.embedding_path
|
||
)
|
||
|
||
|
||
================================================================================
文件: scripts/evaluate.py
================================================================================

"""
评估脚本:评估文本分类模型性能
"""
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import logging
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
# 将项目根目录添加到系统路径
|
||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.append(project_root)
|
||
|
||
from config.system_config import (
|
||
RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR
|
||
)
|
||
from config.model_config import (
    BATCH_SIZE, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS
)
|
||
from data.dataloader import DataLoader
|
||
from data.data_manager import DataManager
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from models.model_factory import ModelFactory
|
||
from evaluation.evaluator import ModelEvaluator
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Evaluation")
|
||
|
||
|
||
def evaluate_model(model_path: str,
|
||
data_dir: Optional[str] = None,
|
||
batch_size: int = BATCH_SIZE,
|
||
output_dir: Optional[str] = None) -> Dict[str, float]:
|
||
"""
|
||
评估文本分类模型
|
||
|
||
Args:
|
||
model_path: 模型路径
|
||
data_dir: 数据目录,如果为None则使用默认目录
|
||
batch_size: 批大小
|
||
output_dir: 评估结果输出目录,如果为None则使用默认目录
|
||
|
||
Returns:
|
||
评估指标
|
||
"""
|
||
logger.info(f"开始评估模型: {model_path}")
|
||
start_time = time.time()
|
||
|
||
# 设置数据目录
|
||
data_dir = data_dir or RAW_DATA_DIR
|
||
|
||
# 设置输出目录
|
||
if output_dir:
|
||
output_dir = os.path.abspath(output_dir)
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 1. 加载模型
|
||
logger.info("加载模型...")
|
||
model = ModelFactory.load_model(model_path)
|
||
|
||
# 2. 加载数据
|
||
logger.info("加载数据...")
|
||
data_loader = DataLoader(data_dir=data_dir)
|
||
data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR)
|
||
|
||
# 加载测试集
|
||
data_manager.load_data()
|
||
test_texts, test_labels = data_manager.get_data(dataset="test")
|
||
|
||
# 3. 准备数据
|
||
# 创建分词器
|
||
tokenizer = ChineseTokenizer()
|
||
|
||
# 对测试文本进行分词
|
||
logger.info("对文本进行分词...")
|
||
tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts]
|
||
|
||
# 创建序列向量化器
|
||
logger.info("加载向量化器...")
|
||
# 查找向量化器文件
|
||
vectorizer_path = None
|
||
for model_type in ["cnn", "rnn", "transformer"]:
|
||
path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl")
|
||
if os.path.exists(path):
|
||
vectorizer_path = path
|
||
break
|
||
|
||
if not vectorizer_path:
|
||
# 如果找不到向量化器,创建一个新的
|
||
logger.warning("未找到向量化器,创建一个新的")
|
||
vectorizer = SequenceVectorizer(
|
||
max_features=MAX_NUM_WORDS,
|
||
max_sequence_length=MAX_SEQUENCE_LENGTH
|
||
)
|
||
else:
|
||
# 加载向量化器
|
||
vectorizer = SequenceVectorizer()
|
||
vectorizer.load(vectorizer_path)
|
||
|
||
# 转换测试文本
|
||
X_test = vectorizer.transform(tokenized_test_texts)
|
||
|
||
# 4. 创建评估器
|
||
logger.info("创建评估器...")
|
||
evaluator = ModelEvaluator(
|
||
model=model,
|
||
class_names=CATEGORIES,
|
||
output_dir=output_dir
|
||
)
|
||
|
||
# 5. 评估模型
|
||
logger.info("评估模型...")
|
||
metrics = evaluator.evaluate(X_test, test_labels, batch_size)
|
||
|
||
# 6. 保存评估结果
|
||
logger.info("保存评估结果...")
|
||
evaluator.save_evaluation_results(save_plots=True)
|
||
|
||
# 7. 可视化混淆矩阵
|
||
logger.info("可视化混淆矩阵...")
|
||
cm = evaluator.evaluation_results['confusion_matrix']
|
||
evaluator.metrics.plot_confusion_matrix(
|
||
y_true=test_labels,
|
||
y_pred=np.argmax(model.predict(X_test), axis=1),
|
||
normalize='true',
|
||
save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png")
|
||
)
|
||
|
||
# 8. 类别性能分析
|
||
logger.info("分析各类别性能...")
|
||
class_performance = evaluator.evaluate_class_performance(X_test, test_labels)
|
||
|
||
# 9. 计算评估时间
|
||
eval_time = time.time() - start_time
|
||
logger.info(f"模型评估完成,耗时: {eval_time:.2f} 秒")
|
||
|
||
# 10. 输出主要指标
|
||
logger.info("主要评估指标:")
|
||
for metric_name, metric_value in metrics.items():
|
||
if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']:
|
||
logger.info(f" {metric_name}: {metric_value:.4f}")
|
||
|
||
return metrics
|
||
|
||
|
||
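
# ---------------------------------------------------------------------------
# Sketch (illustrative, not part of the original script): evaluate_model
# returns a plain metrics dict, so comparing every saved model is a small loop
# over ModelFactory.get_available_models(). Assumes the test split and the
# matching vectorizers are already on disk.
# ---------------------------------------------------------------------------
def _compare_saved_models() -> None:
    """Evaluate each saved model and print its macro-F1."""
    for info in ModelFactory.get_available_models():
        metrics = evaluate_model(model_path=info["path"])
        f1 = metrics.get("f1_macro", float("nan"))
        print(f"{info['name']} ({info['type']}): f1_macro={f1:.4f}")
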
if __name__ == "__main__":
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description="评估文本分类模型")
|
||
parser.add_argument("--model_path", required=True, help="模型路径")
|
||
parser.add_argument("--data_dir", help="数据目录")
|
||
parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小")
|
||
parser.add_argument("--output_dir", help="评估结果输出目录")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 评估模型
|
||
evaluate_model(
|
||
model_path=args.model_path,
|
||
data_dir=args.data_dir,
|
||
batch_size=args.batch_size,
|
||
output_dir=args.output_dir
|
||
)
|
||
|
||
|
||
================================================================================
文件: inference/predictor.py
================================================================================

"""
预测器模块:实现模型预测功能,支持单条和批量文本预测
"""
import os
|
||
import time
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import pandas as pd
|
||
import json
|
||
|
||
from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY
|
||
from models.base_model import TextClassificationModel
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
from preprocessing.vectorizer import SequenceVectorizer
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Predictor")
|
||
|
||
|
||
class Predictor:
|
||
"""预测器,负责加载模型和进行预测"""
|
||
|
||
def __init__(self, model: TextClassificationModel,
|
||
tokenizer: Optional[ChineseTokenizer] = None,
|
||
vectorizer: Optional[SequenceVectorizer] = None,
|
||
class_names: Optional[List[str]] = None,
|
||
max_sequence_length: int = 500,
|
||
batch_size: Optional[int] = None):
|
||
"""
|
||
初始化预测器
|
||
|
||
Args:
|
||
model: 已训练的模型实例
|
||
tokenizer: 分词器实例,如果为None则创建一个新的分词器
|
||
vectorizer: 文本向量化器实例,如果为None则表示模型直接接收序列
|
||
class_names: 类别名称列表,如果为None则使用ID_TO_CATEGORY
|
||
max_sequence_length: 最大序列长度
|
||
batch_size: 批大小,如果为None则使用模型默认值
|
||
"""
|
||
self.model = model
|
||
self.tokenizer = tokenizer or ChineseTokenizer()
|
||
self.vectorizer = vectorizer
|
||
self.class_names = class_names
|
||
|
||
if class_names is None and hasattr(model, 'num_classes'):
|
||
# 如果模型具有类别数量信息,从ID_TO_CATEGORY获取类别名称
|
||
self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)]
|
||
|
||
self.max_sequence_length = max_sequence_length
|
||
self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32)
|
||
|
||
logger.info(f"初始化预测器,批大小: {self.batch_size}")
|
||
|
||
def preprocess_text(self, text: str) -> Any:
|
||
"""
|
||
预处理单条文本
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
预处理后的文本表示
|
||
"""
|
||
# 分词
|
||
tokenized_text = self.tokenizer.tokenize(text, return_string=True)
|
||
|
||
# 如果有向量化器,应用向量化
|
||
if self.vectorizer is not None:
|
||
return self.vectorizer.transform([tokenized_text])[0]
|
||
|
||
return tokenized_text
|
||
|
||
def preprocess_texts(self, texts: List[str]) -> Any:
|
||
"""
|
||
批量预处理文本
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
|
||
Returns:
|
||
预处理后的批量文本表示
|
||
"""
|
||
# 分词
|
||
tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts]
|
||
|
||
# 如果有向量化器,应用向量化
|
||
if self.vectorizer is not None:
|
||
return self.vectorizer.transform(tokenized_texts)
|
||
|
||
return tokenized_texts
|
||
|
||
def predict(self, text: str, return_top_k: int = 1,
|
||
return_probabilities: bool = False) -> Union[str, Dict, List]:
|
||
"""
|
||
预测单条文本的类别
|
||
|
||
Args:
|
||
text: 原始文本
|
||
return_top_k: 返回概率最高的前k个类别
|
||
return_probabilities: 是否返回概率值
|
||
|
||
Returns:
|
||
预测结果,格式取决于参数设置
|
||
"""
|
||
# 预处理文本
|
||
processed_text = self.preprocess_text(text)
|
||
|
||
# 添加批次维度
|
||
if isinstance(processed_text, str):
|
||
input_data = np.array([processed_text])
|
||
else:
|
||
input_data = np.expand_dims(processed_text, axis=0)
|
||
|
||
# 预测
|
||
start_time = time.time()
|
||
predictions = self.model.predict(input_data)
|
||
prediction_time = time.time() - start_time
|
||
|
||
# 获取前k个预测结果
|
||
if return_top_k > 1:
|
||
top_indices = np.argsort(predictions[0])[::-1][:return_top_k]
|
||
top_probs = predictions[0][top_indices]
|
||
|
||
if self.class_names:
|
||
top_classes = [self.class_names[idx] for idx in top_indices]
|
||
else:
|
||
top_classes = [str(idx) for idx in top_indices]
|
||
|
||
if return_probabilities:
|
||
return [{'class': cls, 'probability': float(prob)}
|
||
for cls, prob in zip(top_classes, top_probs)]
|
||
else:
|
||
return top_classes
|
||
else:
|
||
# 获取最高概率的类别
|
||
pred_idx = np.argmax(predictions[0])
|
||
pred_prob = float(predictions[0][pred_idx])
|
||
|
||
if self.class_names:
|
||
pred_class = self.class_names[pred_idx]
|
||
else:
|
||
pred_class = str(pred_idx)
|
||
|
||
if return_probabilities:
|
||
return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time}
|
||
else:
|
||
return pred_class
|
||
|
||
def predict_batch(self, texts: List[str], return_top_k: int = 1,
|
||
return_probabilities: bool = False) -> List:
|
||
"""
|
||
批量预测文本类别
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
return_top_k: 返回概率最高的前k个类别
|
||
return_probabilities: 是否返回概率值
|
||
|
||
Returns:
|
||
预测结果列表
|
||
"""
|
||
# 空列表检查
|
||
if not texts:
|
||
return []
|
||
|
||
# 预处理文本
|
||
processed_texts = self.preprocess_texts(texts)
|
||
|
||
# 预测
|
||
start_time = time.time()
|
||
predictions = self.model.predict(processed_texts, batch_size=self.batch_size)
|
||
prediction_time = time.time() - start_time
|
||
|
||
# 处理预测结果
|
||
results = []
|
||
|
||
for i, pred in enumerate(predictions):
|
||
if return_top_k > 1:
|
||
top_indices = np.argsort(pred)[::-1][:return_top_k]
|
||
top_probs = pred[top_indices]
|
||
|
||
if self.class_names:
|
||
top_classes = [self.class_names[idx] for idx in top_indices]
|
||
else:
|
||
top_classes = [str(idx) for idx in top_indices]
|
||
|
||
if return_probabilities:
|
||
results.append([{'class': cls, 'probability': float(prob)}
|
||
for cls, prob in zip(top_classes, top_probs)])
|
||
else:
|
||
results.append(top_classes)
|
||
else:
|
||
# 获取最高概率的类别
|
||
pred_idx = np.argmax(pred)
|
||
pred_prob = float(pred[pred_idx])
|
||
|
||
if self.class_names:
|
||
pred_class = self.class_names[pred_idx]
|
||
else:
|
||
pred_class = str(pred_idx)
|
||
|
||
if return_probabilities:
|
||
results.append({'class': pred_class, 'probability': pred_prob})
|
||
else:
|
||
results.append(pred_class)
|
||
|
||
logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f} 秒")
|
||
|
||
return results
|
||
|
||
def predict_to_dataframe(self, texts: List[str],
|
||
text_ids: Optional[List[Union[str, int]]] = None,
|
||
return_top_k: int = 1) -> pd.DataFrame:
|
||
"""
|
||
批量预测并返回DataFrame
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
text_ids: 文本ID列表,如果为None则使用索引
|
||
return_top_k: 返回概率最高的前k个类别
|
||
|
||
Returns:
|
||
预测结果DataFrame
|
||
"""
|
||
# 预测
|
||
predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True)
|
||
|
||
# 创建DataFrame
|
||
if text_ids is None:
|
||
text_ids = list(range(len(texts)))
|
||
|
||
if return_top_k > 1:
|
||
# 多个类别的情况
|
||
results = []
|
||
for i, preds in enumerate(predictions):
|
||
for j, pred in enumerate(preds):
|
||
results.append({
|
||
'id': text_ids[i],
|
||
'text': texts[i],
|
||
'rank': j + 1,
|
||
'predicted_class': pred['class'],
|
||
'probability': pred['probability']
|
||
})
|
||
df = pd.DataFrame(results)
|
||
else:
|
||
# 单个类别的情况
|
||
df = pd.DataFrame({
|
||
'id': text_ids,
|
||
'text': texts,
|
||
'predicted_class': [pred['class'] for pred in predictions],
|
||
'probability': [pred['probability'] for pred in predictions]
|
||
})
|
||
|
||
return df
|
||
|
||
def save_predictions(self, texts: List[str],
|
||
output_path: str,
|
||
text_ids: Optional[List[Union[str, int]]] = None,
|
||
return_top_k: int = 1,
|
||
format: str = 'csv') -> str:
|
||
"""
|
||
批量预测并保存结果
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
output_path: 输出文件路径
|
||
text_ids: 文本ID列表,如果为None则使用索引
|
||
return_top_k: 返回概率最高的前k个类别
|
||
format: 输出格式,'csv'或'json'
|
||
|
||
Returns:
|
||
输出文件路径
|
||
"""
|
||
# 获取预测结果DataFrame
|
||
df = self.predict_to_dataframe(texts, text_ids, return_top_k)
|
||
|
||
# 保存结果
|
||
if format.lower() == 'csv':
|
||
df.to_csv(output_path, index=False, encoding='utf-8')
|
||
elif format.lower() == 'json':
|
||
# 转换为嵌套的JSON格式
|
||
if return_top_k > 1:
|
||
# 分组后转换为嵌套格式
|
||
result = {}
|
||
for id_val in df['id'].unique():
|
||
sub_df = df[df['id'] == id_val]
|
||
predictions = []
|
||
for _, row in sub_df.iterrows():
|
||
predictions.append({
|
||
'class': row['predicted_class'],
|
||
'probability': row['probability']
|
||
})
|
||
result[str(id_val)] = {
|
||
'text': sub_df.iloc[0]['text'],
|
||
'predictions': predictions
|
||
}
|
||
else:
|
||
# 直接构建JSON
|
||
result = {}
|
||
for _, row in df.iterrows():
|
||
result[str(row['id'])] = {
|
||
'text': row['text'],
|
||
'predicted_class': row['predicted_class'],
|
||
'probability': row['probability']
|
||
}
|
||
|
||
# 保存为JSON
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {format}")
|
||
|
||
logger.info(f"预测结果已保存到: {output_path}")
|
||
|
||
return output_path
|
||
|
||
|
||
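
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the three main
# Predictor entry points. Building the model/vectorizer is elided; the input
# texts are placeholders.
# ---------------------------------------------------------------------------
def _predictor_api_sketch(predictor: Predictor) -> None:
    """Exercise predict, predict_batch and predict_to_dataframe on toy input."""
    texts = ["示例文本一", "示例文本二"]

    # Single text: a class name by default, richer structures on request
    label = predictor.predict(texts[0])
    ranked = predictor.predict(texts[0], return_top_k=3, return_probabilities=True)

    # Batch: one result per input text, same shape rules as predict()
    labels = predictor.predict_batch(texts)

    # DataFrame with id / text / predicted_class / probability columns
    df = predictor.predict_to_dataframe(texts, return_top_k=1)
    print(label, ranked, labels, df.shape)
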
================================================================================
文件: inference/batch_processor.py
================================================================================

"""
批处理模块:实现批量处理大规模文本数据
"""
import os
|
||
import time
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator
|
||
import concurrent.futures
|
||
from tqdm import tqdm
|
||
import glob
|
||
import json
|
||
|
||
from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file, ensure_dir
|
||
from inference.predictor import Predictor
|
||
|
||
logger = get_logger("BatchProcessor")
|
||
|
||
|
||
class BatchProcessor:
|
||
"""批处理器,负责批量处理大规模文本数据"""
|
||
|
||
def __init__(self, predictor: Predictor,
|
||
batch_size: int = 64,
|
||
max_workers: int = DATA_LOADING_WORKERS,
|
||
max_batch_queue: int = 10):
|
||
"""
|
||
初始化批处理器
|
||
|
||
Args:
|
||
predictor: 预测器实例
|
||
batch_size: 批大小
|
||
max_workers: 最大工作线程数
|
||
max_batch_queue: 最大批次队列长度
|
||
"""
|
||
self.predictor = predictor
|
||
self.batch_size = batch_size
|
||
self.max_workers = max_workers
|
||
self.max_batch_queue = max_batch_queue
|
||
|
||
logger.info(f"初始化批处理器,批大小: {batch_size}, 最大工作线程数: {max_workers}")
|
||
|
||
def _extract_text_from_file(self, file_path: str) -> str:
|
||
"""
|
||
从文件中提取文本
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
|
||
Returns:
|
||
文本内容
|
||
"""
|
||
return read_text_file(file_path, encoding=ENCODING)
|
||
|
||
def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]:
|
||
"""
|
||
生成文本批次
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
batch_size: 批大小
|
||
|
||
Returns:
|
||
文本批次生成器
|
||
"""
|
||
for i in range(0, len(texts), batch_size):
|
||
yield texts[i:i + batch_size]
|
||
|
||
def process_files(self, file_paths: List[str], output_path: Optional[str] = None,
|
||
return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
|
||
"""
|
||
批量处理文件
|
||
|
||
Args:
|
||
file_paths: 文件路径列表
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
return_top_k: 返回概率最高的前k个类别
|
||
format: 输出格式,'csv'或'json'
|
||
|
||
Returns:
|
||
预测结果DataFrame
|
||
"""
|
||
logger.info(f"开始批量处理 {len(file_paths)} 个文件")
|
||
start_time = time.time()
|
||
|
||
# 使用线程池并行读取文件
|
||
texts = []
|
||
file_names = []
|
||
|
||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||
future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path for file_path in
|
||
file_paths}
|
||
|
||
for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(file_paths), desc="读取文件"):
|
||
file_path = future_to_file[future]
|
||
try:
|
||
text = future.result()
|
||
if text:
|
||
texts.append(text)
|
||
file_names.append(os.path.basename(file_path))
|
||
except Exception as e:
|
||
logger.error(f"处理文件 {file_path} 时出错: {e}")
|
||
|
||
# 批量预测
|
||
all_predictions = []
|
||
|
||
for batch in tqdm(self._batch_generator(texts, self.batch_size),
|
||
total=(len(texts) + self.batch_size - 1) // self.batch_size, desc="预测"):
|
||
predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k, return_probabilities=True)
|
||
all_predictions.extend(predictions)
|
||
|
||
# 整合结果
|
||
if return_top_k > 1:
|
||
# 多个类别的情况
|
||
results = []
|
||
for i, preds in enumerate(all_predictions):
|
||
for j, pred in enumerate(preds):
|
||
results.append({
|
||
'file_name': file_names[i],
|
||
'rank': j + 1,
|
||
'predicted_class': pred['class'],
|
||
'probability': pred['probability']
|
||
})
|
||
df = pd.DataFrame(results)
|
||
else:
|
||
# 单个类别的情况
|
||
df = pd.DataFrame({
|
||
'file_name': file_names,
|
||
'predicted_class': [pred['class'] for pred in all_predictions],
|
||
'probability': [pred['probability'] for pred in all_predictions]
|
||
})
|
||
|
||
# 保存结果
|
||
if output_path:
|
||
if format.lower() == 'csv':
|
||
df.to_csv(output_path, index=False, encoding='utf-8')
|
||
elif format.lower() == 'json':
|
||
# 转换为嵌套的JSON格式
|
||
if return_top_k > 1:
|
||
# 分组后转换为嵌套格式
|
||
result = {}
|
||
for file_name in df['file_name'].unique():
|
||
sub_df = df[df['file_name'] == file_name]
|
||
predictions = []
|
||
for _, row in sub_df.iterrows():
|
||
predictions.append({
|
||
'class': row['predicted_class'],
|
||
'probability': row['probability']
|
||
})
|
||
result[file_name] = {
|
||
'predictions': predictions
|
||
}
|
||
else:
|
||
# 直接构建JSON
|
||
result = {}
|
||
for _, row in df.iterrows():
|
||
result[row['file_name']] = {
|
||
'predicted_class': row['predicted_class'],
|
||
'probability': row['probability']
|
||
}
|
||
|
||
# 保存为JSON
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {format}")
|
||
|
||
logger.info(f"预测结果已保存到: {output_path}")
|
||
|
||
processing_time = time.time() - start_time
|
||
logger.info(f"批量处理完成,共处理 {len(texts)} 个文件,用时: {processing_time:.2f} 秒")
|
||
|
||
return df
|
||
|
||
def process_directory(self, directory: str, pattern: str = "*.txt",
|
||
output_path: Optional[str] = None,
|
||
return_top_k: int = 1, format: str = 'csv',
|
||
recursive: bool = True) -> pd.DataFrame:
|
||
"""
|
||
批量处理目录中的文件
|
||
|
||
Args:
|
||
directory: 目录路径
|
||
pattern: 文件匹配模式
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
return_top_k: 返回概率最高的前k个类别
|
||
format: 输出格式,'csv'或'json'
|
||
recursive: 是否递归处理子目录
|
||
|
||
Returns:
|
||
预测结果DataFrame
|
||
"""
|
||
# 获取符合模式的文件路径
|
||
if recursive:
|
||
file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True)
|
||
else:
|
||
file_paths = glob.glob(os.path.join(directory, pattern))
|
||
|
||
if not file_paths:
|
||
logger.warning(f"在目录 {directory} 中未找到符合模式 {pattern} 的文件")
|
||
return pd.DataFrame()
|
||
|
||
logger.info(f"在目录 {directory} 中找到 {len(file_paths)} 个符合模式 {pattern} 的文件")
|
||
|
||
# 调用process_files处理文件
|
||
return self.process_files(file_paths, output_path, return_top_k, format)
|
||
|
||
def process_dataframe(self, df: pd.DataFrame, text_column: str,
|
||
id_column: Optional[str] = None,
|
||
output_path: Optional[str] = None,
|
||
return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame:
|
||
"""
|
||
批量处理DataFrame中的文本
|
||
|
||
Args:
|
||
df: 输入DataFrame
|
||
text_column: 文本列名
|
||
id_column: ID列名,如果为None则使用索引
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
return_top_k: 返回概率最高的前k个类别
|
||
format: 输出格式,'csv'或'json'
|
||
|
||
Returns:
|
||
预测结果DataFrame
|
||
"""
|
||
# 获取文本和ID
|
||
texts = df[text_column].tolist()
|
||
|
||
if id_column:
|
||
ids = df[id_column].tolist()
|
||
else:
|
||
ids = df.index.tolist()
|
||
|
||
# 批量预测
|
||
result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k)
|
||
|
||
# 保存结果
|
||
if output_path:
|
||
if format.lower() == 'csv':
|
||
result_df.to_csv(output_path, index=False, encoding='utf-8')
|
||
elif format.lower() == 'json':
|
||
# 转换为嵌套的JSON格式
|
||
self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json')
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {format}")
|
||
|
||
logger.info(f"预测结果已保存到: {output_path}")
|
||
|
||
return result_df
|
||
|
||
def process_large_file(self, file_path: str, output_path: Optional[str] = None,
|
||
return_top_k: int = 1, format: str = 'csv',
|
||
chunk_size: int = MAX_TEXT_PER_BATCH,
|
||
delimiter: str = '\n\n') -> None:
|
||
"""
|
||
处理大型文本文件,文件会被分块读取和处理
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
output_path: 输出文件路径,如果为None则不保存
|
||
return_top_k: 返回概率最高的前k个类别
|
||
format: 输出格式,'csv'或'json'
|
||
chunk_size: 每个块的大小(文本数量)
|
||
delimiter: 文本分隔符
|
||
"""
|
||
logger.info(f"开始处理大型文件: {file_path}")
|
||
start_time = time.time()
|
||
|
||
# 读取文件内容
|
||
with open(file_path, 'r', encoding=ENCODING) as f:
|
||
content = f.read()
|
||
|
||
# 分割文本
|
||
texts = content.split(delimiter)
|
||
texts = [text.strip() for text in texts if text.strip()]
|
||
|
||
logger.info(f"文件共包含 {len(texts)} 条文本")
|
||
|
||
# 创建输出文件
|
||
if output_path:
|
||
if format.lower() == 'csv':
|
||
# 创建CSV文件头
|
||
if return_top_k > 1:
|
||
header = "text_id,text,rank,predicted_class,probability\n"
|
||
else:
|
||
header = "text_id,text,predicted_class,probability\n"
|
||
|
||
with open(output_path, 'w', encoding=ENCODING) as f:
|
||
f.write(header)
|
||
elif format.lower() == 'json':
|
||
# 创建JSON文件
|
||
with open(output_path, 'w', encoding=ENCODING) as f:
|
||
f.write('{\n')
|
||
|
||
# 分块处理
|
||
total_chunks = (len(texts) + chunk_size - 1) // chunk_size
|
||
|
||
for i in range(0, len(texts), chunk_size):
|
||
chunk = texts[i:i + chunk_size]
|
||
chunk_ids = list(range(i, i + len(chunk)))
|
||
|
||
logger.info(f"处理第 {i // chunk_size + 1}/{total_chunks} 块,包含 {len(chunk)} 条文本")
|
||
|
||
# 批量预测
|
||
result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k)
|
||
|
||
# 追加到输出文件
|
||
if output_path:
|
||
if format.lower() == 'csv':
|
||
result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False)
|
||
elif format.lower() == 'json':
|
||
# 转换为JSON并追加
|
||
if return_top_k > 1:
|
||
# 分组后转换为嵌套格式
|
||
for id_val in result_df['id'].unique():
|
||
sub_df = result_df[result_df['id'] == id_val]
|
||
predictions = []
|
||
for _, row in sub_df.iterrows():
|
||
predictions.append({
|
||
'class': row['predicted_class'],
|
||
'probability': float(row['probability'])
|
||
})
|
||
|
||
json_str = f' "{id_val}": {{\n'
|
||
json_str += f' "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n'
|
||
json_str += f' "predictions": {json.dumps(predictions, ensure_ascii=False)}\n'
|
||
json_str += ' },'
|
||
|
||
with open(output_path, 'a', encoding=ENCODING) as f:
|
||
f.write(json_str + '\n')
|
||
else:
|
||
# 直接构建JSON
|
||
for _, row in result_df.iterrows():
|
||
json_str = f' "{row["id"]}": {{\n'
|
||
json_str += f' "text": {json.dumps(row["text"], ensure_ascii=False)},\n'
|
||
json_str += f' "predicted_class": "{row["predicted_class"]}",\n'
|
||
json_str += f' "probability": {float(row["probability"])}\n'
|
||
json_str += ' },'
|
||
|
||
with open(output_path, 'a', encoding=ENCODING) as f:
|
||
f.write(json_str + '\n')
|
||
|
||
# 完成JSON文件
|
||
if output_path and format.lower() == 'json':
|
||
with open(output_path, 'a', encoding=ENCODING) as f:
|
||
f.write('}\n')
|
||
|
||
# 修复JSON文件中的最后一个逗号
|
||
with open(output_path, 'r', encoding=ENCODING) as f:
|
||
content = f.read()
|
||
|
||
content = content.rstrip('\n}')
|
||
content = content.rstrip(',')
|
||
content += '\n}\n'
|
||
|
||
with open(output_path, 'w', encoding=ENCODING) as f:
|
||
f.write(content)
|
||
|
||
processing_time = time.time() - start_time
|
||
logger.info(f"处理大型文件完成,共处理 {len(texts)} 条文本,用时: {processing_time:.2f} 秒")
|
||
|
||
|
||
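
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, placeholder paths): classify every .txt file
# under a directory and write a CSV summary, much as scripts/predict.py does
# for tabular input via process_dataframe.
# ---------------------------------------------------------------------------
def _batch_processing_sketch(predictor: Predictor) -> None:
    """Run directory-level batch prediction with a ready-made Predictor."""
    processor = BatchProcessor(predictor, batch_size=64)
    result_df = processor.process_directory(
        directory="data/raw/unlabeled",   # placeholder path
        pattern="*.txt",
        output_path="predictions.csv",    # placeholder path
        return_top_k=1,
        format="csv",
        recursive=True,
    )
    print(f"classified {len(result_df)} files")
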
================================================================================
文件: inference/__init__.py
================================================================================



================================================================================
文件: evaluation/metrics.py
================================================================================

"""
评估指标模块:实现各种评估指标
"""
import numpy as np
|
||
import tensorflow as tf
|
||
from sklearn.metrics import (
|
||
accuracy_score, precision_score, recall_score, f1_score,
|
||
confusion_matrix, classification_report, roc_auc_score,
|
||
precision_recall_curve, average_precision_score
|
||
)
|
||
import matplotlib.pyplot as plt
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import pandas as pd
|
||
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("Metrics")
|
||
|
||
|
||
class ClassificationMetrics:
|
||
"""分类评估指标类,计算各种分类评估指标"""
|
||
|
||
def __init__(self, class_names: Optional[List[str]] = None):
|
||
"""
|
||
初始化分类评估指标类
|
||
|
||
Args:
|
||
class_names: 类别名称列表
|
||
"""
|
||
self.class_names = class_names
|
||
|
||
def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||
"""
|
||
计算准确率
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
|
||
Returns:
|
||
准确率
|
||
"""
|
||
return accuracy_score(y_true, y_pred)
|
||
|
||
def precision(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
average: str = 'macro') -> Union[float, np.ndarray]:
|
||
"""
|
||
计算精确率
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
|
||
|
||
Returns:
|
||
精确率
|
||
"""
|
||
return precision_score(y_true, y_pred, average=average, zero_division=0)
|
||
|
||
def recall(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
average: str = 'macro') -> Union[float, np.ndarray]:
|
||
"""
|
||
计算召回率
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
|
||
|
||
Returns:
|
||
召回率
|
||
"""
|
||
return recall_score(y_true, y_pred, average=average, zero_division=0)
|
||
|
||
def f1(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
average: str = 'macro') -> Union[float, np.ndarray]:
|
||
"""
|
||
计算F1分数
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
|
||
|
||
Returns:
|
||
F1分数
|
||
"""
|
||
return f1_score(y_true, y_pred, average=average, zero_division=0)
|
||
|
||
def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
normalize: Optional[str] = None) -> np.ndarray:
|
||
"""
|
||
计算混淆矩阵
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
|
||
|
||
Returns:
|
||
混淆矩阵
|
||
"""
|
||
cm = confusion_matrix(y_true, y_pred)
|
||
|
||
if normalize is not None:
|
||
if normalize == 'true':
|
||
cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
|
||
elif normalize == 'pred':
|
||
cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
|
||
elif normalize == 'all':
|
||
cm = cm.astype('float') / cm.sum()
|
||
|
||
return cm
|
||
|
||
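
    # -----------------------------------------------------------------------
    # Worked example (illustrative) for the `normalize` argument above, with
    # cm = [[8, 2], [1, 9]]:
    #   normalize='true'  -> rows sum to 1:    [[0.8, 0.2], [0.1, 0.9]]
    #   normalize='pred'  -> columns sum to 1: [[8/9, 2/11], [1/9, 9/11]]
    #   normalize='all'   -> every cell divided by the total count (20).
    # -----------------------------------------------------------------------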
def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
normalize: Optional[str] = None,
|
||
figsize: Tuple[int, int] = (10, 8),
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制混淆矩阵
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
|
||
figsize: 图像大小
|
||
save_path: 保存路径,如果为None则显示图像
|
||
"""
|
||
# 计算混淆矩阵
|
||
cm = self.confusion_matrix(y_true, y_pred, normalize)
|
||
|
||
# 确定类别名称
|
||
if self.class_names is None:
|
||
class_names = [str(i) for i in range(cm.shape[0])]
|
||
else:
|
||
class_names = self.class_names
|
||
|
||
# 创建图像
|
||
plt.figure(figsize=figsize)
|
||
|
||
# 使用热图显示混淆矩阵
|
||
im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||
plt.colorbar(im)
|
||
|
||
# 设置坐标轴标签
|
||
plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right')
|
||
plt.yticks(np.arange(cm.shape[0]), class_names)
|
||
|
||
# 设置标题
|
||
if normalize is not None:
|
||
plt.title(f"Normalized ({normalize}) Confusion Matrix")
|
||
else:
|
||
plt.title("Confusion Matrix")
|
||
|
||
plt.ylabel('True label')
|
||
plt.xlabel('Predicted label')
|
||
|
||
# 在每个单元格中显示数值
|
||
thresh = cm.max() / 2.0
|
||
for i in range(cm.shape[0]):
|
||
for j in range(cm.shape[1]):
|
||
if normalize is not None:
|
||
plt.text(j, i, f"{cm[i, j]:.2f}",
|
||
ha="center", va="center",
|
||
color="white" if cm[i, j] > thresh else "black")
|
||
else:
|
||
plt.text(j, i, f"{cm[i, j]}",
|
||
ha="center", va="center",
|
||
color="white" if cm[i, j] > thresh else "black")
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图像
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"混淆矩阵图已保存到: {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
output_dict: bool = False) -> Union[str, Dict]:
|
||
"""
|
||
生成分类报告
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
output_dict: 是否以字典形式返回
|
||
|
||
Returns:
|
||
分类报告
|
||
"""
|
||
target_names = self.class_names if self.class_names else None
|
||
return classification_report(y_true, y_pred,
|
||
target_names=target_names,
|
||
output_dict=output_dict,
|
||
zero_division=0)
|
||
|
||
def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray,
|
||
multi_class: str = 'ovr') -> Union[float, np.ndarray]:
|
||
"""
|
||
计算AUC-ROC
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_prob: 预测概率
|
||
multi_class: 多分类处理方式,可选值: 'ovr', 'ovo'
|
||
|
||
Returns:
|
||
AUC-ROC
|
||
"""
|
||
try:
|
||
# 如果y_true是one-hot编码,转换为类别索引
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 多分类
|
||
if y_prob.shape[1] > 2:
|
||
return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro')
|
||
# 二分类
|
||
else:
|
||
return roc_auc_score(y_true, y_prob[:, 1])
|
||
except Exception as e:
|
||
logger.error(f"计算AUC-ROC时出错: {e}")
|
||
return 0.0
|
||
|
||
def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray,
|
||
average: str = 'macro') -> Union[float, np.ndarray]:
|
||
"""
|
||
计算平均精确率
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_prob: 预测概率
|
||
average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None
|
||
|
||
Returns:
|
||
平均精确率
|
||
"""
|
||
try:
|
||
# 如果y_true是one-hot编码,转换为类别索引
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 多分类:使用sklearn的方法
|
||
return average_precision_score(
|
||
tf.keras.utils.to_categorical(y_true, num_classes=y_prob.shape[1]),
|
||
y_prob,
|
||
average=average
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"计算平均精确率时出错: {e}")
|
||
return 0.0
|
||
|
||
def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray,
|
||
class_id: Optional[int] = None,
|
||
figsize: Tuple[int, int] = (10, 8),
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制精确率-召回率曲线
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_prob: 预测概率
|
||
class_id: 要绘制的类别ID,如果为None则绘制所有类别
|
||
figsize: 图像大小
|
||
save_path: 保存路径,如果为None则显示图像
|
||
"""
|
||
# 如果y_true是one-hot编码,转换为类别索引
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 创建图像
|
||
plt.figure(figsize=figsize)
|
||
|
||
# 确定要绘制的类别
|
||
if class_id is not None and class_id < y_prob.shape[1]:
|
||
# 绘制指定类别的PR曲线
|
||
y_true_bin = (y_true == class_id).astype(int)
|
||
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, class_id])
|
||
avg_prec = average_precision_score(y_true_bin, y_prob[:, class_id])
|
||
|
||
plt.step(recall, precision, where='post',
|
||
label=f'Class {class_id} (AP = {avg_prec:.3f})')
|
||
else:
|
||
# 绘制所有类别的PR曲线
|
||
for i in range(y_prob.shape[1]):
|
||
y_true_bin = (y_true == i).astype(int)
|
||
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
|
||
avg_prec = average_precision_score(y_true_bin, y_prob[:, i])
|
||
|
||
class_name = self.class_names[i] if self.class_names else f"Class {i}"
|
||
plt.step(recall, precision, where='post',
|
||
label=f'{class_name} (AP = {avg_prec:.3f})')
|
||
|
||
plt.xlabel('Recall')
|
||
plt.ylabel('Precision')
|
||
plt.title('Precision-Recall Curve')
|
||
plt.legend(loc='lower left')
|
||
plt.grid(True)
|
||
|
||
# 保存或显示图像
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
|
||
y_prob: Optional[np.ndarray] = None) -> Dict[str, float]:
|
||
"""
|
||
计算所有评估指标
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_pred: 预测标签
|
||
y_prob: 预测概率
|
||
|
||
Returns:
|
||
包含所有评估指标的字典
|
||
"""
|
||
metrics = {}
|
||
|
||
# 基础指标
|
||
metrics['accuracy'] = self.accuracy(y_true, y_pred)
|
||
metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro')
|
||
metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro')
|
||
metrics['f1_macro'] = self.f1(y_true, y_pred, average='macro')
|
||
|
||
# 如果提供了预测概率,计算AUC-ROC和平均精确率
|
||
if y_prob is not None:
|
||
try:
|
||
metrics['auc_roc'] = self.auc_roc(y_true, y_prob)
|
||
metrics['average_precision'] = self.average_precision(y_true, y_prob)
|
||
except Exception as e:
|
||
logger.error(f"计算概率指标时出错: {e}")
|
||
|
||
# 类别级别的指标
|
||
for avg in ['micro', 'weighted']:
|
||
metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, average=avg)
|
||
metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg)
|
||
metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg)
|
||
|
||
return metrics
|
||
|
||
def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame:
|
||
"""
|
||
将评估指标转换为DataFrame
|
||
|
||
Args:
|
||
metrics: 评估指标字典
|
||
|
||
Returns:
|
||
评估指标DataFrame
|
||
"""
|
||
return pd.DataFrame(metrics.items(), columns=['Metric', 'Value']).set_index('Metric')
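# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):演示ClassificationMetrics的典型调用。
# 其中的类别名称与标签/概率数组均为假设的演示数据;示例依赖本文件顶部已
# 导入的numpy等库,构造函数按evaluator.py中的用法传入类别名称列表。
# ----------------------------------------------------------------------
def _demo_classification_metrics() -> None:
    demo_metrics = ClassificationMetrics(["体育", "财经", "科技"])  # 假设的类别名称

    y_true = np.array([0, 1, 2, 2, 1, 0])
    y_pred = np.array([0, 1, 2, 1, 1, 0])
    # 每行是一个样本在3个类别上的预测概率(各行和为1)
    y_prob = np.array([
        [0.80, 0.10, 0.10],
        [0.10, 0.70, 0.20],
        [0.10, 0.20, 0.70],
        [0.20, 0.50, 0.30],
        [0.20, 0.60, 0.20],
        [0.90, 0.05, 0.05],
    ])

    results = demo_metrics.calculate_all_metrics(y_true, y_pred, y_prob)
    print(demo_metrics.metrics_to_dataframe(results))
    print(demo_metrics.classification_report(y_true, y_pred))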
================================================================================
|
||
文件: evaluation/__init__.py
|
||
================================================================================
|
||
|
||
|
||
|
||
================================================================================
|
||
文件: evaluation/visualization.py
|
||
================================================================================
|
||
|
||
"""
|
||
可视化模块:实现评估结果的可视化
|
||
"""
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
import seaborn as sns
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import os
|
||
import itertools
|
||
from sklearn.metrics import roc_curve, precision_recall_curve, auc
|
||
from sklearn.manifold import TSNE
|
||
from sklearn.decomposition import PCA
|
||
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir
|
||
|
||
logger = get_logger("Visualization")
|
||
|
||
|
||
class EvaluationVisualizer:
|
||
"""评估结果可视化类"""
|
||
|
||
def __init__(self, output_dir: Optional[str] = None,
|
||
class_names: Optional[List[str]] = None,
|
||
figsize: Tuple[int, int] = (10, 8)):
|
||
"""
|
||
初始化评估结果可视化类
|
||
|
||
Args:
|
||
output_dir: 输出目录,用于保存可视化结果
|
||
class_names: 类别名称列表
|
||
figsize: 图像默认大小
|
||
"""
|
||
self.output_dir = output_dir
|
||
if output_dir:
|
||
ensure_dir(output_dir)
|
||
|
||
self.class_names = class_names
|
||
self.figsize = figsize
|
||
|
||
def plot_confusion_matrix(self, cm: np.ndarray,
|
||
normalize: Optional[str] = None,
|
||
title: str = 'Confusion Matrix',
|
||
cmap: str = 'Blues',
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制混淆矩阵
|
||
|
||
Args:
|
||
cm: 混淆矩阵
|
||
normalize: 归一化方式,可选值: 'true', 'pred', 'all', None
|
||
title: 图像标题
|
||
cmap: 颜色映射
|
||
save_path: 保存路径,如果为None则使用output_dir/confusion_matrix.png
|
||
"""
|
||
if normalize is not None:
|
||
if normalize == 'true':
|
||
cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
|
||
title = 'Normalized (by true) ' + title
|
||
elif normalize == 'pred':
|
||
cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)
|
||
title = 'Normalized (by pred) ' + title
|
||
elif normalize == 'all':
|
||
cm = cm.astype('float') / cm.sum()
|
||
title = 'Normalized (by all) ' + title
|
||
|
||
plt.figure(figsize=self.figsize)
|
||
|
||
# 使用seaborn绘制热图
|
||
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
|
||
cmap=cmap, square=True, cbar=True)
|
||
|
||
# 设置坐标轴标签
|
||
if self.class_names:
|
||
tick_marks = np.arange(len(self.class_names))
|
||
plt.xticks(tick_marks + 0.5, self.class_names, rotation=45, ha='right')
|
||
plt.yticks(tick_marks + 0.5, self.class_names, rotation=0)
|
||
|
||
plt.title(title)
|
||
plt.ylabel('True label')
|
||
plt.xlabel('Predicted label')
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, 'confusion_matrix.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"混淆矩阵图已保存到: {save_path}")
|
||
|
||
plt.close()
|
||
|
||
def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]],
|
||
selected_metrics: Optional[List[str]] = None,
|
||
title: str = 'Metrics Comparison',
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制多个模型的评估指标比较
|
||
|
||
Args:
|
||
metrics_dict: 模型评估指标字典,格式为{model_name: {metric_name: value}}
|
||
selected_metrics: 要比较的指标列表,如果为None则使用所有指标
|
||
title: 图像标题
|
||
save_path: 保存路径,如果为None则使用output_dir/metrics_comparison.png
|
||
"""
|
||
# 创建DataFrame
|
||
df = pd.DataFrame(metrics_dict).T
|
||
|
||
# 筛选指标
|
||
if selected_metrics:
|
||
df = df[selected_metrics]
|
||
|
||
        # 绘制条形图(DataFrame.plot会自行创建figure,无需再先调用plt.figure)
        df.plot(kind='bar', figsize=self.figsize)
|
||
|
||
plt.title(title)
|
||
plt.ylabel('Score')
|
||
plt.ylim(0, 1)
|
||
plt.legend(loc='best')
|
||
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, 'metrics_comparison.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"评估指标比较图已保存到: {save_path}")
|
||
|
||
plt.close()
|
||
|
||
def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
|
||
title: str = 'ROC Curves',
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制ROC曲线
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_prob: 预测概率
|
||
title: 图像标题
|
||
save_path: 保存路径,如果为None则使用output_dir/roc_curves.png
|
||
"""
|
||
plt.figure(figsize=self.figsize)
|
||
|
||
# 确保y_true是一维数组
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 获取类别数
|
||
num_classes = y_prob.shape[1]
|
||
|
||
# 绘制每个类别的ROC曲线
|
||
for i in range(num_classes):
|
||
# 二分类转换:当前类别为正类,其他为负类
|
||
y_true_bin = (y_true == i).astype(int)
|
||
|
||
# 计算ROC曲线
|
||
fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i])
|
||
roc_auc = auc(fpr, tpr)
|
||
|
||
# 确定类别名称
|
||
if self.class_names and i < len(self.class_names):
|
||
class_name = self.class_names[i]
|
||
else:
|
||
class_name = f'Class {i}'
|
||
|
||
# 绘制ROC曲线
|
||
plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})')
|
||
|
||
# 绘制随机猜测的基准线
|
||
plt.plot([0, 1], [0, 1], 'k--', lw=2)
|
||
|
||
plt.xlim([0.0, 1.0])
|
||
plt.ylim([0.0, 1.05])
|
||
plt.xlabel('False Positive Rate')
|
||
plt.ylabel('True Positive Rate')
|
||
plt.title(title)
|
||
plt.legend(loc='lower right')
|
||
plt.grid(True, alpha=0.3)
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, 'roc_curves.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"ROC曲线图已保存到: {save_path}")
|
||
|
||
plt.close()
|
||
|
||
def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray,
|
||
title: str = 'Precision-Recall Curves',
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制精确率-召回率曲线
|
||
|
||
Args:
|
||
y_true: 真实标签
|
||
y_prob: 预测概率
|
||
title: 图像标题
|
||
save_path: 保存路径,如果为None则使用output_dir/precision_recall_curves.png
|
||
"""
|
||
plt.figure(figsize=self.figsize)
|
||
|
||
# 确保y_true是一维数组
|
||
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
|
||
y_true = np.argmax(y_true, axis=1)
|
||
|
||
# 获取类别数
|
||
num_classes = y_prob.shape[1]
|
||
|
||
# 绘制每个类别的PR曲线
|
||
for i in range(num_classes):
|
||
# 二分类转换:当前类别为正类,其他为负类
|
||
y_true_bin = (y_true == i).astype(int)
|
||
|
||
# 计算PR曲线
|
||
precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i])
|
||
pr_auc = auc(recall, precision)
|
||
|
||
# 确定类别名称
|
||
if self.class_names and i < len(self.class_names):
|
||
class_name = self.class_names[i]
|
||
else:
|
||
class_name = f'Class {i}'
|
||
|
||
# 绘制PR曲线
|
||
plt.plot(recall, precision, lw=2, label=f'{class_name} (AUC = {pr_auc:.3f})')
|
||
|
||
plt.xlim([0.0, 1.0])
|
||
plt.ylim([0.0, 1.05])
|
||
plt.xlabel('Recall')
|
||
plt.ylabel('Precision')
|
||
plt.title(title)
|
||
plt.legend(loc='best')
|
||
plt.grid(True, alpha=0.3)
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, 'precision_recall_curves.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"精确率-召回率曲线图已保存到: {save_path}")
|
||
|
||
plt.close()
|
||
|
||
def plot_feature_importance(self, feature_names: List[str],
|
||
importance: np.ndarray,
|
||
title: str = 'Feature Importance',
|
||
top_n: int = 20,
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制特征重要性
|
||
|
||
Args:
|
||
feature_names: 特征名称列表
|
||
importance: 特征重要性数组
|
||
title: 图像标题
|
||
top_n: 显示前N个重要的特征
|
||
save_path: 保存路径,如果为None则使用output_dir/feature_importance.png
|
||
"""
|
||
# 创建DataFrame
|
||
df = pd.DataFrame({'Feature': feature_names, 'Importance': importance})
|
||
|
||
# 按重要性排序
|
||
df = df.sort_values('Importance', ascending=False).head(top_n)
|
||
|
||
# 绘制条形图
|
||
plt.figure(figsize=self.figsize)
|
||
sns.barplot(x='Importance', y='Feature', data=df)
|
||
|
||
plt.title(title)
|
||
plt.xlabel('Importance')
|
||
plt.ylabel('Feature')
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, 'feature_importance.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"特征重要性图已保存到: {save_path}")
|
||
|
||
plt.close()
|
||
|
||
def plot_embedding_visualization(self, embeddings: np.ndarray,
|
||
labels: np.ndarray,
|
||
method: str = 'tsne',
|
||
title: str = 'Embedding Visualization',
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
绘制嵌入向量可视化
|
||
|
||
Args:
|
||
embeddings: 嵌入向量,形状为(样本数, 嵌入维度)
|
||
labels: 类别标签,形状为(样本数,)
|
||
method: 降维方法,'tsne'或'pca'
|
||
title: 图像标题
|
||
save_path: 保存路径,如果为None则使用output_dir/embedding_visualization.png
|
||
"""
|
||
# 降维
|
||
if method.lower() == 'tsne':
|
||
reducer = TSNE(n_components=2, random_state=42)
|
||
elif method.lower() == 'pca':
|
||
reducer = PCA(n_components=2, random_state=42)
|
||
else:
|
||
raise ValueError(f"不支持的降维方法: {method}")
|
||
|
||
# 如果嵌入向量太多,采样一部分
|
||
max_samples = 5000
|
||
if len(embeddings) > max_samples:
|
||
indices = np.random.choice(len(embeddings), max_samples, replace=False)
|
||
embeddings_sample = embeddings[indices]
|
||
labels_sample = labels[indices]
|
||
else:
|
||
embeddings_sample = embeddings
|
||
labels_sample = labels
|
||
|
||
# 执行降维
|
||
embeddings_2d = reducer.fit_transform(embeddings_sample)
|
||
|
||
# 绘制散点图
|
||
plt.figure(figsize=self.figsize)
|
||
|
||
# 确保标签是一维数组
|
||
if len(labels_sample.shape) > 1 and labels_sample.shape[1] > 1:
|
||
labels_sample = np.argmax(labels_sample, axis=1)
|
||
|
||
# 获取唯一类别
|
||
unique_labels = np.unique(labels_sample)
|
||
|
||
# 为每个类别分配不同的颜色
|
||
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
|
||
|
||
for i, label in enumerate(unique_labels):
|
||
mask = labels_sample == label
|
||
|
||
# 确定类别名称
|
||
if self.class_names and label < len(self.class_names):
|
||
class_name = self.class_names[int(label)]
|
||
else:
|
||
class_name = f'Class {int(label)}'
|
||
|
||
plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
|
||
c=[colors[i]], label=class_name, alpha=0.7)
|
||
|
||
plt.title(title)
|
||
plt.legend(loc='best')
|
||
plt.grid(True, alpha=0.3)
|
||
plt.tight_layout()
|
||
|
||
# 保存图像
|
||
if save_path is None and self.output_dir:
|
||
save_path = os.path.join(self.output_dir, f'embedding_visualization_{method}.png')
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
logger.info(f"嵌入向量可视化图已保存到: {save_path}")
|
||
|
||
plt.close()
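# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):用随机生成的演示数据展示
# EvaluationVisualizer的混淆矩阵与ROC曲线绘制;输出目录与类别名称均为假设值。
# ----------------------------------------------------------------------
def _demo_evaluation_visualizer() -> None:
    rng = np.random.default_rng(42)
    num_samples, num_classes = 200, 3

    y_true = rng.integers(0, num_classes, size=num_samples)
    # 构造一组按行归一化的伪预测概率
    raw = rng.random((num_samples, num_classes))
    y_prob = raw / raw.sum(axis=1, keepdims=True)
    y_pred = y_prob.argmax(axis=1)

    # 直接统计混淆矩阵(行为真实类别,列为预测类别)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1

    viz = EvaluationVisualizer(output_dir="outputs/eval_demo",  # 假设的输出目录
                               class_names=["体育", "财经", "科技"])
    viz.plot_confusion_matrix(cm, normalize='true')
    viz.plot_roc_curves(y_true, y_prob)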
================================================================================
|
||
文件: evaluation/evaluator.py
|
||
================================================================================
|
||
|
||
"""
|
||
评估器模块:实现模型评估流程
|
||
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
import time
|
||
import os
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import json
|
||
|
||
from config.system_config import SAVED_MODELS_DIR
|
||
from models.base_model import TextClassificationModel
|
||
from evaluation.metrics import ClassificationMetrics
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import ensure_dir, save_json
|
||
|
||
logger = get_logger("Evaluator")
|
||
|
||
|
||
class ModelEvaluator:
|
||
"""模型评估器,负责评估模型性能"""
|
||
|
||
def __init__(self, model: TextClassificationModel,
|
||
class_names: Optional[List[str]] = None,
|
||
output_dir: Optional[str] = None,
|
||
batch_size: Optional[int] = None):
|
||
"""
|
||
初始化模型评估器
|
||
|
||
Args:
|
||
model: 要评估的模型
|
||
class_names: 类别名称列表
|
||
output_dir: 输出目录,用于保存评估结果
|
||
batch_size: 批大小,如果为None则使用模型默认值
|
||
"""
|
||
self.model = model
|
||
self.class_names = class_names
|
||
self.batch_size = batch_size or model.batch_size
|
||
|
||
# 设置输出目录
|
||
if output_dir is None:
|
||
self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name)
|
||
else:
|
||
self.output_dir = output_dir
|
||
|
||
ensure_dir(self.output_dir)
|
||
|
||
# 创建评估指标计算器
|
||
self.metrics = ClassificationMetrics(class_names)
|
||
|
||
# 评估结果
|
||
self.evaluation_results = None
|
||
|
||
logger.info(f"初始化模型评估器,模型: {model.model_name}")
|
||
|
||
def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset],
|
||
y_test: Optional[np.ndarray] = None,
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 1) -> Dict[str, float]:
|
||
"""
|
||
评估模型
|
||
|
||
Args:
|
||
x_test: 测试数据特征
|
||
y_test: 测试数据标签
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
评估结果
|
||
"""
|
||
batch_size = batch_size or self.batch_size
|
||
|
||
logger.info(f"开始评估模型: {self.model.model_name}")
|
||
start_time = time.time()
|
||
|
||
# 使用模型评估
|
||
model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose)
|
||
|
||
# 获取预测结果
|
||
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
|
||
y_pred = np.argmax(y_prob, axis=1)
|
||
|
||
# 处理y_test,确保y_test是一维数组
|
||
if isinstance(x_test, tf.data.Dataset):
|
||
# 如果是TensorFlow Dataset,需要从中提取y_test
|
||
y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
|
||
y_test = y_test_extracted
|
||
|
||
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
|
||
y_test = np.argmax(y_test, axis=1)
|
||
|
||
# 计算所有指标
|
||
all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob)
|
||
|
||
# 合并模型内置指标和自定义指标
|
||
metrics_names = self.model.model.metrics_names
|
||
model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)}
|
||
all_metrics.update(model_metrics_dict)
|
||
|
||
# 记录评估时间
|
||
evaluation_time = time.time() - start_time
|
||
all_metrics['evaluation_time'] = evaluation_time
|
||
|
||
# 保存评估结果
|
||
self.evaluation_results = {
|
||
'metrics': all_metrics,
|
||
'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(),
|
||
'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True)
|
||
}
|
||
|
||
logger.info(f"模型评估完成,用时: {evaluation_time:.2f} 秒")
|
||
logger.info(f"主要评估指标: accuracy={all_metrics.get('accuracy', 'N/A'):.4f}, "
|
||
f"f1_macro={all_metrics.get('f1_macro', 'N/A'):.4f}")
|
||
|
||
return all_metrics
|
||
|
||
def save_evaluation_results(self, save_plots: bool = True) -> str:
|
||
"""
|
||
保存评估结果
|
||
|
||
Args:
|
||
save_plots: 是否保存可视化图表
|
||
|
||
Returns:
|
||
结果保存路径
|
||
"""
|
||
if self.evaluation_results is None:
|
||
raise ValueError("请先调用evaluate方法进行评估")
|
||
|
||
# 保存评估结果为JSON
|
||
results_path = os.path.join(self.output_dir, 'evaluation_results.json')
|
||
with open(results_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4)
|
||
|
||
# 保存评估指标为CSV
|
||
metrics_df = pd.DataFrame(
|
||
self.evaluation_results['metrics'].items(),
|
||
columns=['Metric', 'Value']
|
||
).set_index('Metric')
|
||
|
||
metrics_path = os.path.join(self.output_dir, 'metrics.csv')
|
||
metrics_df.to_csv(metrics_path)
|
||
|
||
# 保存可视化图表
|
||
if save_plots:
|
||
self._save_plots()
|
||
|
||
logger.info(f"评估结果已保存到: {self.output_dir}")
|
||
|
||
return self.output_dir
|
||
|
||
def _save_plots(self) -> None:
|
||
"""保存评估结果可视化图表"""
|
||
if self.evaluation_results is None:
|
||
raise ValueError("请先调用evaluate方法进行评估")
|
||
|
||
# 创建可视化目录
|
||
plots_dir = os.path.join(self.output_dir, 'plots')
|
||
ensure_dir(plots_dir)
|
||
|
||
# 混淆矩阵图
|
||
cm_path = os.path.join(plots_dir, 'confusion_matrix.png')
|
||
cm = np.array(self.evaluation_results['confusion_matrix'])
|
||
|
||
# 将混淆矩阵转换为NumPy数组
|
||
if isinstance(cm, list):
|
||
cm = np.array(cm)
|
||
|
||
        # 绘制混淆矩阵:直接使用上面保存的混淆矩阵数据,
        # 而不是用占位标签重新计算(原实现画出的矩阵与真实评估结果无关)
        cm_normalized = cm.astype('float') / np.maximum(cm.sum(axis=1, keepdims=True), 1)

        plt.figure(figsize=(10, 8))
        plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues)
        plt.colorbar()
        plt.title('Normalized Confusion Matrix')
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig(cm_path)
        plt.close()
|
||
|
||
# 保存评估指标条形图
|
||
metrics_path = os.path.join(plots_dir, 'metrics_bar.png')
|
||
metrics = self.evaluation_results['metrics']
|
||
|
||
# 选择要展示的主要指标
|
||
main_metrics = {
|
||
'accuracy': metrics.get('accuracy', 0),
|
||
'precision_macro': metrics.get('precision_macro', 0),
|
||
'recall_macro': metrics.get('recall_macro', 0),
|
||
'f1_macro': metrics.get('f1_macro', 0)
|
||
}
|
||
|
||
# 绘制条形图
|
||
plt.figure(figsize=(10, 6))
|
||
plt.bar(main_metrics.keys(), main_metrics.values())
|
||
plt.title('Main Evaluation Metrics')
|
||
plt.ylabel('Score')
|
||
plt.ylim(0, 1)
|
||
plt.xticks(rotation=45, ha='right')
|
||
plt.tight_layout()
|
||
plt.savefig(metrics_path)
|
||
plt.close()
|
||
|
||
# 如果有类别级别的指标,绘制每个类别的指标
|
||
if 'classification_report' in self.evaluation_results:
|
||
report = self.evaluation_results['classification_report']
|
||
|
||
# 提取每个类别的精确率、召回率和F1值
|
||
class_metrics = {}
|
||
for key, value in report.items():
|
||
if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
|
||
if isinstance(value, dict):
|
||
class_metrics[key] = value
|
||
|
||
if class_metrics:
|
||
# 绘制每个类别的F1分数
|
||
class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png')
|
||
|
||
classes = list(class_metrics.keys())
|
||
f1_scores = [metrics['f1-score'] for metrics in class_metrics.values()]
|
||
|
||
plt.figure(figsize=(12, 6))
|
||
bars = plt.bar(classes, f1_scores)
|
||
|
||
# 在柱状图上方显示数值
|
||
for bar in bars:
|
||
height = bar.get_height()
|
||
plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
|
||
f'{height:.2f}',
|
||
ha='center', va='bottom', rotation=0)
|
||
|
||
plt.title('F1 Score by Class')
|
||
plt.ylabel('F1 Score')
|
||
plt.ylim(0, 1.1)
|
||
plt.xticks(rotation=45, ha='right')
|
||
plt.tight_layout()
|
||
plt.savefig(class_f1_path)
|
||
plt.close()
|
||
|
||
# 绘制每个类别的精确率和召回率
|
||
class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png')
|
||
|
||
precisions = [metrics['precision'] for metrics in class_metrics.values()]
|
||
recalls = [metrics['recall'] for metrics in class_metrics.values()]
|
||
|
||
plt.figure(figsize=(12, 6))
|
||
x = np.arange(len(classes))
|
||
width = 0.35
|
||
|
||
plt.bar(x - width / 2, precisions, width, label='Precision')
|
||
plt.bar(x + width / 2, recalls, width, label='Recall')
|
||
|
||
plt.ylabel('Score')
|
||
plt.title('Precision and Recall by Class')
|
||
plt.xticks(x, classes, rotation=45, ha='right')
|
||
plt.legend()
|
||
plt.ylim(0, 1.1)
|
||
plt.tight_layout()
|
||
plt.savefig(class_prec_rec_path)
|
||
plt.close()
|
||
|
||
logger.info(f"评估可视化图表已保存到: {plots_dir}")
|
||
|
||
def compare_models(self, other_evaluators: List['ModelEvaluator'],
|
||
metrics: Optional[List[str]] = None,
|
||
save_path: Optional[str] = None) -> pd.DataFrame:
|
||
"""
|
||
比较多个模型的评估结果
|
||
|
||
Args:
|
||
other_evaluators: 其他模型评估器列表
|
||
metrics: 要比较的指标列表,默认为['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
|
||
save_path: 比较结果的保存路径
|
||
|
||
Returns:
|
||
比较结果DataFrame
|
||
"""
|
||
if self.evaluation_results is None:
|
||
raise ValueError("请先调用evaluate方法进行评估")
|
||
|
||
# 默认比较指标
|
||
if metrics is None:
|
||
metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
|
||
|
||
# 收集所有模型的评估指标
|
||
models_metrics = {}
|
||
|
||
# 当前模型
|
||
models_metrics[self.model.model_name] = {
|
||
metric: self.evaluation_results['metrics'].get(metric, 'N/A')
|
||
for metric in metrics
|
||
}
|
||
|
||
# 其他模型
|
||
for evaluator in other_evaluators:
|
||
if evaluator.evaluation_results is None:
|
||
logger.warning(f"模型 {evaluator.model.model_name} 尚未评估,跳过")
|
||
continue
|
||
|
||
models_metrics[evaluator.model.model_name] = {
|
||
metric: evaluator.evaluation_results['metrics'].get(metric, 'N/A')
|
||
for metric in metrics
|
||
}
|
||
|
||
# 创建比较DataFrame
|
||
comparison_df = pd.DataFrame(models_metrics).T
|
||
|
||
# 保存比较结果
|
||
if save_path:
|
||
comparison_df.to_csv(save_path)
|
||
logger.info(f"模型比较结果已保存到: {save_path}")
|
||
|
||
        # 绘制比较条形图(DataFrame.plot会自行创建figure)
        comparison_df.plot(kind='bar', figsize=(12, 6))
|
||
plt.title('Model Comparison')
|
||
plt.ylabel('Score')
|
||
plt.ylim(0, 1)
|
||
plt.legend(loc='lower right')
|
||
plt.tight_layout()
|
||
|
||
        # 确定图表保存路径:save_path为CSV时替换为PNG;未提供时保存到默认输出目录
        if save_path and save_path.endswith('.csv'):
            plot_path = save_path.replace('.csv', '.png')
        elif save_path:
            plot_path = save_path + '.png'
        else:
            plot_path = os.path.join(self.output_dir, 'model_comparison.png')
|
||
|
||
plt.savefig(plot_path)
|
||
plt.close()
|
||
|
||
logger.info(f"模型比较图表已保存到: {plot_path}")
|
||
|
||
return comparison_df
|
||
|
||
def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset],
|
||
y_test: Optional[np.ndarray] = None,
|
||
batch_size: Optional[int] = None,
|
||
verbose: int = 0) -> pd.DataFrame:
|
||
"""
|
||
评估模型在各个类别上的性能
|
||
|
||
Args:
|
||
x_test: 测试数据特征
|
||
y_test: 测试数据标签
|
||
batch_size: 批大小
|
||
verbose: 详细程度
|
||
|
||
Returns:
|
||
各类别性能指标DataFrame
|
||
"""
|
||
batch_size = batch_size or self.batch_size
|
||
|
||
# 获取预测结果
|
||
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose)
|
||
y_pred = np.argmax(y_prob, axis=1)
|
||
|
||
# 处理y_test,确保y_test是一维数组
|
||
if isinstance(x_test, tf.data.Dataset):
|
||
# 如果是TensorFlow Dataset,需要从中提取y_test
|
||
y_test_extracted = np.concatenate([y for _, y in x_test], axis=0)
|
||
y_test = y_test_extracted
|
||
|
||
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
|
||
y_test = np.argmax(y_test, axis=1)
|
||
|
||
# 获取分类报告
|
||
report = self.metrics.classification_report(y_test, y_pred, output_dict=True)
|
||
|
||
# 提取各类别指标
|
||
class_metrics = {}
|
||
for key, value in report.items():
|
||
if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']:
|
||
if isinstance(value, dict):
|
||
class_metrics[key] = value
|
||
|
||
# 转换为DataFrame
|
||
class_performance_df = pd.DataFrame(class_metrics).T
|
||
|
||
        # 添加支持度(样本数量)
        class_counts = np.bincount(y_test)
        for idx, count in enumerate(class_counts):
            # classification_report的索引可能是数字字符串,也可能是类别名称
            key = self.class_names[idx] if self.class_names and idx < len(self.class_names) else str(idx)
            if key in class_performance_df.index:
                class_performance_df.loc[key, 'support'] = count

        # 添加类别名称(索引已是类别名称时直接沿用)
        if self.class_names:
            class_performance_df['class_name'] = [
                self.class_names[int(idx)] if str(idx).isdigit() and int(idx) < len(self.class_names) else idx
                for idx in class_performance_df.index
            ]
|
||
|
||
# 保存类别性能指标
|
||
performance_path = os.path.join(self.output_dir, 'class_performance.csv')
|
||
class_performance_df.to_csv(performance_path)
|
||
logger.info(f"各类别性能指标已保存到: {performance_path}")
|
||
|
||
return class_performance_df
|
||
|
||
def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset],
|
||
y_test: Optional[np.ndarray] = None,
|
||
batch_size: Optional[int] = None,
|
||
num_samples: int = 10,
|
||
                            save_path: Optional[str] = None) -> Optional[pd.DataFrame]:
|
||
"""
|
||
绘制误分类样本分析
|
||
|
||
Args:
|
||
x_test: 测试数据特征
|
||
y_test: 测试数据标签
|
||
batch_size: 批大小
|
||
num_samples: 要展示的误分类样本数量
|
||
save_path: 保存路径
|
||
"""
|
||
# 仅适用于文本数据的分析,需要原始文本
|
||
logger.info("误分类样本分析需要原始文本数据,此方法可能需要根据实际数据类型进行修改")
|
||
|
||
# 在实际应用中,这里应该根据实际数据类型进行修改
|
||
# 例如,对于序列化的文本,可能需要反序列化,或者使用词汇表将索引转换回文本
|
||
|
||
# 此处仅展示一个基本框架
|
||
batch_size = batch_size or self.batch_size
|
||
|
||
# 获取预测结果
|
||
y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0)
|
||
y_pred = np.argmax(y_prob, axis=1)
|
||
|
||
# 处理y_test,确保y_test是一维数组
|
||
if isinstance(x_test, tf.data.Dataset):
|
||
# 如果是TensorFlow Dataset,需要从中提取y_test和x_test
|
||
dataset_iterator = iter(x_test)
|
||
x_test_extracted = []
|
||
y_test_extracted = []
|
||
|
||
try:
|
||
while True:
|
||
x, y = next(dataset_iterator)
|
||
x_test_extracted.append(x)
|
||
y_test_extracted.append(y)
|
||
except StopIteration:
|
||
pass
|
||
|
||
x_test = np.concatenate(x_test_extracted, axis=0)
|
||
y_test = np.concatenate(y_test_extracted, axis=0)
|
||
|
||
if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1:
|
||
y_test = np.argmax(y_test, axis=1)
|
||
|
||
# 找出误分类样本
|
||
misclassified_indices = np.where(y_pred != y_test)[0]
|
||
|
||
# 如果没有误分类样本,返回
|
||
if len(misclassified_indices) == 0:
|
||
logger.info("没有误分类样本")
|
||
return
|
||
|
||
# 随机选择一些误分类样本
|
||
if len(misclassified_indices) > num_samples:
|
||
misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False)
|
||
|
||
# 保存误分类样本分析结果
|
||
misclassified_data = []
|
||
|
||
for idx in misclassified_indices:
|
||
true_label = y_test[idx]
|
||
pred_label = y_pred[idx]
|
||
|
||
true_class = self.class_names[true_label] if self.class_names else str(true_label)
|
||
pred_class = self.class_names[pred_label] if self.class_names else str(pred_label)
|
||
|
||
# 对于序列化的文本,此处需要进行反序列化
|
||
# 这里仅作示例,实际应用中需要根据具体数据类型修改
|
||
sample_text = f"Sample {idx}"
|
||
|
||
misclassified_data.append({
|
||
'sample_id': idx,
|
||
'true_label': true_label,
|
||
'predicted_label': pred_label,
|
||
'true_class': true_class,
|
||
'predicted_class': pred_class,
|
||
'confidence': float(y_prob[idx, pred_label]),
|
||
'sample_text': sample_text
|
||
})
|
||
|
||
# 创建DataFrame
|
||
misclassified_df = pd.DataFrame(misclassified_data)
|
||
|
||
# 保存结果
|
||
if save_path:
|
||
misclassified_df.to_csv(save_path)
|
||
logger.info(f"误分类样本分析已保存到: {save_path}")
|
||
|
||
return misclassified_df
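# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):假设已有训练好的模型对象trained_model
# (TextClassificationModel子类实例)及测试集x_test/y_test,演示评估、保存结果
# 与类别级性能分析的典型调用顺序;输出目录为假设路径。
# ----------------------------------------------------------------------
def _demo_model_evaluator(trained_model, x_test, y_test, class_names=None):
    evaluator = ModelEvaluator(model=trained_model,
                               class_names=class_names,
                               output_dir="outputs/evaluation_demo")  # 假设的输出目录

    # 计算整体指标(accuracy、macro/micro/weighted的precision/recall/F1等)
    metrics = evaluator.evaluate(x_test, y_test)

    # 保存JSON/CSV结果以及混淆矩阵等图表
    evaluator.save_evaluation_results(save_plots=True)

    # 查看各类别的precision/recall/f1-score
    class_perf = evaluator.evaluate_class_performance(x_test, y_test)
    return metrics, class_perf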
================================================================================
|
||
文件: preprocessing/__init__.py
|
||
================================================================================
|
||
|
||
|
||
|
||
================================================================================
|
||
文件: preprocessing/tokenization.py
|
||
================================================================================
|
||
|
||
"""
|
||
中文分词模块:负责中文文本分词处理
|
||
"""
|
||
import os
|
||
import jieba
|
||
import re
|
||
from typing import List, Dict, Tuple, Optional, Any, Set, Union
|
||
import pandas as pd
|
||
from collections import Counter
|
||
|
||
from config.system_config import STOPWORDS_DIR, ENCODING
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file, write_text_file, ensure_dir
|
||
|
||
logger = get_logger("Tokenization")
|
||
|
||
|
||
class ChineseTokenizer:
|
||
"""中文分词器,基于jieba实现"""
|
||
|
||
def __init__(self, user_dict_path: Optional[str] = None,
|
||
use_hmm: bool = True,
|
||
remove_stopwords: bool = True,
|
||
stopwords_path: Optional[str] = None,
|
||
add_custom_words: Optional[List[str]] = None):
|
||
"""
|
||
初始化中文分词器
|
||
|
||
Args:
|
||
user_dict_path: 用户自定义词典路径
|
||
use_hmm: 是否使用HMM模型进行分词
|
||
remove_stopwords: 是否移除停用词
|
||
stopwords_path: 停用词表路径,如果为None,则使用默认停用词表
|
||
add_custom_words: 要添加的自定义词语列表
|
||
"""
|
||
self.use_hmm = use_hmm
|
||
self.remove_stopwords = remove_stopwords
|
||
|
||
# 加载用户自定义词典
|
||
if user_dict_path and os.path.exists(user_dict_path):
|
||
jieba.load_userdict(user_dict_path)
|
||
logger.info(f"已加载用户自定义词典:{user_dict_path}")
|
||
|
||
# 加载停用词
|
||
self.stopwords = set()
|
||
if remove_stopwords:
|
||
self._load_stopwords(stopwords_path)
|
||
|
||
# 添加自定义词语
|
||
if add_custom_words:
|
||
for word in add_custom_words:
|
||
jieba.add_word(word)
|
||
logger.info(f"已添加 {len(add_custom_words)} 个自定义词语")
|
||
|
||
def _load_stopwords(self, stopwords_path: Optional[str] = None) -> None:
|
||
"""
|
||
加载停用词
|
||
|
||
Args:
|
||
stopwords_path: 停用词表路径,如果为None,则使用默认停用词表
|
||
"""
|
||
# 如果没有指定停用词表路径,则使用默认停用词表
|
||
if not stopwords_path:
|
||
stopwords_path = os.path.join(STOPWORDS_DIR, "chinese_stopwords.txt")
|
||
|
||
# 如果没有找到默认停用词表,则创建一个空的停用词表
|
||
if not os.path.exists(stopwords_path):
|
||
ensure_dir(os.path.dirname(stopwords_path))
|
||
# 常见中文停用词
|
||
default_stopwords = [
|
||
"的", "了", "和", "是", "就", "都", "而", "及", "与", "这", "那", "你",
|
||
"我", "他", "她", "它", "们", "或", "上", "下", "之", "地", "得", "着",
|
||
"说", "对", "在", "于", "由", "因", "为", "所", "以", "能", "可", "会"
|
||
]
|
||
write_text_file("\n".join(default_stopwords), stopwords_path)
|
||
logger.info(f"未找到停用词表,已创建默认停用词表:{stopwords_path}")
|
||
|
||
# 加载停用词表
|
||
try:
|
||
with open(stopwords_path, "r", encoding=ENCODING) as f:
|
||
for line in f:
|
||
word = line.strip()
|
||
if word:
|
||
self.stopwords.add(word)
|
||
logger.info(f"已加载 {len(self.stopwords)} 个停用词")
|
||
except Exception as e:
|
||
logger.error(f"加载停用词表失败:{e}")
|
||
|
||
def add_stopwords(self, words: Union[str, List[str]]) -> None:
|
||
"""
|
||
添加停用词
|
||
|
||
Args:
|
||
words: 要添加的停用词(字符串或列表)
|
||
"""
|
||
if isinstance(words, str):
|
||
self.stopwords.add(words.strip())
|
||
else:
|
||
for word in words:
|
||
self.stopwords.add(word.strip())
|
||
|
||
def remove_stopwords_from_list(self, words: List[str]) -> List[str]:
|
||
"""
|
||
从词语列表中移除停用词
|
||
|
||
Args:
|
||
words: 词语列表
|
||
|
||
Returns:
|
||
移除停用词后的词语列表
|
||
"""
|
||
if not self.remove_stopwords:
|
||
return words
|
||
|
||
return [word for word in words if word not in self.stopwords]
|
||
|
||
def tokenize(self, text: str, return_string: bool = False,
|
||
cut_all: bool = False) -> Union[List[str], str]:
|
||
"""
|
||
对文本进行分词
|
||
|
||
Args:
|
||
text: 要分词的文本
|
||
return_string: 是否返回字符串(以空格分隔的词语)
|
||
cut_all: 是否使用全模式(默认使用精确模式)
|
||
|
||
Returns:
|
||
分词结果(词语列表或字符串)
|
||
"""
|
||
if not text:
|
||
return "" if return_string else []
|
||
|
||
# 使用jieba进行分词
|
||
if cut_all:
|
||
words = jieba.lcut(text, cut_all=True)
|
||
else:
|
||
words = jieba.lcut(text, HMM=self.use_hmm)
|
||
|
||
# 移除停用词
|
||
if self.remove_stopwords:
|
||
words = self.remove_stopwords_from_list(words)
|
||
|
||
# 返回结果
|
||
if return_string:
|
||
return " ".join(words)
|
||
else:
|
||
return words
|
||
|
||
def batch_tokenize(self, texts: List[str], return_string: bool = False,
|
||
cut_all: bool = False) -> List[Union[List[str], str]]:
|
||
"""
|
||
批量分词
|
||
|
||
Args:
|
||
texts: 要分词的文本列表
|
||
return_string: 是否返回字符串(以空格分隔的词语)
|
||
cut_all: 是否使用全模式(默认使用精确模式)
|
||
|
||
Returns:
|
||
分词结果列表
|
||
"""
|
||
return [self.tokenize(text, return_string, cut_all) for text in texts]
|
||
|
||
def analyze_tokens(self, texts: List[str], top_n: int = 20) -> Dict[str, Any]:
|
||
"""
|
||
分析文本中的词频分布
|
||
|
||
Args:
|
||
texts: 要分析的文本列表
|
||
top_n: 返回前多少个高频词
|
||
|
||
Returns:
|
||
包含词频分析结果的字典
|
||
"""
|
||
all_tokens = []
|
||
for text in texts:
|
||
tokens = self.tokenize(text, return_string=False)
|
||
all_tokens.extend(tokens)
|
||
|
||
# 统计词频
|
||
token_counter = Counter(all_tokens)
|
||
|
||
# 获取最常见的词
|
||
most_common = token_counter.most_common(top_n)
|
||
|
||
# 计算唯一词数量
|
||
unique_tokens = len(token_counter)
|
||
|
||
return {
|
||
"total_tokens": len(all_tokens),
|
||
"unique_tokens": unique_tokens,
|
||
"most_common": most_common,
|
||
"token_counter": token_counter
|
||
}
|
||
|
||
def get_top_keywords(self, texts: List[str], top_n: int = 20,
|
||
min_freq: int = 3, min_length: int = 2) -> List[Tuple[str, int]]:
|
||
"""
|
||
获取文本中的关键词
|
||
|
||
Args:
|
||
texts: 要分析的文本列表
|
||
top_n: 返回前多少个关键词
|
||
min_freq: 最小词频
|
||
min_length: 最小词长度(字符数)
|
||
|
||
Returns:
|
||
包含(关键词, 词频)的元组列表
|
||
"""
|
||
tokens_analysis = self.analyze_tokens(texts)
|
||
token_counter = tokens_analysis["token_counter"]
|
||
|
||
# 过滤满足条件的词
|
||
filtered_keywords = [(word, count) for word, count in token_counter.items()
|
||
if count >= min_freq and len(word) >= min_length]
|
||
|
||
# 按词频排序
|
||
sorted_keywords = sorted(filtered_keywords, key=lambda x: x[1], reverse=True)
|
||
|
||
return sorted_keywords[:top_n]
|
||
|
||
def get_vocabulary(self, texts: List[str], min_freq: int = 1) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
min_freq: 最小词频
|
||
|
||
Returns:
|
||
词汇表(词语列表)
|
||
"""
|
||
tokens_analysis = self.analyze_tokens(texts)
|
||
token_counter = tokens_analysis["token_counter"]
|
||
|
||
# 过滤满足最小词频的词
|
||
vocabulary = [word for word, count in token_counter.items() if count >= min_freq]
|
||
|
||
return vocabulary
|
||
|
||
def get_stopwords(self) -> Set[str]:
|
||
"""
|
||
获取停用词集合
|
||
|
||
Returns:
|
||
停用词集合
|
||
"""
|
||
return self.stopwords.copy()
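# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):演示ChineseTokenizer的分词与
# 关键词统计;示例文本为假设的演示数据。
# ----------------------------------------------------------------------
def _demo_chinese_tokenizer() -> None:
    tokenizer = ChineseTokenizer(remove_stopwords=True)

    demo_texts = [
        "机器学习是人工智能的一个重要分支",
        "深度学习模型在文本分类任务中表现出色",
    ]

    # 单条文本分词:返回词语列表,或以空格连接的字符串
    print(tokenizer.tokenize(demo_texts[0]))
    print(tokenizer.tokenize(demo_texts[0], return_string=True))

    # 批量分词与高频关键词统计
    print(tokenizer.batch_tokenize(demo_texts, return_string=True))
    print(tokenizer.get_top_keywords(demo_texts, top_n=5, min_freq=1, min_length=2))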
================================================================================
|
||
文件: preprocessing/vectorizer.py
|
||
================================================================================
|
||
|
||
"""
|
||
文本向量化模块:实现文本向量化,包括词袋模型、TF-IDF和词嵌入等多种文本表示方法
|
||
"""
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from tensorflow.keras.preprocessing.text import Tokenizer
|
||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer as SklearnTfidfVectorizer  # 使用别名,避免与本模块的TfidfVectorizer类重名
|
||
import pickle
|
||
import os
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import gensim
|
||
from gensim.models import Word2Vec, KeyedVectors
|
||
|
||
from config.system_config import PROCESSED_DATA_DIR, EMBEDDINGS_DIR
|
||
from config.model_config import (
|
||
MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS, MIN_WORD_FREQUENCY
|
||
)
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import save_pickle, load_pickle, ensure_dir
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
|
||
logger = get_logger("Vectorizer")
|
||
|
||
|
||
class TextVectorizer:
|
||
"""文本向量化基类,定义通用接口"""
|
||
|
||
def __init__(self, max_features: int = MAX_NUM_WORDS):
|
||
"""
|
||
初始化文本向量化器
|
||
|
||
Args:
|
||
max_features: 最大特征数(词汇表大小)
|
||
"""
|
||
self.max_features = max_features
|
||
self.vectorizer = None
|
||
self.is_fitted = False
|
||
|
||
def fit(self, texts: List[str]) -> None:
|
||
"""
|
||
在文本上训练向量化器
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def transform(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
将文本转换为向量表示
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
向量表示
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def fit_transform(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
在文本上训练向量化器,并将文本转换为向量表示
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
向量表示
|
||
"""
|
||
self.fit(texts)
|
||
return self.transform(texts)
|
||
|
||
def save(self, path: str) -> None:
|
||
"""
|
||
保存向量化器
|
||
|
||
Args:
|
||
path: 保存路径
|
||
"""
|
||
ensure_dir(os.path.dirname(path))
|
||
save_pickle(self.vectorizer, path)
|
||
logger.info(f"向量化器已保存到:{path}")
|
||
|
||
def load(self, path: str) -> None:
|
||
"""
|
||
加载向量化器
|
||
|
||
Args:
|
||
path: 加载路径
|
||
"""
|
||
self.vectorizer = load_pickle(path)
|
||
self.is_fitted = True
|
||
logger.info(f"向量化器已从 {path} 加载")
|
||
|
||
def get_vocabulary(self) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Returns:
|
||
词汇表
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def get_vocabulary_size(self) -> int:
|
||
"""
|
||
获取词汇表大小
|
||
|
||
Returns:
|
||
词汇表大小
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
|
||
class BagOfWordsVectorizer(TextVectorizer):
|
||
"""词袋模型向量化器"""
|
||
|
||
def __init__(self, max_features: int = MAX_NUM_WORDS,
|
||
min_df: int = MIN_WORD_FREQUENCY,
|
||
tokenizer: Optional[Callable[[str], List[str]]] = None,
|
||
binary: bool = False):
|
||
"""
|
||
初始化词袋模型向量化器
|
||
|
||
Args:
|
||
max_features: 最大特征数(词汇表大小)
|
||
min_df: 最小文档频率
|
||
tokenizer: 分词器函数,接收文本,返回词语列表
|
||
binary: 是否使用二进制计数(只关注词语是否出现,不关注频率)
|
||
"""
|
||
super().__init__(max_features)
|
||
self.min_df = min_df
|
||
self.binary = binary
|
||
|
||
# 创建sklearn的CountVectorizer
|
||
self.vectorizer = CountVectorizer(
|
||
max_features=max_features,
|
||
min_df=min_df,
|
||
tokenizer=tokenizer,
|
||
binary=binary
|
||
)
|
||
|
||
def fit(self, texts: List[str]) -> None:
|
||
"""
|
||
在文本上训练词袋模型
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
"""
|
||
self.vectorizer.fit(texts)
|
||
self.is_fitted = True
|
||
logger.info(f"词袋模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}")
|
||
|
||
def transform(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
将文本转换为词袋向量表示
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
词袋向量表示(稀疏矩阵)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return self.vectorizer.transform(texts)
|
||
|
||
def get_vocabulary(self) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Returns:
|
||
词汇表(按索引排序)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# CountVectorizer的词汇表是一个字典,键为词,值为索引
|
||
vocab_dict = self.vectorizer.vocabulary_
|
||
vocab_list = [""] * len(vocab_dict)
|
||
for word, idx in vocab_dict.items():
|
||
vocab_list[idx] = word
|
||
|
||
return vocab_list
|
||
|
||
def get_vocabulary_size(self) -> int:
|
||
"""
|
||
获取词汇表大小
|
||
|
||
Returns:
|
||
词汇表大小
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return len(self.vectorizer.vocabulary_)
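# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):演示BagOfWordsVectorizer的fit/transform。
# 示例文本已按空格分好词,分词函数直接用str.split;max_features/min_df为演示用小值。
# ----------------------------------------------------------------------
def _demo_bag_of_words() -> None:
    demo_texts = [
        "机器 学习 是 人工智能 的 重要 分支",
        "深度 学习 在 文本 分类 中 表现 出色",
    ]

    bow = BagOfWordsVectorizer(max_features=1000, min_df=1,
                               tokenizer=lambda text: text.split())
    features = bow.fit_transform(demo_texts)

    print(bow.get_vocabulary_size())
    print(features.toarray())  # 稀疏矩阵转稠密便于查看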
class TfidfVectorizer(TextVectorizer):
|
||
"""TF-IDF向量化器"""
|
||
|
||
def __init__(self, max_features: int = MAX_NUM_WORDS,
|
||
min_df: int = MIN_WORD_FREQUENCY,
|
||
tokenizer: Optional[Callable[[str], List[str]]] = None,
|
||
norm: str = 'l2',
|
||
use_idf: bool = True,
|
||
smooth_idf: bool = True,
|
||
sublinear_tf: bool = False):
|
||
"""
|
||
初始化TF-IDF向量化器
|
||
|
||
Args:
|
||
max_features: 最大特征数(词汇表大小)
|
||
min_df: 最小文档频率
|
||
tokenizer: 分词器函数,接收文本,返回词语列表
|
||
norm: 规范化方法,默认为L2范数
|
||
use_idf: 是否使用IDF(逆文档频率)
|
||
smooth_idf: 是否平滑IDF权重
|
||
sublinear_tf: 是否应用sublinear scaling(对TF取对数)
|
||
"""
|
||
super().__init__(max_features)
|
||
self.min_df = min_df
|
||
|
||
        # 创建sklearn的TfidfVectorizer(通过文件顶部的导入别名引用,避免与本类同名冲突)
        self.vectorizer = SklearnTfidfVectorizer(
|
||
max_features=max_features,
|
||
min_df=min_df,
|
||
tokenizer=tokenizer,
|
||
norm=norm,
|
||
use_idf=use_idf,
|
||
smooth_idf=smooth_idf,
|
||
sublinear_tf=sublinear_tf
|
||
)
|
||
|
||
def fit(self, texts: List[str]) -> None:
|
||
"""
|
||
在文本上训练TF-IDF模型
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
"""
|
||
self.vectorizer.fit(texts)
|
||
self.is_fitted = True
|
||
logger.info(f"TF-IDF模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}")
|
||
|
||
def transform(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
将文本转换为TF-IDF向量表示
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
TF-IDF向量表示(稀疏矩阵)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return self.vectorizer.transform(texts)
|
||
|
||
def get_vocabulary(self) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Returns:
|
||
词汇表(按索引排序)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# TfidfVectorizer的词汇表是一个字典,键为词,值为索引
|
||
vocab_dict = self.vectorizer.vocabulary_
|
||
vocab_list = [""] * len(vocab_dict)
|
||
for word, idx in vocab_dict.items():
|
||
vocab_list[idx] = word
|
||
|
||
return vocab_list
|
||
|
||
def get_vocabulary_size(self) -> int:
|
||
"""
|
||
获取词汇表大小
|
||
|
||
Returns:
|
||
词汇表大小
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return len(self.vectorizer.vocabulary_)
|
||
|
||
def get_feature_names(self) -> List[str]:
|
||
"""
|
||
获取特征名称(词汇表)
|
||
|
||
Returns:
|
||
特征名称列表
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return self.vectorizer.get_feature_names_out()
|
||
|
||
def get_idf(self) -> np.ndarray:
|
||
"""
|
||
获取IDF权重
|
||
|
||
Returns:
|
||
IDF权重数组
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return self.vectorizer.idf_
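# ----------------------------------------------------------------------
# 用法示例(最小示意,非源码的一部分):结合ChineseTokenizer使用本模块的
# TfidfVectorizer(注意它是本模块定义的封装类,并非sklearn的同名类);
# 示例文本与参数取值均为演示用的假设数据。
# ----------------------------------------------------------------------
def _demo_tfidf_vectorizer() -> None:
    demo_texts = [
        "机器学习是人工智能的重要分支",
        "深度学习在文本分类任务中表现出色",
    ]

    tokenizer = ChineseTokenizer(remove_stopwords=True)
    tfidf = TfidfVectorizer(max_features=1000, min_df=1,
                            tokenizer=tokenizer.tokenize)
    features = tfidf.fit_transform(demo_texts)

    print(tfidf.get_vocabulary_size())
    print(tfidf.get_idf()[:10])
    print(features.shape)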
class SequenceVectorizer(TextVectorizer):
|
||
"""序列向量化器,使用Keras的Tokenizer"""
|
||
|
||
def __init__(self, max_features: int = MAX_NUM_WORDS,
|
||
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
|
||
oov_token: str = "<OOV>",
|
||
padding: str = "post",
|
||
truncating: str = "post"):
|
||
"""
|
||
初始化序列向量化器
|
||
|
||
Args:
|
||
max_features: 最大特征数(词汇表大小)
|
||
max_sequence_length: 序列最大长度
|
||
oov_token: 未登录词标记
|
||
padding: 填充方式,'pre'或'post'
|
||
truncating: 截断方式,'pre'或'post'
|
||
"""
|
||
super().__init__(max_features)
|
||
self.max_sequence_length = max_sequence_length
|
||
self.oov_token = oov_token
|
||
self.padding = padding
|
||
self.truncating = truncating
|
||
|
||
# 创建Keras的Tokenizer
|
||
self.vectorizer = Tokenizer(num_words=max_features, oov_token=oov_token)
|
||
|
||
def fit(self, texts: List[str]) -> None:
|
||
"""
|
||
在文本上训练序列向量化器
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
"""
|
||
self.vectorizer.fit_on_texts(texts)
|
||
self.is_fitted = True
|
||
logger.info(f"序列向量化器已训练,词汇表大小:{len(self.vectorizer.word_index)}")
|
||
|
||
def transform(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
将文本转换为整数序列,并进行填充
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
整数序列表示
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
sequences = self.vectorizer.texts_to_sequences(texts)
|
||
padded_sequences = pad_sequences(
|
||
sequences,
|
||
maxlen=self.max_sequence_length,
|
||
padding=self.padding,
|
||
truncating=self.truncating
|
||
)
|
||
|
||
return padded_sequences
|
||
|
||
def get_vocabulary(self) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Returns:
|
||
词汇表(按索引排序)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# Tokenizer的词汇表是一个字典,键为词,值为索引(从1开始)
|
||
word_index = self.vectorizer.word_index
|
||
index_word = {index: word for word, index in word_index.items()}
|
||
|
||
        # 注意:索引0保留给padding;若设置了oov_token,Tokenizer会把它放在索引1,
        # 因此在下面的循环中会自然包含,无需重复添加
        vocab = ["<PAD>"]

        max_index = min(self.max_features, len(word_index) + 1) if self.max_features else len(word_index) + 1
        for i in range(1, max_index):
            if i in index_word:
                vocab.append(index_word[i])
|
||
|
||
return vocab
|
||
|
||
def get_vocabulary_size(self) -> int:
|
||
"""
|
||
获取词汇表大小
|
||
|
||
Returns:
|
||
词汇表大小
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# +1是因为索引0保留给padding
|
||
return min(self.max_features, len(self.vectorizer.word_index) + 1) if self.max_features else len(
|
||
self.vectorizer.word_index) + 1
|
||
|
||
def texts_to_sequences(self, texts: List[str]) -> List[List[int]]:
|
||
"""
|
||
将文本转换为整数序列(不填充)
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
整数序列列表
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return self.vectorizer.texts_to_sequences(texts)
|
||
|
||
def sequences_to_padded(self, sequences: List[List[int]]) -> np.ndarray:
|
||
"""
|
||
将整数序列填充到指定长度
|
||
|
||
Args:
|
||
sequences: 整数序列列表
|
||
|
||
Returns:
|
||
填充后的整数序列
|
||
"""
|
||
return pad_sequences(
|
||
sequences,
|
||
maxlen=self.max_sequence_length,
|
||
padding=self.padding,
|
||
truncating=self.truncating
|
||
)
|
||
|
||
def save(self, path: str) -> None:
|
||
"""
|
||
保存序列向量化器
|
||
|
||
Args:
|
||
path: 保存路径
|
||
"""
|
||
ensure_dir(os.path.dirname(path))
|
||
|
||
# 保存配置和状态
|
||
tokenizer_state = {
|
||
'tokenizer': self.vectorizer,
|
||
'max_features': self.max_features,
|
||
'max_sequence_length': self.max_sequence_length,
|
||
'oov_token': self.oov_token,
|
||
'padding': self.padding,
|
||
'truncating': self.truncating,
|
||
'is_fitted': self.is_fitted
|
||
}
|
||
|
||
save_pickle(tokenizer_state, path)
|
||
logger.info(f"序列向量化器已保存到:{path}")
|
||
|
||
def load(self, path: str) -> None:
|
||
"""
|
||
加载序列向量化器
|
||
|
||
Args:
|
||
path: 加载路径
|
||
"""
|
||
tokenizer_state = load_pickle(path)
|
||
|
||
self.vectorizer = tokenizer_state['tokenizer']
|
||
self.max_features = tokenizer_state['max_features']
|
||
self.max_sequence_length = tokenizer_state['max_sequence_length']
|
||
self.oov_token = tokenizer_state['oov_token']
|
||
self.padding = tokenizer_state['padding']
|
||
self.truncating = tokenizer_state['truncating']
|
||
self.is_fitted = tokenizer_state['is_fitted']
|
||
|
||
logger.info(f"序列向量化器已从 {path} 加载,词汇表大小:{len(self.vectorizer.word_index)}")
class Word2VecVectorizer(TextVectorizer):
|
||
"""Word2Vec词嵌入向量化器"""
|
||
|
||
def __init__(self, vector_size: int = 100,
|
||
window: int = 5,
|
||
min_count: int = MIN_WORD_FREQUENCY,
|
||
workers: int = 4,
|
||
sg: int = 1, # 1表示Skip-gram模型,0表示CBOW模型
|
||
max_sequence_length: int = MAX_SEQUENCE_LENGTH,
|
||
padding: str = "post",
|
||
truncating: str = "post",
|
||
pretrained_path: Optional[str] = None):
|
||
"""
|
||
初始化Word2Vec词嵌入向量化器
|
||
|
||
Args:
|
||
vector_size: 词向量维度
|
||
window: 上下文窗口大小
|
||
min_count: 最小词频
|
||
workers: 并行训练的线程数
|
||
sg: 训练算法,1表示Skip-gram,0表示CBOW
|
||
max_sequence_length: 序列最大长度
|
||
padding: 填充方式,'pre'或'post'
|
||
truncating: 截断方式,'pre'或'post'
|
||
pretrained_path: 预训练词向量路径,如果不为None,则加载预训练词向量
|
||
"""
|
||
super().__init__(max_features=None) # Word2Vec没有max_features限制
|
||
self.vector_size = vector_size
|
||
self.window = window
|
||
self.min_count = min_count
|
||
self.workers = workers
|
||
self.sg = sg
|
||
self.max_sequence_length = max_sequence_length
|
||
self.padding = padding
|
||
self.truncating = truncating
|
||
self.pretrained_path = pretrained_path
|
||
|
||
# Word2Vec模型
|
||
self.model = None
|
||
|
||
# 词汇表
|
||
self.word_index = {}
|
||
self.index_word = {}
|
||
|
||
# 如果有预训练词向量,加载它
|
||
if pretrained_path and os.path.exists(pretrained_path):
|
||
self._load_pretrained(pretrained_path)
|
||
|
||
def _load_pretrained(self, path: str) -> None:
|
||
"""
|
||
加载预训练词向量
|
||
|
||
Args:
|
||
path: 预训练词向量路径
|
||
"""
|
||
try:
|
||
# 尝试加载Word2Vec模型
|
||
self.model = Word2Vec.load(path)
|
||
logger.info(f"已加载预训练Word2Vec模型:{path}")
|
||
        except Exception:
|
||
try:
|
||
# 尝试加载词向量(Word2Vec、GloVe或FastText格式)
|
||
self.model = KeyedVectors.load_word2vec_format(path, binary=path.endswith('.bin'))
|
||
logger.info(f"已加载预训练词向量:{path}")
|
||
except Exception as e:
|
||
logger.error(f"加载预训练词向量失败:{e}")
|
||
return
|
||
|
||
# 如果加载成功,构建词汇表
|
||
self._build_vocab_from_model()
|
||
self.is_fitted = True
|
||
|
||
def _build_vocab_from_model(self) -> None:
|
||
"""从模型构建词汇表"""
|
||
# 获取词汇表
|
||
vocabulary = list(self.model.wv.index_to_key)
|
||
|
||
# 构建词汇表索引
|
||
self.word_index = {word: idx + 1 for idx, word in enumerate(vocabulary)} # 索引0保留给padding
|
||
self.index_word = {idx + 1: word for idx, word in enumerate(vocabulary)}
|
||
self.index_word[0] = "<PAD>"
|
||
|
||
def fit(self, tokenized_texts: List[List[str]]) -> None:
|
||
"""
|
||
在分词后的文本上训练Word2Vec模型
|
||
|
||
Args:
|
||
tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
|
||
"""
|
||
# 如果已经有预训练模型,跳过训练
|
||
if self.is_fitted and self.model is not None:
|
||
logger.info("已有预训练模型,跳过训练")
|
||
return
|
||
|
||
# 训练Word2Vec模型
|
||
self.model = Word2Vec(
|
||
sentences=tokenized_texts,
|
||
vector_size=self.vector_size,
|
||
window=self.window,
|
||
min_count=self.min_count,
|
||
workers=self.workers,
|
||
sg=self.sg
|
||
)
|
||
|
||
# 构建词汇表
|
||
self._build_vocab_from_model()
|
||
self.is_fitted = True
|
||
|
||
logger.info(f"Word2Vec模型已训练,词汇表大小:{len(self.word_index)}")
|
||
|
||
def transform(self, tokenized_texts: List[List[str]]) -> np.ndarray:
|
||
"""
|
||
将分词后的文本转换为词向量序列
|
||
|
||
Args:
|
||
tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
|
||
|
||
Returns:
|
||
词向量序列,形状为(样本数, 最大序列长度, 词向量维度)
|
||
"""
|
||
if not self.is_fitted or self.model is None:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# 初始化结果数组
|
||
result = np.zeros((len(tokenized_texts), self.max_sequence_length, self.vector_size))
|
||
|
||
# 处理每个文本
|
||
for i, text in enumerate(tokenized_texts):
|
||
seq_len = min(len(text), self.max_sequence_length)
|
||
|
||
# 根据截断方式处理
|
||
if self.truncating == 'pre' and len(text) > self.max_sequence_length:
|
||
text = text[-self.max_sequence_length:]
|
||
elif self.truncating == 'post' and len(text) > self.max_sequence_length:
|
||
text = text[:self.max_sequence_length]
|
||
|
||
# 获取每个词的词向量
|
||
for j, word in enumerate(text[:seq_len]):
|
||
if word in self.model.wv:
|
||
# 根据填充方式确定位置
|
||
pos = j if self.padding == 'post' else self.max_sequence_length - seq_len + j
|
||
result[i, pos] = self.model.wv[word]
|
||
|
||
return result
|
||
|
||
def transform_to_indices(self, tokenized_texts: List[List[str]]) -> np.ndarray:
|
||
"""
|
||
将分词后的文本转换为词索引序列,并填充
|
||
|
||
Args:
|
||
tokenized_texts: 分词后的文本列表(每个文本是一个词语列表)
|
||
|
||
Returns:
|
||
词索引序列
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
# 将词转换为索引
|
||
sequences = []
|
||
for text in tokenized_texts:
|
||
seq = [self.word_index.get(word, 0) for word in text] # 未登录词用0(padding)
|
||
sequences.append(seq)
|
||
|
||
# 填充序列
|
||
padded_sequences = pad_sequences(
|
||
sequences,
|
||
maxlen=self.max_sequence_length,
|
||
padding=self.padding,
|
||
truncating=self.truncating
|
||
)
|
||
|
||
return padded_sequences
|
||
|
||
def get_embedding_matrix(self) -> np.ndarray:
|
||
"""
|
||
获取嵌入矩阵,用于Embedding层的权重初始化
|
||
|
||
Returns:
|
||
嵌入矩阵,形状为(词汇表大小, 词向量维度)
|
||
"""
|
||
if not self.is_fitted or self.model is None:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
vocab_size = len(self.word_index) + 1 # +1是因为索引0保留给padding
|
||
embedding_matrix = np.zeros((vocab_size, self.vector_size))
|
||
|
||
# 填充嵌入矩阵
|
||
for word, idx in self.word_index.items():
|
||
if word in self.model.wv:
|
||
embedding_matrix[idx] = self.model.wv[word]
|
||
|
||
return embedding_matrix
|
||
|
||
def get_vocabulary(self) -> List[str]:
|
||
"""
|
||
获取词汇表
|
||
|
||
Returns:
|
||
词汇表(按索引排序)
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
vocab = ["<PAD>"] # 索引0保留给padding
|
||
for idx in range(1, len(self.index_word) + 1):
|
||
if idx in self.index_word:
|
||
vocab.append(self.index_word[idx])
|
||
|
||
return vocab
|
||
|
||
def get_vocabulary_size(self) -> int:
|
||
"""
|
||
获取词汇表大小
|
||
|
||
Returns:
|
||
词汇表大小
|
||
"""
|
||
if not self.is_fitted:
|
||
raise ValueError("向量化器尚未训练,请先调用fit方法")
|
||
|
||
return len(self.word_index) + 1 # +1是因为索引0保留给padding
|
||
|
||
def save(self, path: str) -> None:
|
||
"""
|
||
保存Word2Vec向量化器
|
||
|
||
Args:
|
||
path: 保存路径
|
||
"""
|
||
ensure_dir(os.path.dirname(path))
|
||
|
||
# 保存模型和配置
|
||
model_path = os.path.join(os.path.dirname(path), "word2vec_model")
|
||
if self.model:
|
||
self.model.save(model_path)
|
||
|
||
# 保存配置和状态
|
||
state = {
|
||
'word_index': self.word_index,
|
||
'index_word': self.index_word,
|
||
'vector_size': self.vector_size,
|
||
'window': self.window,
|
||
'min_count': self.min_count,
|
||
'workers': self.workers,
|
||
'sg': self.sg,
|
||
'max_sequence_length': self.max_sequence_length,
|
||
'padding': self.padding,
|
||
'truncating': self.truncating,
|
||
'is_fitted': self.is_fitted,
|
||
'model_path': model_path if self.model else None
|
||
}
|
||
|
||
save_pickle(state, path)
|
||
logger.info(f"Word2Vec向量化器已保存到:{path}")
|
||
|
||
def load(self, path: str) -> None:
|
||
"""
|
||
加载Word2Vec向量化器
|
||
|
||
Args:
|
||
path: 加载路径
|
||
"""
|
||
state = load_pickle(path)
|
||
|
||
self.word_index = state['word_index']
|
||
self.index_word = state['index_word']
|
||
self.vector_size = state['vector_size']
|
||
self.window = state['window']
|
||
self.min_count = state['min_count']
|
||
self.workers = state['workers']
|
||
self.sg = state['sg']
|
||
self.max_sequence_length = state['max_sequence_length']
|
||
self.padding = state['padding']
|
||
self.truncating = state['truncating']
|
||
self.is_fitted = state['is_fitted']
|
||
|
||
# 加载模型
|
||
model_path = state.get('model_path')
|
||
if model_path and os.path.exists(model_path):
|
||
self.model = Word2Vec.load(model_path)
|
||
|
||
logger.info(f"Word2Vec向量化器已从 {path} 加载,词汇表大小:{len(self.word_index)}")
================================================================================
|
||
文件: preprocessing/text_cleaner.py
|
||
================================================================================
|
||
|
||
"""
|
||
文本清洗模块:实现文本清洗,去除无用字符、HTML标签等
|
||
"""
|
||
import re
|
||
import unicodedata
|
||
import html
|
||
from typing import List, Dict, Tuple, Optional, Any, Callable, Set, Union
|
||
import string
|
||
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("TextCleaner")
|
||
|
||
|
||
class TextCleaner:
|
||
"""文本清洗类,提供各种文本清洗方法"""
|
||
|
||
def __init__(self, remove_html: bool = True,
|
||
remove_urls: bool = True,
|
||
remove_emails: bool = True,
|
||
remove_numbers: bool = False,
|
||
remove_punctuation: bool = False,
|
||
lowercase: bool = False,
|
||
normalize_unicode: bool = True,
|
||
remove_excessive_spaces: bool = True,
|
||
remove_short_texts: bool = False,
|
||
min_text_length: int = 10,
|
||
custom_patterns: Optional[List[str]] = None):
|
||
"""
|
||
初始化文本清洗器
|
||
|
||
Args:
|
||
remove_html: 是否移除HTML标签
|
||
remove_urls: 是否移除URL
|
||
remove_emails: 是否移除电子邮件地址
|
||
remove_numbers: 是否移除数字
|
||
remove_punctuation: 是否移除标点符号
|
||
lowercase: 是否转为小写(对中文无效)
|
||
normalize_unicode: 是否规范化Unicode字符
|
||
remove_excessive_spaces: 是否移除多余空格
|
||
remove_short_texts: 是否过滤掉短文本
|
||
min_text_length: 最小文本长度(当remove_short_texts=True时有效)
|
||
custom_patterns: 自定义的正则表达式模式列表,用于额外的文本清洗
|
||
"""
|
||
self.remove_html = remove_html
|
||
self.remove_urls = remove_urls
|
||
self.remove_emails = remove_emails
|
||
self.remove_numbers = remove_numbers
|
||
self.remove_punctuation = remove_punctuation
|
||
self.lowercase = lowercase
|
||
self.normalize_unicode = normalize_unicode
|
||
self.remove_excessive_spaces = remove_excessive_spaces
|
||
self.remove_short_texts = remove_short_texts
|
||
self.min_text_length = min_text_length
|
||
self.custom_patterns = custom_patterns or []
|
||
|
||
# 编译正则表达式
|
||
self.html_pattern = re.compile(r'<.*?>')
|
||
self.url_pattern = re.compile(r'https?://\S+|www\.\S+')
|
||
self.email_pattern = re.compile(r'\S+@\S+\.\S+')
|
||
self.number_pattern = re.compile(r'\d+')
|
||
self.space_pattern = re.compile(r'\s+')
|
||
|
||
# 编译自定义模式
|
||
self.compiled_custom_patterns = [re.compile(pattern) for pattern in self.custom_patterns]
|
||
|
||
# 中文标点符号
|
||
        self.chinese_punctuation = ",。!?;:“”‘’【】《》()、…—~·"
|
||
|
||
logger.info("文本清洗器初始化完成")
|
||
|
||
def clean_text(self, text: str) -> str:
|
||
"""
|
||
清洗文本,应用所有已配置的清洗方法
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
清洗后的文本
|
||
"""
|
||
if not text:
|
||
return ""
|
||
|
||
# HTML解码
|
||
if self.remove_html:
|
||
text = html.unescape(text)
|
||
text = self.html_pattern.sub(' ', text)
|
||
|
||
# 移除URL
|
||
if self.remove_urls:
|
||
text = self.url_pattern.sub(' ', text)
|
||
|
||
# 移除电子邮件
|
||
if self.remove_emails:
|
||
text = self.email_pattern.sub(' ', text)
|
||
|
||
# Unicode规范化
|
||
if self.normalize_unicode:
|
||
text = unicodedata.normalize('NFKC', text)
|
||
|
||
# 移除数字
|
||
if self.remove_numbers:
|
||
text = self.number_pattern.sub(' ', text)
|
||
|
||
# 移除标点符号
|
||
if self.remove_punctuation:
|
||
# 处理英文标点
|
||
for punct in string.punctuation:
|
||
text = text.replace(punct, ' ')
|
||
# 处理中文标点
|
||
for punct in self.chinese_punctuation:
|
||
text = text.replace(punct, ' ')
|
||
|
||
# 应用自定义清洗模式
|
||
for pattern in self.compiled_custom_patterns:
|
||
text = pattern.sub(' ', text)
|
||
|
||
# 转为小写
|
||
if self.lowercase:
|
||
text = text.lower()
|
||
|
||
# 移除多余空格
|
||
if self.remove_excessive_spaces:
|
||
text = self.space_pattern.sub(' ', text)
|
||
text = text.strip()
|
||
|
||
# 过滤掉短文本
|
||
if self.remove_short_texts and len(text) < self.min_text_length:
|
||
return ""
|
||
|
||
return text
|
||
|
||
def clean_texts(self, texts: List[str]) -> List[str]:
|
||
"""
|
||
批量清洗文本
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
|
||
Returns:
|
||
清洗后的文本列表
|
||
"""
|
||
return [self.clean_text(text) for text in texts]
|
||
|
||
def remove_redundant_texts(self, texts: List[str]) -> List[str]:
|
||
"""
|
||
移除冗余文本(空文本和长度小于阈值的文本)
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
|
||
Returns:
|
||
移除冗余后的文本列表
|
||
"""
|
||
return [text for text in texts if text and len(text) >= self.min_text_length]
|
||
|
||
@staticmethod
|
||
def remove_specific_characters(text: str, chars_to_remove: Union[str, Set[str]]) -> str:
|
||
"""
|
||
移除特定字符
|
||
|
||
Args:
|
||
text: 原始文本
|
||
chars_to_remove: 要移除的字符(字符串或字符集合)
|
||
|
||
Returns:
|
||
移除特定字符后的文本
|
||
"""
|
||
        # A str and a set of characters are both iterable, so a single loop handles either form.
        for char in chars_to_remove:
            text = text.replace(char, '')
        return text
|
||
|
||
@staticmethod
|
||
def replace_characters(text: str, char_map: Dict[str, str]) -> str:
|
||
"""
|
||
替换特定字符
|
||
|
||
Args:
|
||
text: 原始文本
|
||
char_map: 字符映射字典,键为要替换的字符,值为替换后的字符
|
||
|
||
Returns:
|
||
替换特定字符后的文本
|
||
"""
|
||
for old_char, new_char in char_map.items():
|
||
text = text.replace(old_char, new_char)
|
||
return text
|
||
|
||
@staticmethod
|
||
def remove_empty_lines(text: str) -> str:
|
||
"""
|
||
移除空行
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
移除空行后的文本
|
||
"""
|
||
lines = text.splitlines()
|
||
non_empty_lines = [line for line in lines if line.strip()]
|
||
return '\n'.join(non_empty_lines)
|
||
|
||
@staticmethod
|
||
def truncate_text(text: str, max_length: int, truncate_from_end: bool = True) -> str:
|
||
"""
|
||
截断文本
|
||
|
||
Args:
|
||
text: 原始文本
|
||
max_length: 最大长度
|
||
truncate_from_end: 是否从末尾截断,如果为False则从开头截断
|
||
|
||
Returns:
|
||
截断后的文本
|
||
"""
|
||
if len(text) <= max_length:
|
||
return text
|
||
|
||
if truncate_from_end:
|
||
return text[:max_length]
|
||
else:
|
||
return text[len(text) - max_length:]
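

# Editor's note: an illustrative usage sketch added by the editor, not part of the
# original file. The sample string and the described outcome are examples only.
def _demo_text_cleaner() -> None:
    cleaner = TextCleaner(remove_html=True, remove_urls=True, remove_emails=True)
    raw = "<p>更多信息见 https://example.com ,联系 someone@example.com</p>"
    # HTML tags, the URL and the e-mail address are replaced by spaces, then the
    # excess whitespace is collapsed by the space_pattern step.
    print(cleaner.clean_text(raw))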
================================================================================
|
||
文件: preprocessing/feature_extraction.py
|
||
================================================================================
|
||
|
||
"""
|
||
特征提取模块:实现文本特征提取,包括语法特征、语义特征等
|
||
"""
|
||
import re
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Set
|
||
from collections import Counter
|
||
import jieba.posseg as pseg
|
||
|
||
from config.system_config import CATEGORIES
|
||
from utils.logger import get_logger
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
|
||
logger = get_logger("FeatureExtraction")
|
||
|
||
|
||
class FeatureExtractor:
|
||
"""特征提取基类,定义通用接口"""
|
||
|
||
def __init__(self):
|
||
"""初始化特征提取器"""
|
||
pass
|
||
|
||
def extract(self, text: str) -> Dict[str, Any]:
|
||
"""
|
||
从文本中提取特征
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征字典
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def batch_extract(self, texts: List[str]) -> List[Dict[str, Any]]:
|
||
"""
|
||
批量提取特征
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
特征字典列表
|
||
"""
|
||
return [self.extract(text) for text in texts]
|
||
|
||
def extract_as_vector(self, text: str) -> np.ndarray:
|
||
"""
|
||
从文本中提取特征,并转换为向量表示
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征向量
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def batch_extract_as_vector(self, texts: List[str]) -> np.ndarray:
|
||
"""
|
||
批量提取特征,并转换为向量表示
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
|
||
Returns:
|
||
特征向量数组
|
||
"""
|
||
return np.array([self.extract_as_vector(text) for text in texts])
|
||
|
||
|
||
class StatisticalFeatureExtractor(FeatureExtractor):
|
||
"""统计特征提取器,提取文本的统计特征"""
|
||
|
||
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None):
|
||
"""
|
||
初始化统计特征提取器
|
||
|
||
Args:
|
||
tokenizer: 分词器,如果为None则创建一个新的分词器
|
||
"""
|
||
super().__init__()
|
||
self.tokenizer = tokenizer or ChineseTokenizer()
|
||
|
||
def extract(self, text: str) -> Dict[str, Any]:
|
||
"""
|
||
从文本中提取统计特征
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征字典,包含各种统计特征
|
||
"""
|
||
if not text:
|
||
return {
|
||
"char_count": 0,
|
||
"word_count": 0,
|
||
"sentence_count": 0,
|
||
"avg_word_length": 0,
|
||
"avg_sentence_length": 0,
|
||
"contains_number": False,
|
||
"contains_english": False,
|
||
"punctuation_ratio": 0,
|
||
"top_words": []
|
||
}
|
||
|
||
# 字符数
|
||
char_count = len(text)
|
||
|
||
# 分词
|
||
words = self.tokenizer.tokenize(text, return_string=False)
|
||
word_count = len(words)
|
||
|
||
# 句子数(按标点符号分割)
|
||
sentences = re.split(r'[。!?!?]+', text)
|
||
sentences = [s for s in sentences if s.strip()]
|
||
sentence_count = len(sentences)
|
||
|
||
# 平均词长
|
||
avg_word_length = sum(len(word) for word in words) / word_count if word_count > 0 else 0
|
||
|
||
# 平均句长(以字符为单位)
|
||
avg_sentence_length = char_count / sentence_count if sentence_count > 0 else 0
|
||
|
||
# 是否包含数字
|
||
contains_number = bool(re.search(r'\d', text))
|
||
|
||
# 是否包含英文
|
||
contains_english = bool(re.search(r'[a-zA-Z]', text))
|
||
|
||
# 标点符号比例
|
||
punctuation_pattern = re.compile(r'[^\w\s]')
|
||
punctuations = punctuation_pattern.findall(text)
|
||
punctuation_ratio = len(punctuations) / char_count if char_count > 0 else 0
|
||
|
||
# 高频词
|
||
word_counter = Counter(words)
|
||
top_words = word_counter.most_common(5)
|
||
|
||
return {
|
||
"char_count": char_count,
|
||
"word_count": word_count,
|
||
"sentence_count": sentence_count,
|
||
"avg_word_length": avg_word_length,
|
||
"avg_sentence_length": avg_sentence_length,
|
||
"contains_number": contains_number,
|
||
"contains_english": contains_english,
|
||
"punctuation_ratio": punctuation_ratio,
|
||
"top_words": top_words
|
||
}
|
||
|
||
def extract_as_vector(self, text: str) -> np.ndarray:
|
||
"""
|
||
从文本中提取统计特征,并转换为向量表示
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征向量,包含各种统计特征
|
||
"""
|
||
features = self.extract(text)
|
||
|
||
# 提取数值特征
|
||
vector = [
|
||
features['char_count'],
|
||
features['word_count'],
|
||
features['sentence_count'],
|
||
features['avg_word_length'],
|
||
features['avg_sentence_length'],
|
||
int(features['contains_number']),
|
||
int(features['contains_english']),
|
||
features['punctuation_ratio']
|
||
]
|
||
|
||
return np.array(vector, dtype=np.float32)
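

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. The values hinted at in the comments are approximate and only show the shape
# of the output.
def _demo_statistical_features() -> None:
    extractor = StatisticalFeatureExtractor()
    sample = "今天股市大涨!投资者信心明显增强。"
    print(extractor.extract(sample))                    # dict with char_count, word_count, ...
    print(extractor.extract_as_vector(sample).shape)    # -> (8,), the eight numeric features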
|
||
|
||
|
||
class POSFeatureExtractor(FeatureExtractor):
|
||
"""词性特征提取器,提取文本的词性特征"""
|
||
|
||
def __init__(self):
|
||
"""初始化词性特征提取器"""
|
||
super().__init__()
|
||
|
||
# 常见中文词性及其解释
|
||
self.pos_tags = {
|
||
'n': '名词', 'f': '方位名词', 's': '处所名词', 't': '时间名词',
|
||
'nr': '人名', 'ns': '地名', 'nt': '机构团体', 'nw': '作品名',
|
||
'nz': '其他专名', 'v': '动词', 'vd': '副动词', 'vn': '名动词',
|
||
'a': '形容词', 'ad': '副形词', 'an': '名形词', 'd': '副词',
|
||
'm': '数词', 'q': '量词', 'r': '代词', 'p': '介词',
|
||
'c': '连词', 'u': '助词', 'xc': '其他虚词', 'w': '标点符号'
|
||
}
|
||
|
||
def extract(self, text: str) -> Dict[str, Any]:
|
||
"""
|
||
从文本中提取词性特征
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征字典,包含各种词性特征
|
||
"""
|
||
if not text:
|
||
return {
|
||
"pos_counts": {},
|
||
"pos_ratios": {}
|
||
}
|
||
|
||
# 使用jieba进行词性标注
|
||
pos_list = pseg.cut(text)
|
||
|
||
# 统计各词性的数量
|
||
pos_counts = {}
|
||
total_count = 0
|
||
|
||
for word, pos in pos_list:
|
||
if pos in pos_counts:
|
||
pos_counts[pos] += 1
|
||
else:
|
||
pos_counts[pos] = 1
|
||
total_count += 1
|
||
|
||
# 计算各词性的比例
|
||
pos_ratios = {pos: count / total_count for pos, count in pos_counts.items()} if total_count > 0 else {}
|
||
|
||
return {
|
||
"pos_counts": pos_counts,
|
||
"pos_ratios": pos_ratios
|
||
}
|
||
|
||
def extract_as_vector(self, text: str) -> np.ndarray:
|
||
"""
|
||
从文本中提取词性特征,并转换为向量表示
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征向量,包含各词性的比例
|
||
"""
|
||
features = self.extract(text)
|
||
pos_ratios = features['pos_ratios']
|
||
|
||
# 按照 self.pos_tags 的顺序构建向量
|
||
vector = []
|
||
for pos in self.pos_tags.keys():
|
||
vector.append(pos_ratios.get(pos, 0.0))
|
||
|
||
return np.array(vector, dtype=np.float32)
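

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. The vector has one slot per tag in self.pos_tags (24 tags), following the
# dict's insertion order.
def _demo_pos_features() -> None:
    extractor = POSFeatureExtractor()
    vec = extractor.extract_as_vector("北京的秋天非常美丽")
    print(vec.shape)   # -> (24,)
    print(vec.sum())   # close to 1.0; tags emitted by jieba but absent from pos_tags are dropped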
|
||
|
||
|
||
class KeywordFeatureExtractor(FeatureExtractor):
|
||
"""关键词特征提取器,基于预定义关键词提取特征"""
|
||
|
||
def __init__(self, category_keywords: Optional[Dict[str, List[str]]] = None):
|
||
"""
|
||
初始化关键词特征提取器
|
||
|
||
Args:
|
||
category_keywords: 类别关键词字典,键为类别名称,值为关键词列表
|
||
"""
|
||
super().__init__()
|
||
self.category_keywords = category_keywords or self._get_default_keywords()
|
||
self.tokenizer = ChineseTokenizer()
|
||
|
||
def _get_default_keywords(self) -> Dict[str, List[str]]:
|
||
"""
|
||
获取默认的类别关键词
|
||
|
||
Returns:
|
||
类别关键词字典
|
||
"""
|
||
# 为每个类别定义一些示例关键词
|
||
default_keywords = {
|
||
"体育": ["比赛", "运动", "球员", "冠军", "球队", "足球", "篮球"],
|
||
"财经": ["股票", "基金", "投资", "市场", "经济", "金融", "股市"],
|
||
"房产": ["房价", "楼市", "地产", "购房", "房贷", "物业", "小区"],
|
||
"家居": ["装修", "家具", "设计", "卧室", "客厅", "厨房", "风格"],
|
||
"教育": ["学校", "学生", "考试", "教育", "大学", "课程", "老师"],
|
||
"科技": ["互联网", "科技", "创新", "数字", "智能", "研发", "技术"],
|
||
"时尚": ["时尚", "潮流", "服装", "搭配", "品牌", "美容", "穿着"],
|
||
"时政": ["政府", "政策", "国家", "发展", "会议", "主席", "总理"],
|
||
"游戏": ["游戏", "玩家", "电竞", "网游", "手游", "角色", "任务"],
|
||
"娱乐": ["明星", "电影", "节目", "综艺", "电视", "演员", "导演"],
|
||
"其他": ["其他", "一般", "常见", "普通", "正常", "通常", "传统"]
|
||
}
|
||
|
||
# 确保 CATEGORIES 中的每个类别都有关键词
|
||
for category in CATEGORIES:
|
||
if category not in default_keywords:
|
||
default_keywords[category] = [category]
|
||
|
||
return default_keywords
|
||
|
||
def extract(self, text: str) -> Dict[str, Any]:
|
||
"""
|
||
从文本中提取关键词特征
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征字典,包含各类别的关键词匹配情况
|
||
"""
|
||
if not text:
|
||
return {
|
||
"keyword_matches": {cat: 0 for cat in self.category_keywords},
|
||
"keyword_match_ratios": {cat: 0.0 for cat in self.category_keywords}
|
||
}
|
||
|
||
# 对文本分词
|
||
words = set(self.tokenizer.tokenize(text, return_string=False))
|
||
|
||
# 统计各类别的关键词匹配数量
|
||
keyword_matches = {}
|
||
for category, keywords in self.category_keywords.items():
|
||
# 计算文本中包含的该类别关键词数量
|
||
matches = sum(1 for kw in keywords if kw in words)
|
||
keyword_matches[category] = matches
|
||
|
||
# 计算匹配比例(归一化)
|
||
total_matches = sum(keyword_matches.values())
|
||
keyword_match_ratios = {
|
||
cat: matches / total_matches if total_matches > 0 else 0.0
|
||
for cat, matches in keyword_matches.items()
|
||
}
|
||
|
||
return {
|
||
"keyword_matches": keyword_matches,
|
||
"keyword_match_ratios": keyword_match_ratios
|
||
}
|
||
|
||
def extract_as_vector(self, text: str) -> np.ndarray:
|
||
"""
|
||
从文本中提取关键词特征,并转换为向量表示
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征向量,包含各类别的关键词匹配比例
|
||
"""
|
||
features = self.extract(text)
|
||
match_ratios = features['keyword_match_ratios']
|
||
|
||
# 按照 CATEGORIES 的顺序构建向量
|
||
vector = [match_ratios.get(cat, 0.0) for cat in CATEGORIES]
|
||
|
||
return np.array(vector, dtype=np.float32)
|
||
|
||
def update_keywords(self, category: str, keywords: List[str]) -> None:
|
||
"""
|
||
更新指定类别的关键词
|
||
|
||
Args:
|
||
category: 类别名称
|
||
keywords: 关键词列表
|
||
"""
|
||
self.category_keywords[category] = keywords
|
||
logger.info(f"已更新类别 {category} 的关键词,共 {len(keywords)} 个")
|
||
|
||
def add_keywords(self, category: str, keywords: List[str]) -> None:
|
||
"""
|
||
向指定类别添加关键词
|
||
|
||
Args:
|
||
category: 类别名称
|
||
keywords: 要添加的关键词列表
|
||
"""
|
||
if category in self.category_keywords:
|
||
existing_keywords = set(self.category_keywords[category])
|
||
for keyword in keywords:
|
||
existing_keywords.add(keyword)
|
||
self.category_keywords[category] = list(existing_keywords)
|
||
else:
|
||
self.category_keywords[category] = keywords
|
||
|
||
logger.info(f"已向类别 {category} 添加关键词,当前共 {len(self.category_keywords[category])} 个")
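

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. Keywords are matched against the exact tokens produced by the tokenizer, so a
# keyword only counts when jieba segments it out as a single word.
def _demo_keyword_features() -> None:
    extractor = KeywordFeatureExtractor()
    sample = "这场足球比赛的冠军属于主场球队"
    print(extractor.extract(sample)["keyword_matches"]["体育"])   # several sports keywords hit
    print(extractor.extract_as_vector(sample).shape)              # -> (len(CATEGORIES),)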
|
||
|
||
class CombinedFeatureExtractor(FeatureExtractor):
|
||
"""组合特征提取器,组合多个特征提取器的结果"""
|
||
|
||
def __init__(self, extractors: List[FeatureExtractor]):
|
||
"""
|
||
初始化组合特征提取器
|
||
|
||
Args:
|
||
extractors: 特征提取器列表
|
||
"""
|
||
super().__init__()
|
||
self.extractors = extractors
|
||
|
||
def extract(self, text: str) -> Dict[str, Any]:
|
||
"""
|
||
从文本中提取组合特征
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征字典,包含所有特征提取器的结果
|
||
"""
|
||
combined_features = {}
|
||
for i, extractor in enumerate(self.extractors):
|
||
extractor_name = type(extractor).__name__
|
||
features = extractor.extract(text)
|
||
combined_features[extractor_name] = features
|
||
|
||
return combined_features
|
||
|
||
def extract_as_vector(self, text: str) -> np.ndarray:
|
||
"""
|
||
从文本中提取组合特征,并转换为向量表示
|
||
|
||
Args:
|
||
text: 文本
|
||
|
||
Returns:
|
||
特征向量,包含所有特征提取器的向量拼接
|
||
"""
|
||
# 获取所有特征提取器的向量
|
||
feature_vectors = [extractor.extract_as_vector(text) for extractor in self.extractors]
|
||
|
||
# 拼接向量
|
||
return np.concatenate(feature_vectors)
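

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. Combining the three extractors above yields one dense vector whose length is
# the sum of the individual vector lengths (8 statistical + 24 POS + len(CATEGORIES)).
def _demo_combined_features() -> None:
    combined = CombinedFeatureExtractor([
        StatisticalFeatureExtractor(),
        POSFeatureExtractor(),
        KeywordFeatureExtractor(),
    ])
    vec = combined.extract_as_vector("科技公司发布了新的智能手机")
    print(vec.shape)   # -> (8 + 24 + len(CATEGORIES),)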
|
||
|
||
================================================================================
|
||
文件: preprocessing/data_augmentation.py
|
||
================================================================================
|
||
|
||
"""
|
||
数据增强模块:实现文本数据增强技术
|
||
"""
|
||
import random
|
||
import re
|
||
import jieba
|
||
import synonyms
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import copy
|
||
|
||
from config.model_config import RANDOM_SEED
|
||
from utils.logger import get_logger
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
|
||
# 设置随机种子以保证可重复性
|
||
random.seed(RANDOM_SEED)
|
||
np.random.seed(RANDOM_SEED)
|
||
|
||
logger = get_logger("DataAugmentation")
|
||
|
||
|
||
class TextAugmenter:
|
||
"""文本增强基类,定义通用接口"""
|
||
|
||
def __init__(self):
|
||
"""初始化文本增强器"""
|
||
pass
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
raise NotImplementedError("子类必须实现此方法")
|
||
|
||
def batch_augment(self, texts: List[str]) -> List[str]:
|
||
"""
|
||
批量对文本进行增强
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
|
||
Returns:
|
||
增强后的文本列表
|
||
"""
|
||
return [self.augment(text) for text in texts]
|
||
|
||
def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]:
|
||
"""
|
||
对文本进行增强,同时保留标签
|
||
|
||
Args:
|
||
text: 原始文本
|
||
label: 标签
|
||
|
||
Returns:
|
||
(增强后的文本, 标签)的元组
|
||
"""
|
||
return self.augment(text), label
|
||
|
||
def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]:
|
||
"""
|
||
批量对文本进行增强,同时保留标签
|
||
|
||
Args:
|
||
texts: 原始文本列表
|
||
labels: 标签列表
|
||
|
||
Returns:
|
||
(增强后的文本, 标签)的元组列表
|
||
"""
|
||
return [self.augment_with_label(text, label) for text, label in zip(texts, labels)]
|
||
|
||
|
||
class SynonymReplacement(TextAugmenter):
|
||
"""同义词替换增强器"""
|
||
|
||
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
|
||
replace_ratio: float = 0.1,
|
||
min_similarity: float = 0.7):
|
||
"""
|
||
初始化同义词替换增强器
|
||
|
||
Args:
|
||
tokenizer: 分词器,如果为None则创建一个新的分词器
|
||
replace_ratio: 替换比例,表示要替换的词占总词数的比例
|
||
min_similarity: 最小相似度,只有相似度大于该值的同义词才会被用于替换
|
||
"""
|
||
super().__init__()
|
||
self.tokenizer = tokenizer or ChineseTokenizer()
|
||
self.replace_ratio = replace_ratio
|
||
self.min_similarity = min_similarity
|
||
|
||
def _get_synonym(self, word: str) -> Optional[str]:
|
||
"""
|
||
获取词的同义词
|
||
|
||
Args:
|
||
word: 原始词
|
||
|
||
Returns:
|
||
同义词,如果没有合适的同义词则返回None
|
||
"""
|
||
# 使用synonyms包获取同义词
|
||
try:
|
||
synonyms_list = synonyms.nearby(word)
|
||
|
||
# synonyms.nearby返回一个元组,第一个元素是相似词列表,第二个元素是相似度列表
|
||
words = synonyms_list[0]
|
||
similarities = synonyms_list[1]
|
||
|
||
# 过滤掉相似度低于阈值的词和原词本身
|
||
valid_synonyms = [(w, s) for w, s in zip(words, similarities)
|
||
if s >= self.min_similarity and w != word]
|
||
|
||
if valid_synonyms:
|
||
# 按相似度排序,选择最相似的词
|
||
valid_synonyms.sort(key=lambda x: x[1], reverse=True)
|
||
return valid_synonyms[0][0]
|
||
|
||
return None
|
||
        except Exception:  # synonym lookup can fail (e.g. word missing from the vocabulary); skip replacement
|
||
return None
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行同义词替换增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
# 分词
|
||
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
|
||
|
||
if not words:
|
||
return text
|
||
|
||
# 计算要替换的词数量
|
||
n_replace = max(1, int(len(words) * self.replace_ratio))
|
||
|
||
# 随机选择要替换的词索引
|
||
replace_indices = random.sample(range(len(words)), min(n_replace, len(words)))
|
||
|
||
# 替换为同义词
|
||
for idx in replace_indices:
|
||
synonym = self._get_synonym(words[idx])
|
||
if synonym:
|
||
words[idx] = synonym
|
||
|
||
# 拼接为文本
|
||
augmented_text = ''.join(words)
|
||
|
||
return augmented_text
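

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. The augmented text is rebuilt with ''.join(words), which suits Chinese text;
# reproducibility comes from the module-level random.seed(RANDOM_SEED) call above.
def _demo_synonym_replacement() -> None:
    augmenter = SynonymReplacement(replace_ratio=0.2, min_similarity=0.7)
    # Roughly 20% of the tokens are swapped for a near synonym when one exists.
    print(augmenter.augment("这部电影的演员表现非常出色"))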
|
||
|
||
|
||
class RandomDeletion(TextAugmenter):
|
||
"""随机删除增强器"""
|
||
|
||
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
|
||
delete_ratio: float = 0.1):
|
||
"""
|
||
初始化随机删除增强器
|
||
|
||
Args:
|
||
tokenizer: 分词器,如果为None则创建一个新的分词器
|
||
delete_ratio: 删除比例,表示要删除的词占总词数的比例
|
||
"""
|
||
super().__init__()
|
||
self.tokenizer = tokenizer or ChineseTokenizer()
|
||
self.delete_ratio = delete_ratio
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行随机删除增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
# 分词
|
||
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
|
||
|
||
if len(words) <= 1:
|
||
return text
|
||
|
||
# 计算要删除的词数量
|
||
n_delete = max(1, int(len(words) * self.delete_ratio))
|
||
|
||
# 随机选择要删除的词索引
|
||
delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1))
|
||
|
||
# 删除选中的词
|
||
augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices]
|
||
|
||
# 拼接为文本
|
||
augmented_text = ''.join(augmented_words)
|
||
|
||
return augmented_text
|
||
|
||
|
||
class RandomSwap(TextAugmenter):
|
||
"""随机交换增强器"""
|
||
|
||
def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
|
||
n_swaps: int = 1):
|
||
"""
|
||
初始化随机交换增强器
|
||
|
||
Args:
|
||
tokenizer: 分词器,如果为None则创建一个新的分词器
|
||
n_swaps: 交换次数
|
||
"""
|
||
super().__init__()
|
||
self.tokenizer = tokenizer or ChineseTokenizer()
|
||
self.n_swaps = n_swaps
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行随机交换增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
# 分词
|
||
words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
|
||
|
||
if len(words) <= 1:
|
||
return text
|
||
|
||
# 进行n_swaps次随机交换
|
||
augmented_words = words.copy()
|
||
for _ in range(min(self.n_swaps, len(words) // 2)):
|
||
# 随机选择两个不同的索引
|
||
idx1, idx2 = random.sample(range(len(augmented_words)), 2)
|
||
|
||
# 交换两个词
|
||
augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1]
|
||
|
||
# 拼接为文本
|
||
augmented_text = ''.join(augmented_words)
|
||
|
||
return augmented_text
|
||
|
||
|
||
class CompositeAugmenter(TextAugmenter):
|
||
"""组合增强器,组合多个增强器"""
|
||
|
||
def __init__(self, augmenters: List[TextAugmenter],
|
||
probs: Optional[List[float]] = None):
|
||
"""
|
||
初始化组合增强器
|
||
|
||
Args:
|
||
augmenters: 增强器列表
|
||
probs: 各增强器被选择的概率列表,如果为None则均匀选择
|
||
"""
|
||
super().__init__()
|
||
self.augmenters = augmenters
|
||
|
||
# 如果没有提供概率,则均匀分配
|
||
if probs is None:
|
||
self.probs = [1.0 / len(augmenters)] * len(augmenters)
|
||
else:
|
||
# 确保概率和为1
|
||
total = sum(probs)
|
||
self.probs = [p / total for p in probs]
|
||
|
||
assert len(self.augmenters) == len(self.probs), "增强器数量与概率数量不匹配"
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行组合增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
# 根据概率随机选择一个增强器
|
||
augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0]
|
||
|
||
# 使用选中的增强器进行增强
|
||
return augmenter.augment(text)
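

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. The weights below are examples; they are normalised internally to sum to 1.
def _demo_composite_augmenter() -> None:
    augmenter = CompositeAugmenter(
        augmenters=[RandomDeletion(delete_ratio=0.1), RandomSwap(n_swaps=2)],
        probs=[2.0, 1.0],   # becomes [2/3, 1/3] after normalisation
    )
    print(augmenter.batch_augment(["人工智能正在改变世界", "球队在客场赢得了比赛"]))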
|
||
|
||
|
||
class BackTranslation(TextAugmenter):
|
||
"""回译增强器"""
|
||
|
||
def __init__(self, translator=None, source_lang: str = 'zh',
|
||
target_langs: List[str] = None):
|
||
"""
|
||
初始化回译增强器
|
||
|
||
Args:
|
||
translator: 翻译器,需要实现translate方法
|
||
source_lang: 源语言代码
|
||
target_langs: 目标语言代码列表,如果为None则使用默认语言
|
||
"""
|
||
super().__init__()
|
||
|
||
# 如果没有提供翻译器,尝试使用第三方翻译库
|
||
if translator is None:
|
||
try:
|
||
# 尝试导入多种翻译库
|
||
# 首先尝试使用googletrans (需要单独安装: pip install googletrans==4.0.0-rc1)
|
||
try:
|
||
from googletrans import Translator
|
||
self.translator = Translator()
|
||
self.translate_func = self._google_translate
|
||
except ImportError:
|
||
# 如果googletrans不可用,尝试使用py-translate
|
||
try:
|
||
import translate
|
||
self.translator = translate
|
||
self.translate_func = self._py_translate
|
||
except ImportError:
|
||
logger.warning("未安装翻译库,回译功能将不可用。请安装googletrans或py-translate")
|
||
self.translator = None
|
||
self.translate_func = self._dummy_translate
|
||
except Exception as e:
|
||
logger.error(f"初始化翻译器失败: {e}")
|
||
self.translator = None
|
||
self.translate_func = self._dummy_translate
|
||
else:
|
||
self.translator = translator
|
||
self.translate_func = self._custom_translate
|
||
|
||
self.source_lang = source_lang
|
||
self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja']
|
||
|
||
def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
||
"""使用googletrans进行翻译"""
|
||
try:
|
||
result = self.translator.translate(text, src=source_lang, dest=target_lang)
|
||
return result.text
|
||
except Exception as e:
|
||
logger.error(f"翻译失败: {e}")
|
||
return text
|
||
|
||
def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
||
"""使用py-translate进行翻译"""
|
||
try:
|
||
return self.translator.translate(text, source_lang, target_lang)
|
||
except Exception as e:
|
||
logger.error(f"翻译失败: {e}")
|
||
return text
|
||
|
||
def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
||
"""使用自定义翻译器进行翻译"""
|
||
try:
|
||
return self.translator.translate(text, source_lang, target_lang)
|
||
except Exception as e:
|
||
logger.error(f"翻译失败: {e}")
|
||
return text
|
||
|
||
def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str:
|
||
"""虚拟翻译功能,仅返回原文本"""
|
||
logger.warning("翻译功能不可用,使用原文本")
|
||
return text
|
||
|
||
def augment(self, text: str) -> str:
|
||
"""
|
||
对文本进行回译增强
|
||
|
||
Args:
|
||
text: 原始文本
|
||
|
||
Returns:
|
||
增强后的文本
|
||
"""
|
||
if not text or self.translator is None:
|
||
return text
|
||
|
||
# 随机选择一个目标语言
|
||
target_lang = random.choice(self.target_langs)
|
||
|
||
try:
|
||
# 将源语言翻译为目标语言
|
||
translated = self.translate_func(text, self.source_lang, target_lang)
|
||
|
||
# 将目标语言翻译回源语言
|
||
back_translated = self.translate_func(translated, target_lang, self.source_lang)
|
||
|
||
return back_translated
|
||
except Exception as e:
|
||
logger.error(f"回译失败: {e}")
|
||
return text
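

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. BackTranslation only requires an object with a translate(text, source_lang,
# target_lang) method, so an offline stub such as the hypothetical IdentityTranslator
# below can be injected for testing without installing any translation library.
class IdentityTranslator:
    """Trivial translator used to exercise the BackTranslation code path offline."""

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        return text  # a real translator would return the translated string


def _demo_back_translation() -> None:
    augmenter = BackTranslation(translator=IdentityTranslator())
    print(augmenter.augment("今天天气很好"))   # unchanged when using the identity stub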
|
||
|
||
================================================================================
|
||
文件: data/__init__.py
|
||
================================================================================
|
||
|
||
|
||
|
||
================================================================================
|
||
文件: data/dataset.py
|
||
================================================================================
|
||
|
||
|
||
|
||
================================================================================
|
||
文件: data/dataloader.py
|
||
================================================================================
|
||
|
||
"""
|
||
数据加载模块:负责从文件系统加载原始文本数据
|
||
"""
|
||
import os
|
||
import glob
|
||
import time
|
||
from pathlib import Path
|
||
from typing import List, Dict, Tuple, Optional, Any
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import random
|
||
import numpy as np
|
||
|
||
from config.system_config import (
|
||
RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES,
|
||
CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH
|
||
)
|
||
from config.model_config import RANDOM_SEED
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import read_text_file, read_files_parallel, list_files
|
||
|
||
# 设置随机种子以保证可重复性
|
||
random.seed(RANDOM_SEED)
|
||
np.random.seed(RANDOM_SEED)
|
||
|
||
logger = get_logger("DataLoader")
|
||
|
||
|
||
class DataLoader:
|
||
"""负责加载THUCNews数据集的类"""
|
||
|
||
def __init__(self, data_dir: Optional[str] = None,
|
||
categories: Optional[List[str]] = None,
|
||
encoding: str = ENCODING,
|
||
max_workers: int = DATA_LOADING_WORKERS,
|
||
max_text_per_batch: int = MAX_TEXT_PER_BATCH):
|
||
"""
|
||
初始化数据加载器
|
||
|
||
Args:
|
||
data_dir: 数据目录,默认使用配置文件中的路径
|
||
categories: 要加载的类别列表,默认加载所有类别
|
||
encoding: 文件编码
|
||
max_workers: 最大工作线程数
|
||
max_text_per_batch: 每批处理的最大文本数量
|
||
"""
|
||
self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR
|
||
self.categories = categories if categories else CATEGORIES
|
||
self.encoding = encoding
|
||
self.max_workers = max_workers
|
||
self.max_text_per_batch = max_text_per_batch
|
||
|
||
# 验证数据目录是否存在
|
||
if not self.data_dir.exists():
|
||
raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
|
||
|
||
# 验证类别是否存在
|
||
for category in self.categories:
|
||
category_dir = self.data_dir / category
|
||
if not category_dir.exists():
|
||
logger.warning(f"类别目录不存在: {category_dir}")
|
||
|
||
# 存储类别目录的映射
|
||
self.category_dirs = {
|
||
category: self.data_dir / category
|
||
for category in self.categories
|
||
if (self.data_dir / category).exists()
|
||
}
|
||
|
||
# 记录类别文件数量
|
||
self.category_file_counts = {}
|
||
|
||
# 统计并记录每个类别的文件数量
|
||
self._count_files()
|
||
logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件")
|
||
|
||
def _count_files(self) -> None:
|
||
"""统计每个类别的文件数量"""
|
||
for category, category_dir in self.category_dirs.items():
|
||
files = list(category_dir.glob("*.txt"))
|
||
self.category_file_counts[category] = len(files)
|
||
logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件")
|
||
|
||
def get_file_paths(self, category: Optional[str] = None,
|
||
sample_ratio: float = 1.0,
|
||
shuffle: bool = True) -> List[Tuple[str, str]]:
|
||
"""
|
||
获取指定类别的文件路径列表
|
||
|
||
Args:
|
||
category: 类别名称,如果为None则获取所有类别
|
||
sample_ratio: 采样比例,默认为1.0(全部)
|
||
shuffle: 是否打乱文件顺序
|
||
|
||
Returns:
|
||
包含(文件路径, 类别)元组的列表
|
||
"""
|
||
file_paths = []
|
||
|
||
# 确定要处理的类别
|
||
categories_to_process = [category] if category else self.categories
|
||
|
||
# 获取每个类别的文件路径
|
||
for cat in categories_to_process:
|
||
if cat in self.category_dirs:
|
||
category_dir = self.category_dirs[cat]
|
||
cat_files = list(category_dir.glob("*.txt"))
|
||
|
||
# 采样
|
||
if sample_ratio < 1.0:
|
||
sample_size = int(len(cat_files) * sample_ratio)
|
||
if shuffle:
|
||
cat_files = random.sample(cat_files, sample_size)
|
||
else:
|
||
cat_files = cat_files[:sample_size]
|
||
|
||
# 添加文件路径和对应的类别
|
||
file_paths.extend([(str(file), cat) for file in cat_files])
|
||
|
||
# 打乱全局顺序(如果需要)
|
||
if shuffle:
|
||
random.shuffle(file_paths)
|
||
|
||
return file_paths
|
||
|
||
def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||
"""
|
||
加载指定路径的文本内容
|
||
|
||
Args:
|
||
file_paths: 包含(文件路径, 类别)元组的列表
|
||
|
||
Returns:
|
||
包含(文本内容, 类别)元组的列表
|
||
"""
|
||
start_time = time.time()
|
||
texts_with_labels = []
|
||
|
||
# 提取文件路径列表
|
||
paths = [path for path, _ in file_paths]
|
||
labels = [label for _, label in file_paths]
|
||
|
||
# 并行加载文本内容
|
||
contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding)
|
||
|
||
# 将内容与标签配对
|
||
for content, label in zip(contents, labels):
|
||
if content: # 确保内容不为空
|
||
texts_with_labels.append((content, label))
|
||
|
||
elapsed = time.time() - start_time
|
||
logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f} 秒")
|
||
|
||
return texts_with_labels
|
||
|
||
def load_data(self, categories: Optional[List[str]] = None,
|
||
sample_ratio: float = 1.0,
|
||
shuffle: bool = True,
|
||
return_generator: bool = False) -> Any:
|
||
"""
|
||
加载指定类别的所有数据
|
||
|
||
Args:
|
||
categories: 要加载的类别列表,默认为所有类别
|
||
sample_ratio: 采样比例,默认为1.0(全部)
|
||
shuffle: 是否打乱数据顺序
|
||
return_generator: 是否返回生成器(批量加载)
|
||
|
||
Returns:
|
||
如果return_generator为False,返回包含(文本内容, 类别)元组的列表
|
||
如果return_generator为True,返回一个生成器,每次产生一批数据
|
||
"""
|
||
# 确定要处理的类别
|
||
cats_to_process = categories if categories else self.categories
|
||
|
||
# 验证类别是否存在
|
||
for cat in cats_to_process:
|
||
if cat not in self.category_dirs:
|
||
logger.warning(f"类别 {cat} 不存在,将被忽略")
|
||
|
||
# 筛选存在的类别
|
||
cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
|
||
|
||
# 获取所有文件路径
|
||
all_file_paths = []
|
||
for cat in cats_to_process:
|
||
cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle)
|
||
all_file_paths.extend(cat_files)
|
||
|
||
# 打乱全局顺序(如果需要)
|
||
if shuffle:
|
||
random.shuffle(all_file_paths)
|
||
|
||
# 如果需要返回生成器,分批次加载数据
|
||
if return_generator:
|
||
def data_generator():
|
||
for i in range(0, len(all_file_paths), self.max_text_per_batch):
|
||
batch_paths = all_file_paths[i:i + self.max_text_per_batch]
|
||
batch_data = self.load_texts(batch_paths)
|
||
yield batch_data
|
||
|
||
return data_generator()
|
||
|
||
# 否则,一次性加载所有数据
|
||
return self.load_texts(all_file_paths)
|
||
|
||
def load_balanced_data(self, n_per_category: int = 1000,
|
||
categories: Optional[List[str]] = None,
|
||
shuffle: bool = True) -> List[Tuple[str, str]]:
|
||
"""
|
||
加载平衡的数据集(每个类别的样本数量相同)
|
||
|
||
Args:
|
||
n_per_category: 每个类别加载的样本数量
|
||
categories: 要加载的类别列表,默认为所有类别
|
||
shuffle: 是否打乱数据顺序
|
||
|
||
Returns:
|
||
包含(文本内容, 类别)元组的列表
|
||
"""
|
||
# 确定要处理的类别
|
||
cats_to_process = categories if categories else self.categories
|
||
cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs]
|
||
|
||
balanced_data = []
|
||
|
||
for cat in cats_to_process:
|
||
# 获取该类别的文件路径
|
||
cat_files = self.get_file_paths(cat, shuffle=shuffle)
|
||
|
||
# 限制数量
|
||
cat_files = cat_files[:n_per_category]
|
||
|
||
# 加载文本
|
||
cat_data = self.load_texts(cat_files)
|
||
balanced_data.extend(cat_data)
|
||
|
||
# 打乱全局顺序(如果需要)
|
||
if shuffle:
|
||
random.shuffle(balanced_data)
|
||
|
||
return balanced_data
|
||
|
||
def get_category_distribution(self) -> Dict[str, int]:
|
||
"""
|
||
获取数据集的类别分布
|
||
|
||
Returns:
|
||
包含各类别样本数量的字典
|
||
"""
|
||
return self.category_file_counts
|
||
|
||
def get_data_stats(self) -> Dict[str, Any]:
|
||
"""
|
||
获取数据集的统计信息
|
||
|
||
Returns:
|
||
包含统计信息的字典
|
||
"""
|
||
# 计算总样本数
|
||
total_samples = sum(self.category_file_counts.values())
|
||
|
||
# 计算各类别占比
|
||
category_percentages = {
|
||
cat: count / total_samples * 100
|
||
for cat, count in self.category_file_counts.items()
|
||
}
|
||
|
||
# 采样几个文件计算平均文本长度
|
||
sample_files = []
|
||
for cat in self.categories:
|
||
if cat in self.category_dirs:
|
||
cat_files = list((self.data_dir / cat).glob("*.txt"))
|
||
if cat_files:
|
||
# 每个类别最多采样10个文件
|
||
sample_files.extend(random.sample(cat_files, min(10, len(cat_files))))
|
||
|
||
# 加载采样的文件内容
|
||
sample_contents = []
|
||
for file_path in sample_files:
|
||
content = read_text_file(str(file_path), encoding=self.encoding)
|
||
if content:
|
||
sample_contents.append(content)
|
||
|
||
# 计算平均文本长度(字符数)
|
||
        avg_char_length = (
            sum(len(content) for content in sample_contents) / len(sample_contents)
            if sample_contents else 0
        )
|
||
|
||
# 返回统计信息
|
||
return {
|
||
"total_samples": total_samples,
|
||
"category_counts": self.category_file_counts,
|
||
"category_percentages": category_percentages,
|
||
"average_text_length": avg_char_length,
|
||
"categories": self.categories,
|
||
"num_categories": len(self.categories),
|
||
}
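

# Editor's note: an illustrative sketch added by the editor, not part of the original
# file. The sample_ratio and n_per_category values are examples only.
def _demo_data_loader() -> None:
    loader = DataLoader()                           # uses RAW_DATA_DIR and CATEGORIES from config
    print(loader.get_data_stats())                  # corpus-level statistics
    balanced = loader.load_balanced_data(n_per_category=100)
    print(len(balanced), balanced[0][1])            # (text, label) tuples, label e.g. '体育'
    # Streaming variant for corpora that do not fit in memory:
    for batch in loader.load_data(sample_ratio=0.01, return_generator=True):
        print(len(batch))
        break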
|
||
|
||
|
||
================================================================================
|
||
文件: data/data_manager.py
|
||
================================================================================
|
||
|
||
"""
|
||
数据管理模块:负责数据的存储、读取和转换
|
||
"""
|
||
import os
|
||
import pickle
|
||
import json
|
||
import time
|
||
from typing import List, Dict, Tuple, Optional, Any, Union
|
||
import numpy as np
|
||
import pandas as pd
|
||
from collections import Counter
|
||
import matplotlib.pyplot as plt
|
||
from sklearn.model_selection import train_test_split
|
||
|
||
from config.system_config import (
|
||
PROCESSED_DATA_DIR, ENCODING, CATEGORY_TO_ID, ID_TO_CATEGORY
|
||
)
|
||
from config.model_config import (
|
||
VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED
|
||
)
|
||
from utils.logger import get_logger
|
||
from utils.file_utils import (
|
||
save_pickle, load_pickle, save_json, load_json, ensure_dir
|
||
)
|
||
from data.dataloader import DataLoader
|
||
|
||
logger = get_logger("DataManager")
|
||
|
||
|
||
class DataManager:
|
||
"""数据管理类,负责数据的存储、读取和转换"""
|
||
|
||
def __init__(self, processed_dir: Optional[str] = None):
|
||
"""
|
||
初始化数据管理器
|
||
|
||
Args:
|
||
processed_dir: 处理后数据的存储目录,默认使用配置文件中的路径
|
||
"""
|
||
self.processed_dir = processed_dir or PROCESSED_DATA_DIR
|
||
ensure_dir(self.processed_dir)
|
||
|
||
# 数据分割后的存储
|
||
self.train_texts = []
|
||
self.train_labels = []
|
||
self.val_texts = []
|
||
self.val_labels = []
|
||
self.test_texts = []
|
||
self.test_labels = []
|
||
|
||
# 数据统计信息
|
||
self.stats = {}
|
||
|
||
# 标签编码映射
|
||
self.label_to_id = CATEGORY_TO_ID
|
||
self.id_to_label = ID_TO_CATEGORY
|
||
|
||
logger.info(f"数据管理器初始化完成,处理后数据将存储在 {self.processed_dir}")
|
||
|
||
def load_and_split_data(self, data_loader: DataLoader,
|
||
categories: Optional[List[str]] = None,
|
||
val_split: float = VALIDATION_SPLIT,
|
||
test_split: float = TEST_SPLIT,
|
||
sample_ratio: float = 1.0,
|
||
balanced: bool = False,
|
||
n_per_category: int = 1000,
|
||
save: bool = True) -> Dict[str, Any]:
|
||
"""
|
||
加载并分割数据集
|
||
|
||
Args:
|
||
data_loader: 数据加载器实例
|
||
categories: 要包含的类别列表,默认为所有类别
|
||
val_split: 验证集比例
|
||
test_split: 测试集比例
|
||
sample_ratio: 采样比例,默认为1.0(全部)
|
||
balanced: 是否平衡各类别的样本数量
|
||
n_per_category: 平衡模式下每个类别的样本数量
|
||
save: 是否保存处理后的数据
|
||
|
||
Returns:
|
||
包含分割后数据集的字典
|
||
"""
|
||
start_time = time.time()
|
||
|
||
# 加载数据
|
||
if balanced:
|
||
logger.info(f"加载平衡数据集,每个类别 {n_per_category} 个样本")
|
||
data = data_loader.load_balanced_data(
|
||
n_per_category=n_per_category,
|
||
categories=categories,
|
||
shuffle=True
|
||
)
|
||
else:
|
||
logger.info(f"加载数据集,采样比例 {sample_ratio}")
|
||
data = data_loader.load_data(
|
||
categories=categories,
|
||
sample_ratio=sample_ratio,
|
||
shuffle=True,
|
||
return_generator=False
|
||
)
|
||
|
||
logger.info(f"加载了 {len(data)} 个样本")
|
||
|
||
# 分离文本和标签
|
||
texts = [text for text, _ in data]
|
||
labels = [label for _, label in data]
|
||
|
||
# 进行标签编码
|
||
encoded_labels = np.array([self.label_to_id[label] for label in labels])
|
||
|
||
# 计算数据统计信息
|
||
self._compute_stats(texts, labels)
|
||
|
||
# 划分训练集、验证集和测试集
|
||
# 先分出测试集
|
||
if test_split > 0:
|
||
train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split(
|
||
texts, encoded_labels,
|
||
test_size=test_split,
|
||
random_state=RANDOM_SEED,
|
||
stratify=encoded_labels if len(set(encoded_labels)) > 1 else None
|
||
)
|
||
else:
|
||
train_val_texts, train_val_labels = texts, encoded_labels
|
||
self.test_texts, self.test_labels = [], []
|
||
|
||
# 再划分训练集和验证集
|
||
if val_split > 0:
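            # val_split and test_split are both fractions of the original dataset. After the
            # test split only (1 - test_split) of the data remains, so the second split uses
            # val_split / (1 - test_split). Example: with test_split=0.1 and val_split=0.1,
            # the remaining 90% is split with test_size ≈ 0.111, i.e. 10% of the original data.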
|
||
self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split(
|
||
train_val_texts, train_val_labels,
|
||
test_size=val_split / (1 - test_split),
|
||
random_state=RANDOM_SEED,
|
||
stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
|
||
)
|
||
else:
|
||
self.train_texts, self.train_labels = train_val_texts, train_val_labels
|
||
self.val_texts, self.val_labels = [], []
|
||
|
||
# 打印数据集划分结果
|
||
logger.info(f"数据集划分结果:")
|
||
logger.info(f" 训练集:{len(self.train_texts)} 个样本")
|
||
logger.info(f" 验证集:{len(self.val_texts)} 个样本")
|
||
logger.info(f" 测试集:{len(self.test_texts)} 个样本")
|
||
|
||
# 保存处理后的数据
|
||
if save:
|
||
self.save_data()
|
||
|
||
elapsed = time.time() - start_time
|
||
logger.info(f"数据加载和分割完成,用时 {elapsed:.2f} 秒")
|
||
|
||
return {
|
||
"train_texts": self.train_texts,
|
||
"train_labels": self.train_labels,
|
||
"val_texts": self.val_texts,
|
||
"val_labels": self.val_labels,
|
||
"test_texts": self.test_texts,
|
||
"test_labels": self.test_labels,
|
||
"stats": self.stats
|
||
}
|
||
|
||
def _compute_stats(self, texts: List[str], labels: List[str]) -> None:
|
||
"""
|
||
计算数据统计信息
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
labels: 标签列表
|
||
"""
|
||
# 文本数量
|
||
num_samples = len(texts)
|
||
|
||
# 类别分布
|
||
label_counter = Counter(labels)
|
||
label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()}
|
||
|
||
# 文本长度统计
|
||
text_lengths = [len(text) for text in texts]
|
||
avg_length = sum(text_lengths) / len(text_lengths)
|
||
max_length = max(text_lengths)
|
||
min_length = min(text_lengths)
|
||
|
||
# 前5个最长和最短的文本的长度
|
||
sorted_lengths = sorted(text_lengths)
|
||
shortest_lengths = sorted_lengths[:5]
|
||
longest_lengths = sorted_lengths[-5:]
|
||
|
||
# 95%的文本长度分位数
|
||
percentile_95 = np.percentile(text_lengths, 95)
|
||
|
||
# 存储统计信息
|
||
self.stats = {
|
||
"num_samples": num_samples,
|
||
"num_categories": len(label_counter),
|
||
"label_counter": label_counter,
|
||
"label_distribution": label_distribution,
|
||
"text_length": {
|
||
"average": avg_length,
|
||
"max": max_length,
|
||
"min": min_length,
|
||
"percentile_95": percentile_95,
|
||
"shortest_5": shortest_lengths,
|
||
"longest_5": longest_lengths
|
||
}
|
||
}
|
||
|
||
def save_data(self, save_dir: Optional[str] = None) -> None:
|
||
"""
|
||
保存处理后的数据
|
||
|
||
Args:
|
||
save_dir: 保存目录,默认使用初始化时设置的目录
|
||
"""
|
||
save_dir = save_dir or self.processed_dir
|
||
ensure_dir(save_dir)
|
||
|
||
# 保存训练集
|
||
save_pickle(
|
||
{"texts": self.train_texts, "labels": self.train_labels},
|
||
os.path.join(save_dir, "train_data.pkl")
|
||
)
|
||
|
||
# 保存验证集
|
||
if len(self.val_texts) > 0:
|
||
save_pickle(
|
||
{"texts": self.val_texts, "labels": self.val_labels},
|
||
os.path.join(save_dir, "val_data.pkl")
|
||
)
|
||
|
||
# 保存测试集
|
||
if len(self.test_texts) > 0:
|
||
save_pickle(
|
||
{"texts": self.test_texts, "labels": self.test_labels},
|
||
os.path.join(save_dir, "test_data.pkl")
|
||
)
|
||
|
||
# 保存标签编码映射
|
||
save_json(
|
||
{"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
|
||
os.path.join(save_dir, "label_mapping.json")
|
||
)
|
||
|
||
# 保存数据统计信息
|
||
# 将Counter对象转换为普通字典以便JSON序列化
|
||
stats_for_json = self.stats.copy()
|
||
if "label_counter" in stats_for_json:
|
||
stats_for_json["label_counter"] = dict(stats_for_json["label_counter"])
|
||
|
||
save_json(
|
||
stats_for_json,
|
||
os.path.join(save_dir, "data_stats.json")
|
||
)
|
||
|
||
logger.info(f"已将处理后的数据保存到 {save_dir}")
|
||
|
||
def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]:
|
||
"""
|
||
加载处理后的数据
|
||
|
||
Args:
|
||
load_dir: 加载目录,默认使用初始化时设置的目录
|
||
|
||
Returns:
|
||
包含加载的数据集的字典
|
||
"""
|
||
load_dir = load_dir or self.processed_dir
|
||
|
||
# 加载训练集
|
||
train_data_path = os.path.join(load_dir, "train_data.pkl")
|
||
if os.path.exists(train_data_path):
|
||
train_data = load_pickle(train_data_path)
|
||
self.train_texts = train_data["texts"]
|
||
self.train_labels = train_data["labels"]
|
||
logger.info(f"已加载训练集,包含 {len(self.train_texts)} 个样本")
|
||
else:
|
||
logger.warning(f"训练集文件不存在: {train_data_path}")
|
||
self.train_texts, self.train_labels = [], []
|
||
|
||
# 加载验证集
|
||
val_data_path = os.path.join(load_dir, "val_data.pkl")
|
||
if os.path.exists(val_data_path):
|
||
val_data = load_pickle(val_data_path)
|
||
self.val_texts = val_data["texts"]
|
||
self.val_labels = val_data["labels"]
|
||
logger.info(f"已加载验证集,包含 {len(self.val_texts)} 个样本")
|
||
else:
|
||
logger.warning(f"验证集文件不存在: {val_data_path}")
|
||
self.val_texts, self.val_labels = [], []
|
||
|
||
# 加载测试集
|
||
test_data_path = os.path.join(load_dir, "test_data.pkl")
|
||
if os.path.exists(test_data_path):
|
||
test_data = load_pickle(test_data_path)
|
||
self.test_texts = test_data["texts"]
|
||
self.test_labels = test_data["labels"]
|
||
logger.info(f"已加载测试集,包含 {len(self.test_texts)} 个样本")
|
||
else:
|
||
logger.warning(f"测试集文件不存在: {test_data_path}")
|
||
self.test_texts, self.test_labels = [], []
|
||
|
||
# 加载标签编码映射
|
||
mapping_path = os.path.join(load_dir, "label_mapping.json")
|
||
if os.path.exists(mapping_path):
|
||
mapping = load_json(mapping_path)
|
||
self.label_to_id = mapping["label_to_id"]
|
||
self.id_to_label = mapping["id_to_label"]
|
||
# 将字符串键转换为整数(JSON序列化会将所有键转为字符串)
|
||
self.id_to_label = {int(k): v for k, v in self.id_to_label.items()}
|
||
logger.info(f"已加载标签编码映射,共 {len(self.label_to_id)} 个类别")
|
||
|
||
# 加载数据统计信息
|
||
stats_path = os.path.join(load_dir, "data_stats.json")
|
||
if os.path.exists(stats_path):
|
||
self.stats = load_json(stats_path)
|
||
logger.info("已加载数据统计信息")
|
||
|
||
return {
|
||
"train_texts": self.train_texts,
|
||
"train_labels": self.train_labels,
|
||
"val_texts": self.val_texts,
|
||
"val_labels": self.val_labels,
|
||
"test_texts": self.test_texts,
|
||
"test_labels": self.test_labels,
|
||
"stats": self.stats
|
||
}
|
||
|
||
def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]:
|
||
"""
|
||
获取指定数据集的标签分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
标签分布字典,键为类别名称,值为比例
|
||
"""
|
||
if dataset == "train":
|
||
labels = self.train_labels
|
||
elif dataset == "val":
|
||
labels = self.val_labels
|
||
elif dataset == "test":
|
||
labels = self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 计算标签分布
|
||
label_counter = Counter(labels)
|
||
num_samples = len(labels)
|
||
|
||
# 将数字标签转换为类别名称
|
||
distribution = {}
|
||
for label_id, count in label_counter.items():
|
||
label_name = self.id_to_label.get(label_id, str(label_id))
|
||
distribution[label_name] = count / num_samples * 100
|
||
|
||
return distribution
|
||
|
||
def visualize_label_distribution(self, dataset: str = "train",
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
可视化标签分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test', 'all'
|
||
save_path: 图表保存路径,默认为None(显示而不保存)
|
||
"""
|
||
plt.figure(figsize=(12, 8))
|
||
|
||
if dataset == "all":
|
||
# 显示所有数据集的标签分布
|
||
train_dist = self.get_label_distribution("train")
|
||
val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {}
|
||
test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {}
|
||
|
||
# 准备数据
|
||
categories = list(train_dist.keys())
|
||
train_values = [train_dist.get(cat, 0) for cat in categories]
|
||
val_values = [val_dist.get(cat, 0) for cat in categories]
|
||
test_values = [test_dist.get(cat, 0) for cat in categories]
|
||
|
||
# 绘制条形图
|
||
x = np.arange(len(categories))
|
||
width = 0.25
|
||
|
||
plt.bar(x - width, train_values, width, label="Training")
|
||
if val_values:
|
||
plt.bar(x, val_values, width, label="Validation")
|
||
if test_values:
|
||
plt.bar(x + width, test_values, width, label="Testing")
|
||
|
||
plt.xlabel("Categories")
|
||
plt.ylabel("Percentage (%)")
|
||
plt.title("Label Distribution Across Datasets")
|
||
plt.xticks(x, categories, rotation=45, ha="right")
|
||
plt.legend()
|
||
plt.tight_layout()
|
||
else:
|
||
# 显示单个数据集的标签分布
|
||
distribution = self.get_label_distribution(dataset)
|
||
|
||
# 按值排序
|
||
sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True)
|
||
categories = [item[0] for item in sorted_items]
|
||
values = [item[1] for item in sorted_items]
|
||
|
||
# 绘制条形图
|
||
plt.bar(categories, values, color='skyblue')
|
||
plt.xlabel("Categories")
|
||
plt.ylabel("Percentage (%)")
|
||
plt.title(f"Label Distribution in {dataset.capitalize()} Dataset")
|
||
plt.xticks(rotation=45, ha="right")
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图表
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"标签分布图已保存到 {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def visualize_text_length_distribution(self, dataset: str = "train",
|
||
bins: int = 50,
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
可视化文本长度分布
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
bins: 直方图的箱数
|
||
save_path: 图表保存路径,默认为None(显示而不保存)
|
||
"""
|
||
if dataset == "train":
|
||
texts = self.train_texts
|
||
elif dataset == "val":
|
||
texts = self.val_texts
|
||
elif dataset == "test":
|
||
texts = self.test_texts
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 计算文本长度
|
||
text_lengths = [len(text) for text in texts]
|
||
|
||
# 绘制直方图
|
||
plt.figure(figsize=(10, 6))
|
||
plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7)
|
||
|
||
# 计算并绘制一些统计量
|
||
avg_length = sum(text_lengths) / len(text_lengths)
|
||
median_length = np.median(text_lengths)
|
||
percentile_95 = np.percentile(text_lengths, 95)
|
||
|
||
plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}')
|
||
plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}')
|
||
plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1,
|
||
label=f'95th Percentile: {percentile_95:.1f}')
|
||
|
||
plt.xlabel('Text Length (characters)')
|
||
plt.ylabel('Frequency')
|
||
plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset')
|
||
plt.legend()
|
||
plt.tight_layout()
|
||
|
||
# 保存或显示图表
|
||
if save_path:
|
||
plt.savefig(save_path)
|
||
logger.info(f"文本长度分布图已保存到 {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
def get_data_summary(self) -> Dict[str, Any]:
|
||
"""
|
||
获取数据集的摘要信息
|
||
|
||
Returns:
|
||
包含数据摘要的字典
|
||
"""
|
||
# 获取数据集的基本信息
|
||
summary = {
|
||
"train_size": len(self.train_texts),
|
||
"val_size": len(self.val_texts),
|
||
"test_size": len(self.test_texts),
|
||
"num_categories": len(self.label_to_id),
|
||
"categories": list(self.label_to_id.keys()),
|
||
}
|
||
|
||
# 添加训练集的标签分布
|
||
if len(self.train_texts) > 0:
|
||
summary["train_label_distribution"] = self.get_label_distribution("train")
|
||
|
||
# 添加验证集的标签分布
|
||
if len(self.val_texts) > 0:
|
||
summary["val_label_distribution"] = self.get_label_distribution("val")
|
||
|
||
# 添加测试集的标签分布
|
||
if len(self.test_texts) > 0:
|
||
summary["test_label_distribution"] = self.get_label_distribution("test")
|
||
|
||
# 添加更多统计信息(如果有)
|
||
if self.stats:
|
||
# 只添加一些关键的统计信息
|
||
if "text_length" in self.stats:
|
||
summary["text_length_stats"] = self.stats["text_length"]
|
||
|
||
return summary
|
||
|
||
def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame:
|
||
"""
|
||
将数据导出为Pandas DataFrame
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
Pandas DataFrame
|
||
"""
|
||
if dataset == "train":
|
||
texts = self.train_texts
|
||
labels_ids = self.train_labels
|
||
elif dataset == "val":
|
||
texts = self.val_texts
|
||
labels_ids = self.val_labels
|
||
elif dataset == "test":
|
||
texts = self.test_texts
|
||
labels_ids = self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
|
||
|
||
# 将数字标签转换为类别名称
|
||
labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids]
|
||
|
||
# 创建DataFrame
|
||
df = pd.DataFrame({
|
||
"text": texts,
|
||
"label_id": labels_ids,
|
||
"label": labels
|
||
})
|
||
|
||
return df
|
||
|
||
def get_label_name(self, label_id: int) -> str:
|
||
"""
|
||
获取标签ID对应的类别名称
|
||
|
||
Args:
|
||
label_id: 标签ID
|
||
|
||
Returns:
|
||
类别名称
|
||
"""
|
||
return self.id_to_label.get(label_id, str(label_id))
|
||
|
||
def get_label_id(self, label_name: str) -> int:
|
||
"""
|
||
获取类别名称对应的标签ID
|
||
|
||
Args:
|
||
label_name: 类别名称
|
||
|
||
Returns:
|
||
标签ID
|
||
"""
|
||
return self.label_to_id.get(label_name, -1)
|
||
|
||
def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]:
|
||
"""
|
||
获取指定数据集的文本和标签
|
||
|
||
Args:
|
||
dataset: 数据集名称,可选值:'train', 'val', 'test'
|
||
|
||
Returns:
|
||
(文本列表, 标签数组)的元组
|
||
"""
|
||
if dataset == "train":
|
||
return self.train_texts, self.train_labels
|
||
elif dataset == "val":
|
||
return self.val_texts, self.val_labels
|
||
elif dataset == "test":
|
||
return self.test_texts, self.test_labels
|
||
else:
|
||
raise ValueError(f"不支持的数据集名称: {dataset}")
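

# Editor's note: an illustrative end-to-end sketch added by the editor, not part of
# the original file. It shows how DataLoader and DataManager are typically chained;
# the argument values are examples.
def _demo_data_manager() -> None:
    loader = DataLoader()
    manager = DataManager()
    splits = manager.load_and_split_data(
        data_loader=loader,
        balanced=True,
        n_per_category=1000,
        save=True,                      # writes train/val/test pickles to PROCESSED_DATA_DIR
    )
    print(len(splits["train_texts"]), len(splits["val_texts"]), len(splits["test_texts"]))
    print(manager.get_data_summary())
    print(manager.export_to_pandas("train").head())
    # Later runs can skip the raw corpus and reload the cached pickles:
    manager.load_data()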
|