diff --git a/export_all_pyfile.py b/export_all_pyfile.py index e69de29..83d71e7 100644 --- a/export_all_pyfile.py +++ b/export_all_pyfile.py @@ -0,0 +1,84 @@ +import os + + +def find_all_python_files(root_dir='.'): + """ + 查找指定目录及其所有子目录下的所有Python文件 + + Args: + root_dir: 根目录路径,默认为当前目录 + + Returns: + 包含所有Python文件路径的列表 + """ + python_files = [] + + # 遍历根目录及所有子目录 + for dirpath, dirnames, filenames in os.walk(root_dir): + # 查找所有.py文件 + for filename in filenames: + if filename.endswith('.py'): + # 构建完整文件路径 + full_path = os.path.join(dirpath, filename) + python_files.append(full_path) + + return python_files + + +def export_file_contents_to_txt(python_files, output_file='python_contents.txt'): + """ + 将所有Python文件的内容导出到一个TXT文件 + + Args: + python_files: Python文件路径列表 + output_file: 输出TXT文件名 + """ + with open(output_file, 'w', encoding='utf-8') as outfile: + for file_path in python_files: + # 获取相对路径以便更好地显示 + rel_path = os.path.relpath(file_path) + + # 写入文件分隔符 + outfile.write(f"\n{'=' * 80}\n") + outfile.write(f"文件: {rel_path}\n") + outfile.write(f"{'=' * 80}\n\n") + + # 读取Python文件内容并写入输出文件 + try: + with open(file_path, 'r', encoding='utf-8') as infile: + content = infile.read() + outfile.write(content) + outfile.write("\n") # 文件末尾添加换行 + except Exception as e: + outfile.write(f"[无法读取文件内容: {str(e)}]\n") + + print(f"已将{len(python_files)}个Python文件的内容导出到 {output_file}") + + +def main(): + # 获取当前工作目录 + current_dir = os.getcwd() + print(f"正在搜索目录: {current_dir}") + + # 查找所有Python文件 + python_files = find_all_python_files(current_dir) + + if python_files: + print(f"找到 {len(python_files)} 个Python文件") + + # 导出所有文件内容到TXT + export_file_contents_to_txt(python_files) + + # 打印处理的文件列表 + for i, file_path in enumerate(python_files[:10]): + rel_path = os.path.relpath(file_path) + print(f"{i + 1}. {rel_path}") + + if len(python_files) > 10: + print(f"... 
还有 {len(python_files) - 10} 个文件") + else: + print("未找到任何Python文件") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python_contents.txt b/python_contents.txt new file mode 100644 index 0000000..6d06dde --- /dev/null +++ b/python_contents.txt @@ -0,0 +1,10286 @@ + +================================================================================ +文件: export_all_pyfile.py +================================================================================ + +import os + + +def find_all_python_files(root_dir='.'): + """ + 查找指定目录及其所有子目录下的所有Python文件 + + Args: + root_dir: 根目录路径,默认为当前目录 + + Returns: + 包含所有Python文件路径的列表 + """ + python_files = [] + + # 遍历根目录及所有子目录 + for dirpath, dirnames, filenames in os.walk(root_dir): + # 查找所有.py文件 + for filename in filenames: + if filename.endswith('.py'): + # 构建完整文件路径 + full_path = os.path.join(dirpath, filename) + python_files.append(full_path) + + return python_files + + +def export_file_contents_to_txt(python_files, output_file='python_contents.txt'): + """ + 将所有Python文件的内容导出到一个TXT文件 + + Args: + python_files: Python文件路径列表 + output_file: 输出TXT文件名 + """ + with open(output_file, 'w', encoding='utf-8') as outfile: + for file_path in python_files: + # 获取相对路径以便更好地显示 + rel_path = os.path.relpath(file_path) + + # 写入文件分隔符 + outfile.write(f"\n{'=' * 80}\n") + outfile.write(f"文件: {rel_path}\n") + outfile.write(f"{'=' * 80}\n\n") + + # 读取Python文件内容并写入输出文件 + try: + with open(file_path, 'r', encoding='utf-8') as infile: + content = infile.read() + outfile.write(content) + outfile.write("\n") # 文件末尾添加换行 + except Exception as e: + outfile.write(f"[无法读取文件内容: {str(e)}]\n") + + print(f"已将{len(python_files)}个Python文件的内容导出到 {output_file}") + + +def main(): + # 获取当前工作目录 + current_dir = os.getcwd() + print(f"正在搜索目录: {current_dir}") + + # 查找所有Python文件 + python_files = find_all_python_files(current_dir) + + if python_files: + print(f"找到 {len(python_files)} 个Python文件") + + # 导出所有文件内容到TXT + export_file_contents_to_txt(python_files) + + # 打印处理的文件列表 + for i, file_path in enumerate(python_files[:10]): + rel_path = os.path.relpath(file_path) + print(f"{i + 1}. {rel_path}") + + if len(python_files) > 10: + print(f"... 
还有 {len(python_files) - 10} 个文件") + else: + print("未找到任何Python文件") + + +if __name__ == "__main__": + main() + +================================================================================ +文件: setup.py +================================================================================ + +from setuptools import setup, find_packages + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="chinese-text-classification", + version="1.0.0", + author="Your Name", + author_email="your.email@example.com", + description="基于Python的中文文本分类系统", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/yourusername/chinese-text-classification", + packages=find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.7", + install_requires=[ + "tensorflow>=2.5.0", + "numpy>=1.19.5", + "pandas>=1.3.0", + "scikit-learn>=0.24.2", + "matplotlib>=3.4.2", + "jieba>=0.42.1", + "tqdm>=4.61.1", + "gensim>=4.0.1", + "flask>=2.0.1", + "fastapi>=0.68.0", + "uvicorn>=0.15.0" + ], + entry_points={ + "console_scripts": [ + "text-classifier=main:main", + ], + }, +) + + +================================================================================ +文件: main.py +================================================================================ + +""" +主入口文件:整合系统的所有功能,提供命令行接口 +""" +import os +import sys +import argparse +import logging +from typing import List, Optional + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger("main") + +# 命令行参数 +parser = argparse.ArgumentParser(description="中文文本分类系统") +subparsers = parser.add_subparsers(dest="command", help="命令") + +# 训练命令 +train_parser = subparsers.add_parser("train", help="训练模型") +train_parser.add_argument("--data_dir", help="数据目录") +train_parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型") +train_parser.add_argument("--epochs", type=int, default=10, help="训练轮数") +train_parser.add_argument("--batch_size", type=int, default=64, help="批大小") +train_parser.add_argument("--save_dir", help="模型保存目录") + +# 评估命令 +evaluate_parser = subparsers.add_parser("evaluate", help="评估模型") +evaluate_parser.add_argument("--model_path", required=True, help="模型路径") +evaluate_parser.add_argument("--data_dir", help="数据目录") +evaluate_parser.add_argument("--batch_size", type=int, default=64, help="批大小") +evaluate_parser.add_argument("--output_dir", help="评估结果输出目录") + +# 预测命令 +predict_parser = subparsers.add_parser("predict", help="使用模型预测") +predict_parser.add_argument("--model_path", help="模型路径") +predict_parser.add_argument("--text", help="要预测的文本") +predict_parser.add_argument("--file", help="要预测的文件") +predict_parser.add_argument("--output", help="输出文件") + +# Web服务命令 +web_parser = subparsers.add_parser("web", help="启动Web服务") +web_parser.add_argument("--host", default="0.0.0.0", help="服务器主机") +web_parser.add_argument("--port", type=int, default=5000, help="服务器端口") +web_parser.add_argument("--debug", action="store_true", help="是否开启调试模式") + +# API服务命令 +api_parser = subparsers.add_parser("api", help="启动API服务") +api_parser.add_argument("--host", default="0.0.0.0", help="服务器主机") +api_parser.add_argument("--port", type=int, default=8000, help="服务器端口") + +# CLI命令 +cli_parser = subparsers.add_parser("cli", help="启动命令行接口") 
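# Usage note (illustrative sketch, not part of main.py): the subcommands defined
# above and below are meant to be invoked from the shell. The values here are
# placeholders chosen to match the argparse flags, not paths from the repository:
#   python main.py train --data_dir <data_dir> --model_type cnn --epochs 10 --batch_size 64
#   python main.py predict --text "<text to classify>" --model_path <model_path>
#   python main.py web --port 5000 --debug
#   python main.py cli --interactive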
+cli_parser.add_argument("--model_path", help="模型路径") +cli_parser.add_argument("--interactive", action="store_true", help="是否开启交互模式") + + +def main(): + """主函数""" + args = parser.parse_args() + + # 如果没有指定命令,显示帮助信息 + if not args.command: + parser.print_help() + return 0 + + # 根据命令调用相应的功能 + if args.command == "train": + # 导入训练模块 + from scripts.train import train_model + + # 调用训练功能 + train_model( + data_dir=args.data_dir, + model_type=args.model_type, + epochs=args.epochs, + batch_size=args.batch_size, + save_dir=args.save_dir + ) + + elif args.command == "evaluate": + # 导入评估模块 + from scripts.evaluate import evaluate_model + + # 调用评估功能 + evaluate_model( + model_path=args.model_path, + data_dir=args.data_dir, + batch_size=args.batch_size, + output_dir=args.output_dir + ) + + elif args.command == "predict": + # 导入预测模块 + from scripts.predict import predict_text, predict_file + + # 根据输入类型调用相应的预测功能 + if args.text: + predict_text(args.text, args.model_path, args.output) + elif args.file: + predict_file(args.file, args.model_path, args.output) + else: + logger.error("请提供要预测的文本或文件") + return 1 + + elif args.command == "web": + # 导入Web服务模块 + from interface.web.app import run_server + + # 启动Web服务 + run_server(host=args.host, port=args.port, debug=args.debug) + + elif args.command == "api": + # 导入API服务模块 + from interface.api import run_server + + # 启动API服务 + run_server(host=args.host, port=args.port) + + elif args.command == "cli": + # 导入CLI模块 + from interface.cli import main as cli_main + + # 将命令行参数转换为CLI模块可接受的格式 + sys.argv = ["interface/cli.py"] + + if args.interactive: + sys.argv.append("interactive") + if args.model_path: + sys.argv.extend(["--model_path", args.model_path]) + elif args.model_path: + sys.argv.extend(["list", "--model_path", args.model_path]) + else: + sys.argv.append("list") + + # 调用CLI主函数 + return cli_main() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + + +================================================================================ +文件: interface/__init__.py +================================================================================ + + + +================================================================================ +文件: interface/api.py +================================================================================ + +""" +API接口模块:提供REST API接口 +""" +import os +import sys +import json +import time +from typing import List, Dict, Tuple, Optional, Any, Union +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends, Request +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +import uvicorn +import asyncio +import pandas as pd +import io + +# 将项目根目录添加到sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from config.system_config import CLASSIFIERS_DIR, CATEGORIES +from models.model_factory import ModelFactory +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from inference.predictor import Predictor +from inference.batch_processor import BatchProcessor +from utils.logger import get_logger + +logger = get_logger("API") + + +# 数据模型 +class TextItem(BaseModel): + text: str + id: Optional[str] = None + + +class BatchPredictRequest(BaseModel): + texts: List[TextItem] + top_k: Optional[int] = 1 + + +class BatchFileRequest(BaseModel): + file_paths: List[str] + top_k: Optional[int] = 1 + + +class ModelInfo(BaseModel): + id: str + name: str + type: str + 
num_classes: int + created_time: str + file_size: str + + +# 应用实例 +app = FastAPI( + title="中文文本分类系统API", + description="提供中文文本分类功能的REST API", + version="1.0.0" +) + +# 允许跨域请求 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 全局对象 +predictor = None + + +def get_predictor() -> Predictor: + """ + 获取或创建全局Predictor实例 + + Returns: + Predictor实例 + """ + global predictor + if predictor is None: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + if not models_info: + raise HTTPException(status_code=500, detail="未找到可用的模型") + + # 使用最新的模型 + model_path = models_info[0]['path'] + logger.info(f"API加载模型: {model_path}") + + # 加载模型 + model = ModelFactory.load_model(model_path) + + # 创建分词器 + tokenizer = ChineseTokenizer() + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + class_names=CATEGORIES, + batch_size=64 + ) + + return predictor + + +@app.get("/") +async def root(): + """API根路径""" + return {"message": "欢迎使用中文文本分类系统API"} + + +@app.post("/predict/text") +async def predict_text(text: str = Form(...), top_k: int = Form(1)): + """ + 预测单条文本 + + Args: + text: 要预测的文本 + top_k: 返回概率最高的前k个类别 + + Returns: + 预测结果 + """ + logger.info(f"接收到文本预测请求,文本长度: {len(text)}") + + try: + # 获取预测器 + predictor = get_predictor() + + # 预测 + start_time = time.time() + result = predictor.predict( + text=text, + return_top_k=top_k, + return_probabilities=True + ) + prediction_time = time.time() - start_time + + # 构建响应 + response = { + "success": True, + "predictions": result if top_k > 1 else [result], + "time": prediction_time + } + + return response + + except Exception as e: + logger.error(f"预测文本时出错: {e}") + raise HTTPException(status_code=500, detail=f"预测文本时出错: {str(e)}") + + +@app.post("/predict/batch") +async def predict_batch(request: BatchPredictRequest): + """ + 批量预测文本 + + Args: + request: 包含文本列表和参数的请求 + + Returns: + 批量预测结果 + """ + texts = [item.text for item in request.texts] + ids = [item.id or str(i) for i, item in enumerate(request.texts)] + + logger.info(f"接收到批量预测请求,共 {len(texts)} 条文本") + + try: + # 获取预测器 + predictor = get_predictor() + + # 预测 + start_time = time.time() + results = predictor.predict_batch( + texts=texts, + return_top_k=request.top_k, + return_probabilities=True + ) + prediction_time = time.time() - start_time + + # 构建响应 + response = { + "success": True, + "total": len(texts), + "time": prediction_time, + "results": {} + } + + # 将结果关联到ID + for i, (id_val, result) in enumerate(zip(ids, results)): + response["results"][id_val] = result + + return response + + except Exception as e: + logger.error(f"批量预测文本时出错: {e}") + raise HTTPException(status_code=500, detail=f"批量预测文本时出错: {str(e)}") + + +@app.post("/predict/file") +async def predict_file(file: UploadFile = File(...), top_k: int = Form(1)): + """ + 预测文件内容 + + Args: + file: 上传的文件 + top_k: 返回概率最高的前k个类别 + + Returns: + 预测结果 + """ + logger.info(f"接收到文件预测请求,文件名: {file.filename}") + + try: + # 读取文件内容 + content = await file.read() + + # 根据文件类型处理内容 + if file.filename.endswith(('.txt', '.md')): + # 文本文件 + text = content.decode('utf-8') + + # 获取预测器 + predictor = get_predictor() + + # 预测 + result = predictor.predict( + text=text, + return_top_k=top_k, + return_probabilities=True + ) + + # 构建响应 + response = { + "success": True, + "filename": file.filename, + "predictions": result if top_k > 1 else [result] + } + + return response + + elif file.filename.endswith(('.csv', '.xls', '.xlsx')): + # 表格文件 + if file.filename.endswith('.csv'): + 
df = pd.read_csv(io.BytesIO(content)) + else: + df = pd.read_excel(io.BytesIO(content)) + + # 查找可能的文本列 + text_columns = [col for col in df.columns if df[col].dtype == 'object'] + + if not text_columns: + raise HTTPException(status_code=400, detail="文件中没有找到可能的文本列") + + # 使用第一个文本列 + text_column = text_columns[0] + texts = df[text_column].fillna('').tolist() + + # 获取预测器 + predictor = get_predictor() + + # 批量预测 + results = predictor.predict_batch( + texts=texts, + return_top_k=top_k, + return_probabilities=True + ) + + # 构建响应 + response = { + "success": True, + "filename": file.filename, + "text_column": text_column, + "total": len(texts), + "results": results + } + + return response + + else: + raise HTTPException(status_code=400, detail=f"不支持的文件类型: {file.filename}") + + except Exception as e: + logger.error(f"预测文件时出错: {e}") + raise HTTPException(status_code=500, detail=f"预测文件时出错: {str(e)}") + + +@app.get("/models") +async def list_models(): + """ + 列出可用的模型 + + Returns: + 可用模型列表 + """ + try: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + # 转换为响应格式 + models = [] + for info in models_info: + models.append(ModelInfo( + id=os.path.basename(info['path']), + name=info['name'], + type=info['type'], + num_classes=info['num_classes'], + created_time=info['created_time'], + file_size=info['file_size'] + )) + + return {"models": models} + + except Exception as e: + logger.error(f"获取模型列表时出错: {e}") + raise HTTPException(status_code=500, detail=f"获取模型列表时出错: {str(e)}") + + +@app.get("/categories") +async def list_categories(): + """ + 列出支持的类别 + + Returns: + 支持的类别列表 + """ + try: + return {"categories": CATEGORIES} + except Exception as e: + logger.error(f"获取类别列表时出错: {e}") + raise HTTPException(status_code=500, detail=f"获取类别列表时出错: {str(e)}") + + +@app.middleware("http") +async def log_requests(request: Request, call_next): + """ + 记录请求日志 + + Args: + request: 请求对象 + call_next: 下一个处理函数 + + Returns: + 响应对象 + """ + start_time = time.time() + + # 记录请求信息 + logger.info(f"请求: {request.method} {request.url}") + + # 处理请求 + response = await call_next(request) + + # 记录响应信息 + process_time = time.time() - start_time + logger.info(f"响应: {response.status_code} ({process_time:.2f}s)") + + return response + + +def run_server(host: str = "0.0.0.0", port: int = 8000): + """ + 运行API服务器 + + Args: + host: 主机地址 + port: 端口号 + """ + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + # 解析命令行参数 + import argparse + + parser = argparse.ArgumentParser(description="中文文本分类系统API服务器") + parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址") + parser.add_argument("--port", type=int, default=8000, help="服务器端口号") + + args = parser.parse_args() + + # 运行服务器 + run_server(host=args.host, port=args.port) + + +================================================================================ +文件: interface/cli.py +================================================================================ + +""" +命令行界面模块:提供命令行交互功能 +""" +import argparse +import os +import sys +import pandas as pd +from typing import List, Dict, Tuple, Optional, Any, Union +import json + +# 将项目根目录添加到sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from config.system_config import CLASSIFIERS_DIR, CATEGORIES +from models.model_factory import ModelFactory +from models.base_model import TextClassificationModel +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from inference.predictor import Predictor +from 
inference.batch_processor import BatchProcessor +from utils.logger import get_logger +from utils.file_utils import ensure_dir, read_text_file + +logger = get_logger("CLI") + + +def load_model_and_components(model_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + vectorizer_path: Optional[str] = None, + class_names: Optional[List[str]] = None) -> Tuple[ + TextClassificationModel, ChineseTokenizer, Optional[SequenceVectorizer]]: + """ + 加载模型和相关组件 + + Args: + model_path: 模型路径,如果为None则使用最新的模型 + tokenizer_path: 分词器路径,如果为None则创建一个新的分词器 + vectorizer_path: 向量化器路径,如果为None则不使用向量化器 + class_names: 类别名称列表,如果为None则使用CATEGORIES + + Returns: + (模型, 分词器, 向量化器)的元组 + """ + # 加载模型 + if model_path is None: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + if not models_info: + raise ValueError("未找到可用的模型,请指定模型路径") + + # 使用最新的模型 + model_path = models_info[0]['path'] + logger.info(f"使用最新的模型: {model_path}") + + # 加载模型 + model = ModelFactory.load_model(model_path) + + # 加载或创建分词器 + if tokenizer_path: + tokenizer = ChineseTokenizer() # 实际上应该从文件加载,这里简化处理 + logger.info(f"已加载分词器: {tokenizer_path}") + else: + tokenizer = ChineseTokenizer() + logger.info("已创建新的分词器") + + # 加载向量化器 + vectorizer = None + if vectorizer_path: + vectorizer = SequenceVectorizer() # 实际上应该从文件加载,这里简化处理 + vectorizer.load(vectorizer_path) + logger.info(f"已加载向量化器: {vectorizer_path}") + + return model, tokenizer, vectorizer + + +def predict_text(args): + """处理单条文本预测命令""" + # 加载模型和组件 + model, tokenizer, vectorizer = load_model_and_components( + args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names + ) + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + vectorizer=vectorizer, + class_names=args.class_names or CATEGORIES, + batch_size=args.batch_size + ) + + # 获取文本 + text = args.text + + # 如果提供的是文件路径而非文本内容 + if args.file and os.path.exists(text): + text = read_text_file(text) + + # 预测 + result = predictor.predict( + text=text, + return_top_k=args.top_k, + return_probabilities=True + ) + + # 输出结果 + if args.top_k > 1: + print("\n预测结果:") + for i, pred in enumerate(result): + print(f"{i + 1}. 
{pred['class']} (概率: {pred['probability']:.4f})") + else: + print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})") + + # 保存结果 + if args.output: + if args.output.endswith('.json'): + with open(args.output, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + else: + with open(args.output, 'w', encoding='utf-8') as f: + if args.top_k > 1: + f.write("rank,class,probability\n") + for i, pred in enumerate(result): + f.write(f"{i + 1},{pred['class']},{pred['probability']}\n") + else: + f.write(f"class,probability\n") + f.write(f"{result['class']},{result['probability']}\n") + + print(f"结果已保存到: {args.output}") + + +def predict_batch(args): + """处理批量文本预测命令""" + # 加载模型和组件 + model, tokenizer, vectorizer = load_model_and_components( + args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names + ) + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + vectorizer=vectorizer, + class_names=args.class_names or CATEGORIES, + batch_size=args.batch_size + ) + + # 创建批处理器 + batch_processor = BatchProcessor( + predictor=predictor, + batch_size=args.batch_size, + max_workers=args.workers + ) + + # 确保输出目录存在 + if args.output: + ensure_dir(os.path.dirname(args.output)) + + # 根据输入类型选择处理方法 + if args.input_type == 'file' and os.path.isfile(args.input): + # 单个文件 + if args.large_file: + # 大型文件,分块处理 + batch_processor.process_large_file( + file_path=args.input, + output_path=args.output, + return_top_k=args.top_k, + format=args.format + ) + else: + # CSV或Excel文件 + if args.input.endswith('.csv'): + df = pd.read_csv(args.input, encoding='utf-8') + elif args.input.endswith(('.xls', '.xlsx')): + df = pd.read_excel(args.input) + else: + print(f"不支持的文件格式: {args.input}") + return + + # 检查文本列是否存在 + if args.text_column not in df.columns: + print(f"文本列 '{args.text_column}' 不在输入文件中,可用列: {', '.join(df.columns)}") + return + + # 处理DataFrame + result_df = batch_processor.process_dataframe( + df=df, + text_column=args.text_column, + id_column=args.id_column, + output_path=args.output, + return_top_k=args.top_k, + format=args.format + ) + + # 输出结果统计 + print(f"\n已处理 {len(result_df)} 条文本") + print("类别分布:") + if args.top_k == 1: + class_counts = result_df['predicted_class'].value_counts() + for cls, count in class_counts.items(): + print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)") + + elif args.input_type == 'dir' and os.path.isdir(args.input): + # 目录 + result_df = batch_processor.process_directory( + directory=args.input, + pattern=args.pattern, + output_path=args.output, + return_top_k=args.top_k, + format=args.format, + recursive=args.recursive + ) + + # 输出结果统计 + if not result_df.empty: + print(f"\n已处理 {len(result_df)} 个文件") + print("类别分布:") + if args.top_k == 1: + class_counts = result_df['predicted_class'].value_counts() + for cls, count in class_counts.items(): + print(f" {cls}: {count} ({count / len(result_df) * 100:.1f}%)") + + else: + print(f"无效的输入: {args.input}") + + +def list_models(args): + """列出可用的模型""" + models_info = ModelFactory.get_available_models() + + if not models_info: + print("未找到可用的模型") + return + + print(f"找到 {len(models_info)} 个可用模型:") + for i, info in enumerate(models_info): + print(f"\n{i + 1}. 
{info['name']} ({info['type']})") + print(f" 路径: {info['path']}") + print(f" 创建时间: {info['created_time']}") + print(f" 类别数: {info['num_classes']}") + print(f" 文件大小: {info['file_size']}") + + +def interactive_mode(args): + """交互模式""" + print("启动交互模式...") + + # 加载模型和组件 + model, tokenizer, vectorizer = load_model_and_components( + args.model_path, args.tokenizer_path, args.vectorizer_path, args.class_names + ) + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + vectorizer=vectorizer, + class_names=args.class_names or CATEGORIES, + batch_size=args.batch_size + ) + + print("\n模型已加载,可以开始交互式文本分类") + print("输入 'quit' 或 'exit' 退出交互模式\n") + + while True: + try: + # 获取用户输入 + text = input("请输入要分类的文本: ") + + # 检查是否退出 + if text.lower() in ['quit', 'exit', 'q']: + print("退出交互模式") + break + + # 空输入 + if not text.strip(): + continue + + # 预测 + result = predictor.predict( + text=text, + return_top_k=args.top_k, + return_probabilities=True + ) + + # 输出结果 + if args.top_k > 1: + print("\n预测结果:") + for i, pred in enumerate(result): + print(f"{i + 1}. {pred['class']} (概率: {pred['probability']:.4f})") + else: + print(f"\n预测结果: {result['class']} (概率: {result['probability']:.4f})") + + print() # 空行 + + except KeyboardInterrupt: + print("\n退出交互模式") + break + except Exception as e: + print(f"处理过程中出错: {e}") + + +def main(): + """主函数,解析命令行参数并调用相应的功能""" + parser = argparse.ArgumentParser(description="中文文本分类系统命令行工具") + + # 创建子命令 + subparsers = parser.add_subparsers(dest="command", help="子命令") + + # 预测单条文本命令 + predict_parser = subparsers.add_parser("predict", help="预测单条文本") + predict_parser.add_argument("text", help="要预测的文本或文本文件路径") + predict_parser.add_argument("--file", action="store_true", help="将text参数视为文件路径") + predict_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") + predict_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") + predict_parser.add_argument("--vectorizer_path", help="向量化器路径") + predict_parser.add_argument("--class_names", nargs="+", help="类别名称列表") + predict_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别") + predict_parser.add_argument("--batch_size", type=int, default=64, help="批大小") + predict_parser.add_argument("--output", help="保存预测结果的文件路径") + predict_parser.set_defaults(func=predict_text) + + # 批量预测命令 + batch_parser = subparsers.add_parser("batch", help="批量预测文本") + batch_parser.add_argument("input", help="输入文件或目录路径") + batch_parser.add_argument("--input_type", choices=["file", "dir"], default="file", help="输入类型") + batch_parser.add_argument("--text_column", default="text", help="CSV/Excel文件中的文本列名") + batch_parser.add_argument("--id_column", help="CSV/Excel文件中的ID列名") + batch_parser.add_argument("--pattern", default="*.txt", help="文件匹配模式") + batch_parser.add_argument("--recursive", action="store_true", help="递归处理子目录") + batch_parser.add_argument("--large_file", action="store_true", help="处理大型文本文件") + batch_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") + batch_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") + batch_parser.add_argument("--vectorizer_path", help="向量化器路径") + batch_parser.add_argument("--class_names", nargs="+", help="类别名称列表") + batch_parser.add_argument("--top_k", type=int, default=1, help="返回概率最高的前k个类别") + batch_parser.add_argument("--batch_size", type=int, default=64, help="批大小") + batch_parser.add_argument("--workers", type=int, default=4, help="工作线程数") + batch_parser.add_argument("--output", required=True, help="输出文件路径") + batch_parser.add_argument("--format", choices=["csv", 
"json"], default="csv", help="输出格式") + batch_parser.set_defaults(func=predict_batch) + + # 列出可用模型命令 + list_parser = subparsers.add_parser("list", help="列出可用的模型") + list_parser.set_defaults(func=list_models) + + # 交互模式命令 + interactive_parser = subparsers.add_parser("interactive", help="启动交互式分类模式") + interactive_parser.add_argument("--model_path", help="模型路径,默认使用最新的模型") + interactive_parser.add_argument("--tokenizer_path", help="分词器路径,默认创建新的分词器") + interactive_parser.add_argument("--vectorizer_path", help="向量化器路径") + interactive_parser.add_argument("--class_names", nargs="+", help="类别名称列表") + interactive_parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别") + interactive_parser.add_argument("--batch_size", type=int, default=1, help="批大小") + interactive_parser.set_defaults(func=interactive_mode) + + # 解析参数 + args = parser.parse_args() + + # 如果没有指定命令,显示帮助 + if not hasattr(args, 'func'): + parser.print_help() + return + + # 执行命令 + try: + args.func(args) + except Exception as e: + logger.error(f"执行命令时出错: {e}") + print(f"执行命令时出错: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + + +================================================================================ +文件: interface/web/__init__.py +================================================================================ + + + +================================================================================ +文件: interface/web/app.py +================================================================================ + +return render_template('predict_text.html', text=text, predictions=predictions) + +except Exception as e: +logger.error(f"预测文本时出错: {e}") +flash(f'预测失败: {str(e)}') +return render_template('predict_text.html', text=text) + +return render_template('predict_text.html') + + +@app.route('/predict_file', methods=['GET', 'POST']) +def predict_file(): + """文件预测页面""" + if request.method == 'POST': + # 检查是否有文件上传 + if 'file' not in request.files: + flash('未选择文件') + return render_template('predict_file.html') + + file = request.files['file'] + + # 如果用户没有选择文件 + if file.filename == '': + flash('未选择文件') + return render_template('predict_file.html') + + if file and allowed_file(file.filename): + # 获取参数 + top_k = int(request.form.get('top_k', 3)) + text_column = request.form.get('text_column', 'text') + + try: + # 安全地保存文件 + filename = secure_filename(file.filename) + file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + file.save(file_path) + + # 处理文件类型 + if filename.endswith('.txt'): + # 文本文件 + with open(file_path, 'r', encoding='utf-8') as f: + text = f.read() + + # 获取预测器 + predictor = get_predictor() + + # 预测 + result = predictor.predict( + text=text, + return_top_k=top_k, + return_probabilities=True + ) + + # 准备结果 + predictions = result if top_k > 1 else [result] + + return render_template( + 'predict_file.html', + filename=filename, + file_type='text', + predictions=predictions + ) + + elif filename.endswith(('.csv', '.xls', '.xlsx')): + # 表格文件 + if filename.endswith('.csv'): + df = pd.read_csv(file_path) + else: + df = pd.read_excel(file_path) + + # 检查文本列是否存在 + if text_column not in df.columns: + flash(f"文件中没有找到列: {text_column}") + return render_template('predict_file.html') + + # 提取文本 + texts = df[text_column].fillna('').tolist() + + # 获取预测器 + predictor = get_predictor() + + # 批量预测 + results = predictor.predict_batch( + texts=texts[:100], # 仅处理前100条记录 + return_top_k=top_k, + return_probabilities=True + ) + + # 准备结果 + batch_results = [] + for i, result in enumerate(results): + 
batch_results.append({ + 'id': i, + 'text': texts[i][:100] + '...' if len(texts[i]) > 100 else texts[i], + 'predictions': result if top_k > 1 else [result] + }) + + return render_template( + 'predict_file.html', + filename=filename, + file_type='table', + total_records=len(texts), + processed_records=min(100, len(texts)), + batch_results=batch_results + ) + + else: + flash(f"不支持的文件类型: {filename}") + return render_template('predict_file.html') + + except Exception as e: + logger.error(f"处理文件时出错: {e}") + flash(f'处理失败: {str(e)}') + return render_template('predict_file.html') + else: + flash(f'不支持的文件类型,允许的类型: {", ".join(ALLOWED_EXTENSIONS)}') + return render_template('predict_file.html') + + return render_template('predict_file.html') + + +@app.route('/batch_predict', methods=['GET', 'POST']) +def batch_predict(): + """批量预测页面""" + if request.method == 'POST': + texts = request.form.get('texts', '') + top_k = int(request.form.get('top_k', 3)) + + if not texts: + flash('请输入文本内容') + return render_template('batch_predict.html') + + # 分割文本 + text_list = [text.strip() for text in texts.split('\n') if text.strip()] + + if not text_list: + flash('请输入有效的文本内容') + return render_template('batch_predict.html', texts=texts) + + try: + # 获取预测器 + predictor = get_predictor() + + # 批量预测 + results = predictor.predict_batch( + texts=text_list, + return_top_k=top_k, + return_probabilities=True + ) + + # 准备结果 + batch_results = [] + for i, result in enumerate(results): + batch_results.append({ + 'id': i + 1, + 'text': text_list[i], + 'predictions': result if top_k > 1 else [result] + }) + + return render_template( + 'batch_predict.html', + texts=texts, + batch_results=batch_results + ) + + except Exception as e: + logger.error(f"批量预测时出错: {e}") + flash(f'预测失败: {str(e)}') + return render_template('batch_predict.html', texts=texts) + + return render_template('batch_predict.html') + + +@app.route('/models') +def list_models(): + """模型列表页面""" + try: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + return render_template('models.html', models=models_info) + + except Exception as e: + logger.error(f"获取模型列表时出错: {e}") + flash(f'获取模型列表失败: {str(e)}') + return render_template('models.html', models=[]) + + +@app.route('/about') +def about(): + """关于页面""" + return render_template('about.html') + + +@app.errorhandler(404) +def page_not_found(e): + """404错误处理""" + return render_template('404.html'), 404 + + +@app.errorhandler(500) +def internal_server_error(e): + """500错误处理""" + return render_template('500.html'), 500 + + +# 过滤器 +@app.template_filter('format_time') +def format_time_filter(timestamp): + """格式化时间戳""" + if isinstance(timestamp, str): + try: + dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S") + return dt.strftime("%Y年%m月%d日 %H:%M") + except: + return timestamp + return timestamp + + +@app.template_filter('truncate_text') +def truncate_text_filter(text, length=100): + """截断文本""" + if len(text) <= length: + return text + return text[:length] + '...' 
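# Usage note (illustrative sketch, not part of app.py): filters registered with
# @app.template_filter are exposed to Jinja2 templates under the given name and
# applied with the pipe syntax, e.g. {{ model.created_time|format_time }}.
# The quick check below is an assumption-laden sketch: it presumes the Flask `app`
# object from interface/web/app.py is importable, and the sample values are
# placeholders rather than project data.
from flask import render_template_string

with app.app_context():
    rendered = render_template_string(
        "{{ ts|format_time }} | {{ text|truncate_text(20) }}",
        ts="2024-01-01 12:00:00",  # matches format_time_filter's strptime pattern
        text="a fairly long placeholder sentence used only for the truncation demo",
    )
    print(rendered)  # -> "2024年01月01日 12:00 | a fairly long placeh..."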
+ + +@app.template_filter('format_percent') +def format_percent_filter(value): + """格式化百分比""" + if isinstance(value, (int, float)): + return f"{value * 100:.2f}%" + return value + + +def run_server(host: str = "0.0.0.0", port: int = 5000, debug: bool = False): + """ + 运行Web服务器 + + Args: + host: 主机地址 + port: 端口号 + debug: 是否开启调试模式 + """ + app.run(host=host, port=port, debug=debug) + + +if __name__ == "__main__": + # 解析命令行参数 + import argparse + + parser = argparse.ArgumentParser(description="中文文本分类系统Web应用") + parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址") + parser.add_argument("--port", type=int, default=5000, help="服务器端口号") + parser.add_argument("--debug", action="store_true", help="是否开启调试模式") + + args = parser.parse_args() + + # 运行服务器 + run_server(host=args.host, port=args.port, debug=args.debug) + + +================================================================================ +文件: interface/web/routes.py +================================================================================ + +""" +Web路由模块:定义Web应用的路由 +""" +from flask import Blueprint, render_template, request, jsonify, redirect, url_for, session, flash +from werkzeug.utils import secure_filename +import os +import pandas as pd +import io +import time +from typing import List, Dict, Tuple, Optional + +from interface.web.app import get_predictor, allowed_file, app + +# 创建蓝图 +bp = Blueprint('routes', __name__) + + +@bp.route('/') +def index(): + """首页""" + return render_template('index.html') + + +@bp.route('/predict_text', methods=['GET', 'POST']) +def predict_text(): + """文本预测页面""" + if request.method == 'POST': + text = request.form.get('text', '') + top_k = int(request.form.get('top_k', 3)) + + if not text: + flash('请输入文本内容') + return render_template('predict_text.html') + + try: + # 获取预测器 + predictor = get_predictor() + + # 预测 + result = predictor.predict( + text=text, + return_top_k=top_k, + return_probabilities=True + ) + + # 准备结果 + predictions = result if top_k > 1 else [result] + + return render_template('predict_text.html', text=text, predictions=predictions) + + except Exception as e: + flash(f'预测失败: {str(e)}') + return render_template('predict_text.html', text=text) + + return render_template('predict_text.html') + + +@bp.route('/predict_file', methods=['GET', 'POST']) +def predict_file(): + """文件预测页面""" + if request.method == 'POST': + # 检查是否有文件上传 + if 'file' not in request.files: + flash('未选择文件') + return render_template('predict_file.html') + + file = request.files['file'] + + # 如果用户没有选择文件 + if file.filename == '': + flash('未选择文件') + return render_template('predict_file.html') + + if file and allowed_file(file.filename): + # 获取参数 + top_k = int(request.form.get('top_k', 3)) + text_column = request.form.get('text_column', 'text') + + try: + # 安全地保存文件 + filename = secure_filename(file.filename) + file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + file.save(file_path) + + # 处理文件 + # 注意:此处与app.py中的predict_file重复,实际项目中应该将逻辑抽取到单独的函数中 + # 这里简化处理,仅返回渲染模板 + + return render_template('predict_file.html', filename=filename) + + except Exception as e: + flash(f'处理失败: {str(e)}') + return render_template('predict_file.html') + else: + flash(f'不支持的文件类型') + return render_template('predict_file.html') + + return render_template('predict_file.html') + + +@bp.route('/batch_predict', methods=['GET', 'POST']) +def batch_predict(): + """批量预测页面""" + if request.method == 'POST': + texts = request.form.get('texts', '') + top_k = int(request.form.get('top_k', 3)) + + if not texts: + flash('请输入文本内容') + return 
render_template('batch_predict.html') + + # 分割文本 + text_list = [text.strip() for text in texts.split('\n') if text.strip()] + + if not text_list: + flash('请输入有效的文本内容') + return render_template('batch_predict.html', texts=texts) + + try: + # 获取预测器 + predictor = get_predictor() + + # 批量预测 + results = predictor.predict_batch( + texts=text_list, + return_top_k=top_k, + return_probabilities=True + ) + + # 准备结果 + batch_results = [] + for i, result in enumerate(results): + batch_results.append({ + 'id': i + 1, + 'text': text_list[i], + 'predictions': result if top_k > 1 else [result] + }) + + return render_template( + 'batch_predict.html', + texts=texts, + batch_results=batch_results + ) + + except Exception as e: + flash(f'预测失败: {str(e)}') + return render_template('batch_predict.html', texts=texts) + + return render_template('batch_predict.html') + + +@bp.route('/models') +def list_models(): + """模型列表页面""" + return render_template('models.html') + + +@bp.route('/about') +def about(): + """关于页面""" + return render_template('about.html') + + +# 注册蓝图 +def init_app(app): + app.register_blueprint(bp) + + +================================================================================ +文件: config/model_config.py +================================================================================ + +""" +模型配置文件 +""" + +# 文本预处理参数 +MAX_SEQUENCE_LENGTH = 500 # 文本序列最大长度 +MAX_NUM_WORDS = 50000 # 词汇表最大大小 +MAX_CHAR_LENGTH = 2000 # 字符级最大长度 +MIN_WORD_FREQUENCY = 5 # 最小词频 + +# 模型架构参数 +CNN_CONFIG = { + "embedding_dim": 200, + "num_filters": 256, + "filter_sizes": [3, 4, 5], + "dropout_rate": 0.5, + "l2_reg_lambda": 0.0, +} + +RNN_CONFIG = { + "embedding_dim": 200, + "hidden_size": 256, + "num_layers": 2, + "bidirectional": True, + "dropout_rate": 0.5, +} + +TRANSFORMER_CONFIG = { + "embedding_dim": 200, + "num_heads": 8, + "ff_dim": 512, + "num_layers": 4, + "dropout_rate": 0.1, +} + +# 针对RTX 4090的优化设置 +BATCH_SIZE = 128 # RTX 4090有24GB显存,可以支持较大的batch +EVAL_BATCH_SIZE = 256 # 评估时可以用更大的batch + +# 训练参数 +LEARNING_RATE = 1e-3 +NUM_EPOCHS = 20 +EARLY_STOPPING_PATIENCE = 3 +REDUCE_LR_PATIENCE = 2 +REDUCE_LR_FACTOR = 0.5 +VALIDATION_SPLIT = 0.1 +TEST_SPLIT = 0.1 + +# 词嵌入参数 +USE_PRETRAINED_EMBEDDING = True +EMBEDDING_TYPE = "word2vec" # 可选: word2vec, glove, fasttext + +# 随机种子,保证实验可重复性 +RANDOM_SEED = 42 + +# 模型保存参数 +SAVE_BEST_ONLY = True +MODEL_CHECKPOINT_PATH = "best_model.h5" + +# 特征工程参数 +USE_CHAR_LEVEL = False # 是否使用字符级特征 +USE_WORD_LEVEL = True # 是否使用词级特征 +USE_TFIDF = False # 是否使用TF-IDF特征 +USE_POS_TAGS = False # 是否使用词性标注特征 + +# 数据增强参数 +USE_DATA_AUGMENTATION = False +AUGMENTATION_FACTOR = 0.2 # 增强20%的数据 + +# 推理参数 +PREDICTION_THRESHOLD = 0.5 +TOP_K_PREDICTIONS = 3 + +================================================================================ +文件: config/__init__.py +================================================================================ + + + +================================================================================ +文件: config/system_config.py +================================================================================ + +""" +系统全局配置文件 +""" +import os +import platform +from pathlib import Path + +# 项目根目录 +ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +""" +Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 是当前文件的上一级目录 +这种写法主要是为了方便移植项目到不同的平台运行 +""" + +# 数据相关路径 +DATA_DIR = ROOT_DIR / "data" +RAW_DATA_DIR = DATA_DIR / "raw" / "THUCNews" +PROCESSED_DATA_DIR = DATA_DIR / "processed" +RESOURCES_DIR = DATA_DIR / "resources" +STOPWORDS_DIR = RESOURCES_DIR / 
"stopwords" +EMBEDDINGS_DIR = RESOURCES_DIR / "embeddings" + +# 确保必要的目录存在 +for directory in [PROCESSED_DATA_DIR, RESOURCES_DIR, STOPWORDS_DIR, EMBEDDINGS_DIR]: + directory.mkdir(parents=True, exist_ok=True) + +# 保存模型的路径 +SAVED_MODELS_DIR = ROOT_DIR / "saved_models" +TOKENIZERS_DIR = SAVED_MODELS_DIR / "tokenizers" +CLASSIFIERS_DIR = SAVED_MODELS_DIR / "classifiers" + +# 确保模型保存目录存在 +for directory in [SAVED_MODELS_DIR, TOKENIZERS_DIR, CLASSIFIERS_DIR]: + directory.mkdir(parents=True, exist_ok=True) + +# 系统资源配置 +CPU_COUNT = os.cpu_count() +USE_GPU = True +MULTI_GPU = False # 目前只使用单个GPU + +# 基于13900K性能设置并行处理参数 +DATA_LOADING_WORKERS = min(16, CPU_COUNT) # 数据加载线程数 +PREPROCESSING_WORKERS = min(24, CPU_COUNT) # 预处理线程数,13900K有强大的多线程能力 + +# 基于64GB内存设置内存相关参数 +MAX_MEMORY_GB = 48 # 保留部分内存给系统和其他应用 +MAX_TEXT_PER_BATCH = 10000 # 每批处理的最大文本数量 + +# 日志配置 +LOG_DIR = ROOT_DIR / "logs" +LOG_DIR.mkdir(exist_ok=True) +LOG_LEVEL = "INFO" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +# 类别标签映射(与THUCNews数据集一致) +CATEGORIES = [ + "体育", "娱乐", "家居", "彩票", "房产", "教育", + "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经" +] +CATEGORY_TO_ID = {category: idx for idx, category in enumerate(CATEGORIES)} +ID_TO_CATEGORY = {idx: category for idx, category in enumerate(CATEGORIES)} + +# 文件编码 +ENCODING = "utf-8" + +# 系统信息 +SYSTEM_INFO = { + "platform": platform.platform(), + "python_version": platform.python_version(), + "processor": platform.processor(), +} + +================================================================================ +文件: training/__init__.py +================================================================================ + + + +================================================================================ +文件: training/callbacks.py +================================================================================ + +""" +回调函数模块:提供用于模型训练的自定义回调函数 +""" +import os +import time +import numpy as np +import tensorflow as tf +from typing import List, Dict, Tuple, Optional, Any, Union +import matplotlib.pyplot as plt +from io import BytesIO + +from utils.logger import get_logger + +logger = get_logger("Callbacks") + + +class MetricsHistory(tf.keras.callbacks.Callback): + """跟踪训练过程中的指标历史""" + + def __init__(self, validation_data: Optional[Tuple] = None, + metrics: Optional[List[str]] = None, + save_path: Optional[str] = None): + """ + 初始化MetricsHistory回调 + + Args: + validation_data: 验证数据,格式为(x_val, y_val) + metrics: 要跟踪的指标列表 + save_path: 指标历史的保存路径 + """ + super().__init__() + self.validation_data = validation_data + self.metrics = metrics or ['loss', 'accuracy'] + self.save_path = save_path + + # 历史指标 + self.history = {metric: [] for metric in self.metrics} + if validation_data is not None: + for metric in self.metrics: + self.history[f'val_{metric}'] = [] + + def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: + """ + 每个epoch结束时调用 + + Args: + epoch: 当前epoch索引 + logs: 训练日志 + """ + logs = logs or {} + + # 记录训练指标 + for metric in self.metrics: + if metric in logs: + self.history[metric].append(logs[metric]) + + # 记录验证指标 + if self.validation_data is not None: + for metric in self.metrics: + val_metric = f'val_{metric}' + if val_metric in logs: + self.history[val_metric].append(logs[val_metric]) + + def plot_metrics(self, save_path: Optional[str] = None) -> None: + """ + 绘制指标历史 + + Args: + save_path: 图像保存路径,如果为None则使用初始化时设置的路径 + """ + plt.figure(figsize=(12, 5)) + + for i, metric in enumerate(self.metrics): + plt.subplot(1, len(self.metrics), i + 1) + + if metric in 
self.history: + plt.plot(self.history[metric], label=f'train_{metric}') + + val_metric = f'val_{metric}' + if val_metric in self.history: + plt.plot(self.history[val_metric], label=f'val_{metric}') + + plt.title(f'Model {metric}') + plt.xlabel('Epoch') + plt.ylabel(metric) + plt.legend() + + plt.tight_layout() + + save_path = save_path or self.save_path + if save_path: + plt.savefig(save_path) + logger.info(f"指标历史图已保存到: {save_path}") + else: + plt.show() + + +class ConfusionMatrixCallback(tf.keras.callbacks.Callback): + """计算并显示验证集上的混淆矩阵""" + + def __init__(self, validation_data: Tuple[np.ndarray, np.ndarray], + class_names: Optional[List[str]] = None, + log_dir: Optional[str] = None, + freq: int = 1, + fig_size: Tuple[int, int] = (10, 8)): + """ + 初始化ConfusionMatrixCallback + + Args: + validation_data: 验证数据,格式为(x_val, y_val) + class_names: 类别名称列表 + log_dir: TensorBoard日志目录 + freq: 计算混淆矩阵的频率(每多少个epoch计算一次) + fig_size: 图像大小 + """ + super().__init__() + self.x_val, self.y_val = validation_data + self.class_names = class_names + self.log_dir = log_dir + self.freq = freq + self.fig_size = fig_size + + # 如果提供了TensorBoard日志目录,创建一个文件写入器 + if log_dir: + self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'confusion_matrix')) + else: + self.file_writer = None + + def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: + """ + 每个epoch结束时调用 + + Args: + epoch: 当前epoch索引 + logs: 训练日志 + """ + # 每freq个epoch计算一次混淆矩阵 + if (epoch + 1) % self.freq == 0 or epoch == 0: + # 获取预测结果 + y_pred = np.argmax(self.model.predict(self.x_val), axis=1) + + # 确保y_val是一维数组 + y_true = self.y_val + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 计算混淆矩阵 + cm = tf.math.confusion_matrix(y_true, y_pred).numpy() + + # 归一化混淆矩阵 + cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True) + + # 绘制混淆矩阵 + fig = self._plot_confusion_matrix(cm_norm, epoch + 1) + + # 如果有TensorBoard日志,将图像添加到TensorBoard + if self.file_writer: + with self.file_writer.as_default(): + # 将matplotlib图像转换为TensorBoard图像 + buf = BytesIO() + fig.savefig(buf, format='png') + buf.seek(0) + + # 将PNG编码为字符串,并创建图像 + image = tf.image.decode_png(buf.getvalue(), channels=4) + image = tf.expand_dims(image, 0) + + # 添加到TensorBoard + tf.summary.image(f'Confusion Matrix (Epoch {epoch + 1})', image, step=epoch) + + plt.close(fig) + + def _plot_confusion_matrix(self, cm: np.ndarray, epoch: int) -> plt.Figure: + """ + 绘制混淆矩阵 + + Args: + cm: 混淆矩阵 + epoch: 当前epoch + + Returns: + matplotlib图像 + """ + fig, ax = plt.subplots(figsize=self.fig_size) + + # 使用热图显示混淆矩阵 + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + # 设置坐标轴标签 + if self.class_names: + ax.set( + xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + xticklabels=self.class_names, + yticklabels=self.class_names + ) + + # 旋转x轴标签 + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # 在每个单元格中显示数值 + thresh = cm.max() / 2.0 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], '.2f'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + ax.set_title(f"Normalized Confusion Matrix (Epoch {epoch})") + ax.set_ylabel('True label') + ax.set_xlabel('Predicted label') + + fig.tight_layout() + return fig + + +class TimingCallback(tf.keras.callbacks.Callback): + """测量训练时间的回调函数""" + + def __init__(self): + """初始化TimingCallback""" + super().__init__() + self.epoch_times = [] + self.batch_times = [] 
+ self.epoch_start_time = None + self.batch_start_time = None + self.training_start_time = None + + def on_train_begin(self, logs: Dict[str, float] = None) -> None: + """ + 训练开始时调用 + + Args: + logs: 训练日志 + """ + self.training_start_time = time.time() + + def on_train_end(self, logs: Dict[str, float] = None) -> None: + """ + 训练结束时调用 + + Args: + logs: 训练日志 + """ + training_time = time.time() - self.training_start_time + logger.info(f"总训练时间: {training_time:.2f} 秒") + + if self.epoch_times: + avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times) + logger.info(f"平均每个epoch时间: {avg_epoch_time:.2f} 秒") + + if self.batch_times: + avg_batch_time = sum(self.batch_times) / len(self.batch_times) + logger.info(f"平均每个batch时间: {avg_batch_time:.4f} 秒") + + def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None: + """ + 每个epoch开始时调用 + + Args: + epoch: 当前epoch索引 + logs: 训练日志 + """ + self.epoch_start_time = time.time() + + def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None: + """ + 每个epoch结束时调用 + + Args: + epoch: 当前epoch索引 + logs: 训练日志 + """ + epoch_time = time.time() - self.epoch_start_time + self.epoch_times.append(epoch_time) + + # 将epoch时间添加到日志中 + if logs is not None: + logs['epoch_time'] = epoch_time + + def on_batch_begin(self, batch: int, logs: Dict[str, float] = None) -> None: + """ + 每个batch开始时调用 + + Args: + batch: 当前batch索引 + logs: 训练日志 + """ + self.batch_start_time = time.time() + + def on_batch_end(self, batch: int, logs: Dict[str, float] = None) -> None: + """ + 每个batch结束时调用 + + Args: + batch: 当前batch索引 + logs: 训练日志 + """ + batch_time = time.time() - self.batch_start_time + self.batch_times.append(batch_time) + + +class LearningRateSchedulerCallback(tf.keras.callbacks.Callback): + """学习率调度器回调函数""" + + def __init__(self, scheduler_func: Callable[[int, float], float], + verbose: int = 0, + log_dir: Optional[str] = None): + """ + 初始化LearningRateSchedulerCallback + + Args: + scheduler_func: 学习率调度函数,接收(epoch, lr)参数,返回新的学习率 + verbose: 详细程度 + log_dir: TensorBoard日志目录 + """ + super().__init__() + self.scheduler_func = scheduler_func + self.verbose = verbose + + # 如果提供了TensorBoard日志目录,创建一个文件写入器 + if log_dir: + self.file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'learning_rate')) + else: + self.file_writer = None + + # 学习率历史 + self.lr_history = [] + + def on_epoch_begin(self, epoch: int, logs: Dict[str, float] = None) -> None: + """ + 每个epoch开始时调用 + + Args: + epoch: 当前epoch索引 + logs: 训练日志 + """ + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') + + # 获取当前学习率 + current_lr = float(tf.keras.backend.get_value(self.model.optimizer.lr)) + + # 计算新的学习率 + new_lr = self.scheduler_func(epoch, current_lr) + + # 设置新的学习率 + tf.keras.backend.set_value(self.model.optimizer.lr, new_lr) + + # 记录学习率 + self.lr_history.append(new_lr) + + # 记录到TensorBoard + if self.file_writer: + with self.file_writer.as_default(): + tf.summary.scalar('learning_rate', new_lr, step=epoch) + + if self.verbose > 0: + logger.info(f"Epoch {epoch + 1}: 学习率设置为 {new_lr:.6f}") + + def get_lr_history(self) -> List[float]: + """ + 获取学习率历史 + + Returns: + 学习率历史列表 + """ + return self.lr_history + + +class EarlyStoppingCallback(tf.keras.callbacks.EarlyStopping): + """增强版早停回调函数,支持最小变化率""" + + def __init__(self, monitor: str = 'val_loss', + min_delta: float = 0, + min_delta_ratio: float = 0, + patience: int = 0, + verbose: int = 0, + mode: str = 'auto', + baseline: Optional[float] = None, + restore_best_weights: bool = False): + """ + 
初始化EarlyStoppingCallback + + Args: + monitor: 监控的指标 + min_delta: 视为改进的最小绝对变化 + min_delta_ratio: 视为改进的最小相对变化率 + patience: 没有改进的轮数 + verbose: 详细程度 + mode: 'auto', 'min' 或 'max' + baseline: 基准值 + restore_best_weights: 是否恢复最佳权重 + """ + super().__init__( + monitor=monitor, + min_delta=min_delta, + patience=patience, + verbose=verbose, + mode=mode, + baseline=baseline, + restore_best_weights=restore_best_weights + ) + self.min_delta_ratio = min_delta_ratio + + def _is_improvement(self, current: float, reference: float) -> bool: + """ + 判断是否有所改进 + + Args: + current: 当前值 + reference: 参考值 + + Returns: + 是否有所改进 + """ + # 先检查绝对变化 + if super()._is_improvement(current, reference): + return True + + # 再检查相对变化率 + if self.monitor_op == np.less: + # 对于 'min' 模式,值越小越好 + relative_delta = (reference - current) / reference if reference != 0 else 0 + return relative_delta > self.min_delta_ratio + else: + # 对于 'max' 模式,值越大越好 + relative_delta = (current - reference) / reference if reference != 0 else 0 + return relative_delta > self.min_delta_ratio + + +================================================================================ +文件: training/optimizer.py +================================================================================ + + + +================================================================================ +文件: training/scheduler.py +================================================================================ + +""" +学习率调度器模块:提供各种学习率调度策略 +""" +import numpy as np +import math +from typing import Callable, Optional, Union, Dict +import tensorflow as tf + +from utils.logger import get_logger + +logger = get_logger("Scheduler") + + +def step_decay(epoch: int, initial_lr: float, + drop_rate: float = 0.5, + epochs_drop: int = 10) -> float: + """ + 阶梯式学习率衰减 + + Args: + epoch: 当前epoch索引 + initial_lr: 初始学习率 + drop_rate: 衰减率 + epochs_drop: 每多少个epoch衰减一次 + + Returns: + 新的学习率 + """ + return initial_lr * math.pow(drop_rate, math.floor((1 + epoch) / epochs_drop)) + + +def exponential_decay(epoch: int, initial_lr: float, + decay_rate: float = 0.9, + staircase: bool = False) -> float: + """ + 指数衰减学习率 + + Args: + epoch: 当前epoch索引 + initial_lr: 初始学习率 + decay_rate: 衰减率 + staircase: 是否阶梯式衰减 + + Returns: + 新的学习率 + """ + if staircase: + return initial_lr * math.pow(decay_rate, math.floor(epoch)) + else: + return initial_lr * math.pow(decay_rate, epoch) + + +def cosine_decay(epoch: int, initial_lr: float, + total_epochs: int = 100, + min_lr: float = 0) -> float: + """ + 余弦退火学习率 + + Args: + epoch: 当前epoch索引 + initial_lr: 初始学习率 + total_epochs: 总epoch数 + min_lr: 最小学习率 + + Returns: + 新的学习率 + """ + return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs)) + + +def cosine_decay_with_warmup(epoch: int, initial_lr: float, + total_epochs: int = 100, + warmup_epochs: int = 5, + min_lr: float = 0, + warmup_init_lr: float = 0) -> float: + """ + 带预热的余弦退火学习率 + + Args: + epoch: 当前epoch索引 + initial_lr: 初始学习率 + total_epochs: 总epoch数 + warmup_epochs: 预热epoch数 + min_lr: 最小学习率 + warmup_init_lr: 预热初始学习率 + + Returns: + 新的学习率 + """ + if epoch < warmup_epochs: + # 线性预热 + return warmup_init_lr + (initial_lr - warmup_init_lr) * epoch / warmup_epochs + else: + # 余弦退火 + return min_lr + 0.5 * (initial_lr - min_lr) * ( + 1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)) + ) + + +def cyclical_learning_rate(epoch: int, initial_lr: float, + max_lr: float = 0.1, + step_size: int = 8, + gamma: float = 1.0) -> float: + """ + 循环学习率 + + Args: + epoch: 当前epoch索引 + initial_lr: 初始学习率 + 
max_lr: 最大学习率 + step_size: 半周期大小(epoch数) + gamma: 循环衰减率 + + Returns: + 新的学习率 + """ + # 计算循环数 + cycle = math.floor(1 + epoch / (2 * step_size)) + + # 计算x值,范围在[0, 2] + x = abs(epoch / step_size - 2 * cycle + 1) + + # 应用循环衰减 + lr_range = (max_lr - initial_lr) * math.pow(gamma, cycle - 1) + + # 计算学习率 + return initial_lr + lr_range * max(0, 1 - x) + + +def create_custom_scheduler(scheduler_type: str, **kwargs) -> Callable[[int, float], float]: + """ + 创建自定义学习率调度器 + + Args: + scheduler_type: 调度器类型,可选值: 'step', 'exp', 'cosine', 'cosine_warmup', 'cyclical' + **kwargs: 调度器参数 + + Returns: + 学习率调度函数 + """ + scheduler_type = scheduler_type.lower() + + if scheduler_type == 'step': + drop_rate = kwargs.get('drop_rate', 0.5) + epochs_drop = kwargs.get('epochs_drop', 10) + + def scheduler(epoch, lr): + return step_decay(epoch, lr, drop_rate, epochs_drop) + + return scheduler + + elif scheduler_type == 'exp': + decay_rate = kwargs.get('decay_rate', 0.9) + staircase = kwargs.get('staircase', False) + + def scheduler(epoch, lr): + if epoch == 0: + # 第一个epoch使用初始学习率 + return lr + return exponential_decay(epoch, lr, decay_rate, staircase) + + return scheduler + + elif scheduler_type == 'cosine': + total_epochs = kwargs.get('total_epochs', 100) + min_lr = kwargs.get('min_lr', 0) + + def scheduler(epoch, lr): + if epoch == 0: + return lr + return cosine_decay(epoch, lr, total_epochs, min_lr) + + return scheduler + + elif scheduler_type == 'cosine_warmup': + total_epochs = kwargs.get('total_epochs', 100) + warmup_epochs = kwargs.get('warmup_epochs', 5) + min_lr = kwargs.get('min_lr', 0) + warmup_init_lr = kwargs.get('warmup_init_lr', 0) + + def scheduler(epoch, lr): + if epoch == 0: + return warmup_init_lr + return cosine_decay_with_warmup(epoch, lr, total_epochs, warmup_epochs, min_lr, warmup_init_lr) + + return scheduler + + elif scheduler_type == 'cyclical': + max_lr = kwargs.get('max_lr', 0.1) + step_size = kwargs.get('step_size', 8) + gamma = kwargs.get('gamma', 1.0) + + def scheduler(epoch, lr): + if epoch == 0: + return lr + return cyclical_learning_rate(epoch, lr, max_lr, step_size, gamma) + + return scheduler + + else: + raise ValueError(f"不支持的调度器类型: {scheduler_type}") + + +class WarmupCosineDecayScheduler(tf.keras.callbacks.Callback): + """预热余弦退火学习率调度器""" + + def __init__(self, learning_rate_base: float, + total_steps: int, + warmup_learning_rate: float = 0.0, + warmup_steps: int = 0, + hold_base_rate_steps: int = 0, + verbose: int = 0): + """ + 初始化预热余弦退火学习率调度器 + + Args: + learning_rate_base: 基础学习率 + total_steps: 总步数 + warmup_learning_rate: 预热学习率 + warmup_steps: 预热步数 + hold_base_rate_steps: 保持基础学习率的步数 + verbose: 详细程度 + """ + super().__init__() + + self.learning_rate_base = learning_rate_base + self.total_steps = total_steps + self.warmup_learning_rate = warmup_learning_rate + self.warmup_steps = warmup_steps + self.hold_base_rate_steps = hold_base_rate_steps + self.verbose = verbose + + # 学习率历史 + self.learning_rates = [] + + def on_train_begin(self, logs: Optional[Dict] = None) -> None: + """ + 训练开始时调用 + + Args: + logs: 训练日志 + """ + self.current_step = 0 + logger.info(f"预热余弦退火学习率调度器初始化: 基础学习率={self.learning_rate_base}, " + f"预热步数={self.warmup_steps}, 总步数={self.total_steps}") + + def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """ + 每个batch结束时调用 + + Args: + batch: 当前batch索引 + logs: 训练日志 + """ + self.current_step += 1 + lr = self._get_learning_rate() + + tf.keras.backend.set_value(self.model.optimizer.lr, lr) + self.learning_rates.append(lr) + + def _get_learning_rate(self) -> 
float: + """ + 计算当前学习率 + + Returns: + 当前学习率 + """ + if self.current_step < self.warmup_steps: + # 预热阶段:线性增加学习率 + lr = self.warmup_learning_rate + self.current_step * ( + (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps + ) + elif self.current_step < self.warmup_steps + self.hold_base_rate_steps: + # 保持基础学习率阶段 + lr = self.learning_rate_base + else: + # 余弦退火阶段 + cosine_steps = self.total_steps - self.warmup_steps - self.hold_base_rate_steps + cosine_current_step = self.current_step - self.warmup_steps - self.hold_base_rate_steps + lr = 0.5 * self.learning_rate_base * ( + 1 + math.cos(math.pi * cosine_current_step / cosine_steps) + ) + + return lr + + +================================================================================ +文件: training/trainer.py +================================================================================ + +""" +训练器模块:实现模型训练流程,包括训练循环、验证等 +""" +import os +import time +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt +from datetime import datetime + +from config.system_config import SAVED_MODELS_DIR +from config.model_config import ( + NUM_EPOCHS, BATCH_SIZE, EARLY_STOPPING_PATIENCE, + VALIDATION_SPLIT, RANDOM_SEED +) +from models.base_model import TextClassificationModel +from utils.logger import get_logger, TrainingLogger +from utils.file_utils import ensure_dir + +logger = get_logger("Trainer") + + +class Trainer: + """模型训练器,负责训练和验证模型""" + + def __init__(self, model: TextClassificationModel, + epochs: int = NUM_EPOCHS, + batch_size: Optional[int] = None, + validation_split: float = VALIDATION_SPLIT, + early_stopping: bool = True, + early_stopping_patience: int = EARLY_STOPPING_PATIENCE, + save_best_only: bool = True, + tensorboard: bool = True, + checkpoint: bool = True, + custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None): + """ + 初始化训练器 + + Args: + model: 要训练的模型 + epochs: 训练轮数 + batch_size: 批大小,如果为None则使用模型默认值 + validation_split: 验证集比例 + early_stopping: 是否使用早停 + early_stopping_patience: 早停耐心值 + save_best_only: 是否只保存最佳模型 + tensorboard: 是否使用TensorBoard + checkpoint: 是否保存检查点 + custom_callbacks: 自定义回调函数列表 + """ + self.model = model + self.epochs = epochs + self.batch_size = batch_size or model.batch_size + self.validation_split = validation_split + self.early_stopping = early_stopping + self.early_stopping_patience = early_stopping_patience + self.save_best_only = save_best_only + self.tensorboard = tensorboard + self.checkpoint = checkpoint + self.custom_callbacks = custom_callbacks or [] + + # 训练历史 + self.history = None + + # 训练日志记录器 + self.training_logger = TrainingLogger(model.model_name) + + logger.info(f"初始化训练器,模型: {model.model_name}, 轮数: {epochs}, 批大小: {self.batch_size}") + + def _create_callbacks(self) -> List[tf.keras.callbacks.Callback]: + """ + 创建回调函数列表 + + Returns: + 回调函数列表 + """ + callbacks = [] + + # 早停 + if self.early_stopping: + early_stopping = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', + patience=self.early_stopping_patience, + restore_best_weights=True, + verbose=1 + ) + callbacks.append(early_stopping) + + # 学习率衰减 + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + monitor='val_loss', + factor=0.5, + patience=self.early_stopping_patience // 2, + min_lr=1e-6, + verbose=1 + ) + callbacks.append(reduce_lr) + + # 模型检查点 + if self.checkpoint: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + checkpoint_dir = os.path.join(SAVED_MODELS_DIR, 'checkpoints') + ensure_dir(checkpoint_dir) + + 
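# A sketch of wiring a schedule from create_custom_scheduler (training/scheduler.py,
# above) into Keras via the standard LearningRateScheduler callback; the parameter
# values and the fit() call are placeholders. Note that the factory's closures pass
# the optimizer's current learning rate through as initial_lr, so the peak rate is
# whatever the model was compiled with.
import tensorflow as tf
from training.scheduler import create_custom_scheduler

schedule_fn = create_custom_scheduler('cosine_warmup',
                                      total_epochs=50,
                                      warmup_epochs=5,
                                      min_lr=1e-6,
                                      warmup_init_lr=1e-5)
lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule_fn, verbose=1)
# model.fit(x_train, y_train, epochs=50, callbacks=[lr_callback])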
checkpoint_path = os.path.join( + checkpoint_dir, + f"{self.model.model_name}_{timestamp}.h5" + ) + + model_checkpoint = tf.keras.callbacks.ModelCheckpoint( + filepath=checkpoint_path, + save_best_only=self.save_best_only, + monitor='val_loss', + verbose=1 + ) + callbacks.append(model_checkpoint) + + # TensorBoard + if self.tensorboard: + log_dir = os.path.join( + SAVED_MODELS_DIR, + 'logs', + f"{self.model.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + ensure_dir(log_dir) + + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=log_dir, + histogram_freq=1, + update_freq='epoch' + ) + callbacks.append(tensorboard_callback) + + # 添加自定义回调函数 + callbacks.extend(self.custom_callbacks) + + return callbacks + + def _log_training_progress(self, epoch: int, logs: Dict[str, float]) -> None: + """ + 记录训练进度 + + Args: + epoch: 当前轮数 + logs: 日志信息 + """ + self.training_logger.log_epoch(epoch, logs) + + def train(self, x_train: Union[np.ndarray, tf.data.Dataset], + y_train: Optional[np.ndarray] = None, + x_val: Optional[Union[np.ndarray, tf.data.Dataset]] = None, + y_val: Optional[np.ndarray] = None, + class_weights: Optional[Dict[int, float]] = None) -> Dict[str, List[float]]: + """ + 训练模型 + + Args: + x_train: 训练数据特征 + y_train: 训练数据标签 + x_val: 验证数据特征 + y_val: 验证数据标签 + class_weights: 类别权重 + + Returns: + 训练历史 + """ + logger.info(f"开始训练模型: {self.model.model_name}") + + # 创建回调函数 + callbacks = self._create_callbacks() + + # 添加训练进度记录回调 + progress_callback = tf.keras.callbacks.LambdaCallback( + on_epoch_end=lambda epoch, logs: self._log_training_progress(epoch, logs) + ) + callbacks.append(progress_callback) + + # 记录开始时间 + start_time = time.time() + + # 记录训练开始信息 + model_config = self.model.get_config() + train_config = { + "epochs": self.epochs, + "batch_size": self.batch_size, + "validation_split": self.validation_split, + "early_stopping": self.early_stopping, + "early_stopping_patience": self.early_stopping_patience + } + self.training_logger.log_training_start({**model_config, **train_config}) + + # 准备验证数据 + validation_data = None + if x_val is not None and y_val is not None: + validation_data = (x_val, y_val) + + # 训练模型 + history = self.model.fit( + x_train, y_train, + validation_data=validation_data, + epochs=self.epochs, + callbacks=callbacks, + class_weights=class_weights, + verbose=1 + ) + + # 计算训练时间 + train_time = time.time() - start_time + + # 保存训练历史 + self.history = history.history + + # 找出最佳性能 + best_val_loss = min(history.history['val_loss']) if 'val_loss' in history.history else None + best_val_acc = max(history.history['val_accuracy']) if 'val_accuracy' in history.history else None + + best_metrics = {} + if best_val_loss is not None: + best_metrics['val_loss'] = best_val_loss + if best_val_acc is not None: + best_metrics['val_accuracy'] = best_val_acc + + # 记录训练结束信息 + self.training_logger.log_training_end(train_time, best_metrics) + + logger.info(f"模型训练完成,用时: {train_time:.2f} 秒") + + return history.history + + def plot_training_history(self, metrics: Optional[List[str]] = None, + save_path: Optional[str] = None) -> None: + """ + 绘制训练历史 + + Args: + metrics: 要绘制的指标列表,默认为['loss', 'accuracy'] + save_path: 保存路径,如果为None则显示图像 + """ + if self.history is None: + raise ValueError("模型尚未训练,没有训练历史") + + if metrics is None: + metrics = ['loss', 'accuracy'] + + plt.figure(figsize=(12, 5)) + + for i, metric in enumerate(metrics): + plt.subplot(1, len(metrics), i + 1) + + if metric in self.history: + plt.plot(self.history[metric], label=f'train_{metric}') + + val_metric = 
f'val_{metric}' + if val_metric in self.history: + plt.plot(self.history[val_metric], label=f'val_{metric}') + + plt.title(f'Model {metric}') + plt.xlabel('Epoch') + plt.ylabel(metric) + plt.legend() + + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + logger.info(f"训练历史图已保存到: {save_path}") + else: + plt.show() + + def save_trained_model(self, filepath: Optional[str] = None) -> str: + """ + 保存训练好的模型 + + Args: + filepath: 保存路径,如果为None则使用默认路径 + + Returns: + 保存路径 + """ + return self.model.save(filepath) + + +================================================================================ +文件: tests/__init__.py +================================================================================ + + + +================================================================================ +文件: tests/test_evaluation.py +================================================================================ + + + +================================================================================ +文件: tests/test_preprocessing.py +================================================================================ + + + +================================================================================ +文件: tests/test_models.py +================================================================================ + + + +================================================================================ +文件: utils/text_utils.py +================================================================================ + + + +================================================================================ +文件: utils/__init__.py +================================================================================ + + + +================================================================================ +文件: utils/time_utils.py +================================================================================ + + + +================================================================================ +文件: utils/logger.py +================================================================================ + +""" +日志工具模块 +""" +import logging +import sys +from pathlib import Path +from datetime import datetime +import os + +from config.system_config import LOG_DIR, LOG_LEVEL, LOG_FORMAT + + +def get_logger(name, level=None, log_file=None): + """ + 获取logger实例 + + Args: + name: logger名称 + level: 日志级别,默认为系统配置 + log_file: 日志文件路径,默认为None(仅控制台输出) + + Returns: + logger实例 + """ + level = level or LOG_LEVEL + + # 创建logger + logger = logging.getLogger(name) + logger.setLevel(getattr(logging, level)) + + # 避免重复添加handler + if logger.handlers: + return logger + + # 创建格式化器 + formatter = logging.Formatter(LOG_FORMAT) + + # 创建控制台处理器 + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # 如果指定了日志文件,创建文件处理器 + if log_file: + log_path = Path(LOG_DIR) / log_file + file_handler = logging.FileHandler(log_path, encoding='utf-8') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_time_logger(name): + """ + 获取带时间戳的logger实例 + + Args: + name: logger名称 + + Returns: + logger实例 + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = f"{name}_{timestamp}.log" + return get_logger(name, log_file=log_file) + + +class TrainingLogger: + """训练过程日志记录器""" + + def __init__(self, model_name): + """ + 初始化训练日志记录器 + + Args: + model_name: 模型名称 + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_file = 
f"training_{model_name}_{timestamp}.log" + self.logger = get_logger(f"training_{model_name}", log_file=self.log_file) + + # 创建CSV日志 + self.csv_path = Path(LOG_DIR) / f"metrics_{model_name}_{timestamp}.csv" + with open(self.csv_path, 'w', encoding='utf-8') as f: + f.write("epoch,loss,accuracy,val_loss,val_accuracy\n") + + def log_epoch(self, epoch, metrics): + """记录每个epoch的指标""" + # 日志记录 + metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]) + self.logger.info(f"Epoch {epoch}: {metrics_str}") + + # CSV记录 + csv_line = f"{epoch},{metrics.get('loss', '')},{metrics.get('accuracy', '')}," \ + f"{metrics.get('val_loss', '')},{metrics.get('val_accuracy', '')}\n" + with open(self.csv_path, 'a', encoding='utf-8') as f: + f.write(csv_line) + + def log_training_start(self, config): + """记录训练开始信息""" + self.logger.info("=" * 50) + self.logger.info("训练开始") + self.logger.info("模型配置:") + for key, value in config.items(): + self.logger.info(f" {key}: {value}") + self.logger.info("=" * 50) + + def log_training_end(self, duration, best_metrics): + """记录训练结束信息""" + self.logger.info("=" * 50) + self.logger.info(f"训练结束,总用时: {duration:.2f}秒") + self.logger.info("最佳性能:") + for key, value in best_metrics.items(): + self.logger.info(f" {key}: {value:.4f}") + self.logger.info("=" * 50) + +================================================================================ +文件: utils/file_utils.py +================================================================================ + +""" +文件处理工具模块 +""" +import os +import shutil +import json +import pickle +import csv +from pathlib import Path +import time +import hashlib +from concurrent.futures import ThreadPoolExecutor, as_completed +import zipfile +import tarfile + +from config.system_config import ENCODING, DATA_LOADING_WORKERS +from utils.logger import get_logger + +logger = get_logger("file_utils") + + +def read_text_file(file_path, encoding=ENCODING): + """ + 读取文本文件内容 + + Args: + file_path: 文件路径 + encoding: 文件编码 + + Returns: + 文件内容 + """ + try: + with open(file_path, 'r', encoding=encoding) as file: + return file.read() + except Exception as e: + logger.error(f"读取文件 {file_path} 时出错: {str(e)}") + return None + + +def write_text_file(content, file_path, encoding=ENCODING): + """ + 写入文本文件 + + Args: + content: 文件内容 + file_path: 文件路径 + encoding: 文件编码 + + Returns: + 成功标志 + """ + try: + # 确保目录存在 + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, 'w', encoding=encoding) as file: + file.write(content) + return True + except Exception as e: + logger.error(f"写入文件 {file_path} 时出错: {str(e)}") + return False + + +def save_json(data, file_path, encoding=ENCODING): + """ + 保存JSON数据到文件 + + Args: + data: 要保存的数据 + file_path: 文件路径 + encoding: 文件编码 + + Returns: + 成功标志 + """ + try: + # 确保目录存在 + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, 'w', encoding=encoding) as file: + json.dump(data, file, ensure_ascii=False, indent=2) + return True + except Exception as e: + logger.error(f"保存JSON文件 {file_path} 时出错: {str(e)}") + return False + + +def load_json(file_path, encoding=ENCODING): + """ + 从文件加载JSON数据 + + Args: + file_path: 文件路径 + encoding: 文件编码 + + Returns: + 加载的数据 + """ + try: + with open(file_path, 'r', encoding=encoding) as file: + return json.load(file) + except Exception as e: + logger.error(f"加载JSON文件 {file_path} 时出错: {str(e)}") + return None + + +def save_pickle(data, file_path): + """ + 使用pickle保存数据 + + Args: + data: 要保存的数据 + file_path: 文件路径 + + Returns: + 成功标志 + """ + try: + # 确保目录存在 + 
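# A sketch of round-tripping a metrics dict through the save_json/load_json helpers
# defined above; the output path is a placeholder. Both helpers swallow exceptions
# and return False/None on failure, so callers should check the result.
from utils.file_utils import save_json, load_json

metrics = {"val_loss": 0.3124, "val_accuracy": 0.9087}
if save_json(metrics, "output/metrics_demo.json"):
    restored = load_json("output/metrics_demo.json")
    assert restored == metrics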
os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, 'wb') as file: + pickle.dump(data, file) + return True + except Exception as e: + logger.error(f"保存pickle文件 {file_path} 时出错: {str(e)}") + return False + + +def load_pickle(file_path): + """ + 从文件加载pickle数据 + + Args: + file_path: 文件路径 + + Returns: + 加载的数据 + """ + try: + with open(file_path, 'rb') as file: + return pickle.load(file) + except Exception as e: + logger.error(f"加载pickle文件 {file_path} 时出错: {str(e)}") + return None + + +def read_files_parallel(file_paths, max_workers=DATA_LOADING_WORKERS, encoding=ENCODING): + """ + 并行读取多个文本文件 + + Args: + file_paths: 文件路径列表 + max_workers: 最大工作线程数 + encoding: 文件编码 + + Returns: + 文件内容列表 + """ + start_time = time.time() + results = [] + + # 定义单个读取函数 + def read_single_file(file_path): + return read_text_file(file_path, encoding) + + # 使用线程池并行读取 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_file = {executor.submit(read_single_file, file_path): file_path + for file_path in file_paths} + + # 收集结果 + for future in as_completed(future_to_file): + file_path = future_to_file[future] + try: + content = future.result() + if content is not None: + results.append(content) + except Exception as e: + logger.error(f"处理文件 {file_path} 时出错: {str(e)}") + + elapsed = time.time() - start_time + logger.info(f"并行读取 {len(file_paths)} 个文件,成功 {len(results)} 个,用时 {elapsed:.2f} 秒") + + return results + + +def get_file_md5(file_path): + """ + 计算文件的MD5哈希值 + + Args: + file_path: 文件路径 + + Returns: + MD5哈希值 + """ + hash_md5 = hashlib.md5() + + try: + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + except Exception as e: + logger.error(f"计算文件 {file_path} 的MD5值时出错: {str(e)}") + return None + + +def extract_archive(archive_path, extract_to=None): + """ + 解压缩文件 + + Args: + archive_path: 压缩文件路径 + extract_to: 解压目标路径,默认为同目录 + + Returns: + 成功标志 + """ + if extract_to is None: + extract_to = os.path.dirname(archive_path) + + try: + if archive_path.endswith('.zip'): + with zipfile.ZipFile(archive_path, 'r') as zip_ref: + zip_ref.extractall(extract_to) + elif archive_path.endswith(('.tar.gz', '.tgz')): + with tarfile.open(archive_path, 'r:gz') as tar_ref: + tar_ref.extractall(extract_to) + elif archive_path.endswith('.tar'): + with tarfile.open(archive_path, 'r') as tar_ref: + tar_ref.extractall(extract_to) + else: + logger.error(f"不支持的压缩格式: {archive_path}") + return False + + logger.info(f"成功解压 {archive_path} 到 {extract_to}") + return True + except Exception as e: + logger.error(f"解压 {archive_path} 时出错: {str(e)}") + return False + + +def list_files(directory, pattern=None, recursive=True): + """ + 列出目录中的文件 + + Args: + directory: 目录路径 + pattern: 文件名模式(支持通配符) + recursive: 是否递归搜索子目录 + + Returns: + 文件路径列表 + """ + if not os.path.exists(directory): + logger.error(f"目录不存在: {directory}") + return [] + + directory = Path(directory) + + if pattern: + if recursive: + return [str(p) for p in directory.glob(f"**/{pattern}")] + else: + return [str(p) for p in directory.glob(pattern)] + else: + if recursive: + files = [] + for p in directory.rglob("*"): + if p.is_file(): + files.append(str(p)) + return files + else: + return [str(p) for p in directory.iterdir() if p.is_file()] + + +def ensure_dir(directory): + """ + 确保目录存在,不存在则创建 + + Args: + directory: 目录路径 + """ + os.makedirs(directory, exist_ok=True) + + +def remove_dir(directory): + """ + 删除目录及其内容 + + Args: + directory: 目录路径 + + Returns: + 成功标志 + """ + try: + 
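# The thread-pool pattern used by read_files_parallel above, reduced to
# standard-library calls only; worker count and error handling are illustrative.
from concurrent.futures import ThreadPoolExecutor, as_completed

def read_one(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()

def read_many(paths, max_workers=4):
    contents = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(read_one, p): p for p in paths}
        for future in as_completed(futures):
            path = futures[future]
            try:
                contents[path] = future.result()
            except OSError as err:
                print(f"failed to read {path}: {err}")
    return contents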
if os.path.exists(directory): + shutil.rmtree(directory) + return True + except Exception as e: + logger.error(f"删除目录 {directory} 时出错: {str(e)}") + return False + +================================================================================ +文件: models/ensemble_model.py +================================================================================ + +""" +集成模型:实现多个模型的集成 +""" +import numpy as np +import tensorflow as tf +from typing import List, Dict, Tuple, Optional, Any, Union +import os + +from config.system_config import CLASSIFIERS_DIR +from models.base_model import TextClassificationModel +from utils.logger import get_logger + +logger = get_logger("EnsembleModel") + + +class EnsembleModel: + """模型集成类,集成多个模型的预测结果""" + + def __init__(self, models: List[TextClassificationModel], + weights: Optional[List[float]] = None, + voting: str = 'soft', + name: str = "ensemble_model"): + """ + 初始化集成模型 + + Args: + models: 模型列表 + weights: 各模型的权重,默认为均等权重 + voting: 投票方式,'hard'表示多数投票,'soft'表示概率平均 + name: 集成模型名称 + """ + self.models = models + self.num_models = len(models) + + # 验证模型数量 + if self.num_models == 0: + raise ValueError("模型列表不能为空") + + # 设置权重 + if weights is None: + self.weights = np.ones(self.num_models) / self.num_models + else: + if len(weights) != self.num_models: + raise ValueError("权重数量必须与模型数量相同") + + # 归一化权重 + self.weights = np.array(weights) / np.sum(weights) + + # 验证投票方式 + self.voting = voting.lower() + if self.voting not in ['hard', 'soft']: + raise ValueError("无效的投票方式,支持的方式: 'hard', 'soft'") + + # 从第一个模型获取类别数 + self.num_classes = models[0].num_classes + + # 验证所有模型的类别数是否相同 + for i, model in enumerate(models[1:], 1): + if model.num_classes != self.num_classes: + raise ValueError( + f"模型 {i} 的类别数 ({model.num_classes}) 与第一个模型的类别数 ({self.num_classes}) 不同") + + self.name = name + + logger.info(f"初始化集成模型,包含 {self.num_models} 个模型,投票方式: {self.voting}") + + def predict(self, x: Union[np.ndarray, tf.data.Dataset, List], + batch_size: Optional[int] = None, + verbose: int = 0) -> np.ndarray: + """ + 使用集成模型进行预测 + + Args: + x: 预测数据 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 预测概率 + """ + # 获取每个模型的预测结果 + all_predictions = [] + + for i, model in enumerate(self.models): + logger.info(f"获取模型 {i + 1}/{self.num_models} 的预测结果") + predictions = model.predict(x, batch_size, verbose) + + # 如果是二分类且输出形状是(n,1),转换为(n,2) + if self.num_classes == 2 and predictions.shape[1:] == (1,): + predictions = np.hstack([1 - predictions, predictions]) + + all_predictions.append(predictions) + + # 根据投票方式进行集成 + if self.voting == 'hard': + # 硬投票:每个模型预测的类别,取众数 + individual_classes = [np.argmax(pred, axis=1) for pred in all_predictions] + + # 获取带权重的预测类别频率 + ensemble_result = np.zeros((len(x), self.num_classes)) + + for i, classes in enumerate(individual_classes): + for j, cls in enumerate(classes): + ensemble_result[j, cls] += self.weights[i] + + return ensemble_result + else: # soft voting + # 软投票:对每个模型的预测概率进行加权平均 + weighted_predictions = [pred * weight for pred, weight in zip(all_predictions, self.weights)] + ensemble_result = np.sum(weighted_predictions, axis=0) + + return ensemble_result + + def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List], + batch_size: Optional[int] = None, + verbose: int = 0) -> np.ndarray: + """ + 使用集成模型预测类别 + + Args: + x: 预测数据 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 预测的类别索引 + """ + # 获取预测概率 + predictions = self.predict(x, batch_size, verbose) + + # 获取最大概率的类别索引 + return np.argmax(predictions, axis=1) + + def save(self, directory: Optional[str] = None) -> str: + """ 
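# The weighted soft-voting rule implemented by EnsembleModel.predict above, shown on
# toy probability matrices (3 samples, 4 classes, 2 models); weights are normalized
# exactly as in __init__.
import numpy as np

preds_a = np.array([[0.7, 0.1, 0.1, 0.1],
                    [0.2, 0.5, 0.2, 0.1],
                    [0.1, 0.1, 0.2, 0.6]])
preds_b = np.array([[0.4, 0.3, 0.2, 0.1],
                    [0.1, 0.7, 0.1, 0.1],
                    [0.3, 0.2, 0.3, 0.2]])
weights = np.array([2.0, 1.0])
weights = weights / weights.sum()              # -> [0.667, 0.333]

ensemble = weights[0] * preds_a + weights[1] * preds_b
print(ensemble.argmax(axis=1))                 # class index per sample: [0 1 3]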
+ 保存集成模型 + + Args: + directory: 保存目录,默认为CLASSIFIERS_DIR + + Returns: + 保存路径 + """ + if directory is None: + import time + timestamp = time.strftime("%Y%m%d_%H%M%S") + directory = os.path.join(CLASSIFIERS_DIR, f"{self.name}_{timestamp}") + + os.makedirs(directory, exist_ok=True) + + # 保存模型列表 + model_paths = [] + for i, model in enumerate(self.models): + model_path = os.path.join(directory, f"model_{i}") + model.save(model_path) + model_paths.append(model_path) + + # 保存集成配置 + config = { + "name": self.name, + "num_models": self.num_models, + "model_paths": model_paths, + "weights": self.weights.tolist(), + "voting": self.voting, + "num_classes": self.num_classes + } + + import json + config_path = os.path.join(directory, "ensemble_config.json") + with open(config_path, 'w', encoding='utf-8') as f: + json.dump(config, f, ensure_ascii=False, indent=4) + + logger.info(f"集成模型已保存到目录: {directory}") + + return directory + + @classmethod + def load(cls, directory: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'EnsembleModel': + """ + 加载集成模型 + + Args: + directory: 模型目录 + custom_objects: 自定义对象字典 + + Returns: + 加载的集成模型实例 + """ + # 加载配置 + config_path = os.path.join(directory, "ensemble_config.json") + + import json + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + # 加载子模型 + from models.model_factory import ModelFactory + + models = [] + model_paths = config["model_paths"] + + for model_path in model_paths: + model = ModelFactory.load_model(model_path, custom_objects) + models.append(model) + + # 创建集成模型 + ensemble = cls( + models=models, + weights=config["weights"], + voting=config["voting"], + name=config["name"] + ) + + logger.info(f"从目录 {directory} 加载集成模型成功") + + return ensemble + + +================================================================================ +文件: models/transformer_model.py +================================================================================ + +""" +Transformer模型:实现基于Transformer的文本分类模型 +""" +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import ( + Input, Embedding, Dense, Dropout, LayerNormalization, + GlobalAveragePooling1D, MultiHeadAttention, Add +) +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np + +from config.model_config import ( + MAX_SEQUENCE_LENGTH, TRANSFORMER_CONFIG +) +from models.base_model import TextClassificationModel +from utils.logger import get_logger + +logger = get_logger("TransformerModel") + + +class TransformerBlock(tf.keras.layers.Layer): + """Transformer块,包含多头注意力和前馈网络""" + + def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float = 0.1): + """ + 初始化Transformer块 + + Args: + embed_dim: 嵌入维度 + num_heads: 注意力头数 + ff_dim: 前馈网络维度 + dropout_rate: Dropout比例 + """ + super(TransformerBlock, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.ff_dim = ff_dim + self.dropout_rate = dropout_rate + + self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) + self.ffn = tf.keras.Sequential([ + Dense(ff_dim, activation="relu"), + Dense(embed_dim), + ]) + self.layernorm1 = LayerNormalization(epsilon=1e-6) + self.layernorm2 = LayerNormalization(epsilon=1e-6) + self.dropout1 = Dropout(dropout_rate) + self.dropout2 = Dropout(dropout_rate) + + def call(self, inputs, training=False): + """ + 前向传播 + + Args: + inputs: 输入张量 + training: 是否处于训练模式 + + Returns: + 输出张量 + """ + # 多头自注意力 + attention_output = self.attention(inputs, inputs) + attention_output = 
self.dropout1(attention_output, training=training) + out1 = self.layernorm1(inputs + attention_output) + + # 前馈网络 + ffn_output = self.ffn(out1) + ffn_output = self.dropout2(ffn_output, training=training) + out2 = self.layernorm2(out1 + ffn_output) + + return out2 + + def get_config(self): + """获取配置""" + config = super(TransformerBlock, self).get_config() + config.update({ + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "ff_dim": self.ff_dim, + "dropout_rate": self.dropout_rate + }) + return config + + +class TransformerTextClassifier(TextClassificationModel): + """Transformer文本分类模型""" + + def __init__(self, num_classes: int, vocab_size: int, + embedding_dim: int = TRANSFORMER_CONFIG["embedding_dim"], + max_sequence_length: int = MAX_SEQUENCE_LENGTH, + num_heads: int = TRANSFORMER_CONFIG["num_heads"], + ff_dim: int = TRANSFORMER_CONFIG["ff_dim"], + num_layers: int = TRANSFORMER_CONFIG["num_layers"], + dropout_rate: float = TRANSFORMER_CONFIG["dropout_rate"], + embedding_matrix: Optional[np.ndarray] = None, + trainable_embedding: bool = True, + use_positional_encoding: bool = True, + model_name: str = "transformer_text_classifier", + batch_size: int = 64, + learning_rate: float = 0.001): + """ + 初始化Transformer文本分类模型 + + Args: + num_classes: 类别数量 + vocab_size: 词汇表大小 + embedding_dim: 词嵌入维度 + max_sequence_length: 最大序列长度 + num_heads: 注意力头数 + ff_dim: 前馈网络维度 + num_layers: Transformer层数 + dropout_rate: Dropout比例 + embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化 + trainable_embedding: 词嵌入是否可训练 + use_positional_encoding: 是否使用位置编码 + model_name: 模型名称 + batch_size: 批大小 + learning_rate: 学习率 + """ + super().__init__(num_classes, model_name, batch_size, learning_rate) + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.max_sequence_length = max_sequence_length + self.num_heads = num_heads + self.ff_dim = ff_dim + self.num_layers = num_layers + self.dropout_rate = dropout_rate + self.embedding_matrix = embedding_matrix + self.trainable_embedding = trainable_embedding + self.use_positional_encoding = use_positional_encoding + + # 更新配置 + self.config.update({ + "vocab_size": vocab_size, + "embedding_dim": embedding_dim, + "max_sequence_length": max_sequence_length, + "num_heads": num_heads, + "ff_dim": ff_dim, + "num_layers": num_layers, + "dropout_rate": dropout_rate, + "trainable_embedding": trainable_embedding, + "use_positional_encoding": use_positional_encoding, + "model_type": "Transformer" + }) + + logger.info(f"初始化Transformer文本分类模型,头数: {num_heads}, 层数: {num_layers}") + + def _positional_encoding(self, max_length: int, d_model: int) -> tf.Tensor: + """ + 生成位置编码 + + Args: + max_length: 最大序列长度 + d_model: 模型维度 + + Returns: + 位置编码张量 + """ + positions = np.arange(max_length)[:, np.newaxis] + depths = np.arange(d_model)[np.newaxis, :] // 2 * 2 + angle_rates = 1 / np.power(10000, depths / d_model) + angle_rads = positions * angle_rates + + # sin用于偶数索引,cos用于奇数索引 + sines = np.sin(angle_rads[:, 0::2]) + cosines = np.cos(angle_rads[:, 1::2]) + + pos_encoding = np.zeros((max_length, d_model)) + pos_encoding[:, 0::2] = sines + pos_encoding[:, 1::2] = cosines + + return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32) + + def build(self) -> None: + """构建Transformer模型架构""" + # Input layer + sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input') + + # Embedding layer + if self.embedding_matrix is not None: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + weights=[self.embedding_matrix], + 
input_length=self.max_sequence_length, + trainable=self.trainable_embedding, + name='embedding' + ) + else: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + input_length=self.max_sequence_length, + trainable=True, + name='embedding' + ) + + embedded_sequences = embedding_layer(sequence_input) + + # 添加位置编码 + if self.use_positional_encoding: + pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim) + embedded_sequences = embedded_sequences + pos_encoding + + # Transformer层 + x = embedded_sequences + for i in range(self.num_layers): + x = TransformerBlock( + embed_dim=self.embedding_dim, + num_heads=self.num_heads, + ff_dim=self.ff_dim, + dropout_rate=self.dropout_rate, + name=f'transformer_block_{i + 1}' + )(x) + + # 全局池化 + x = GlobalAveragePooling1D(name='global_avg_pooling')(x) + + # Dropout for regularization + x = Dropout(self.dropout_rate, name='dropout_1')(x) + + # Dense layer + x = Dense(128, activation='relu', name='dense_1')(x) + x = Dropout(self.dropout_rate, name='dropout_2')(x) + + # Output layer + if self.num_classes == 2: + # Binary classification + predictions = Dense(1, activation='sigmoid', name='predictions')(x) + else: + # Multi-class classification + predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x) + + # Build the model + self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name) + + logger.info(f"Transformer模型构建完成,头数: {self.num_heads}, 层数: {self.num_layers}") + + def compile(self, optimizer=None, loss=None, metrics=None) -> None: + """ + 编译Transformer模型 + + Args: + optimizer: 优化器,默认为Adam + loss: 损失函数,默认根据类别数量选择 + metrics: 评估指标,默认为accuracy + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 默认优化器 + if optimizer is None: + optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + # 默认损失函数 + if loss is None: + if self.num_classes == 2: + loss = 'binary_crossentropy' + else: + loss = 'sparse_categorical_crossentropy' + + # 默认评估指标 + if metrics is None: + metrics = ['accuracy'] + + # 编译模型 + self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + logger.info(f"Transformer模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}") + +================================================================================ +文件: models/__init__.py +================================================================================ + + + +================================================================================ +文件: models/base_model.py +================================================================================ + +""" +模型基类:定义所有文本分类模型的通用接口 +""" +import os +import time +import json +import numpy as np +import tensorflow as tf +from tensorflow.keras.models import Model, load_model +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +from abc import ABC, abstractmethod + +from config.system_config import SAVED_MODELS_DIR, CLASSIFIERS_DIR +from config.model_config import ( + BATCH_SIZE, LEARNING_RATE, EARLY_STOPPING_PATIENCE, + REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR +) +from utils.logger import get_logger +from utils.file_utils import ensure_dir, save_json + +logger = get_logger("BaseModel") + + +class TextClassificationModel(ABC): + """文本分类模型基类,定义所有模型的通用接口""" + + def __init__(self, num_classes: int, model_name: str = "text_classifier", + batch_size: int = BATCH_SIZE, + learning_rate: float = LEARNING_RATE): + """ + 初始化文本分类模型 + + Args: + num_classes: 类别数量 + model_name: 模型名称 + 
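# A standalone NumPy version of the sinusoidal positional encoding built by
# TransformerTextClassifier._positional_encoding above, used only to check shapes
# and the sin/cos interleaving on toy dimensions.
import numpy as np

def positional_encoding(max_length: int, d_model: int) -> np.ndarray:
    positions = np.arange(max_length)[:, np.newaxis]
    depths = np.arange(d_model)[np.newaxis, :] // 2 * 2
    angle_rads = positions / np.power(10000, depths / d_model)
    pe = np.zeros((max_length, d_model))
    pe[:, 0::2] = np.sin(angle_rads[:, 0::2])   # even indices get sine
    pe[:, 1::2] = np.cos(angle_rads[:, 1::2])   # odd indices get cosine
    return pe

pe = positional_encoding(max_length=6, d_model=8)
print(pe.shape)    # (6, 8)
print(pe[0])       # position 0: sine terms are 0.0, cosine terms are 1.0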
batch_size: 批大小 + learning_rate: 学习率 + """ + self.num_classes = num_classes + self.model_name = model_name + self.batch_size = batch_size + self.learning_rate = learning_rate + + # 模型实例 + self.model = None + + # 训练历史 + self.history = None + + # 训练配置 + self.config = { + "model_name": model_name, + "num_classes": num_classes, + "batch_size": batch_size, + "learning_rate": learning_rate + } + + # 验证集合最佳性能 + self.best_val_loss = float('inf') + self.best_val_accuracy = 0.0 + + logger.info(f"初始化 {model_name} 模型,类别数: {num_classes}") + + @abstractmethod + def build(self) -> None: + """构建模型架构,这是一个抽象方法,子类必须实现""" + pass + + def compile(self, optimizer: Optional[tf.keras.optimizers.Optimizer] = None, + loss: Optional[Union[str, tf.keras.losses.Loss]] = None, + metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None) -> None: + """ + 编译模型 + + Args: + optimizer: 优化器,默认为Adam + loss: 损失函数,默认为sparse_categorical_crossentropy + metrics: 评估指标,默认为accuracy + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 默认优化器 + if optimizer is None: + optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + # 默认损失函数 + if loss is None: + loss = 'sparse_categorical_crossentropy' + + # 默认评估指标 + if metrics is None: + metrics = ['accuracy'] + + # 编译模型 + self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + logger.info(f"模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}, 评估指标: {metrics}") + + def summary(self) -> None: + """打印模型概要""" + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + self.model.summary() + + def fit(self, x_train: Union[np.ndarray, tf.data.Dataset], + y_train: Optional[np.ndarray] = None, + validation_data: Optional[Union[Tuple[np.ndarray, np.ndarray], tf.data.Dataset]] = None, + epochs: int = 10, + callbacks: Optional[List[tf.keras.callbacks.Callback]] = None, + class_weights: Optional[Dict[int, float]] = None, + verbose: int = 1) -> tf.keras.callbacks.History: + """ + 训练模型 + + Args: + x_train: 训练数据特征 + y_train: 训练数据标签 + validation_data: 验证数据 + epochs: 训练轮数 + callbacks: 回调函数列表 + class_weights: 类别权重 + verbose: 详细程度 + + Returns: + 训练历史 + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 记录开始时间 + start_time = time.time() + + # 添加默认回调函数 + if callbacks is None: + callbacks = self._get_default_callbacks() + + # 训练模型 + if isinstance(x_train, tf.data.Dataset): + # 如果输入是TensorFlow Dataset + history = self.model.fit( + x_train, + epochs=epochs, + validation_data=validation_data, + callbacks=callbacks, + class_weight=class_weights, + verbose=verbose + ) + else: + # 如果输入是NumPy数组 + history = self.model.fit( + x_train, y_train, + batch_size=self.batch_size, + epochs=epochs, + validation_data=validation_data, + callbacks=callbacks, + class_weight=class_weights, + verbose=verbose + ) + + # 计算训练时间 + train_time = time.time() - start_time + + # 保存训练历史 + self.history = history.history + self.history['train_time'] = train_time + + logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒") + + return history + + def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset], + y_test: Optional[np.ndarray] = None, + verbose: int = 1) -> Dict[str, float]: + """ + 评估模型 + + Args: + x_test: 测试数据特征 + y_test: 测试数据标签 + verbose: 详细程度 + + Returns: + 评估结果字典 + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 评估模型 + if isinstance(x_test, tf.data.Dataset): + # 如果输入是TensorFlow Dataset + results = self.model.evaluate(x_test, verbose=verbose) + else: + # 如果输入是NumPy数组 + results = self.model.evaluate(x_test, y_test, 
batch_size=self.batch_size, verbose=verbose) + + # 构建评估结果字典 + metrics_names = self.model.metrics_names + evaluation_results = {name: float(value) for name, value in zip(metrics_names, results)} + + logger.info(f"模型评估结果: {evaluation_results}") + + return evaluation_results + + def predict(self, x: Union[np.ndarray, tf.data.Dataset, List], + batch_size: Optional[int] = None, + verbose: int = 0) -> np.ndarray: + """ + 使用模型进行预测 + + Args: + x: 预测数据 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 预测结果 + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 使用模型进行预测 + if batch_size is None: + batch_size = self.batch_size + + return self.model.predict(x, batch_size=batch_size, verbose=verbose) + + def predict_classes(self, x: Union[np.ndarray, tf.data.Dataset, List], + batch_size: Optional[int] = None, + verbose: int = 0) -> np.ndarray: + """ + 使用模型预测类别 + + Args: + x: 预测数据 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 预测的类别索引 + """ + # 获取模型预测概率 + predictions = self.predict(x, batch_size, verbose) + + # 获取最大概率的类别索引 + return np.argmax(predictions, axis=1) + + def save(self, filepath: Optional[str] = None, + save_format: str = 'tf', + include_optimizer: bool = True) -> str: + """ + 保存模型 + + Args: + filepath: 保存路径,如果为None则使用默认路径 + save_format: 保存格式,'tf'或'h5' + include_optimizer: 是否包含优化器状态 + + Returns: + 保存路径 + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 如果未指定保存路径,使用默认路径 + if filepath is None: + ensure_dir(CLASSIFIERS_DIR) + timestamp = time.strftime("%Y%m%d_%H%M%S") + filepath = os.path.join(CLASSIFIERS_DIR, f"{self.model_name}_{timestamp}") + + # 保存模型 + self.model.save(filepath, save_format=save_format, include_optimizer=include_optimizer) + + # 保存模型配置 + config_path = f"{filepath}_config.json" + with open(config_path, 'w', encoding='utf-8') as f: + json.dump(self.config, f, ensure_ascii=False, indent=4) + + logger.info(f"模型已保存到: {filepath}") + + return filepath + + @classmethod + def load(cls, filepath: str, custom_objects: Optional[Dict[str, Any]] = None) -> 'TextClassificationModel': + """ + 加载模型 + + Args: + filepath: 模型文件路径 + custom_objects: 自定义对象字典 + + Returns: + 加载的模型实例 + """ + # 加载模型配置 + config_path = f"{filepath}_config.json" + + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + except FileNotFoundError: + logger.warning(f"未找到模型配置文件: {config_path},将使用默认配置") + config = {} + + # 创建模型实例 + model_name = config.get('model_name', 'loaded_model') + num_classes = config.get('num_classes', 1) + batch_size = config.get('batch_size', BATCH_SIZE) + learning_rate = config.get('learning_rate', LEARNING_RATE) + + instance = cls(num_classes, model_name, batch_size, learning_rate) + + # 加载Keras模型 + instance.model = load_model(filepath, custom_objects=custom_objects) + + # 加载配置 + instance.config = config + + logger.info(f"从 {filepath} 加载模型成功") + + return instance + + def _get_default_callbacks(self) -> List[tf.keras.callbacks.Callback]: + """获取默认的回调函数列表""" + # 早停 + early_stopping = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', + patience=EARLY_STOPPING_PATIENCE, + restore_best_weights=True, + verbose=1 + ) + + # 学习率衰减 + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + monitor='val_loss', + factor=REDUCE_LR_FACTOR, + patience=REDUCE_LR_PATIENCE, + min_lr=1e-6, + verbose=1 + ) + + # 模型检查点 + checkpoint_path = os.path.join(SAVED_MODELS_DIR, 'checkpoints', self.model_name) + ensure_dir(os.path.dirname(checkpoint_path)) + + model_checkpoint = tf.keras.callbacks.ModelCheckpoint( + filepath=checkpoint_path, + 
save_best_only=True, + monitor='val_loss', + verbose=1 + ) + + # TensorBoard日志 + log_dir = os.path.join(SAVED_MODELS_DIR, 'logs', f"{self.model_name}_{time.strftime('%Y%m%d_%H%M%S')}") + ensure_dir(log_dir) + + tensorboard = tf.keras.callbacks.TensorBoard( + log_dir=log_dir, + histogram_freq=1 + ) + + return [early_stopping, reduce_lr, model_checkpoint, tensorboard] + + def get_config(self) -> Dict[str, Any]: + """获取模型配置""" + return self.config.copy() + + def get_model(self) -> Model: + """获取Keras模型实例""" + return self.model + + def get_training_history(self) -> Optional[Dict[str, List[float]]]: + """获取训练历史""" + return self.history + + def plot_training_history(self, save_path: Optional[str] = None, + metrics: Optional[List[str]] = None) -> None: + """ + 绘制训练历史 + + Args: + save_path: 保存路径,如果为None则显示图像 + metrics: 要绘制的指标列表,默认为['loss', 'accuracy'] + """ + if self.history is None: + raise ValueError("模型尚未训练,没有训练历史") + + import matplotlib.pyplot as plt + + if metrics is None: + metrics = ['loss', 'accuracy'] + + # 创建图形 + plt.figure(figsize=(12, 5)) + + # 绘制指标 + for i, metric in enumerate(metrics): + plt.subplot(1, len(metrics), i + 1) + + if metric in self.history: + plt.plot(self.history[metric], label=f'train_{metric}') + + val_metric = f'val_{metric}' + if val_metric in self.history: + plt.plot(self.history[val_metric], label=f'val_{metric}') + + plt.title(f'Model {metric}') + plt.xlabel('Epoch') + plt.ylabel(metric) + plt.legend() + + plt.tight_layout() + + # 保存或显示图像 + if save_path: + plt.savefig(save_path) + logger.info(f"训练历史图已保存到: {save_path}") + else: + plt.show() + +================================================================================ +文件: models/rnn_model.py +================================================================================ + +""" +RNN模型:实现基于循环神经网络的文本分类模型 +""" +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import ( + Input, Embedding, LSTM, GRU, Bidirectional, Dense, Dropout, + BatchNormalization, Activation, GlobalMaxPooling1D, GlobalAveragePooling1D +) +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np + +from config.model_config import ( + MAX_SEQUENCE_LENGTH, RNN_CONFIG +) +from models.base_model import TextClassificationModel +from utils.logger import get_logger + +logger = get_logger("RNNModel") + + +class RNNTextClassifier(TextClassificationModel): + """循环神经网络文本分类模型""" + + def __init__(self, num_classes: int, vocab_size: int, + embedding_dim: int = RNN_CONFIG["embedding_dim"], + max_sequence_length: int = MAX_SEQUENCE_LENGTH, + hidden_size: int = RNN_CONFIG["hidden_size"], + num_layers: int = RNN_CONFIG["num_layers"], + bidirectional: bool = RNN_CONFIG["bidirectional"], + rnn_type: str = "lstm", # 'lstm' or 'gru' + dropout_rate: float = RNN_CONFIG["dropout_rate"], + embedding_matrix: Optional[np.ndarray] = None, + trainable_embedding: bool = True, + pool_type: str = "max", # 'max', 'avg', or 'both' + model_name: str = "rnn_text_classifier", + batch_size: int = 64, + learning_rate: float = 0.001): + """ + 初始化RNN文本分类模型 + + Args: + num_classes: 类别数量 + vocab_size: 词汇表大小 + embedding_dim: 词嵌入维度 + max_sequence_length: 最大序列长度 + hidden_size: 隐藏层大小 + num_layers: RNN层数 + bidirectional: 是否使用双向RNN + rnn_type: RNN类型,'lstm'或'gru' + dropout_rate: Dropout比例 + embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化 + trainable_embedding: 词嵌入是否可训练 + pool_type: 池化类型,'max'、'avg'或'both' + model_name: 模型名称 + batch_size: 批大小 + learning_rate: 学习率 + """ + super().__init__(num_classes, model_name, batch_size, 
learning_rate) + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.max_sequence_length = max_sequence_length + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bidirectional = bidirectional + self.rnn_type = rnn_type.lower() + self.dropout_rate = dropout_rate + self.embedding_matrix = embedding_matrix + self.trainable_embedding = trainable_embedding + self.pool_type = pool_type + + # 验证RNN类型 + if self.rnn_type not in ["lstm", "gru"]: + raise ValueError("无效的RNN类型,支持的类型: 'lstm', 'gru'") + + # 验证池化类型 + if self.pool_type not in ["max", "avg", "both"]: + raise ValueError("无效的池化类型,支持的类型: 'max', 'avg', 'both'") + + # 更新配置 + self.config.update({ + "vocab_size": vocab_size, + "embedding_dim": embedding_dim, + "max_sequence_length": max_sequence_length, + "hidden_size": hidden_size, + "num_layers": num_layers, + "bidirectional": bidirectional, + "rnn_type": rnn_type, + "dropout_rate": dropout_rate, + "trainable_embedding": trainable_embedding, + "pool_type": pool_type, + "model_type": "RNN" + }) + + logger.info(f"初始化RNN文本分类模型,类型: {rnn_type.upper()}, 隐藏层大小: {hidden_size}, 层数: {num_layers}") + + def build(self) -> None: + """构建RNN模型架构""" + # Input layer + sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input') + + # Embedding layer + if self.embedding_matrix is not None: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + weights=[self.embedding_matrix], + input_length=self.max_sequence_length, + trainable=self.trainable_embedding, + name='embedding' + ) + else: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + input_length=self.max_sequence_length, + trainable=True, + name='embedding' + ) + + embedded_sequences = embedding_layer(sequence_input) + + # 选择RNN层类型 + if self.rnn_type == "lstm": + rnn_layer = LSTM + else: # gru + rnn_layer = GRU + + # 构建多层RNN + x = embedded_sequences + for i in range(self.num_layers): + return_sequences = i < self.num_layers - 1 or self.pool_type != "last" + + if self.bidirectional: + x = Bidirectional( + rnn_layer( + self.hidden_size, + return_sequences=return_sequences, + dropout=self.dropout_rate if i < self.num_layers - 1 else 0, + name=f'{self.rnn_type}_{i + 1}' + ) + )(x) + else: + x = rnn_layer( + self.hidden_size, + return_sequences=return_sequences, + dropout=self.dropout_rate if i < self.num_layers - 1 else 0, + name=f'{self.rnn_type}_{i + 1}' + )(x) + + # 根据池化类型选择池化方法 + if self.pool_type == "max": + # 使用全局最大池化 + pooled = GlobalMaxPooling1D(name='global_max_pooling')(x) + elif self.pool_type == "avg": + # 使用全局平均池化 + pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x) + elif self.pool_type == "both": + # 同时使用最大池化和平均池化,然后拼接 + max_pooled = GlobalMaxPooling1D(name='global_max_pooling')(x) + avg_pooled = GlobalAveragePooling1D(name='global_avg_pooling')(x) + pooled = tf.keras.layers.Concatenate(name='concatenate')([max_pooled, avg_pooled]) + else: # "last",使用最后一个时间步的输出 + # 最后一层RNN已经返回了最后一个时间步的状态,不需要额外池化 + pooled = x + + # Dropout for regularization + x = Dropout(self.dropout_rate, name='dropout_1')(pooled) + + # Dense layer + x = Dense(128, name='dense_1')(x) + x = BatchNormalization(name='batch_norm_1')(x) + x = Activation('relu', name='activation_1')(x) + x = Dropout(self.dropout_rate, name='dropout_2')(x) + + # Output layer + if self.num_classes == 2: + # Binary classification + predictions = Dense(1, activation='sigmoid', name='predictions')(x) + else: + # Multi-class classification + 
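# A sketch of building and compiling the RNN classifier defined above; vocabulary
# size, class count and layer options are placeholder values. build() must run
# before compile(), which otherwise raises ValueError in the base class.
from models.rnn_model import RNNTextClassifier

rnn = RNNTextClassifier(num_classes=10,
                        vocab_size=50000,
                        rnn_type="gru",
                        bidirectional=True,
                        pool_type="both")
rnn.build()
rnn.compile()      # defaults: Adam, sparse_categorical_crossentropy, ['accuracy']
rnn.summary()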
predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x) + + # Build the model + self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name) + + logger.info( + f"RNN模型构建完成,类型: {self.rnn_type.upper()}, 双向: {self.bidirectional}, 池化类型: {self.pool_type}") + + def compile(self, optimizer=None, loss=None, metrics=None) -> None: + """ + 编译RNN模型 + + Args: + optimizer: 优化器,默认为Adam + loss: 损失函数,默认根据类别数量选择 + metrics: 评估指标,默认为accuracy + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 默认优化器 + if optimizer is None: + optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + # 默认损失函数 + if loss is None: + if self.num_classes == 2: + loss = 'binary_crossentropy' + else: + loss = 'sparse_categorical_crossentropy' + + # 默认评估指标 + if metrics is None: + metrics = ['accuracy'] + + # 编译模型 + self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + logger.info(f"RNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}") + + +================================================================================ +文件: models/model_factory.py +================================================================================ + +""" +模型工厂:统一创建和管理不同类型的模型 +""" +from typing import List, Dict, Tuple, Optional, Any, Union +import os +import glob +import time +import numpy as np + +from config.system_config import CLASSIFIERS_DIR +from config.model_config import ( + BATCH_SIZE, LEARNING_RATE +) +from models.base_model import TextClassificationModel +from models.cnn_model import CNNTextClassifier +from models.rnn_model import RNNTextClassifier +from models.transformer_model import TransformerTextClassifier +from utils.logger import get_logger + +logger = get_logger("ModelFactory") + + +class ModelFactory: + """模型工厂,用于创建和管理不同类型的模型""" + + @staticmethod + def create_model(model_type: str, num_classes: int, vocab_size: int, + embedding_matrix: Optional[np.ndarray] = None, + model_config: Optional[Dict[str, Any]] = None, + **kwargs) -> TextClassificationModel: + """ + 创建指定类型的模型 + + Args: + model_type: 模型类型,可选值: 'cnn', 'rnn', 'transformer' + num_classes: 类别数量 + vocab_size: 词汇表大小 + embedding_matrix: 预训练词嵌入矩阵 + model_config: 模型配置字典 + **kwargs: 其他参数 + + Returns: + 创建的模型实例 + """ + model_type = model_type.lower() + + # 合并配置 + config = model_config or {} + config.update(kwargs) + + # 创建模型 + if model_type == 'cnn': + model = CNNTextClassifier( + num_classes=num_classes, + vocab_size=vocab_size, + embedding_matrix=embedding_matrix, + **config + ) + elif model_type == 'rnn': + model = RNNTextClassifier( + num_classes=num_classes, + vocab_size=vocab_size, + embedding_matrix=embedding_matrix, + **config + ) + elif model_type == 'transformer': + model = TransformerTextClassifier( + num_classes=num_classes, + vocab_size=vocab_size, + embedding_matrix=embedding_matrix, + **config + ) + else: + raise ValueError(f"不支持的模型类型: {model_type}") + + logger.info(f"已创建 {model_type.upper()} 模型") + + return model + + @staticmethod + def load_model(model_path: str, custom_objects: Optional[Dict[str, Any]] = None) -> TextClassificationModel: + """ + 加载保存的模型 + + Args: + model_path: 模型路径 + custom_objects: 自定义对象字典 + + Returns: + 加载的模型实例 + """ + # 添加Transformer相关的自定义对象 + if custom_objects is None: + custom_objects = {} + + if 'TransformerBlock' not in custom_objects: + from models.transformer_model import TransformerBlock + custom_objects['TransformerBlock'] = TransformerBlock + + # 根据配置确定模型类型 + model_config_path = f"{model_path}_config.json" + + import json + 
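# A sketch of creating a model through ModelFactory.create_model above instead of
# instantiating a class directly; extra keys in model_config are forwarded to the
# chosen model's constructor, and the values shown are placeholders.
from models.model_factory import ModelFactory

model = ModelFactory.create_model(
    model_type="transformer",
    num_classes=10,
    vocab_size=50000,
    embedding_matrix=None,
    model_config={"num_heads": 8, "num_layers": 4},
)
model.build()
model.compile()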
with open(model_config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + model_type = config.get('model_type', '').lower() + + # 根据模型类型选择加载方法 + if model_type == 'cnn': + model = CNNTextClassifier.load(model_path, custom_objects) + elif model_type == 'rnn': + model = RNNTextClassifier.load(model_path, custom_objects) + elif model_type == 'transformer': + model = TransformerTextClassifier.load(model_path, custom_objects) + else: + # 如果无法确定模型类型,使用基类加载 + logger.warning(f"无法确定模型类型,使用基类加载: {model_path}") + model = TextClassificationModel.load(model_path, custom_objects) + + logger.info(f"已加载模型: {model_path}") + + return model + + @staticmethod + def get_available_models() -> List[Dict[str, Any]]: + """ + 获取可用的已保存模型列表 + + Returns: + 模型信息列表,每个元素是包含模型信息的字典 + """ + model_files = glob.glob(os.path.join(CLASSIFIERS_DIR, "*")) + model_files = [f for f in model_files if not f.endswith("_config.json")] + + models_info = [] + + for model_file in model_files: + config_file = f"{model_file}_config.json" + + if os.path.exists(config_file): + try: + import json + with open(config_file, 'r', encoding='utf-8') as f: + config = json.load(f) + + # 获取模型文件的创建时间 + created_time = os.path.getctime(model_file) + created_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_time)) + + # 获取模型文件大小 + file_size = os.path.getsize(model_file) / (1024 * 1024) # MB + + models_info.append({ + "path": model_file, + "name": config.get("model_name", os.path.basename(model_file)), + "type": config.get("model_type", "unknown"), + "num_classes": config.get("num_classes", 0), + "created_time": created_time_str, + "file_size": f"{file_size:.2f} MB", + "config": config + }) + except Exception as e: + logger.error(f"读取模型配置失败: {config_file}, 错误: {e}") + + # 按创建时间降序排序 + models_info.sort(key=lambda x: x.get("created_time", ""), reverse=True) + + return models_info + + +================================================================================ +文件: models/cnn_model.py +================================================================================ + +""" +CNN模型:实现基于卷积神经网络的文本分类模型 +""" +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import ( + Input, Embedding, Conv1D, MaxPooling1D, GlobalMaxPooling1D, + Dense, Dropout, Concatenate, BatchNormalization, Activation +) +from typing import List, Dict, Tuple, Optional, Any, Union + +from config.model_config import ( + MAX_SEQUENCE_LENGTH, CNN_CONFIG +) +from models.base_model import TextClassificationModel +from utils.logger import get_logger + +logger = get_logger("CNNModel") + + +class CNNTextClassifier(TextClassificationModel): + """卷积神经网络文本分类模型""" + + def __init__(self, num_classes: int, vocab_size: int, + embedding_dim: int = CNN_CONFIG["embedding_dim"], + max_sequence_length: int = MAX_SEQUENCE_LENGTH, + num_filters: int = CNN_CONFIG["num_filters"], + filter_sizes: List[int] = CNN_CONFIG["filter_sizes"], + dropout_rate: float = CNN_CONFIG["dropout_rate"], + l2_reg_lambda: float = CNN_CONFIG["l2_reg_lambda"], + embedding_matrix: Optional[np.ndarray] = None, + trainable_embedding: bool = True, + model_name: str = "cnn_text_classifier", + batch_size: int = 64, + learning_rate: float = 0.001): + """ + 初始化CNN文本分类模型 + + Args: + num_classes: 类别数量 + vocab_size: 词汇表大小 + embedding_dim: 词嵌入维度 + max_sequence_length: 最大序列长度 + num_filters: 卷积核数量 + filter_sizes: 卷积核大小列表 + dropout_rate: Dropout比例 + l2_reg_lambda: L2正则化系数 + embedding_matrix: 预训练词嵌入矩阵,如果为None则使用随机初始化 + trainable_embedding: 词嵌入是否可训练 + model_name: 模型名称 + 
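# A sketch of listing previously saved classifiers with
# ModelFactory.get_available_models above and reloading the most recent one;
# the returned entries are already sorted by creation time, newest first.
from models.model_factory import ModelFactory

models_info = ModelFactory.get_available_models()
for info in models_info:
    print(info["name"], info["type"], info["created_time"], info["file_size"])

if models_info:
    latest = ModelFactory.load_model(models_info[0]["path"])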
batch_size: 批大小 + learning_rate: 学习率 + """ + super().__init__(num_classes, model_name, batch_size, learning_rate) + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.max_sequence_length = max_sequence_length + self.num_filters = num_filters + self.filter_sizes = filter_sizes + self.dropout_rate = dropout_rate + self.l2_reg_lambda = l2_reg_lambda + self.embedding_matrix = embedding_matrix + self.trainable_embedding = trainable_embedding + + # 更新配置 + self.config.update({ + "vocab_size": vocab_size, + "embedding_dim": embedding_dim, + "max_sequence_length": max_sequence_length, + "num_filters": num_filters, + "filter_sizes": filter_sizes, + "dropout_rate": dropout_rate, + "l2_reg_lambda": l2_reg_lambda, + "trainable_embedding": trainable_embedding, + "model_type": "CNN" + }) + + logger.info(f"初始化CNN文本分类模型,词汇表大小: {vocab_size}, 嵌入维度: {embedding_dim}") + + def build(self) -> None: + """构建CNN模型架构""" + # Input layer + sequence_input = Input(shape=(self.max_sequence_length,), dtype='int32', name='sequence_input') + + # Embedding layer + if self.embedding_matrix is not None: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + weights=[self.embedding_matrix], + input_length=self.max_sequence_length, + trainable=self.trainable_embedding, + name='embedding' + ) + else: + embedding_layer = Embedding( + input_dim=self.vocab_size, + output_dim=self.embedding_dim, + input_length=self.max_sequence_length, + trainable=True, + name='embedding' + ) + + embedded_sequences = embedding_layer(sequence_input) + + # Convolutional layers with different filter sizes + conv_blocks = [] + for filter_size in self.filter_sizes: + conv = Conv1D( + filters=self.num_filters, + kernel_size=filter_size, + padding='valid', + activation='relu', + kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg_lambda), + name=f'conv_{filter_size}' + )(embedded_sequences) + + # Max pooling + pooled = GlobalMaxPooling1D(name=f'max_pooling_{filter_size}')(conv) + conv_blocks.append(pooled) + + # Concatenate pooled features if we have multiple filter sizes + if len(self.filter_sizes) > 1: + concatenated = Concatenate(name='concatenate')(conv_blocks) + else: + concatenated = conv_blocks[0] + + # Dropout for regularization + x = Dropout(self.dropout_rate, name='dropout_1')(concatenated) + + # Dense layer + x = Dense(128, name='dense_1')(x) + x = BatchNormalization(name='batch_norm_1')(x) + x = Activation('relu', name='activation_1')(x) + x = Dropout(self.dropout_rate, name='dropout_2')(x) + + # Output layer + if self.num_classes == 2: + # Binary classification + predictions = Dense(1, activation='sigmoid', name='predictions')(x) + else: + # Multi-class classification + predictions = Dense(self.num_classes, activation='softmax', name='predictions')(x) + + # Build the model + self.model = Model(inputs=sequence_input, outputs=predictions, name=self.model_name) + + logger.info(f"CNN模型构建完成,过滤器大小: {self.filter_sizes}, 每种大小的过滤器数量: {self.num_filters}") + + def compile(self, optimizer=None, loss=None, metrics=None) -> None: + """ + 编译CNN模型 + + Args: + optimizer: 优化器,默认为Adam + loss: 损失函数,默认根据类别数量选择 + metrics: 评估指标,默认为accuracy + """ + if self.model is None: + raise ValueError("模型尚未构建,请先调用build方法") + + # 默认优化器 + if optimizer is None: + optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) + + # 默认损失函数 + if loss is None: + if self.num_classes == 2: + loss = 'binary_crossentropy' + else: + loss = 'sparse_categorical_crossentropy' + + # 默认评估指标 + if metrics is None: + metrics 
= ['accuracy'] + + # 编译模型 + self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + logger.info(f"CNN模型已编译,优化器: {optimizer.__class__.__name__}, 损失函数: {loss}") + + +================================================================================ +文件: models/layers/__init__.py +================================================================================ + + + +================================================================================ +文件: scripts/predict.py +================================================================================ + +""" +预测脚本:使用模型进行预测 +""" +import os +import sys +import time +import argparse +import logging +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np +import json + +# 将项目根目录添加到系统路径 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +from config.system_config import ( + CATEGORIES, CLASSIFIERS_DIR +) +from models.model_factory import ModelFactory +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from inference.predictor import Predictor +from inference.batch_processor import BatchProcessor +from utils.logger import get_logger +from utils.file_utils import read_text_file + +logger = get_logger("Prediction") + + +def predict_text(text: str, model_path: Optional[str] = None, + output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]: + """ + 预测单条文本 + + Args: + text: 要预测的文本 + model_path: 模型路径,如果为None则使用最新的模型 + output_path: 输出文件路径,如果为None则不保存 + top_k: 返回概率最高的前k个类别 + + Returns: + 预测结果 + """ + logger.info("开始预测文本") + + # 1. 加载模型 + if model_path is None: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + if not models_info: + raise ValueError("未找到可用的模型") + + # 使用最新的模型 + model_path = models_info[0]['path'] + + logger.info(f"加载模型: {model_path}") + model = ModelFactory.load_model(model_path) + + # 2. 创建分词器和预测器 + tokenizer = ChineseTokenizer() + + # 查找向量化器文件 + vectorizer = None + for model_type in ["cnn", "rnn", "transformer"]: + vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl") + if os.path.exists(vectorizer_path): + # 加载向量化器 + vectorizer = SequenceVectorizer() + vectorizer.load(vectorizer_path) + logger.info(f"加载向量化器: {vectorizer_path}") + break + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + vectorizer=vectorizer, + class_names=CATEGORIES + ) + + # 3. 预测 + result = predictor.predict( + text=text, + return_top_k=top_k, + return_probabilities=True + ) + + # 4. 输出结果 + if top_k > 1: + logger.info("预测结果:") + for i, pred in enumerate(result): + logger.info(f" {i + 1}. {pred['class']} (概率: {pred['probability']:.4f})") + else: + logger.info(f"预测结果: {result['class']} (概率: {result['probability']:.4f})") + + # 5. 
保存结果 + if output_path: + if output_path.endswith('.json'): + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + else: + with open(output_path, 'w', encoding='utf-8') as f: + if top_k > 1: + f.write("rank,class,probability\n") + for i, pred in enumerate(result): + f.write(f"{i + 1},{pred['class']},{pred['probability']}\n") + else: + f.write(f"class,probability\n") + f.write(f"{result['class']},{result['probability']}\n") + + logger.info(f"结果已保存到: {output_path}") + + return result + + +def predict_file(file_path: str, model_path: Optional[str] = None, + output_path: Optional[str] = None, top_k: int = 3) -> Dict[str, Any]: + """ + 预测文件内容 + + Args: + file_path: 文件路径 + model_path: 模型路径,如果为None则使用最新的模型 + output_path: 输出文件路径,如果为None则不保存 + top_k: 返回概率最高的前k个类别 + + Returns: + 预测结果 + """ + logger.info(f"开始预测文件: {file_path}") + + # 检查文件是否存在 + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + # 读取文件内容 + if file_path.endswith('.txt'): + # 文本文件 + text = read_text_file(file_path) + return predict_text(text, model_path, output_path, top_k) + + elif file_path.endswith(('.csv', '.xls', '.xlsx')): + # 表格文件 + import pandas as pd + + if file_path.endswith('.csv'): + df = pd.read_csv(file_path) + else: + df = pd.read_excel(file_path) + + # 查找可能的文本列 + text_columns = [col for col in df.columns if df[col].dtype == 'object'] + + if not text_columns: + raise ValueError("文件中没有找到可能的文本列") + + # 使用第一个文本列 + text_column = text_columns[0] + logger.info(f"使用文本列: {text_column}") + + # 1. 加载模型 + if model_path is None: + # 获取可用模型列表 + models_info = ModelFactory.get_available_models() + + if not models_info: + raise ValueError("未找到可用的模型") + + # 使用最新的模型 + model_path = models_info[0]['path'] + + logger.info(f"加载模型: {model_path}") + model = ModelFactory.load_model(model_path) + + # 2. 创建分词器和预测器 + tokenizer = ChineseTokenizer() + + # 查找向量化器文件 + vectorizer = None + for model_type in ["cnn", "rnn", "transformer"]: + vectorizer_path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl") + if os.path.exists(vectorizer_path): + # 加载向量化器 + vectorizer = SequenceVectorizer() + vectorizer.load(vectorizer_path) + logger.info(f"加载向量化器: {vectorizer_path}") + break + + # 创建预测器 + predictor = Predictor( + model=model, + tokenizer=tokenizer, + vectorizer=vectorizer, + class_names=CATEGORIES + ) + + # 3. 创建批处理器 + batch_processor = BatchProcessor( + predictor=predictor, + batch_size=64 + ) + + # 4. 
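    # NOTE (illustrative): BatchProcessor.process_dataframe delegates to
    # Predictor.predict_to_dataframe, so the DataFrame returned by step 4
    # has roughly these columns (values are hypothetical):
    #
    #   return_top_k == 1:  id | text | predicted_class | probability
    #   return_top_k  > 1:  id | text | rank | predicted_class | probability
    #                       (one row per (text, rank) pair)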
批量预测 + result_df = batch_processor.process_dataframe( + df=df, + text_column=text_column, + output_path=output_path, + return_top_k=top_k, + format=output_path.split('.')[-1] if output_path else 'csv' + ) + + logger.info(f"已处理 {len(result_df)} 行数据") + + # 返回结果 + return result_df.to_dict(orient='records') + + else: + raise ValueError(f"不支持的文件类型: {file_path}") + + +if __name__ == "__main__": + # 解析命令行参数 + parser = argparse.ArgumentParser(description="使用模型预测") + parser.add_argument("--model_path", help="模型路径") + parser.add_argument("--text", help="要预测的文本") + parser.add_argument("--file", help="要预测的文件") + parser.add_argument("--output", help="输出文件") + parser.add_argument("--top_k", type=int, default=3, help="返回概率最高的前k个类别") + + args = parser.parse_args() + + # 检查输入 + if not args.text and not args.file: + parser.error("请提供要预测的文本或文件") + + # 预测 + if args.text: + predict_text(args.text, args.model_path, args.output, args.top_k) + else: + predict_file(args.file, args.model_path, args.output, args.top_k) + + +================================================================================ +文件: scripts/train.py +================================================================================ + +""" +训练脚本:训练文本分类模型 +""" +import os +import sys +import time +import argparse +import logging +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt + +# 将项目根目录添加到系统路径 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +from config.system_config import ( + RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR +) +from config.model_config import ( + BATCH_SIZE, NUM_EPOCHS, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS +) +from data.dataloader import DataLoader +from data.data_manager import DataManager +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from models.model_factory import ModelFactory +from training.trainer import Trainer +from utils.logger import get_logger + +logger = get_logger("Training") + + +def train_model(data_dir: Optional[str] = None, + model_type: str = "cnn", + epochs: int = NUM_EPOCHS, + batch_size: int = BATCH_SIZE, + save_dir: Optional[str] = None, + validation_split: float = 0.1, + use_pretrained_embedding: bool = False, + embedding_path: Optional[str] = None) -> str: + """ + 训练文本分类模型 + + Args: + data_dir: 数据目录,如果为None则使用默认目录 + model_type: 模型类型,'cnn', 'rnn', 或 'transformer' + epochs: 训练轮数 + batch_size: 批大小 + save_dir: 模型保存目录,如果为None则使用默认目录 + validation_split: 验证集比例 + use_pretrained_embedding: 是否使用预训练词向量 + embedding_path: 预训练词向量路径 + + Returns: + 保存的模型路径 + """ + logger.info(f"开始训练 {model_type.upper()} 模型") + start_time = time.time() + + # 设置数据目录 + data_dir = data_dir or RAW_DATA_DIR + + # 设置保存目录 + if save_dir: + save_dir = os.path.abspath(save_dir) + os.makedirs(save_dir, exist_ok=True) + else: + save_dir = CLASSIFIERS_DIR + os.makedirs(save_dir, exist_ok=True) + + # 1. 加载数据 + logger.info("加载数据...") + data_loader = DataLoader(data_dir=data_dir) + data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR) + + # 加载和分割数据 + data = data_manager.load_and_split_data( + data_loader=data_loader, + val_split=validation_split, + sample_ratio=1.0, + save=True + ) + + # 获取训练集和验证集 + train_texts, train_labels = data_manager.get_data(dataset="train") + val_texts, val_labels = data_manager.get_data(dataset="val") + + # 2. 
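    # NOTE (illustrative): the data manager is expected to return parallel
    # sequences of raw texts and integer class indices, e.g. (hypothetical):
    #
    #   train_texts  = ["央行宣布降息...", "球队客场取胜...", ...]
    #   train_labels = np.array([3, 0, ...])   # indices into CATEGORIES
    #
    # Integer labels are what the default 'sparse_categorical_crossentropy'
    # loss chosen in model.compile() expects.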
准备数据 + # 创建分词器 + tokenizer = ChineseTokenizer() + + # 对训练文本进行分词 + logger.info("对文本进行分词...") + tokenized_train_texts = [tokenizer.tokenize(text, return_string=True) for text in train_texts] + tokenized_val_texts = [tokenizer.tokenize(text, return_string=True) for text in val_texts] + + # 创建序列向量化器 + logger.info("创建序列向量化器...") + vectorizer = SequenceVectorizer( + max_features=MAX_NUM_WORDS, + max_sequence_length=MAX_SEQUENCE_LENGTH + ) + + # 训练向量化器并转换文本 + vectorizer.fit(tokenized_train_texts) + X_train = vectorizer.transform(tokenized_train_texts) + X_val = vectorizer.transform(tokenized_val_texts) + + # 保存向量化器 + vectorizer_path = os.path.join(save_dir, f"vectorizer_{model_type}.pkl") + vectorizer.save(vectorizer_path) + logger.info(f"向量化器已保存到: {vectorizer_path}") + + # 获取一些基本参数 + num_classes = len(CATEGORIES) + vocab_size = vectorizer.get_vocabulary_size() + + # 3. 创建模型 + logger.info(f"创建 {model_type.upper()} 模型...") + + # 加载预训练词向量(如果指定) + embedding_matrix = None + if use_pretrained_embedding and embedding_path: + # 这里简化处理,实际应用中应该加载和处理预训练词向量 + logger.info("加载预训练词向量...") + embedding_matrix = np.random.random((vocab_size, 200)) + + # 创建模型 + model = ModelFactory.create_model( + model_type=model_type, + num_classes=num_classes, + vocab_size=vocab_size, + embedding_matrix=embedding_matrix, + batch_size=batch_size + ) + + # 构建模型 + model.build() + model.compile() + model.summary() + + # 4. 训练模型 + logger.info("开始训练模型...") + trainer = Trainer( + model=model, + epochs=epochs, + batch_size=batch_size, + early_stopping=True, + tensorboard=True + ) + + # 训练 + history = trainer.train( + x_train=X_train, + y_train=train_labels, + x_val=X_val, + y_val=val_labels + ) + + # 5. 保存模型 + timestamp = time.strftime("%Y%m%d_%H%M%S") + model_path = os.path.join(save_dir, f"{model_type}_model_{timestamp}") + model.save(model_path) + logger.info(f"模型已保存到: {model_path}") + + # 6. 绘制训练历史 + logger.info("绘制训练历史...") + model.plot_training_history(save_path=os.path.join(save_dir, f"training_history_{model_type}_{timestamp}.png")) + + # 7. 
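    # NOTE (illustrative): the random placeholder used for embedding_matrix in
    # step 3 above is only a stand-in. A more realistic loader for a text-format
    # word-vector file (GloVe/word2vec style) would look roughly like this;
    # `word_index` is a hypothetical word->id mapping from the fitted
    # vectorizer, and 200 matches the placeholder's embedding dimension:
    #
    #   embedding_matrix = np.zeros((vocab_size, 200))
    #   with open(embedding_path, encoding='utf-8') as f:
    #       for line in f:
    #           parts = line.rstrip().split(' ')
    #           word, vec = parts[0], np.asarray(parts[1:], dtype='float32')
    #           idx = word_index.get(word)
    #           if idx is not None and idx < vocab_size:
    #               embedding_matrix[idx] = vec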
计算训练时间 + train_time = time.time() - start_time + logger.info(f"模型训练完成,耗时: {train_time:.2f} 秒") + + return model_path + + +if __name__ == "__main__": + # 解析命令行参数 + parser = argparse.ArgumentParser(description="训练文本分类模型") + parser.add_argument("--data_dir", help="数据目录") + parser.add_argument("--model_type", choices=["cnn", "rnn", "transformer"], default="cnn", help="模型类型") + parser.add_argument("--epochs", type=int, default=NUM_EPOCHS, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小") + parser.add_argument("--save_dir", help="模型保存目录") + parser.add_argument("--validation_split", type=float, default=0.1, help="验证集比例") + parser.add_argument("--use_pretrained_embedding", action="store_true", help="是否使用预训练词向量") + parser.add_argument("--embedding_path", help="预训练词向量路径") + + args = parser.parse_args() + + # 训练模型 + train_model( + data_dir=args.data_dir, + model_type=args.model_type, + epochs=args.epochs, + batch_size=args.batch_size, + save_dir=args.save_dir, + validation_split=args.validation_split, + use_pretrained_embedding=args.use_pretrained_embedding, + embedding_path=args.embedding_path + ) + + +================================================================================ +文件: scripts/evaluate.py +================================================================================ + +""" +评估脚本:评估文本分类模型性能 +""" +import os +import sys +import time +import argparse +import logging +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np +import matplotlib.pyplot as plt + +# 将项目根目录添加到系统路径 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +from config.system_config import ( + RAW_DATA_DIR, CATEGORIES, CLASSIFIERS_DIR, PROCESSED_DATA_DIR +) +from config.model_config import ( + BATCH_SIZE, MAX_SEQUENCE_LENGTH +) +from data.dataloader import DataLoader +from data.data_manager import DataManager +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from models.model_factory import ModelFactory +from evaluation.evaluator import ModelEvaluator +from utils.logger import get_logger + +logger = get_logger("Evaluation") + + +def evaluate_model(model_path: str, + data_dir: Optional[str] = None, + batch_size: int = BATCH_SIZE, + output_dir: Optional[str] = None) -> Dict[str, float]: + """ + 评估文本分类模型 + + Args: + model_path: 模型路径 + data_dir: 数据目录,如果为None则使用默认目录 + batch_size: 批大小 + output_dir: 评估结果输出目录,如果为None则使用默认目录 + + Returns: + 评估指标 + """ + logger.info(f"开始评估模型: {model_path}") + start_time = time.time() + + # 设置数据目录 + data_dir = data_dir or RAW_DATA_DIR + + # 设置输出目录 + if output_dir: + output_dir = os.path.abspath(output_dir) + os.makedirs(output_dir, exist_ok=True) + + # 1. 加载模型 + logger.info("加载模型...") + model = ModelFactory.load_model(model_path) + + # 2. 加载数据 + logger.info("加载数据...") + data_loader = DataLoader(data_dir=data_dir) + data_manager = DataManager(processed_dir=PROCESSED_DATA_DIR) + + # 加载测试集 + data_manager.load_data() + test_texts, test_labels = data_manager.get_data(dataset="test") + + # 3. 
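    # NOTE (illustrative): the vectorizer lookup below assumes the layout that
    # scripts/train.py produces, i.e. the vectorizer .pkl sits next to the
    # saved model (directory names are hypothetical):
    #
    #   <save_dir>/
    #       cnn_model_20240101_120000      <- model_path
    #       vectorizer_cnn.pkl             <- found via os.path.dirname(model_path)
    #
    # Caveat: the fallback branch below references MAX_NUM_WORDS, which this
    # script does not import from config.model_config, so that branch would
    # raise NameError unless the import is added.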
准备数据 + # 创建分词器 + tokenizer = ChineseTokenizer() + + # 对测试文本进行分词 + logger.info("对文本进行分词...") + tokenized_test_texts = [tokenizer.tokenize(text, return_string=True) for text in test_texts] + + # 创建序列向量化器 + logger.info("加载向量化器...") + # 查找向量化器文件 + vectorizer_path = None + for model_type in ["cnn", "rnn", "transformer"]: + path = os.path.join(os.path.dirname(model_path), f"vectorizer_{model_type}.pkl") + if os.path.exists(path): + vectorizer_path = path + break + + if not vectorizer_path: + # 如果找不到向量化器,创建一个新的 + logger.warning("未找到向量化器,创建一个新的") + vectorizer = SequenceVectorizer( + max_features=MAX_NUM_WORDS, + max_sequence_length=MAX_SEQUENCE_LENGTH + ) + else: + # 加载向量化器 + vectorizer = SequenceVectorizer() + vectorizer.load(vectorizer_path) + + # 转换测试文本 + X_test = vectorizer.transform(tokenized_test_texts) + + # 4. 创建评估器 + logger.info("创建评估器...") + evaluator = ModelEvaluator( + model=model, + class_names=CATEGORIES, + output_dir=output_dir + ) + + # 5. 评估模型 + logger.info("评估模型...") + metrics = evaluator.evaluate(X_test, test_labels, batch_size) + + # 6. 保存评估结果 + logger.info("保存评估结果...") + evaluator.save_evaluation_results(save_plots=True) + + # 7. 可视化混淆矩阵 + logger.info("可视化混淆矩阵...") + cm = evaluator.evaluation_results['confusion_matrix'] + evaluator.metrics.plot_confusion_matrix( + y_true=test_labels, + y_pred=np.argmax(model.predict(X_test), axis=1), + normalize='true', + save_path=os.path.join(output_dir or os.path.dirname(model_path), "confusion_matrix.png") + ) + + # 8. 类别性能分析 + logger.info("分析各类别性能...") + class_performance = evaluator.evaluate_class_performance(X_test, test_labels) + + # 9. 计算评估时间 + eval_time = time.time() - start_time + logger.info(f"模型评估完成,耗时: {eval_time:.2f} 秒") + + # 10. 输出主要指标 + logger.info("主要评估指标:") + for metric_name, metric_value in metrics.items(): + if metric_name in ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']: + logger.info(f" {metric_name}: {metric_value:.4f}") + + return metrics + + +if __name__ == "__main__": + # 解析命令行参数 + parser = argparse.ArgumentParser(description="评估文本分类模型") + parser.add_argument("--model_path", required=True, help="模型路径") + parser.add_argument("--data_dir", help="数据目录") + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="批大小") + parser.add_argument("--output_dir", help="评估结果输出目录") + + args = parser.parse_args() + + # 评估模型 + evaluate_model( + model_path=args.model_path, + data_dir=args.data_dir, + batch_size=args.batch_size, + output_dir=args.output_dir + ) + + +================================================================================ +文件: inference/predictor.py +================================================================================ + +""" +预测器模块:实现模型预测功能,支持单条和批量文本预测 +""" +import os +import time +import numpy as np +import tensorflow as tf +from typing import List, Dict, Tuple, Optional, Any, Union +import pandas as pd +import json + +from config.system_config import CATEGORY_TO_ID, ID_TO_CATEGORY +from models.base_model import TextClassificationModel +from preprocessing.tokenization import ChineseTokenizer +from preprocessing.vectorizer import SequenceVectorizer +from utils.logger import get_logger + +logger = get_logger("Predictor") + + +class Predictor: + """预测器,负责加载模型和进行预测""" + + def __init__(self, model: TextClassificationModel, + tokenizer: Optional[ChineseTokenizer] = None, + vectorizer: Optional[SequenceVectorizer] = None, + class_names: Optional[List[str]] = None, + max_sequence_length: int = 500, + batch_size: Optional[int] = None): + """ + 初始化预测器 + + Args: + model: 已训练的模型实例 
+ tokenizer: 分词器实例,如果为None则创建一个新的分词器 + vectorizer: 文本向量化器实例,如果为None则表示模型直接接收序列 + class_names: 类别名称列表,如果为None则使用ID_TO_CATEGORY + max_sequence_length: 最大序列长度 + batch_size: 批大小,如果为None则使用模型默认值 + """ + self.model = model + self.tokenizer = tokenizer or ChineseTokenizer() + self.vectorizer = vectorizer + self.class_names = class_names + + if class_names is None and hasattr(model, 'num_classes'): + # 如果模型具有类别数量信息,从ID_TO_CATEGORY获取类别名称 + self.class_names = [ID_TO_CATEGORY.get(i, str(i)) for i in range(model.num_classes)] + + self.max_sequence_length = max_sequence_length + self.batch_size = batch_size or (model.batch_size if hasattr(model, 'batch_size') else 32) + + logger.info(f"初始化预测器,批大小: {self.batch_size}") + + def preprocess_text(self, text: str) -> Any: + """ + 预处理单条文本 + + Args: + text: 原始文本 + + Returns: + 预处理后的文本表示 + """ + # 分词 + tokenized_text = self.tokenizer.tokenize(text, return_string=True) + + # 如果有向量化器,应用向量化 + if self.vectorizer is not None: + return self.vectorizer.transform([tokenized_text])[0] + + return tokenized_text + + def preprocess_texts(self, texts: List[str]) -> Any: + """ + 批量预处理文本 + + Args: + texts: 原始文本列表 + + Returns: + 预处理后的批量文本表示 + """ + # 分词 + tokenized_texts = [self.tokenizer.tokenize(text, return_string=True) for text in texts] + + # 如果有向量化器,应用向量化 + if self.vectorizer is not None: + return self.vectorizer.transform(tokenized_texts) + + return tokenized_texts + + def predict(self, text: str, return_top_k: int = 1, + return_probabilities: bool = False) -> Union[str, Dict, List]: + """ + 预测单条文本的类别 + + Args: + text: 原始文本 + return_top_k: 返回概率最高的前k个类别 + return_probabilities: 是否返回概率值 + + Returns: + 预测结果,格式取决于参数设置 + """ + # 预处理文本 + processed_text = self.preprocess_text(text) + + # 添加批次维度 + if isinstance(processed_text, str): + input_data = np.array([processed_text]) + else: + input_data = np.expand_dims(processed_text, axis=0) + + # 预测 + start_time = time.time() + predictions = self.model.predict(input_data) + prediction_time = time.time() - start_time + + # 获取前k个预测结果 + if return_top_k > 1: + top_indices = np.argsort(predictions[0])[::-1][:return_top_k] + top_probs = predictions[0][top_indices] + + if self.class_names: + top_classes = [self.class_names[idx] for idx in top_indices] + else: + top_classes = [str(idx) for idx in top_indices] + + if return_probabilities: + return [{'class': cls, 'probability': float(prob)} + for cls, prob in zip(top_classes, top_probs)] + else: + return top_classes + else: + # 获取最高概率的类别 + pred_idx = np.argmax(predictions[0]) + pred_prob = float(predictions[0][pred_idx]) + + if self.class_names: + pred_class = self.class_names[pred_idx] + else: + pred_class = str(pred_idx) + + if return_probabilities: + return {'class': pred_class, 'probability': pred_prob, 'time': prediction_time} + else: + return pred_class + + def predict_batch(self, texts: List[str], return_top_k: int = 1, + return_probabilities: bool = False) -> List: + """ + 批量预测文本类别 + + Args: + texts: 原始文本列表 + return_top_k: 返回概率最高的前k个类别 + return_probabilities: 是否返回概率值 + + Returns: + 预测结果列表 + """ + # 空列表检查 + if not texts: + return [] + + # 预处理文本 + processed_texts = self.preprocess_texts(texts) + + # 预测 + start_time = time.time() + predictions = self.model.predict(processed_texts, batch_size=self.batch_size) + prediction_time = time.time() - start_time + + # 处理预测结果 + results = [] + + for i, pred in enumerate(predictions): + if return_top_k > 1: + top_indices = np.argsort(pred)[::-1][:return_top_k] + top_probs = pred[top_indices] + + if self.class_names: + top_classes = 
[self.class_names[idx] for idx in top_indices] + else: + top_classes = [str(idx) for idx in top_indices] + + if return_probabilities: + results.append([{'class': cls, 'probability': float(prob)} + for cls, prob in zip(top_classes, top_probs)]) + else: + results.append(top_classes) + else: + # 获取最高概率的类别 + pred_idx = np.argmax(pred) + pred_prob = float(pred[pred_idx]) + + if self.class_names: + pred_class = self.class_names[pred_idx] + else: + pred_class = str(pred_idx) + + if return_probabilities: + results.append({'class': pred_class, 'probability': pred_prob}) + else: + results.append(pred_class) + + logger.info(f"批量预测 {len(texts)} 条文本完成,用时: {prediction_time:.2f} 秒") + + return results + + def predict_to_dataframe(self, texts: List[str], + text_ids: Optional[List[Union[str, int]]] = None, + return_top_k: int = 1) -> pd.DataFrame: + """ + 批量预测并返回DataFrame + + Args: + texts: 原始文本列表 + text_ids: 文本ID列表,如果为None则使用索引 + return_top_k: 返回概率最高的前k个类别 + + Returns: + 预测结果DataFrame + """ + # 预测 + predictions = self.predict_batch(texts, return_top_k=return_top_k, return_probabilities=True) + + # 创建DataFrame + if text_ids is None: + text_ids = list(range(len(texts))) + + if return_top_k > 1: + # 多个类别的情况 + results = [] + for i, preds in enumerate(predictions): + for j, pred in enumerate(preds): + results.append({ + 'id': text_ids[i], + 'text': texts[i], + 'rank': j + 1, + 'predicted_class': pred['class'], + 'probability': pred['probability'] + }) + df = pd.DataFrame(results) + else: + # 单个类别的情况 + df = pd.DataFrame({ + 'id': text_ids, + 'text': texts, + 'predicted_class': [pred['class'] for pred in predictions], + 'probability': [pred['probability'] for pred in predictions] + }) + + return df + + def save_predictions(self, texts: List[str], + output_path: str, + text_ids: Optional[List[Union[str, int]]] = None, + return_top_k: int = 1, + format: str = 'csv') -> str: + """ + 批量预测并保存结果 + + Args: + texts: 原始文本列表 + output_path: 输出文件路径 + text_ids: 文本ID列表,如果为None则使用索引 + return_top_k: 返回概率最高的前k个类别 + format: 输出格式,'csv'或'json' + + Returns: + 输出文件路径 + """ + # 获取预测结果DataFrame + df = self.predict_to_dataframe(texts, text_ids, return_top_k) + + # 保存结果 + if format.lower() == 'csv': + df.to_csv(output_path, index=False, encoding='utf-8') + elif format.lower() == 'json': + # 转换为嵌套的JSON格式 + if return_top_k > 1: + # 分组后转换为嵌套格式 + result = {} + for id_val in df['id'].unique(): + sub_df = df[df['id'] == id_val] + predictions = [] + for _, row in sub_df.iterrows(): + predictions.append({ + 'class': row['predicted_class'], + 'probability': row['probability'] + }) + result[str(id_val)] = { + 'text': sub_df.iloc[0]['text'], + 'predictions': predictions + } + else: + # 直接构建JSON + result = {} + for _, row in df.iterrows(): + result[str(row['id'])] = { + 'text': row['text'], + 'predicted_class': row['predicted_class'], + 'probability': row['probability'] + } + + # 保存为JSON + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + else: + raise ValueError(f"不支持的输出格式: {format}") + + logger.info(f"预测结果已保存到: {output_path}") + + return output_path + + +================================================================================ +文件: inference/batch_processor.py +================================================================================ + +""" +批处理模块:实现批量处理大规模文本数据 +""" +import os +import time +import pandas as pd +import numpy as np +from typing import List, Dict, Tuple, Optional, Any, Union, Callable, Iterator +import concurrent.futures +from tqdm import tqdm +import glob +import 
json + +from config.system_config import ENCODING, DATA_LOADING_WORKERS, MAX_TEXT_PER_BATCH +from utils.logger import get_logger +from utils.file_utils import read_text_file, ensure_dir +from inference.predictor import Predictor + +logger = get_logger("BatchProcessor") + + +class BatchProcessor: + """批处理器,负责批量处理大规模文本数据""" + + def __init__(self, predictor: Predictor, + batch_size: int = 64, + max_workers: int = DATA_LOADING_WORKERS, + max_batch_queue: int = 10): + """ + 初始化批处理器 + + Args: + predictor: 预测器实例 + batch_size: 批大小 + max_workers: 最大工作线程数 + max_batch_queue: 最大批次队列长度 + """ + self.predictor = predictor + self.batch_size = batch_size + self.max_workers = max_workers + self.max_batch_queue = max_batch_queue + + logger.info(f"初始化批处理器,批大小: {batch_size}, 最大工作线程数: {max_workers}") + + def _extract_text_from_file(self, file_path: str) -> str: + """ + 从文件中提取文本 + + Args: + file_path: 文件路径 + + Returns: + 文本内容 + """ + return read_text_file(file_path, encoding=ENCODING) + + def _batch_generator(self, texts: List[str], batch_size: int) -> Iterator[List[str]]: + """ + 生成文本批次 + + Args: + texts: 文本列表 + batch_size: 批大小 + + Returns: + 文本批次生成器 + """ + for i in range(0, len(texts), batch_size): + yield texts[i:i + batch_size] + + def process_files(self, file_paths: List[str], output_path: Optional[str] = None, + return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame: + """ + 批量处理文件 + + Args: + file_paths: 文件路径列表 + output_path: 输出文件路径,如果为None则不保存 + return_top_k: 返回概率最高的前k个类别 + format: 输出格式,'csv'或'json' + + Returns: + 预测结果DataFrame + """ + logger.info(f"开始批量处理 {len(file_paths)} 个文件") + start_time = time.time() + + # 使用线程池并行读取文件 + texts = [] + file_names = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_file = {executor.submit(self._extract_text_from_file, file_path): file_path for file_path in + file_paths} + + for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(file_paths), desc="读取文件"): + file_path = future_to_file[future] + try: + text = future.result() + if text: + texts.append(text) + file_names.append(os.path.basename(file_path)) + except Exception as e: + logger.error(f"处理文件 {file_path} 时出错: {e}") + + # 批量预测 + all_predictions = [] + + for batch in tqdm(self._batch_generator(texts, self.batch_size), + total=(len(texts) + self.batch_size - 1) // self.batch_size, desc="预测"): + predictions = self.predictor.predict_batch(batch, return_top_k=return_top_k, return_probabilities=True) + all_predictions.extend(predictions) + + # 整合结果 + if return_top_k > 1: + # 多个类别的情况 + results = [] + for i, preds in enumerate(all_predictions): + for j, pred in enumerate(preds): + results.append({ + 'file_name': file_names[i], + 'rank': j + 1, + 'predicted_class': pred['class'], + 'probability': pred['probability'] + }) + df = pd.DataFrame(results) + else: + # 单个类别的情况 + df = pd.DataFrame({ + 'file_name': file_names, + 'predicted_class': [pred['class'] for pred in all_predictions], + 'probability': [pred['probability'] for pred in all_predictions] + }) + + # 保存结果 + if output_path: + if format.lower() == 'csv': + df.to_csv(output_path, index=False, encoding='utf-8') + elif format.lower() == 'json': + # 转换为嵌套的JSON格式 + if return_top_k > 1: + # 分组后转换为嵌套格式 + result = {} + for file_name in df['file_name'].unique(): + sub_df = df[df['file_name'] == file_name] + predictions = [] + for _, row in sub_df.iterrows(): + predictions.append({ + 'class': row['predicted_class'], + 'probability': row['probability'] + }) + result[file_name] = { + 
'predictions': predictions + } + else: + # 直接构建JSON + result = {} + for _, row in df.iterrows(): + result[row['file_name']] = { + 'predicted_class': row['predicted_class'], + 'probability': row['probability'] + } + + # 保存为JSON + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + else: + raise ValueError(f"不支持的输出格式: {format}") + + logger.info(f"预测结果已保存到: {output_path}") + + processing_time = time.time() - start_time + logger.info(f"批量处理完成,共处理 {len(texts)} 个文件,用时: {processing_time:.2f} 秒") + + return df + + def process_directory(self, directory: str, pattern: str = "*.txt", + output_path: Optional[str] = None, + return_top_k: int = 1, format: str = 'csv', + recursive: bool = True) -> pd.DataFrame: + """ + 批量处理目录中的文件 + + Args: + directory: 目录路径 + pattern: 文件匹配模式 + output_path: 输出文件路径,如果为None则不保存 + return_top_k: 返回概率最高的前k个类别 + format: 输出格式,'csv'或'json' + recursive: 是否递归处理子目录 + + Returns: + 预测结果DataFrame + """ + # 获取符合模式的文件路径 + if recursive: + file_paths = glob.glob(os.path.join(directory, "**", pattern), recursive=True) + else: + file_paths = glob.glob(os.path.join(directory, pattern)) + + if not file_paths: + logger.warning(f"在目录 {directory} 中未找到符合模式 {pattern} 的文件") + return pd.DataFrame() + + logger.info(f"在目录 {directory} 中找到 {len(file_paths)} 个符合模式 {pattern} 的文件") + + # 调用process_files处理文件 + return self.process_files(file_paths, output_path, return_top_k, format) + + def process_dataframe(self, df: pd.DataFrame, text_column: str, + id_column: Optional[str] = None, + output_path: Optional[str] = None, + return_top_k: int = 1, format: str = 'csv') -> pd.DataFrame: + """ + 批量处理DataFrame中的文本 + + Args: + df: 输入DataFrame + text_column: 文本列名 + id_column: ID列名,如果为None则使用索引 + output_path: 输出文件路径,如果为None则不保存 + return_top_k: 返回概率最高的前k个类别 + format: 输出格式,'csv'或'json' + + Returns: + 预测结果DataFrame + """ + # 获取文本和ID + texts = df[text_column].tolist() + + if id_column: + ids = df[id_column].tolist() + else: + ids = df.index.tolist() + + # 批量预测 + result_df = self.predictor.predict_to_dataframe(texts, ids, return_top_k) + + # 保存结果 + if output_path: + if format.lower() == 'csv': + result_df.to_csv(output_path, index=False, encoding='utf-8') + elif format.lower() == 'json': + # 转换为嵌套的JSON格式 + self.predictor.save_predictions(texts, output_path, ids, return_top_k, 'json') + else: + raise ValueError(f"不支持的输出格式: {format}") + + logger.info(f"预测结果已保存到: {output_path}") + + return result_df + + def process_large_file(self, file_path: str, output_path: Optional[str] = None, + return_top_k: int = 1, format: str = 'csv', + chunk_size: int = MAX_TEXT_PER_BATCH, + delimiter: str = '\n\n') -> None: + """ + 处理大型文本文件,文件会被分块读取和处理 + + Args: + file_path: 文件路径 + output_path: 输出文件路径,如果为None则不保存 + return_top_k: 返回概率最高的前k个类别 + format: 输出格式,'csv'或'json' + chunk_size: 每个块的大小(文本数量) + delimiter: 文本分隔符 + """ + logger.info(f"开始处理大型文件: {file_path}") + start_time = time.time() + + # 读取文件内容 + with open(file_path, 'r', encoding=ENCODING) as f: + content = f.read() + + # 分割文本 + texts = content.split(delimiter) + texts = [text.strip() for text in texts if text.strip()] + + logger.info(f"文件共包含 {len(texts)} 条文本") + + # 创建输出文件 + if output_path: + if format.lower() == 'csv': + # 创建CSV文件头 + if return_top_k > 1: + header = "text_id,text,rank,predicted_class,probability\n" + else: + header = "text_id,text,predicted_class,probability\n" + + with open(output_path, 'w', encoding=ENCODING) as f: + f.write(header) + elif format.lower() == 'json': + # 创建JSON文件 + with open(output_path, 'w', 
encoding=ENCODING) as f: + f.write('{\n') + + # 分块处理 + total_chunks = (len(texts) + chunk_size - 1) // chunk_size + + for i in range(0, len(texts), chunk_size): + chunk = texts[i:i + chunk_size] + chunk_ids = list(range(i, i + len(chunk))) + + logger.info(f"处理第 {i // chunk_size + 1}/{total_chunks} 块,包含 {len(chunk)} 条文本") + + # 批量预测 + result_df = self.predictor.predict_to_dataframe(chunk, chunk_ids, return_top_k) + + # 追加到输出文件 + if output_path: + if format.lower() == 'csv': + result_df.to_csv(output_path, index=False, encoding=ENCODING, mode='a', header=False) + elif format.lower() == 'json': + # 转换为JSON并追加 + if return_top_k > 1: + # 分组后转换为嵌套格式 + for id_val in result_df['id'].unique(): + sub_df = result_df[result_df['id'] == id_val] + predictions = [] + for _, row in sub_df.iterrows(): + predictions.append({ + 'class': row['predicted_class'], + 'probability': float(row['probability']) + }) + + json_str = f' "{id_val}": {{\n' + json_str += f' "text": {json.dumps(sub_df.iloc[0]["text"], ensure_ascii=False)},\n' + json_str += f' "predictions": {json.dumps(predictions, ensure_ascii=False)}\n' + json_str += ' },' + + with open(output_path, 'a', encoding=ENCODING) as f: + f.write(json_str + '\n') + else: + # 直接构建JSON + for _, row in result_df.iterrows(): + json_str = f' "{row["id"]}": {{\n' + json_str += f' "text": {json.dumps(row["text"], ensure_ascii=False)},\n' + json_str += f' "predicted_class": "{row["predicted_class"]}",\n' + json_str += f' "probability": {float(row["probability"])}\n' + json_str += ' },' + + with open(output_path, 'a', encoding=ENCODING) as f: + f.write(json_str + '\n') + + # 完成JSON文件 + if output_path and format.lower() == 'json': + with open(output_path, 'a', encoding=ENCODING) as f: + f.write('}\n') + + # 修复JSON文件中的最后一个逗号 + with open(output_path, 'r', encoding=ENCODING) as f: + content = f.read() + + content = content.rstrip('\n}') + content = content.rstrip(',') + content += '\n}\n' + + with open(output_path, 'w', encoding=ENCODING) as f: + f.write(content) + + processing_time = time.time() - start_time + logger.info(f"处理大型文件完成,共处理 {len(texts)} 条文本,用时: {processing_time:.2f} 秒") + + +================================================================================ +文件: inference/__init__.py +================================================================================ + + + +================================================================================ +文件: evaluation/metrics.py +================================================================================ + +""" +评估指标模块:实现各种评估指标 +""" +import numpy as np +import tensorflow as tf +from sklearn.metrics import ( + accuracy_score, precision_score, recall_score, f1_score, + confusion_matrix, classification_report, roc_auc_score, + precision_recall_curve, average_precision_score +) +import matplotlib.pyplot as plt +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +import pandas as pd + +from utils.logger import get_logger + +logger = get_logger("Metrics") + + +class ClassificationMetrics: + """分类评估指标类,计算各种分类评估指标""" + + def __init__(self, class_names: Optional[List[str]] = None): + """ + 初始化分类评估指标类 + + Args: + class_names: 类别名称列表 + """ + self.class_names = class_names + + def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: + """ + 计算准确率 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + + Returns: + 准确率 + """ + return accuracy_score(y_true, y_pred) + + def precision(self, y_true: np.ndarray, y_pred: np.ndarray, + average: str = 'macro') -> Union[float, np.ndarray]: + """ + 计算精确率 + + Args: + 
y_true: 真实标签 + y_pred: 预测标签 + average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None + + Returns: + 精确率 + """ + return precision_score(y_true, y_pred, average=average, zero_division=0) + + def recall(self, y_true: np.ndarray, y_pred: np.ndarray, + average: str = 'macro') -> Union[float, np.ndarray]: + """ + 计算召回率 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None + + Returns: + 召回率 + """ + return recall_score(y_true, y_pred, average=average, zero_division=0) + + def f1(self, y_true: np.ndarray, y_pred: np.ndarray, + average: str = 'macro') -> Union[float, np.ndarray]: + """ + 计算F1分数 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None + + Returns: + F1分数 + """ + return f1_score(y_true, y_pred, average=average, zero_division=0) + + def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, + normalize: Optional[str] = None) -> np.ndarray: + """ + 计算混淆矩阵 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + normalize: 归一化方式,可选值: 'true', 'pred', 'all', None + + Returns: + 混淆矩阵 + """ + cm = confusion_matrix(y_true, y_pred) + + if normalize is not None: + if normalize == 'true': + cm = cm.astype('float') / cm.sum(axis=1, keepdims=True) + elif normalize == 'pred': + cm = cm.astype('float') / cm.sum(axis=0, keepdims=True) + elif normalize == 'all': + cm = cm.astype('float') / cm.sum() + + return cm + + def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, + normalize: Optional[str] = None, + figsize: Tuple[int, int] = (10, 8), + save_path: Optional[str] = None) -> None: + """ + 绘制混淆矩阵 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + normalize: 归一化方式,可选值: 'true', 'pred', 'all', None + figsize: 图像大小 + save_path: 保存路径,如果为None则显示图像 + """ + # 计算混淆矩阵 + cm = self.confusion_matrix(y_true, y_pred, normalize) + + # 确定类别名称 + if self.class_names is None: + class_names = [str(i) for i in range(cm.shape[0])] + else: + class_names = self.class_names + + # 创建图像 + plt.figure(figsize=figsize) + + # 使用热图显示混淆矩阵 + im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + plt.colorbar(im) + + # 设置坐标轴标签 + plt.xticks(np.arange(cm.shape[1]), class_names, rotation=45, ha='right') + plt.yticks(np.arange(cm.shape[0]), class_names) + + # 设置标题 + if normalize is not None: + plt.title(f"Normalized ({normalize}) Confusion Matrix") + else: + plt.title("Confusion Matrix") + + plt.ylabel('True label') + plt.xlabel('Predicted label') + + # 在每个单元格中显示数值 + thresh = cm.max() / 2.0 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + if normalize is not None: + plt.text(j, i, f"{cm[i, j]:.2f}", + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + else: + plt.text(j, i, f"{cm[i, j]}", + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + plt.tight_layout() + + # 保存或显示图像 + if save_path: + plt.savefig(save_path) + logger.info(f"混淆矩阵图已保存到: {save_path}") + else: + plt.show() + + def classification_report(self, y_true: np.ndarray, y_pred: np.ndarray, + output_dict: bool = False) -> Union[str, Dict]: + """ + 生成分类报告 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + output_dict: 是否以字典形式返回 + + Returns: + 分类报告 + """ + target_names = self.class_names if self.class_names else None + return classification_report(y_true, y_pred, + target_names=target_names, + output_dict=output_dict, + zero_division=0) + + def auc_roc(self, y_true: np.ndarray, y_prob: np.ndarray, + multi_class: str = 'ovr') -> Union[float, np.ndarray]: + """ + 计算AUC-ROC + + Args: + y_true: 真实标签 
+ y_prob: 预测概率 + multi_class: 多分类处理方式,可选值: 'ovr', 'ovo' + + Returns: + AUC-ROC + """ + try: + # 如果y_true是one-hot编码,转换为类别索引 + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 多分类 + if y_prob.shape[1] > 2: + return roc_auc_score(y_true, y_prob, multi_class=multi_class, average='macro') + # 二分类 + else: + return roc_auc_score(y_true, y_prob[:, 1]) + except Exception as e: + logger.error(f"计算AUC-ROC时出错: {e}") + return 0.0 + + def average_precision(self, y_true: np.ndarray, y_prob: np.ndarray, + average: str = 'macro') -> Union[float, np.ndarray]: + """ + 计算平均精确率 + + Args: + y_true: 真实标签 + y_prob: 预测概率 + average: 平均方式,可选值: 'micro', 'macro', 'weighted', 'samples', None + + Returns: + 平均精确率 + """ + try: + # 如果y_true是one-hot编码,转换为类别索引 + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 多分类:使用sklearn的方法 + return average_precision_score( + tf.keras.utils.to_categorical(y_true, num_classes=y_prob.shape[1]), + y_prob, + average=average + ) + except Exception as e: + logger.error(f"计算平均精确率时出错: {e}") + return 0.0 + + def plot_precision_recall_curve(self, y_true: np.ndarray, y_prob: np.ndarray, + class_id: Optional[int] = None, + figsize: Tuple[int, int] = (10, 8), + save_path: Optional[str] = None) -> None: + """ + 绘制精确率-召回率曲线 + + Args: + y_true: 真实标签 + y_prob: 预测概率 + class_id: 要绘制的类别ID,如果为None则绘制所有类别 + figsize: 图像大小 + save_path: 保存路径,如果为None则显示图像 + """ + # 如果y_true是one-hot编码,转换为类别索引 + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 创建图像 + plt.figure(figsize=figsize) + + # 确定要绘制的类别 + if class_id is not None and class_id < y_prob.shape[1]: + # 绘制指定类别的PR曲线 + y_true_bin = (y_true == class_id).astype(int) + precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, class_id]) + avg_prec = average_precision_score(y_true_bin, y_prob[:, class_id]) + + plt.step(recall, precision, where='post', + label=f'Class {class_id} (AP = {avg_prec:.3f})') + else: + # 绘制所有类别的PR曲线 + for i in range(y_prob.shape[1]): + y_true_bin = (y_true == i).astype(int) + precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i]) + avg_prec = average_precision_score(y_true_bin, y_prob[:, i]) + + class_name = self.class_names[i] if self.class_names else f"Class {i}" + plt.step(recall, precision, where='post', + label=f'{class_name} (AP = {avg_prec:.3f})') + + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title('Precision-Recall Curve') + plt.legend(loc='lower left') + plt.grid(True) + + # 保存或显示图像 + if save_path: + plt.savefig(save_path) + logger.info(f"精确率-召回率曲线图已保存到: {save_path}") + else: + plt.show() + + def calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray, + y_prob: Optional[np.ndarray] = None) -> Dict[str, float]: + """ + 计算所有评估指标 + + Args: + y_true: 真实标签 + y_pred: 预测标签 + y_prob: 预测概率 + + Returns: + 包含所有评估指标的字典 + """ + metrics = {} + + # 基础指标 + metrics['accuracy'] = self.accuracy(y_true, y_pred) + metrics['precision_macro'] = self.precision(y_true, y_pred, average='macro') + metrics['recall_macro'] = self.recall(y_true, y_pred, average='macro') + metrics['f1_macro'] = self.f1(y_true, y_pred, average='macro') + + # 如果提供了预测概率,计算AUC-ROC和平均精确率 + if y_prob is not None: + try: + metrics['auc_roc'] = self.auc_roc(y_true, y_prob) + metrics['average_precision'] = self.average_precision(y_true, y_prob) + except Exception as e: + logger.error(f"计算概率指标时出错: {e}") + + # 类别级别的指标 + for avg in ['micro', 'weighted']: + metrics[f'precision_{avg}'] = self.precision(y_true, y_pred, 
average=avg) + metrics[f'recall_{avg}'] = self.recall(y_true, y_pred, average=avg) + metrics[f'f1_{avg}'] = self.f1(y_true, y_pred, average=avg) + + return metrics + + def metrics_to_dataframe(self, metrics: Dict[str, float]) -> pd.DataFrame: + """ + 将评估指标转换为DataFrame + + Args: + metrics: 评估指标字典 + + Returns: + 评估指标DataFrame + """ + return pd.DataFrame(metrics.items(), columns=['Metric', 'Value']).set_index('Metric') + + +================================================================================ +文件: evaluation/__init__.py +================================================================================ + + + +================================================================================ +文件: evaluation/visualization.py +================================================================================ + +""" +可视化模块:实现评估结果的可视化 +""" +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from typing import List, Dict, Tuple, Optional, Any, Union +import os +import itertools +from sklearn.metrics import roc_curve, precision_recall_curve, auc +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA + +from utils.logger import get_logger +from utils.file_utils import ensure_dir + +logger = get_logger("Visualization") + + +class EvaluationVisualizer: + """评估结果可视化类""" + + def __init__(self, output_dir: Optional[str] = None, + class_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (10, 8)): + """ + 初始化评估结果可视化类 + + Args: + output_dir: 输出目录,用于保存可视化结果 + class_names: 类别名称列表 + figsize: 图像默认大小 + """ + self.output_dir = output_dir + if output_dir: + ensure_dir(output_dir) + + self.class_names = class_names + self.figsize = figsize + + def plot_confusion_matrix(self, cm: np.ndarray, + normalize: Optional[str] = None, + title: str = 'Confusion Matrix', + cmap: str = 'Blues', + save_path: Optional[str] = None) -> None: + """ + 绘制混淆矩阵 + + Args: + cm: 混淆矩阵 + normalize: 归一化方式,可选值: 'true', 'pred', 'all', None + title: 图像标题 + cmap: 颜色映射 + save_path: 保存路径,如果为None则使用output_dir/confusion_matrix.png + """ + if normalize is not None: + if normalize == 'true': + cm = cm.astype('float') / cm.sum(axis=1, keepdims=True) + title = 'Normalized (by true) ' + title + elif normalize == 'pred': + cm = cm.astype('float') / cm.sum(axis=0, keepdims=True) + title = 'Normalized (by pred) ' + title + elif normalize == 'all': + cm = cm.astype('float') / cm.sum() + title = 'Normalized (by all) ' + title + + plt.figure(figsize=self.figsize) + + # 使用seaborn绘制热图 + sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', + cmap=cmap, square=True, cbar=True) + + # 设置坐标轴标签 + if self.class_names: + tick_marks = np.arange(len(self.class_names)) + plt.xticks(tick_marks + 0.5, self.class_names, rotation=45, ha='right') + plt.yticks(tick_marks + 0.5, self.class_names, rotation=0) + + plt.title(title) + plt.ylabel('True label') + plt.xlabel('Predicted label') + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, 'confusion_matrix.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"混淆矩阵图已保存到: {save_path}") + + plt.close() + + def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]], + selected_metrics: Optional[List[str]] = None, + title: str = 'Metrics Comparison', + save_path: Optional[str] = None) -> None: + """ + 绘制多个模型的评估指标比较 + + Args: + metrics_dict: 模型评估指标字典,格式为{model_name: {metric_name: value}} + selected_metrics: 
要比较的指标列表,如果为None则使用所有指标 + title: 图像标题 + save_path: 保存路径,如果为None则使用output_dir/metrics_comparison.png + """ + # 创建DataFrame + df = pd.DataFrame(metrics_dict).T + + # 筛选指标 + if selected_metrics: + df = df[selected_metrics] + + # 绘制条形图 + plt.figure(figsize=self.figsize) + df.plot(kind='bar', figsize=self.figsize) + + plt.title(title) + plt.ylabel('Score') + plt.ylim(0, 1) + plt.legend(loc='best') + plt.grid(axis='y', linestyle='--', alpha=0.7) + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, 'metrics_comparison.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"评估指标比较图已保存到: {save_path}") + + plt.close() + + def plot_roc_curves(self, y_true: np.ndarray, y_prob: np.ndarray, + title: str = 'ROC Curves', + save_path: Optional[str] = None) -> None: + """ + 绘制ROC曲线 + + Args: + y_true: 真实标签 + y_prob: 预测概率 + title: 图像标题 + save_path: 保存路径,如果为None则使用output_dir/roc_curves.png + """ + plt.figure(figsize=self.figsize) + + # 确保y_true是一维数组 + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 获取类别数 + num_classes = y_prob.shape[1] + + # 绘制每个类别的ROC曲线 + for i in range(num_classes): + # 二分类转换:当前类别为正类,其他为负类 + y_true_bin = (y_true == i).astype(int) + + # 计算ROC曲线 + fpr, tpr, _ = roc_curve(y_true_bin, y_prob[:, i]) + roc_auc = auc(fpr, tpr) + + # 确定类别名称 + if self.class_names and i < len(self.class_names): + class_name = self.class_names[i] + else: + class_name = f'Class {i}' + + # 绘制ROC曲线 + plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})') + + # 绘制随机猜测的基准线 + plt.plot([0, 1], [0, 1], 'k--', lw=2) + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title(title) + plt.legend(loc='lower right') + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, 'roc_curves.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"ROC曲线图已保存到: {save_path}") + + plt.close() + + def plot_precision_recall_curves(self, y_true: np.ndarray, y_prob: np.ndarray, + title: str = 'Precision-Recall Curves', + save_path: Optional[str] = None) -> None: + """ + 绘制精确率-召回率曲线 + + Args: + y_true: 真实标签 + y_prob: 预测概率 + title: 图像标题 + save_path: 保存路径,如果为None则使用output_dir/precision_recall_curves.png + """ + plt.figure(figsize=self.figsize) + + # 确保y_true是一维数组 + if len(y_true.shape) > 1 and y_true.shape[1] > 1: + y_true = np.argmax(y_true, axis=1) + + # 获取类别数 + num_classes = y_prob.shape[1] + + # 绘制每个类别的PR曲线 + for i in range(num_classes): + # 二分类转换:当前类别为正类,其他为负类 + y_true_bin = (y_true == i).astype(int) + + # 计算PR曲线 + precision, recall, _ = precision_recall_curve(y_true_bin, y_prob[:, i]) + pr_auc = auc(recall, precision) + + # 确定类别名称 + if self.class_names and i < len(self.class_names): + class_name = self.class_names[i] + else: + class_name = f'Class {i}' + + # 绘制PR曲线 + plt.plot(recall, precision, lw=2, label=f'{class_name} (AUC = {pr_auc:.3f})') + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title(title) + plt.legend(loc='best') + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, 'precision_recall_curves.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"精确率-召回率曲线图已保存到: {save_path}") + + 
plt.close() + + def plot_feature_importance(self, feature_names: List[str], + importance: np.ndarray, + title: str = 'Feature Importance', + top_n: int = 20, + save_path: Optional[str] = None) -> None: + """ + 绘制特征重要性 + + Args: + feature_names: 特征名称列表 + importance: 特征重要性数组 + title: 图像标题 + top_n: 显示前N个重要的特征 + save_path: 保存路径,如果为None则使用output_dir/feature_importance.png + """ + # 创建DataFrame + df = pd.DataFrame({'Feature': feature_names, 'Importance': importance}) + + # 按重要性排序 + df = df.sort_values('Importance', ascending=False).head(top_n) + + # 绘制条形图 + plt.figure(figsize=self.figsize) + sns.barplot(x='Importance', y='Feature', data=df) + + plt.title(title) + plt.xlabel('Importance') + plt.ylabel('Feature') + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, 'feature_importance.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"特征重要性图已保存到: {save_path}") + + plt.close() + + def plot_embedding_visualization(self, embeddings: np.ndarray, + labels: np.ndarray, + method: str = 'tsne', + title: str = 'Embedding Visualization', + save_path: Optional[str] = None) -> None: + """ + 绘制嵌入向量可视化 + + Args: + embeddings: 嵌入向量,形状为(样本数, 嵌入维度) + labels: 类别标签,形状为(样本数,) + method: 降维方法,'tsne'或'pca' + title: 图像标题 + save_path: 保存路径,如果为None则使用output_dir/embedding_visualization.png + """ + # 降维 + if method.lower() == 'tsne': + reducer = TSNE(n_components=2, random_state=42) + elif method.lower() == 'pca': + reducer = PCA(n_components=2, random_state=42) + else: + raise ValueError(f"不支持的降维方法: {method}") + + # 如果嵌入向量太多,采样一部分 + max_samples = 5000 + if len(embeddings) > max_samples: + indices = np.random.choice(len(embeddings), max_samples, replace=False) + embeddings_sample = embeddings[indices] + labels_sample = labels[indices] + else: + embeddings_sample = embeddings + labels_sample = labels + + # 执行降维 + embeddings_2d = reducer.fit_transform(embeddings_sample) + + # 绘制散点图 + plt.figure(figsize=self.figsize) + + # 确保标签是一维数组 + if len(labels_sample.shape) > 1 and labels_sample.shape[1] > 1: + labels_sample = np.argmax(labels_sample, axis=1) + + # 获取唯一类别 + unique_labels = np.unique(labels_sample) + + # 为每个类别分配不同的颜色 + colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) + + for i, label in enumerate(unique_labels): + mask = labels_sample == label + + # 确定类别名称 + if self.class_names and label < len(self.class_names): + class_name = self.class_names[int(label)] + else: + class_name = f'Class {int(label)}' + + plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], + c=[colors[i]], label=class_name, alpha=0.7) + + plt.title(title) + plt.legend(loc='best') + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # 保存图像 + if save_path is None and self.output_dir: + save_path = os.path.join(self.output_dir, f'embedding_visualization_{method}.png') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logger.info(f"嵌入向量可视化图已保存到: {save_path}") + + plt.close() + + + +================================================================================ +文件: evaluation/evaluator.py +================================================================================ + +""" +评估器模块:实现模型评估流程 +""" +import numpy as np +import tensorflow as tf +import time +import os +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +import pandas as pd +import matplotlib.pyplot as plt +import json + +from config.system_config import SAVED_MODELS_DIR +from models.base_model import TextClassificationModel 
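# Illustrative usage of the evaluator defined below (the model and test-set
# variables are hypothetical; the method names follow this module's own API):
#
#   evaluator = ModelEvaluator(model=model, class_names=CATEGORIES)
#   metrics = evaluator.evaluate(X_test, y_test)
#   evaluator.save_evaluation_results(save_plots=True)
#   per_class = evaluator.evaluate_class_performance(X_test, y_test)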
+from evaluation.metrics import ClassificationMetrics +from utils.logger import get_logger +from utils.file_utils import ensure_dir, save_json + +logger = get_logger("Evaluator") + + +class ModelEvaluator: + """模型评估器,负责评估模型性能""" + + def __init__(self, model: TextClassificationModel, + class_names: Optional[List[str]] = None, + output_dir: Optional[str] = None, + batch_size: Optional[int] = None): + """ + 初始化模型评估器 + + Args: + model: 要评估的模型 + class_names: 类别名称列表 + output_dir: 输出目录,用于保存评估结果 + batch_size: 批大小,如果为None则使用模型默认值 + """ + self.model = model + self.class_names = class_names + self.batch_size = batch_size or model.batch_size + + # 设置输出目录 + if output_dir is None: + self.output_dir = os.path.join(SAVED_MODELS_DIR, 'evaluation', model.model_name) + else: + self.output_dir = output_dir + + ensure_dir(self.output_dir) + + # 创建评估指标计算器 + self.metrics = ClassificationMetrics(class_names) + + # 评估结果 + self.evaluation_results = None + + logger.info(f"初始化模型评估器,模型: {model.model_name}") + + def evaluate(self, x_test: Union[np.ndarray, tf.data.Dataset], + y_test: Optional[np.ndarray] = None, + batch_size: Optional[int] = None, + verbose: int = 1) -> Dict[str, float]: + """ + 评估模型 + + Args: + x_test: 测试数据特征 + y_test: 测试数据标签 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 评估结果 + """ + batch_size = batch_size or self.batch_size + + logger.info(f"开始评估模型: {self.model.model_name}") + start_time = time.time() + + # 使用模型评估 + model_metrics = self.model.evaluate(x_test, y_test, verbose=verbose) + + # 获取预测结果 + y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0) + y_pred = np.argmax(y_prob, axis=1) + + # 处理y_test,确保y_test是一维数组 + if isinstance(x_test, tf.data.Dataset): + # 如果是TensorFlow Dataset,需要从中提取y_test + y_test_extracted = np.concatenate([y for _, y in x_test], axis=0) + y_test = y_test_extracted + + if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1: + y_test = np.argmax(y_test, axis=1) + + # 计算所有指标 + all_metrics = self.metrics.calculate_all_metrics(y_test, y_pred, y_prob) + + # 合并模型内置指标和自定义指标 + metrics_names = self.model.model.metrics_names + model_metrics_dict = {name: float(value) for name, value in zip(metrics_names, model_metrics)} + all_metrics.update(model_metrics_dict) + + # 记录评估时间 + evaluation_time = time.time() - start_time + all_metrics['evaluation_time'] = evaluation_time + + # 保存评估结果 + self.evaluation_results = { + 'metrics': all_metrics, + 'confusion_matrix': self.metrics.confusion_matrix(y_test, y_pred).tolist(), + 'classification_report': self.metrics.classification_report(y_test, y_pred, output_dict=True) + } + + logger.info(f"模型评估完成,用时: {evaluation_time:.2f} 秒") + logger.info(f"主要评估指标: accuracy={all_metrics.get('accuracy', 'N/A'):.4f}, " + f"f1_macro={all_metrics.get('f1_macro', 'N/A'):.4f}") + + return all_metrics + + def save_evaluation_results(self, save_plots: bool = True) -> str: + """ + 保存评估结果 + + Args: + save_plots: 是否保存可视化图表 + + Returns: + 结果保存路径 + """ + if self.evaluation_results is None: + raise ValueError("请先调用evaluate方法进行评估") + + # 保存评估结果为JSON + results_path = os.path.join(self.output_dir, 'evaluation_results.json') + with open(results_path, 'w', encoding='utf-8') as f: + json.dump(self.evaluation_results, f, ensure_ascii=False, indent=4) + + # 保存评估指标为CSV + metrics_df = pd.DataFrame( + self.evaluation_results['metrics'].items(), + columns=['Metric', 'Value'] + ).set_index('Metric') + + metrics_path = os.path.join(self.output_dir, 'metrics.csv') + metrics_df.to_csv(metrics_path) + + # 保存可视化图表 + if save_plots: + self._save_plots() + + 
logger.info(f"评估结果已保存到: {self.output_dir}") + + return self.output_dir + + def _save_plots(self) -> None: + """保存评估结果可视化图表""" + if self.evaluation_results is None: + raise ValueError("请先调用evaluate方法进行评估") + + # 创建可视化目录 + plots_dir = os.path.join(self.output_dir, 'plots') + ensure_dir(plots_dir) + + # 混淆矩阵图 + cm_path = os.path.join(plots_dir, 'confusion_matrix.png') + cm = np.array(self.evaluation_results['confusion_matrix']) + + # 将混淆矩阵转换为NumPy数组 + if isinstance(cm, list): + cm = np.array(cm) + + # 绘制混淆矩阵 + self.metrics.plot_confusion_matrix( + np.arange(cm.shape[0]), # 假设标签 + np.arange(cm.shape[1]), # 假设预测 + normalize='true', + save_path=cm_path + ) + + # 保存评估指标条形图 + metrics_path = os.path.join(plots_dir, 'metrics_bar.png') + metrics = self.evaluation_results['metrics'] + + # 选择要展示的主要指标 + main_metrics = { + 'accuracy': metrics.get('accuracy', 0), + 'precision_macro': metrics.get('precision_macro', 0), + 'recall_macro': metrics.get('recall_macro', 0), + 'f1_macro': metrics.get('f1_macro', 0) + } + + # 绘制条形图 + plt.figure(figsize=(10, 6)) + plt.bar(main_metrics.keys(), main_metrics.values()) + plt.title('Main Evaluation Metrics') + plt.ylabel('Score') + plt.ylim(0, 1) + plt.xticks(rotation=45, ha='right') + plt.tight_layout() + plt.savefig(metrics_path) + plt.close() + + # 如果有类别级别的指标,绘制每个类别的指标 + if 'classification_report' in self.evaluation_results: + report = self.evaluation_results['classification_report'] + + # 提取每个类别的精确率、召回率和F1值 + class_metrics = {} + for key, value in report.items(): + if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']: + if isinstance(value, dict): + class_metrics[key] = value + + if class_metrics: + # 绘制每个类别的F1分数 + class_f1_path = os.path.join(plots_dir, 'class_f1_scores.png') + + classes = list(class_metrics.keys()) + f1_scores = [metrics['f1-score'] for metrics in class_metrics.values()] + + plt.figure(figsize=(12, 6)) + bars = plt.bar(classes, f1_scores) + + # 在柱状图上方显示数值 + for bar in bars: + height = bar.get_height() + plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01, + f'{height:.2f}', + ha='center', va='bottom', rotation=0) + + plt.title('F1 Score by Class') + plt.ylabel('F1 Score') + plt.ylim(0, 1.1) + plt.xticks(rotation=45, ha='right') + plt.tight_layout() + plt.savefig(class_f1_path) + plt.close() + + # 绘制每个类别的精确率和召回率 + class_prec_rec_path = os.path.join(plots_dir, 'class_precision_recall.png') + + precisions = [metrics['precision'] for metrics in class_metrics.values()] + recalls = [metrics['recall'] for metrics in class_metrics.values()] + + plt.figure(figsize=(12, 6)) + x = np.arange(len(classes)) + width = 0.35 + + plt.bar(x - width / 2, precisions, width, label='Precision') + plt.bar(x + width / 2, recalls, width, label='Recall') + + plt.ylabel('Score') + plt.title('Precision and Recall by Class') + plt.xticks(x, classes, rotation=45, ha='right') + plt.legend() + plt.ylim(0, 1.1) + plt.tight_layout() + plt.savefig(class_prec_rec_path) + plt.close() + + logger.info(f"评估可视化图表已保存到: {plots_dir}") + + def compare_models(self, other_evaluators: List['ModelEvaluator'], + metrics: Optional[List[str]] = None, + save_path: Optional[str] = None) -> pd.DataFrame: + """ + 比较多个模型的评估结果 + + Args: + other_evaluators: 其他模型评估器列表 + metrics: 要比较的指标列表,默认为['accuracy', 'precision_macro', 'recall_macro', 'f1_macro'] + save_path: 比较结果的保存路径 + + Returns: + 比较结果DataFrame + """ + if self.evaluation_results is None: + raise ValueError("请先调用evaluate方法进行评估") + + # 默认比较指标 + if metrics is None: + metrics = ['accuracy', 'precision_macro', 'recall_macro', 
'f1_macro'] + + # 收集所有模型的评估指标 + models_metrics = {} + + # 当前模型 + models_metrics[self.model.model_name] = { + metric: self.evaluation_results['metrics'].get(metric, 'N/A') + for metric in metrics + } + + # 其他模型 + for evaluator in other_evaluators: + if evaluator.evaluation_results is None: + logger.warning(f"模型 {evaluator.model.model_name} 尚未评估,跳过") + continue + + models_metrics[evaluator.model.model_name] = { + metric: evaluator.evaluation_results['metrics'].get(metric, 'N/A') + for metric in metrics + } + + # 创建比较DataFrame + comparison_df = pd.DataFrame(models_metrics).T + + # 保存比较结果 + if save_path: + comparison_df.to_csv(save_path) + logger.info(f"模型比较结果已保存到: {save_path}") + + # 绘制比较条形图 + plt.figure(figsize=(12, 6)) + comparison_df.plot(kind='bar', figsize=(12, 6)) + plt.title('Model Comparison') + plt.ylabel('Score') + plt.ylim(0, 1) + plt.legend(loc='lower right') + plt.tight_layout() + + # 如果save_path是CSV文件,将其替换为PNG文件 + if save_path.endswith('.csv'): + plot_path = save_path.replace('.csv', '.png') + else: + plot_path = save_path + '.png' + + plt.savefig(plot_path) + plt.close() + + logger.info(f"模型比较图表已保存到: {plot_path}") + + return comparison_df + + def evaluate_class_performance(self, x_test: Union[np.ndarray, tf.data.Dataset], + y_test: Optional[np.ndarray] = None, + batch_size: Optional[int] = None, + verbose: int = 0) -> pd.DataFrame: + """ + 评估模型在各个类别上的性能 + + Args: + x_test: 测试数据特征 + y_test: 测试数据标签 + batch_size: 批大小 + verbose: 详细程度 + + Returns: + 各类别性能指标DataFrame + """ + batch_size = batch_size or self.batch_size + + # 获取预测结果 + y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=verbose) + y_pred = np.argmax(y_prob, axis=1) + + # 处理y_test,确保y_test是一维数组 + if isinstance(x_test, tf.data.Dataset): + # 如果是TensorFlow Dataset,需要从中提取y_test + y_test_extracted = np.concatenate([y for _, y in x_test], axis=0) + y_test = y_test_extracted + + if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1: + y_test = np.argmax(y_test, axis=1) + + # 获取分类报告 + report = self.metrics.classification_report(y_test, y_pred, output_dict=True) + + # 提取各类别指标 + class_metrics = {} + for key, value in report.items(): + if key not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']: + if isinstance(value, dict): + class_metrics[key] = value + + # 转换为DataFrame + class_performance_df = pd.DataFrame(class_metrics).T + + # 添加支持度(样本数量) + class_counts = np.bincount(y_test) + for idx, count in enumerate(class_counts): + if str(idx) in class_performance_df.index: + class_performance_df.loc[str(idx), 'support'] = count + + # 添加类别名称 + if self.class_names: + class_performance_df['class_name'] = [ + self.class_names[int(idx)] if int(idx) < len(self.class_names) else idx + for idx in class_performance_df.index + ] + + # 保存类别性能指标 + performance_path = os.path.join(self.output_dir, 'class_performance.csv') + class_performance_df.to_csv(performance_path) + logger.info(f"各类别性能指标已保存到: {performance_path}") + + return class_performance_df + + def plot_error_analysis(self, x_test: Union[np.ndarray, tf.data.Dataset], + y_test: Optional[np.ndarray] = None, + batch_size: Optional[int] = None, + num_samples: int = 10, + save_path: Optional[str] = None) -> None: + """ + 绘制误分类样本分析 + + Args: + x_test: 测试数据特征 + y_test: 测试数据标签 + batch_size: 批大小 + num_samples: 要展示的误分类样本数量 + save_path: 保存路径 + """ + # 仅适用于文本数据的分析,需要原始文本 + logger.info("误分类样本分析需要原始文本数据,此方法可能需要根据实际数据类型进行修改") + + # 在实际应用中,这里应该根据实际数据类型进行修改 + # 例如,对于序列化的文本,可能需要反序列化,或者使用词汇表将索引转换回文本 + + # 此处仅展示一个基本框架 + batch_size = batch_size or self.batch_size + + # 
获取预测结果 + y_prob = self.model.predict(x_test, batch_size=batch_size, verbose=0) + y_pred = np.argmax(y_prob, axis=1) + + # 处理y_test,确保y_test是一维数组 + if isinstance(x_test, tf.data.Dataset): + # 如果是TensorFlow Dataset,需要从中提取y_test和x_test + dataset_iterator = iter(x_test) + x_test_extracted = [] + y_test_extracted = [] + + try: + while True: + x, y = next(dataset_iterator) + x_test_extracted.append(x) + y_test_extracted.append(y) + except StopIteration: + pass + + x_test = np.concatenate(x_test_extracted, axis=0) + y_test = np.concatenate(y_test_extracted, axis=0) + + if y_test is not None and len(y_test.shape) > 1 and y_test.shape[1] > 1: + y_test = np.argmax(y_test, axis=1) + + # 找出误分类样本 + misclassified_indices = np.where(y_pred != y_test)[0] + + # 如果没有误分类样本,返回 + if len(misclassified_indices) == 0: + logger.info("没有误分类样本") + return + + # 随机选择一些误分类样本 + if len(misclassified_indices) > num_samples: + misclassified_indices = np.random.choice(misclassified_indices, num_samples, replace=False) + + # 保存误分类样本分析结果 + misclassified_data = [] + + for idx in misclassified_indices: + true_label = y_test[idx] + pred_label = y_pred[idx] + + true_class = self.class_names[true_label] if self.class_names else str(true_label) + pred_class = self.class_names[pred_label] if self.class_names else str(pred_label) + + # 对于序列化的文本,此处需要进行反序列化 + # 这里仅作示例,实际应用中需要根据具体数据类型修改 + sample_text = f"Sample {idx}" + + misclassified_data.append({ + 'sample_id': idx, + 'true_label': true_label, + 'predicted_label': pred_label, + 'true_class': true_class, + 'predicted_class': pred_class, + 'confidence': float(y_prob[idx, pred_label]), + 'sample_text': sample_text + }) + + # 创建DataFrame + misclassified_df = pd.DataFrame(misclassified_data) + + # 保存结果 + if save_path: + misclassified_df.to_csv(save_path) + logger.info(f"误分类样本分析已保存到: {save_path}") + + return misclassified_df + + +================================================================================ +文件: preprocessing/__init__.py +================================================================================ + + + +================================================================================ +文件: preprocessing/tokenization.py +================================================================================ + +""" +中文分词模块:负责中文文本分词处理 +""" +import os +import jieba +import re +from typing import List, Dict, Tuple, Optional, Any, Set, Union +import pandas as pd +from collections import Counter + +from config.system_config import STOPWORDS_DIR, ENCODING +from utils.logger import get_logger +from utils.file_utils import read_text_file, write_text_file, ensure_dir + +logger = get_logger("Tokenization") + + +class ChineseTokenizer: + """中文分词器,基于jieba实现""" + + def __init__(self, user_dict_path: Optional[str] = None, + use_hmm: bool = True, + remove_stopwords: bool = True, + stopwords_path: Optional[str] = None, + add_custom_words: Optional[List[str]] = None): + """ + 初始化中文分词器 + + Args: + user_dict_path: 用户自定义词典路径 + use_hmm: 是否使用HMM模型进行分词 + remove_stopwords: 是否移除停用词 + stopwords_path: 停用词表路径,如果为None,则使用默认停用词表 + add_custom_words: 要添加的自定义词语列表 + """ + self.use_hmm = use_hmm + self.remove_stopwords = remove_stopwords + + # 加载用户自定义词典 + if user_dict_path and os.path.exists(user_dict_path): + jieba.load_userdict(user_dict_path) + logger.info(f"已加载用户自定义词典:{user_dict_path}") + + # 加载停用词 + self.stopwords = set() + if remove_stopwords: + self._load_stopwords(stopwords_path) + + # 添加自定义词语 + if add_custom_words: + for word in add_custom_words: + jieba.add_word(word) + logger.info(f"已添加 
{len(add_custom_words)} 个自定义词语") + + def _load_stopwords(self, stopwords_path: Optional[str] = None) -> None: + """ + 加载停用词 + + Args: + stopwords_path: 停用词表路径,如果为None,则使用默认停用词表 + """ + # 如果没有指定停用词表路径,则使用默认停用词表 + if not stopwords_path: + stopwords_path = os.path.join(STOPWORDS_DIR, "chinese_stopwords.txt") + + # 如果没有找到默认停用词表,则创建一个空的停用词表 + if not os.path.exists(stopwords_path): + ensure_dir(os.path.dirname(stopwords_path)) + # 常见中文停用词 + default_stopwords = [ + "的", "了", "和", "是", "就", "都", "而", "及", "与", "这", "那", "你", + "我", "他", "她", "它", "们", "或", "上", "下", "之", "地", "得", "着", + "说", "对", "在", "于", "由", "因", "为", "所", "以", "能", "可", "会" + ] + write_text_file("\n".join(default_stopwords), stopwords_path) + logger.info(f"未找到停用词表,已创建默认停用词表:{stopwords_path}") + + # 加载停用词表 + try: + with open(stopwords_path, "r", encoding=ENCODING) as f: + for line in f: + word = line.strip() + if word: + self.stopwords.add(word) + logger.info(f"已加载 {len(self.stopwords)} 个停用词") + except Exception as e: + logger.error(f"加载停用词表失败:{e}") + + def add_stopwords(self, words: Union[str, List[str]]) -> None: + """ + 添加停用词 + + Args: + words: 要添加的停用词(字符串或列表) + """ + if isinstance(words, str): + self.stopwords.add(words.strip()) + else: + for word in words: + self.stopwords.add(word.strip()) + + def remove_stopwords_from_list(self, words: List[str]) -> List[str]: + """ + 从词语列表中移除停用词 + + Args: + words: 词语列表 + + Returns: + 移除停用词后的词语列表 + """ + if not self.remove_stopwords: + return words + + return [word for word in words if word not in self.stopwords] + + def tokenize(self, text: str, return_string: bool = False, + cut_all: bool = False) -> Union[List[str], str]: + """ + 对文本进行分词 + + Args: + text: 要分词的文本 + return_string: 是否返回字符串(以空格分隔的词语) + cut_all: 是否使用全模式(默认使用精确模式) + + Returns: + 分词结果(词语列表或字符串) + """ + if not text: + return "" if return_string else [] + + # 使用jieba进行分词 + if cut_all: + words = jieba.lcut(text, cut_all=True) + else: + words = jieba.lcut(text, HMM=self.use_hmm) + + # 移除停用词 + if self.remove_stopwords: + words = self.remove_stopwords_from_list(words) + + # 返回结果 + if return_string: + return " ".join(words) + else: + return words + + def batch_tokenize(self, texts: List[str], return_string: bool = False, + cut_all: bool = False) -> List[Union[List[str], str]]: + """ + 批量分词 + + Args: + texts: 要分词的文本列表 + return_string: 是否返回字符串(以空格分隔的词语) + cut_all: 是否使用全模式(默认使用精确模式) + + Returns: + 分词结果列表 + """ + return [self.tokenize(text, return_string, cut_all) for text in texts] + + def analyze_tokens(self, texts: List[str], top_n: int = 20) -> Dict[str, Any]: + """ + 分析文本中的词频分布 + + Args: + texts: 要分析的文本列表 + top_n: 返回前多少个高频词 + + Returns: + 包含词频分析结果的字典 + """ + all_tokens = [] + for text in texts: + tokens = self.tokenize(text, return_string=False) + all_tokens.extend(tokens) + + # 统计词频 + token_counter = Counter(all_tokens) + + # 获取最常见的词 + most_common = token_counter.most_common(top_n) + + # 计算唯一词数量 + unique_tokens = len(token_counter) + + return { + "total_tokens": len(all_tokens), + "unique_tokens": unique_tokens, + "most_common": most_common, + "token_counter": token_counter + } + + def get_top_keywords(self, texts: List[str], top_n: int = 20, + min_freq: int = 3, min_length: int = 2) -> List[Tuple[str, int]]: + """ + 获取文本中的关键词 + + Args: + texts: 要分析的文本列表 + top_n: 返回前多少个关键词 + min_freq: 最小词频 + min_length: 最小词长度(字符数) + + Returns: + 包含(关键词, 词频)的元组列表 + """ + tokens_analysis = self.analyze_tokens(texts) + token_counter = tokens_analysis["token_counter"] + + # 过滤满足条件的词 + filtered_keywords = [(word, count) for word, count in 
token_counter.items() + if count >= min_freq and len(word) >= min_length] + + # 按词频排序 + sorted_keywords = sorted(filtered_keywords, key=lambda x: x[1], reverse=True) + + return sorted_keywords[:top_n] + + def get_vocabulary(self, texts: List[str], min_freq: int = 1) -> List[str]: + """ + 获取词汇表 + + Args: + texts: 文本列表 + min_freq: 最小词频 + + Returns: + 词汇表(词语列表) + """ + tokens_analysis = self.analyze_tokens(texts) + token_counter = tokens_analysis["token_counter"] + + # 过滤满足最小词频的词 + vocabulary = [word for word, count in token_counter.items() if count >= min_freq] + + return vocabulary + + def get_stopwords(self) -> Set[str]: + """ + 获取停用词集合 + + Returns: + 停用词集合 + """ + return self.stopwords.copy() + +================================================================================ +文件: preprocessing/vectorizer.py +================================================================================ + +""" +文本向量化模块:实现文本向量化,包括词袋模型、TF-IDF和词嵌入等多种文本表示方法 +""" +import numpy as np +import tensorflow as tf +from tensorflow.keras.preprocessing.text import Tokenizer +from tensorflow.keras.preprocessing.sequence import pad_sequences +from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer +import pickle +import os +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +import gensim +from gensim.models import Word2Vec, KeyedVectors + +from config.system_config import PROCESSED_DATA_DIR, EMBEDDINGS_DIR +from config.model_config import ( + MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS, MIN_WORD_FREQUENCY +) +from utils.logger import get_logger +from utils.file_utils import save_pickle, load_pickle, ensure_dir +from preprocessing.tokenization import ChineseTokenizer + +logger = get_logger("Vectorizer") + + +class TextVectorizer: + """文本向量化基类,定义通用接口""" + + def __init__(self, max_features: int = MAX_NUM_WORDS): + """ + 初始化文本向量化器 + + Args: + max_features: 最大特征数(词汇表大小) + """ + self.max_features = max_features + self.vectorizer = None + self.is_fitted = False + + def fit(self, texts: List[str]) -> None: + """ + 在文本上训练向量化器 + + Args: + texts: 文本列表 + """ + raise NotImplementedError("子类必须实现此方法") + + def transform(self, texts: List[str]) -> np.ndarray: + """ + 将文本转换为向量表示 + + Args: + texts: 文本列表 + + Returns: + 向量表示 + """ + raise NotImplementedError("子类必须实现此方法") + + def fit_transform(self, texts: List[str]) -> np.ndarray: + """ + 在文本上训练向量化器,并将文本转换为向量表示 + + Args: + texts: 文本列表 + + Returns: + 向量表示 + """ + self.fit(texts) + return self.transform(texts) + + def save(self, path: str) -> None: + """ + 保存向量化器 + + Args: + path: 保存路径 + """ + ensure_dir(os.path.dirname(path)) + save_pickle(self.vectorizer, path) + logger.info(f"向量化器已保存到:{path}") + + def load(self, path: str) -> None: + """ + 加载向量化器 + + Args: + path: 加载路径 + """ + self.vectorizer = load_pickle(path) + self.is_fitted = True + logger.info(f"向量化器已从 {path} 加载") + + def get_vocabulary(self) -> List[str]: + """ + 获取词汇表 + + Returns: + 词汇表 + """ + raise NotImplementedError("子类必须实现此方法") + + def get_vocabulary_size(self) -> int: + """ + 获取词汇表大小 + + Returns: + 词汇表大小 + """ + raise NotImplementedError("子类必须实现此方法") + + +class BagOfWordsVectorizer(TextVectorizer): + """词袋模型向量化器""" + + def __init__(self, max_features: int = MAX_NUM_WORDS, + min_df: int = MIN_WORD_FREQUENCY, + tokenizer: Optional[Callable[[str], List[str]]] = None, + binary: bool = False): + """ + 初始化词袋模型向量化器 + + Args: + max_features: 最大特征数(词汇表大小) + min_df: 最小文档频率 + tokenizer: 分词器函数,接收文本,返回词语列表 + binary: 是否使用二进制计数(只关注词语是否出现,不关注频率) + """ + super().__init__(max_features) + self.min_df = min_df + 
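        # Note: when a custom tokenizer callable (e.g. ChineseTokenizer.tokenize) is
+        # passed through to CountVectorizer below, sklearn uses it instead of its
+        # regex token_pattern, which is what lets the bag-of-words features work on
+        # unsegmented Chinese text.
+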
self.binary = binary + + # 创建sklearn的CountVectorizer + self.vectorizer = CountVectorizer( + max_features=max_features, + min_df=min_df, + tokenizer=tokenizer, + binary=binary + ) + + def fit(self, texts: List[str]) -> None: + """ + 在文本上训练词袋模型 + + Args: + texts: 文本列表 + """ + self.vectorizer.fit(texts) + self.is_fitted = True + logger.info(f"词袋模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}") + + def transform(self, texts: List[str]) -> np.ndarray: + """ + 将文本转换为词袋向量表示 + + Args: + texts: 文本列表 + + Returns: + 词袋向量表示(稀疏矩阵) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return self.vectorizer.transform(texts) + + def get_vocabulary(self) -> List[str]: + """ + 获取词汇表 + + Returns: + 词汇表(按索引排序) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # CountVectorizer的词汇表是一个字典,键为词,值为索引 + vocab_dict = self.vectorizer.vocabulary_ + vocab_list = [""] * len(vocab_dict) + for word, idx in vocab_dict.items(): + vocab_list[idx] = word + + return vocab_list + + def get_vocabulary_size(self) -> int: + """ + 获取词汇表大小 + + Returns: + 词汇表大小 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return len(self.vectorizer.vocabulary_) + + +class TfidfVectorizer(TextVectorizer): + """TF-IDF向量化器""" + + def __init__(self, max_features: int = MAX_NUM_WORDS, + min_df: int = MIN_WORD_FREQUENCY, + tokenizer: Optional[Callable[[str], List[str]]] = None, + norm: str = 'l2', + use_idf: bool = True, + smooth_idf: bool = True, + sublinear_tf: bool = False): + """ + 初始化TF-IDF向量化器 + + Args: + max_features: 最大特征数(词汇表大小) + min_df: 最小文档频率 + tokenizer: 分词器函数,接收文本,返回词语列表 + norm: 规范化方法,默认为L2范数 + use_idf: 是否使用IDF(逆文档频率) + smooth_idf: 是否平滑IDF权重 + sublinear_tf: 是否应用sublinear scaling(对TF取对数) + """ + super().__init__(max_features) + self.min_df = min_df + + # 创建sklearn的TfidfVectorizer + self.vectorizer = sklearn.feature_extraction.text.TfidfVectorizer( + max_features=max_features, + min_df=min_df, + tokenizer=tokenizer, + norm=norm, + use_idf=use_idf, + smooth_idf=smooth_idf, + sublinear_tf=sublinear_tf + ) + + def fit(self, texts: List[str]) -> None: + """ + 在文本上训练TF-IDF模型 + + Args: + texts: 文本列表 + """ + self.vectorizer.fit(texts) + self.is_fitted = True + logger.info(f"TF-IDF模型已训练,词汇表大小:{len(self.vectorizer.vocabulary_)}") + + def transform(self, texts: List[str]) -> np.ndarray: + """ + 将文本转换为TF-IDF向量表示 + + Args: + texts: 文本列表 + + Returns: + TF-IDF向量表示(稀疏矩阵) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return self.vectorizer.transform(texts) + + def get_vocabulary(self) -> List[str]: + """ + 获取词汇表 + + Returns: + 词汇表(按索引排序) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # TfidfVectorizer的词汇表是一个字典,键为词,值为索引 + vocab_dict = self.vectorizer.vocabulary_ + vocab_list = [""] * len(vocab_dict) + for word, idx in vocab_dict.items(): + vocab_list[idx] = word + + return vocab_list + + def get_vocabulary_size(self) -> int: + """ + 获取词汇表大小 + + Returns: + 词汇表大小 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return len(self.vectorizer.vocabulary_) + + def get_feature_names(self) -> List[str]: + """ + 获取特征名称(词汇表) + + Returns: + 特征名称列表 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return self.vectorizer.get_feature_names_out() + + def get_idf(self) -> np.ndarray: + """ + 获取IDF权重 + + Returns: + IDF权重数组 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return self.vectorizer.idf_ + + +class SequenceVectorizer(TextVectorizer): + 
"""序列向量化器,使用Keras的Tokenizer""" + + def __init__(self, max_features: int = MAX_NUM_WORDS, + max_sequence_length: int = MAX_SEQUENCE_LENGTH, + oov_token: str = "", + padding: str = "post", + truncating: str = "post"): + """ + 初始化序列向量化器 + + Args: + max_features: 最大特征数(词汇表大小) + max_sequence_length: 序列最大长度 + oov_token: 未登录词标记 + padding: 填充方式,'pre'或'post' + truncating: 截断方式,'pre'或'post' + """ + super().__init__(max_features) + self.max_sequence_length = max_sequence_length + self.oov_token = oov_token + self.padding = padding + self.truncating = truncating + + # 创建Keras的Tokenizer + self.vectorizer = Tokenizer(num_words=max_features, oov_token=oov_token) + + def fit(self, texts: List[str]) -> None: + """ + 在文本上训练序列向量化器 + + Args: + texts: 文本列表 + """ + self.vectorizer.fit_on_texts(texts) + self.is_fitted = True + logger.info(f"序列向量化器已训练,词汇表大小:{len(self.vectorizer.word_index)}") + + def transform(self, texts: List[str]) -> np.ndarray: + """ + 将文本转换为整数序列,并进行填充 + + Args: + texts: 文本列表 + + Returns: + 整数序列表示 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + sequences = self.vectorizer.texts_to_sequences(texts) + padded_sequences = pad_sequences( + sequences, + maxlen=self.max_sequence_length, + padding=self.padding, + truncating=self.truncating + ) + + return padded_sequences + + def get_vocabulary(self) -> List[str]: + """ + 获取词汇表 + + Returns: + 词汇表(按索引排序) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # Tokenizer的词汇表是一个字典,键为词,值为索引(从1开始) + word_index = self.vectorizer.word_index + index_word = {index: word for word, index in word_index.items()} + + # 注意索引0保留给padding,索引1保留给OOV(如果有设置) + vocab = [""] + if self.oov_token: + vocab.append(self.oov_token) + + max_index = min(self.max_features, len(word_index) + 1) if self.max_features else len(word_index) + 1 + for i in range(1, max_index): + if i in index_word: + vocab.append(index_word[i]) + + return vocab + + def get_vocabulary_size(self) -> int: + """ + 获取词汇表大小 + + Returns: + 词汇表大小 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # +1是因为索引0保留给padding + return min(self.max_features, len(self.vectorizer.word_index) + 1) if self.max_features else len( + self.vectorizer.word_index) + 1 + + def texts_to_sequences(self, texts: List[str]) -> List[List[int]]: + """ + 将文本转换为整数序列(不填充) + + Args: + texts: 文本列表 + + Returns: + 整数序列列表 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return self.vectorizer.texts_to_sequences(texts) + + def sequences_to_padded(self, sequences: List[List[int]]) -> np.ndarray: + """ + 将整数序列填充到指定长度 + + Args: + sequences: 整数序列列表 + + Returns: + 填充后的整数序列 + """ + return pad_sequences( + sequences, + maxlen=self.max_sequence_length, + padding=self.padding, + truncating=self.truncating + ) + + def save(self, path: str) -> None: + """ + 保存序列向量化器 + + Args: + path: 保存路径 + """ + ensure_dir(os.path.dirname(path)) + + # 保存配置和状态 + tokenizer_state = { + 'tokenizer': self.vectorizer, + 'max_features': self.max_features, + 'max_sequence_length': self.max_sequence_length, + 'oov_token': self.oov_token, + 'padding': self.padding, + 'truncating': self.truncating, + 'is_fitted': self.is_fitted + } + + save_pickle(tokenizer_state, path) + logger.info(f"序列向量化器已保存到:{path}") + + def load(self, path: str) -> None: + """ + 加载序列向量化器 + + Args: + path: 加载路径 + """ + tokenizer_state = load_pickle(path) + + self.vectorizer = tokenizer_state['tokenizer'] + self.max_features = tokenizer_state['max_features'] + self.max_sequence_length = 
tokenizer_state['max_sequence_length'] + self.oov_token = tokenizer_state['oov_token'] + self.padding = tokenizer_state['padding'] + self.truncating = tokenizer_state['truncating'] + self.is_fitted = tokenizer_state['is_fitted'] + + logger.info(f"序列向量化器已从 {path} 加载,词汇表大小:{len(self.vectorizer.word_index)}") + + class Word2VecVectorizer(TextVectorizer): + """Word2Vec词嵌入向量化器""" + + def __init__(self, vector_size: int = 100, + window: int = 5, + min_count: int = MIN_WORD_FREQUENCY, + workers: int = 4, + sg: int = 1, # 1表示Skip-gram模型,0表示CBOW模型 + max_sequence_length: int = MAX_SEQUENCE_LENGTH, + padding: str = "post", + truncating: str = "post", + pretrained_path: Optional[str] = None): + """ + 初始化Word2Vec词嵌入向量化器 + + Args: + vector_size: 词向量维度 + window: 上下文窗口大小 + min_count: 最小词频 + workers: 并行训练的线程数 + sg: 训练算法,1表示Skip-gram,0表示CBOW + max_sequence_length: 序列最大长度 + padding: 填充方式,'pre'或'post' + truncating: 截断方式,'pre'或'post' + pretrained_path: 预训练词向量路径,如果不为None,则加载预训练词向量 + """ + super().__init__(max_features=None) # Word2Vec没有max_features限制 + self.vector_size = vector_size + self.window = window + self.min_count = min_count + self.workers = workers + self.sg = sg + self.max_sequence_length = max_sequence_length + self.padding = padding + self.truncating = truncating + self.pretrained_path = pretrained_path + + # Word2Vec模型 + self.model = None + + # 词汇表 + self.word_index = {} + self.index_word = {} + + # 如果有预训练词向量,加载它 + if pretrained_path and os.path.exists(pretrained_path): + self._load_pretrained(pretrained_path) + + def _load_pretrained(self, path: str) -> None: + """ + 加载预训练词向量 + + Args: + path: 预训练词向量路径 + """ + try: + # 尝试加载Word2Vec模型 + self.model = Word2Vec.load(path) + logger.info(f"已加载预训练Word2Vec模型:{path}") + except: + try: + # 尝试加载词向量(Word2Vec、GloVe或FastText格式) + self.model = KeyedVectors.load_word2vec_format(path, binary=path.endswith('.bin')) + logger.info(f"已加载预训练词向量:{path}") + except Exception as e: + logger.error(f"加载预训练词向量失败:{e}") + return + + # 如果加载成功,构建词汇表 + self._build_vocab_from_model() + self.is_fitted = True + + def _build_vocab_from_model(self) -> None: + """从模型构建词汇表""" + # 获取词汇表 + vocabulary = list(self.model.wv.index_to_key) + + # 构建词汇表索引 + self.word_index = {word: idx + 1 for idx, word in enumerate(vocabulary)} # 索引0保留给padding + self.index_word = {idx + 1: word for idx, word in enumerate(vocabulary)} + self.index_word[0] = "" + + def fit(self, tokenized_texts: List[List[str]]) -> None: + """ + 在分词后的文本上训练Word2Vec模型 + + Args: + tokenized_texts: 分词后的文本列表(每个文本是一个词语列表) + """ + # 如果已经有预训练模型,跳过训练 + if self.is_fitted and self.model is not None: + logger.info("已有预训练模型,跳过训练") + return + + # 训练Word2Vec模型 + self.model = Word2Vec( + sentences=tokenized_texts, + vector_size=self.vector_size, + window=self.window, + min_count=self.min_count, + workers=self.workers, + sg=self.sg + ) + + # 构建词汇表 + self._build_vocab_from_model() + self.is_fitted = True + + logger.info(f"Word2Vec模型已训练,词汇表大小:{len(self.word_index)}") + + def transform(self, tokenized_texts: List[List[str]]) -> np.ndarray: + """ + 将分词后的文本转换为词向量序列 + + Args: + tokenized_texts: 分词后的文本列表(每个文本是一个词语列表) + + Returns: + 词向量序列,形状为(样本数, 最大序列长度, 词向量维度) + """ + if not self.is_fitted or self.model is None: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # 初始化结果数组 + result = np.zeros((len(tokenized_texts), self.max_sequence_length, self.vector_size)) + + # 处理每个文本 + for i, text in enumerate(tokenized_texts): + seq_len = min(len(text), self.max_sequence_length) + + # 根据截断方式处理 + if self.truncating == 'pre' and len(text) > self.max_sequence_length: + 
text = text[-self.max_sequence_length:] + elif self.truncating == 'post' and len(text) > self.max_sequence_length: + text = text[:self.max_sequence_length] + + # 获取每个词的词向量 + for j, word in enumerate(text[:seq_len]): + if word in self.model.wv: + # 根据填充方式确定位置 + pos = j if self.padding == 'post' else self.max_sequence_length - seq_len + j + result[i, pos] = self.model.wv[word] + + return result + + def transform_to_indices(self, tokenized_texts: List[List[str]]) -> np.ndarray: + """ + 将分词后的文本转换为词索引序列,并填充 + + Args: + tokenized_texts: 分词后的文本列表(每个文本是一个词语列表) + + Returns: + 词索引序列 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + # 将词转换为索引 + sequences = [] + for text in tokenized_texts: + seq = [self.word_index.get(word, 0) for word in text] # 未登录词用0(padding) + sequences.append(seq) + + # 填充序列 + padded_sequences = pad_sequences( + sequences, + maxlen=self.max_sequence_length, + padding=self.padding, + truncating=self.truncating + ) + + return padded_sequences + + def get_embedding_matrix(self) -> np.ndarray: + """ + 获取嵌入矩阵,用于Embedding层的权重初始化 + + Returns: + 嵌入矩阵,形状为(词汇表大小, 词向量维度) + """ + if not self.is_fitted or self.model is None: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + vocab_size = len(self.word_index) + 1 # +1是因为索引0保留给padding + embedding_matrix = np.zeros((vocab_size, self.vector_size)) + + # 填充嵌入矩阵 + for word, idx in self.word_index.items(): + if word in self.model.wv: + embedding_matrix[idx] = self.model.wv[word] + + return embedding_matrix + + def get_vocabulary(self) -> List[str]: + """ + 获取词汇表 + + Returns: + 词汇表(按索引排序) + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + vocab = [""] # 索引0保留给padding + for idx in range(1, len(self.index_word) + 1): + if idx in self.index_word: + vocab.append(self.index_word[idx]) + + return vocab + + def get_vocabulary_size(self) -> int: + """ + 获取词汇表大小 + + Returns: + 词汇表大小 + """ + if not self.is_fitted: + raise ValueError("向量化器尚未训练,请先调用fit方法") + + return len(self.word_index) + 1 # +1是因为索引0保留给padding + + def save(self, path: str) -> None: + """ + 保存Word2Vec向量化器 + + Args: + path: 保存路径 + """ + ensure_dir(os.path.dirname(path)) + + # 保存模型和配置 + model_path = os.path.join(os.path.dirname(path), "word2vec_model") + if self.model: + self.model.save(model_path) + + # 保存配置和状态 + state = { + 'word_index': self.word_index, + 'index_word': self.index_word, + 'vector_size': self.vector_size, + 'window': self.window, + 'min_count': self.min_count, + 'workers': self.workers, + 'sg': self.sg, + 'max_sequence_length': self.max_sequence_length, + 'padding': self.padding, + 'truncating': self.truncating, + 'is_fitted': self.is_fitted, + 'model_path': model_path if self.model else None + } + + save_pickle(state, path) + logger.info(f"Word2Vec向量化器已保存到:{path}") + + def load(self, path: str) -> None: + """ + 加载Word2Vec向量化器 + + Args: + path: 加载路径 + """ + state = load_pickle(path) + + self.word_index = state['word_index'] + self.index_word = state['index_word'] + self.vector_size = state['vector_size'] + self.window = state['window'] + self.min_count = state['min_count'] + self.workers = state['workers'] + self.sg = state['sg'] + self.max_sequence_length = state['max_sequence_length'] + self.padding = state['padding'] + self.truncating = state['truncating'] + self.is_fitted = state['is_fitted'] + + # 加载模型 + model_path = state.get('model_path') + if model_path and os.path.exists(model_path): + self.model = Word2Vec.load(model_path) + + logger.info(f"Word2Vec向量化器已从 {path} 加载,词汇表大小:{len(self.word_index)}") + 
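+# ---------------------------------------------------------------------------
+# Usage sketch (illustrative addition, not part of the exported project code).
+# It assumes the package layout shown in this dump is importable and that
+# Word2VecVectorizer is a module-level class (its indentation above suggests
+# the export may have nested it under SequenceVectorizer). The toy corpus and
+# min_count=1 are placeholders so the Word2Vec vocabulary is not filtered out.
+# ---------------------------------------------------------------------------
+from preprocessing.tokenization import ChineseTokenizer
+from preprocessing.vectorizer import SequenceVectorizer, Word2VecVectorizer
+
+texts = ["国足昨晚客场击败对手", "央行宣布下调存款准备金率"]
+
+tokenizer = ChineseTokenizer(remove_stopwords=True)
+joined = tokenizer.batch_tokenize(texts, return_string=True)    # "词 词 ..." strings
+tokens = tokenizer.batch_tokenize(texts, return_string=False)   # lists of tokens
+
+# Integer sequences padded/truncated to a fixed length, ready for a Keras model
+seq_vec = SequenceVectorizer(max_features=50000, max_sequence_length=200)
+x_seq = seq_vec.fit_transform(joined)                 # shape: (2, 200)
+
+# Word2Vec embeddings trained on the token lists, used to seed an Embedding layer
+w2v = Word2VecVectorizer(vector_size=100, min_count=1)
+w2v.fit(tokens)
+embedding_matrix = w2v.get_embedding_matrix()         # (vocab_size, 100)
+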
+================================================================================ +文件: preprocessing/text_cleaner.py +================================================================================ + +""" +文本清洗模块:实现文本清洗,去除无用字符、HTML标签等 +""" +import re +import unicodedata +import html +from typing import List, Dict, Tuple, Optional, Any, Callable, Set, Union +import string + +from utils.logger import get_logger + +logger = get_logger("TextCleaner") + + +class TextCleaner: + """文本清洗类,提供各种文本清洗方法""" + + def __init__(self, remove_html: bool = True, + remove_urls: bool = True, + remove_emails: bool = True, + remove_numbers: bool = False, + remove_punctuation: bool = False, + lowercase: bool = False, + normalize_unicode: bool = True, + remove_excessive_spaces: bool = True, + remove_short_texts: bool = False, + min_text_length: int = 10, + custom_patterns: Optional[List[str]] = None): + """ + 初始化文本清洗器 + + Args: + remove_html: 是否移除HTML标签 + remove_urls: 是否移除URL + remove_emails: 是否移除电子邮件地址 + remove_numbers: 是否移除数字 + remove_punctuation: 是否移除标点符号 + lowercase: 是否转为小写(对中文无效) + normalize_unicode: 是否规范化Unicode字符 + remove_excessive_spaces: 是否移除多余空格 + remove_short_texts: 是否过滤掉短文本 + min_text_length: 最小文本长度(当remove_short_texts=True时有效) + custom_patterns: 自定义的正则表达式模式列表,用于额外的文本清洗 + """ + self.remove_html = remove_html + self.remove_urls = remove_urls + self.remove_emails = remove_emails + self.remove_numbers = remove_numbers + self.remove_punctuation = remove_punctuation + self.lowercase = lowercase + self.normalize_unicode = normalize_unicode + self.remove_excessive_spaces = remove_excessive_spaces + self.remove_short_texts = remove_short_texts + self.min_text_length = min_text_length + self.custom_patterns = custom_patterns or [] + + # 编译正则表达式 + self.html_pattern = re.compile(r'<.*?>') + self.url_pattern = re.compile(r'https?://\S+|www\.\S+') + self.email_pattern = re.compile(r'\S+@\S+\.\S+') + self.number_pattern = re.compile(r'\d+') + self.space_pattern = re.compile(r'\s+') + + # 编译自定义模式 + self.compiled_custom_patterns = [re.compile(pattern) for pattern in self.custom_patterns] + + # 中文标点符号 + self.chinese_punctuation = ",。!?;:""''【】《》()、…—~·" + + logger.info("文本清洗器初始化完成") + + def clean_text(self, text: str) -> str: + """ + 清洗文本,应用所有已配置的清洗方法 + + Args: + text: 原始文本 + + Returns: + 清洗后的文本 + """ + if not text: + return "" + + # HTML解码 + if self.remove_html: + text = html.unescape(text) + text = self.html_pattern.sub(' ', text) + + # 移除URL + if self.remove_urls: + text = self.url_pattern.sub(' ', text) + + # 移除电子邮件 + if self.remove_emails: + text = self.email_pattern.sub(' ', text) + + # Unicode规范化 + if self.normalize_unicode: + text = unicodedata.normalize('NFKC', text) + + # 移除数字 + if self.remove_numbers: + text = self.number_pattern.sub(' ', text) + + # 移除标点符号 + if self.remove_punctuation: + # 处理英文标点 + for punct in string.punctuation: + text = text.replace(punct, ' ') + # 处理中文标点 + for punct in self.chinese_punctuation: + text = text.replace(punct, ' ') + + # 应用自定义清洗模式 + for pattern in self.compiled_custom_patterns: + text = pattern.sub(' ', text) + + # 转为小写 + if self.lowercase: + text = text.lower() + + # 移除多余空格 + if self.remove_excessive_spaces: + text = self.space_pattern.sub(' ', text) + text = text.strip() + + # 过滤掉短文本 + if self.remove_short_texts and len(text) < self.min_text_length: + return "" + + return text + + def clean_texts(self, texts: List[str]) -> List[str]: + """ + 批量清洗文本 + + Args: + texts: 原始文本列表 + + Returns: + 清洗后的文本列表 + """ + return [self.clean_text(text) for text in texts] + + def 
remove_redundant_texts(self, texts: List[str]) -> List[str]: + """ + 移除冗余文本(空文本和长度小于阈值的文本) + + Args: + texts: 原始文本列表 + + Returns: + 移除冗余后的文本列表 + """ + return [text for text in texts if text and len(text) >= self.min_text_length] + + @staticmethod + def remove_specific_characters(text: str, chars_to_remove: Union[str, Set[str]]) -> str: + """ + 移除特定字符 + + Args: + text: 原始文本 + chars_to_remove: 要移除的字符(字符串或字符集合) + + Returns: + 移除特定字符后的文本 + """ + if isinstance(chars_to_remove, str): + for char in chars_to_remove: + text = text.replace(char, '') + else: + for char in chars_to_remove: + text = text.replace(char, '') + return text + + @staticmethod + def replace_characters(text: str, char_map: Dict[str, str]) -> str: + """ + 替换特定字符 + + Args: + text: 原始文本 + char_map: 字符映射字典,键为要替换的字符,值为替换后的字符 + + Returns: + 替换特定字符后的文本 + """ + for old_char, new_char in char_map.items(): + text = text.replace(old_char, new_char) + return text + + @staticmethod + def remove_empty_lines(text: str) -> str: + """ + 移除空行 + + Args: + text: 原始文本 + + Returns: + 移除空行后的文本 + """ + lines = text.splitlines() + non_empty_lines = [line for line in lines if line.strip()] + return '\n'.join(non_empty_lines) + + @staticmethod + def truncate_text(text: str, max_length: int, truncate_from_end: bool = True) -> str: + """ + 截断文本 + + Args: + text: 原始文本 + max_length: 最大长度 + truncate_from_end: 是否从末尾截断,如果为False则从开头截断 + + Returns: + 截断后的文本 + """ + if len(text) <= max_length: + return text + + if truncate_from_end: + return text[:max_length] + else: + return text[len(text) - max_length:] + + + +================================================================================ +文件: preprocessing/feature_extraction.py +================================================================================ + +""" +特征提取模块:实现文本特征提取,包括语法特征、语义特征等 +""" +import re +import numpy as np +from typing import List, Dict, Tuple, Optional, Any, Union, Set +from collections import Counter +import jieba.posseg as pseg + +from config.system_config import CATEGORIES +from utils.logger import get_logger +from preprocessing.tokenization import ChineseTokenizer + +logger = get_logger("FeatureExtraction") + + +class FeatureExtractor: + """特征提取基类,定义通用接口""" + + def __init__(self): + """初始化特征提取器""" + pass + + def extract(self, text: str) -> Dict[str, Any]: + """ + 从文本中提取特征 + + Args: + text: 文本 + + Returns: + 特征字典 + """ + raise NotImplementedError("子类必须实现此方法") + + def batch_extract(self, texts: List[str]) -> List[Dict[str, Any]]: + """ + 批量提取特征 + + Args: + texts: 文本列表 + + Returns: + 特征字典列表 + """ + return [self.extract(text) for text in texts] + + def extract_as_vector(self, text: str) -> np.ndarray: + """ + 从文本中提取特征,并转换为向量表示 + + Args: + text: 文本 + + Returns: + 特征向量 + """ + raise NotImplementedError("子类必须实现此方法") + + def batch_extract_as_vector(self, texts: List[str]) -> np.ndarray: + """ + 批量提取特征,并转换为向量表示 + + Args: + texts: 文本列表 + + Returns: + 特征向量数组 + """ + return np.array([self.extract_as_vector(text) for text in texts]) + + +class StatisticalFeatureExtractor(FeatureExtractor): + """统计特征提取器,提取文本的统计特征""" + + def __init__(self, tokenizer: Optional[ChineseTokenizer] = None): + """ + 初始化统计特征提取器 + + Args: + tokenizer: 分词器,如果为None则创建一个新的分词器 + """ + super().__init__() + self.tokenizer = tokenizer or ChineseTokenizer() + + def extract(self, text: str) -> Dict[str, Any]: + """ + 从文本中提取统计特征 + + Args: + text: 文本 + + Returns: + 特征字典,包含各种统计特征 + """ + if not text: + return { + "char_count": 0, + "word_count": 0, + "sentence_count": 0, + "avg_word_length": 0, + "avg_sentence_length": 0, + 
"contains_number": False, + "contains_english": False, + "punctuation_ratio": 0, + "top_words": [] + } + + # 字符数 + char_count = len(text) + + # 分词 + words = self.tokenizer.tokenize(text, return_string=False) + word_count = len(words) + + # 句子数(按标点符号分割) + sentences = re.split(r'[。!?!?]+', text) + sentences = [s for s in sentences if s.strip()] + sentence_count = len(sentences) + + # 平均词长 + avg_word_length = sum(len(word) for word in words) / word_count if word_count > 0 else 0 + + # 平均句长(以字符为单位) + avg_sentence_length = char_count / sentence_count if sentence_count > 0 else 0 + + # 是否包含数字 + contains_number = bool(re.search(r'\d', text)) + + # 是否包含英文 + contains_english = bool(re.search(r'[a-zA-Z]', text)) + + # 标点符号比例 + punctuation_pattern = re.compile(r'[^\w\s]') + punctuations = punctuation_pattern.findall(text) + punctuation_ratio = len(punctuations) / char_count if char_count > 0 else 0 + + # 高频词 + word_counter = Counter(words) + top_words = word_counter.most_common(5) + + return { + "char_count": char_count, + "word_count": word_count, + "sentence_count": sentence_count, + "avg_word_length": avg_word_length, + "avg_sentence_length": avg_sentence_length, + "contains_number": contains_number, + "contains_english": contains_english, + "punctuation_ratio": punctuation_ratio, + "top_words": top_words + } + + def extract_as_vector(self, text: str) -> np.ndarray: + """ + 从文本中提取统计特征,并转换为向量表示 + + Args: + text: 文本 + + Returns: + 特征向量,包含各种统计特征 + """ + features = self.extract(text) + + # 提取数值特征 + vector = [ + features['char_count'], + features['word_count'], + features['sentence_count'], + features['avg_word_length'], + features['avg_sentence_length'], + int(features['contains_number']), + int(features['contains_english']), + features['punctuation_ratio'] + ] + + return np.array(vector, dtype=np.float32) + + +class POSFeatureExtractor(FeatureExtractor): + """词性特征提取器,提取文本的词性特征""" + + def __init__(self): + """初始化词性特征提取器""" + super().__init__() + + # 常见中文词性及其解释 + self.pos_tags = { + 'n': '名词', 'f': '方位名词', 's': '处所名词', 't': '时间名词', + 'nr': '人名', 'ns': '地名', 'nt': '机构团体', 'nw': '作品名', + 'nz': '其他专名', 'v': '动词', 'vd': '副动词', 'vn': '名动词', + 'a': '形容词', 'ad': '副形词', 'an': '名形词', 'd': '副词', + 'm': '数词', 'q': '量词', 'r': '代词', 'p': '介词', + 'c': '连词', 'u': '助词', 'xc': '其他虚词', 'w': '标点符号' + } + + def extract(self, text: str) -> Dict[str, Any]: + """ + 从文本中提取词性特征 + + Args: + text: 文本 + + Returns: + 特征字典,包含各种词性特征 + """ + if not text: + return { + "pos_counts": {}, + "pos_ratios": {} + } + + # 使用jieba进行词性标注 + pos_list = pseg.cut(text) + + # 统计各词性的数量 + pos_counts = {} + total_count = 0 + + for word, pos in pos_list: + if pos in pos_counts: + pos_counts[pos] += 1 + else: + pos_counts[pos] = 1 + total_count += 1 + + # 计算各词性的比例 + pos_ratios = {pos: count / total_count for pos, count in pos_counts.items()} if total_count > 0 else {} + + return { + "pos_counts": pos_counts, + "pos_ratios": pos_ratios + } + + def extract_as_vector(self, text: str) -> np.ndarray: + """ + 从文本中提取词性特征,并转换为向量表示 + + Args: + text: 文本 + + Returns: + 特征向量,包含各词性的比例 + """ + features = self.extract(text) + pos_ratios = features['pos_ratios'] + + # 按照 self.pos_tags 的顺序构建向量 + vector = [] + for pos in self.pos_tags.keys(): + vector.append(pos_ratios.get(pos, 0.0)) + + return np.array(vector, dtype=np.float32) + + +class KeywordFeatureExtractor(FeatureExtractor): + """关键词特征提取器,基于预定义关键词提取特征""" + + def __init__(self, category_keywords: Optional[Dict[str, List[str]]] = None): + """ + 初始化关键词特征提取器 + + Args: + category_keywords: 类别关键词字典,键为类别名称,值为关键词列表 + """ + 
super().__init__() + self.category_keywords = category_keywords or self._get_default_keywords() + self.tokenizer = ChineseTokenizer() + + def _get_default_keywords(self) -> Dict[str, List[str]]: + """ + 获取默认的类别关键词 + + Returns: + 类别关键词字典 + """ + # 为每个类别定义一些示例关键词 + default_keywords = { + "体育": ["比赛", "运动", "球员", "冠军", "球队", "足球", "篮球"], + "财经": ["股票", "基金", "投资", "市场", "经济", "金融", "股市"], + "房产": ["房价", "楼市", "地产", "购房", "房贷", "物业", "小区"], + "家居": ["装修", "家具", "设计", "卧室", "客厅", "厨房", "风格"], + "教育": ["学校", "学生", "考试", "教育", "大学", "课程", "老师"], + "科技": ["互联网", "科技", "创新", "数字", "智能", "研发", "技术"], + "时尚": ["时尚", "潮流", "服装", "搭配", "品牌", "美容", "穿着"], + "时政": ["政府", "政策", "国家", "发展", "会议", "主席", "总理"], + "游戏": ["游戏", "玩家", "电竞", "网游", "手游", "角色", "任务"], + "娱乐": ["明星", "电影", "节目", "综艺", "电视", "演员", "导演"], + "其他": ["其他", "一般", "常见", "普通", "正常", "通常", "传统"] + } + + # 确保 CATEGORIES 中的每个类别都有关键词 + for category in CATEGORIES: + if category not in default_keywords: + default_keywords[category] = [category] + + return default_keywords + + def extract(self, text: str) -> Dict[str, Any]: + """ + 从文本中提取关键词特征 + + Args: + text: 文本 + + Returns: + 特征字典,包含各类别的关键词匹配情况 + """ + if not text: + return { + "keyword_matches": {cat: 0 for cat in self.category_keywords}, + "keyword_match_ratios": {cat: 0.0 for cat in self.category_keywords} + } + + # 对文本分词 + words = set(self.tokenizer.tokenize(text, return_string=False)) + + # 统计各类别的关键词匹配数量 + keyword_matches = {} + for category, keywords in self.category_keywords.items(): + # 计算文本中包含的该类别关键词数量 + matches = sum(1 for kw in keywords if kw in words) + keyword_matches[category] = matches + + # 计算匹配比例(归一化) + total_matches = sum(keyword_matches.values()) + keyword_match_ratios = { + cat: matches / total_matches if total_matches > 0 else 0.0 + for cat, matches in keyword_matches.items() + } + + return { + "keyword_matches": keyword_matches, + "keyword_match_ratios": keyword_match_ratios + } + + def extract_as_vector(self, text: str) -> np.ndarray: + """ + 从文本中提取关键词特征,并转换为向量表示 + + Args: + text: 文本 + + Returns: + 特征向量,包含各类别的关键词匹配比例 + """ + features = self.extract(text) + match_ratios = features['keyword_match_ratios'] + + # 按照 CATEGORIES 的顺序构建向量 + vector = [match_ratios.get(cat, 0.0) for cat in CATEGORIES] + + return np.array(vector, dtype=np.float32) + + def update_keywords(self, category: str, keywords: List[str]) -> None: + """ + 更新指定类别的关键词 + + Args: + category: 类别名称 + keywords: 关键词列表 + """ + self.category_keywords[category] = keywords + logger.info(f"已更新类别 {category} 的关键词,共 {len(keywords)} 个") + + def add_keywords(self, category: str, keywords: List[str]) -> None: + """ + 向指定类别添加关键词 + + Args: + category: 类别名称 + keywords: 要添加的关键词列表 + """ + if category in self.category_keywords: + existing_keywords = set(self.category_keywords[category]) + for keyword in keywords: + existing_keywords.add(keyword) + self.category_keywords[category] = list(existing_keywords) + else: + self.category_keywords[category] = keywords + + logger.info(f"已向类别 {category} 添加关键词,当前共 {len(self.category_keywords[category])} 个") + + class CombinedFeatureExtractor(FeatureExtractor): + """组合特征提取器,组合多个特征提取器的结果""" + + def __init__(self, extractors: List[FeatureExtractor]): + """ + 初始化组合特征提取器 + + Args: + extractors: 特征提取器列表 + """ + super().__init__() + self.extractors = extractors + + def extract(self, text: str) -> Dict[str, Any]: + """ + 从文本中提取组合特征 + + Args: + text: 文本 + + Returns: + 特征字典,包含所有特征提取器的结果 + """ + combined_features = {} + for i, extractor in enumerate(self.extractors): + extractor_name = 
type(extractor).__name__ + features = extractor.extract(text) + combined_features[extractor_name] = features + + return combined_features + + def extract_as_vector(self, text: str) -> np.ndarray: + """ + 从文本中提取组合特征,并转换为向量表示 + + Args: + text: 文本 + + Returns: + 特征向量,包含所有特征提取器的向量拼接 + """ + # 获取所有特征提取器的向量 + feature_vectors = [extractor.extract_as_vector(text) for extractor in self.extractors] + + # 拼接向量 + return np.concatenate(feature_vectors) + +================================================================================ +文件: preprocessing/data_augmentation.py +================================================================================ + +""" +数据增强模块:实现文本数据增强技术 +""" +import random +import re +import jieba +import synonyms +import numpy as np +from typing import List, Dict, Tuple, Optional, Any, Union, Callable +import copy + +from config.model_config import RANDOM_SEED +from utils.logger import get_logger +from preprocessing.tokenization import ChineseTokenizer + +# 设置随机种子以保证可重复性 +random.seed(RANDOM_SEED) +np.random.seed(RANDOM_SEED) + +logger = get_logger("DataAugmentation") + + +class TextAugmenter: + """文本增强基类,定义通用接口""" + + def __init__(self): + """初始化文本增强器""" + pass + + def augment(self, text: str) -> str: + """ + 对文本进行增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + raise NotImplementedError("子类必须实现此方法") + + def batch_augment(self, texts: List[str]) -> List[str]: + """ + 批量对文本进行增强 + + Args: + texts: 原始文本列表 + + Returns: + 增强后的文本列表 + """ + return [self.augment(text) for text in texts] + + def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]: + """ + 对文本进行增强,同时保留标签 + + Args: + text: 原始文本 + label: 标签 + + Returns: + (增强后的文本, 标签)的元组 + """ + return self.augment(text), label + + def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]: + """ + 批量对文本进行增强,同时保留标签 + + Args: + texts: 原始文本列表 + labels: 标签列表 + + Returns: + (增强后的文本, 标签)的元组列表 + """ + return [self.augment_with_label(text, label) for text, label in zip(texts, labels)] + + +class SynonymReplacement(TextAugmenter): + """同义词替换增强器""" + + def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, + replace_ratio: float = 0.1, + min_similarity: float = 0.7): + """ + 初始化同义词替换增强器 + + Args: + tokenizer: 分词器,如果为None则创建一个新的分词器 + replace_ratio: 替换比例,表示要替换的词占总词数的比例 + min_similarity: 最小相似度,只有相似度大于该值的同义词才会被用于替换 + """ + super().__init__() + self.tokenizer = tokenizer or ChineseTokenizer() + self.replace_ratio = replace_ratio + self.min_similarity = min_similarity + + def _get_synonym(self, word: str) -> Optional[str]: + """ + 获取词的同义词 + + Args: + word: 原始词 + + Returns: + 同义词,如果没有合适的同义词则返回None + """ + # 使用synonyms包获取同义词 + try: + synonyms_list = synonyms.nearby(word) + + # synonyms.nearby返回一个元组,第一个元素是相似词列表,第二个元素是相似度列表 + words = synonyms_list[0] + similarities = synonyms_list[1] + + # 过滤掉相似度低于阈值的词和原词本身 + valid_synonyms = [(w, s) for w, s in zip(words, similarities) + if s >= self.min_similarity and w != word] + + if valid_synonyms: + # 按相似度排序,选择最相似的词 + valid_synonyms.sort(key=lambda x: x[1], reverse=True) + return valid_synonyms[0][0] + + return None + except: + return None + + def augment(self, text: str) -> str: + """ + 对文本进行同义词替换增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + if not text: + return text + + # 分词 + words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) + + if not words: + return text + + # 计算要替换的词数量 + n_replace = max(1, int(len(words) * self.replace_ratio)) + + # 随机选择要替换的词索引 + replace_indices = random.sample(range(len(words)), 
min(n_replace, len(words))) + + # 替换为同义词 + for idx in replace_indices: + synonym = self._get_synonym(words[idx]) + if synonym: + words[idx] = synonym + + # 拼接为文本 + augmented_text = ''.join(words) + + return augmented_text + + +class RandomDeletion(TextAugmenter): + """随机删除增强器""" + + def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, + delete_ratio: float = 0.1): + """ + 初始化随机删除增强器 + + Args: + tokenizer: 分词器,如果为None则创建一个新的分词器 + delete_ratio: 删除比例,表示要删除的词占总词数的比例 + """ + super().__init__() + self.tokenizer = tokenizer or ChineseTokenizer() + self.delete_ratio = delete_ratio + + def augment(self, text: str) -> str: + """ + 对文本进行随机删除增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + if not text: + return text + + # 分词 + words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) + + if len(words) <= 1: + return text + + # 计算要删除的词数量 + n_delete = max(1, int(len(words) * self.delete_ratio)) + + # 随机选择要删除的词索引 + delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1)) + + # 删除选中的词 + augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices] + + # 拼接为文本 + augmented_text = ''.join(augmented_words) + + return augmented_text + + +class RandomSwap(TextAugmenter): + """随机交换增强器""" + + def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, + n_swaps: int = 1): + """ + 初始化随机交换增强器 + + Args: + tokenizer: 分词器,如果为None则创建一个新的分词器 + n_swaps: 交换次数 + """ + super().__init__() + self.tokenizer = tokenizer or ChineseTokenizer() + self.n_swaps = n_swaps + + def augment(self, text: str) -> str: + """ + 对文本进行随机交换增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + if not text: + return text + + # 分词 + words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) + + if len(words) <= 1: + return text + + # 进行n_swaps次随机交换 + augmented_words = words.copy() + for _ in range(min(self.n_swaps, len(words) // 2)): + # 随机选择两个不同的索引 + idx1, idx2 = random.sample(range(len(augmented_words)), 2) + + # 交换两个词 + augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1] + + # 拼接为文本 + augmented_text = ''.join(augmented_words) + + return augmented_text + + +class CompositeAugmenter(TextAugmenter): + """组合增强器,组合多个增强器""" + + def __init__(self, augmenters: List[TextAugmenter], + probs: Optional[List[float]] = None): + """ + 初始化组合增强器 + + Args: + augmenters: 增强器列表 + probs: 各增强器被选择的概率列表,如果为None则均匀选择 + """ + super().__init__() + self.augmenters = augmenters + + # 如果没有提供概率,则均匀分配 + if probs is None: + self.probs = [1.0 / len(augmenters)] * len(augmenters) + else: + # 确保概率和为1 + total = sum(probs) + self.probs = [p / total for p in probs] + + assert len(self.augmenters) == len(self.probs), "增强器数量与概率数量不匹配" + + def augment(self, text: str) -> str: + """ + 对文本进行组合增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + if not text: + return text + + # 根据概率随机选择一个增强器 + augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0] + + # 使用选中的增强器进行增强 + return augmenter.augment(text) + + +class BackTranslation(TextAugmenter): + """回译增强器""" + + def __init__(self, translator=None, source_lang: str = 'zh', + target_langs: List[str] = None): + """ + 初始化回译增强器 + + Args: + translator: 翻译器,需要实现translate方法 + source_lang: 源语言代码 + target_langs: 目标语言代码列表,如果为None则使用默认语言 + """ + super().__init__() + + # 如果没有提供翻译器,尝试使用第三方翻译库 + if translator is None: + try: + # 尝试导入多种翻译库 + # 首先尝试使用googletrans (需要单独安装: pip install googletrans==4.0.0-rc1) + try: + from googletrans import Translator + self.translator = Translator() + 
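                    # googletrans 4.0.0-rc1 exposes a synchronous Translator.translate();
+                    # newer googletrans releases use an async client, so this wrapper may
+                    # need an asyncio adaptation depending on the installed version.
+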
self.translate_func = self._google_translate + except ImportError: + # 如果googletrans不可用,尝试使用py-translate + try: + import translate + self.translator = translate + self.translate_func = self._py_translate + except ImportError: + logger.warning("未安装翻译库,回译功能将不可用。请安装googletrans或py-translate") + self.translator = None + self.translate_func = self._dummy_translate + except Exception as e: + logger.error(f"初始化翻译器失败: {e}") + self.translator = None + self.translate_func = self._dummy_translate + else: + self.translator = translator + self.translate_func = self._custom_translate + + self.source_lang = source_lang + self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja'] + + def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str: + """使用googletrans进行翻译""" + try: + result = self.translator.translate(text, src=source_lang, dest=target_lang) + return result.text + except Exception as e: + logger.error(f"翻译失败: {e}") + return text + + def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str: + """使用py-translate进行翻译""" + try: + return self.translator.translate(text, source_lang, target_lang) + except Exception as e: + logger.error(f"翻译失败: {e}") + return text + + def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str: + """使用自定义翻译器进行翻译""" + try: + return self.translator.translate(text, source_lang, target_lang) + except Exception as e: + logger.error(f"翻译失败: {e}") + return text + + def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str: + """虚拟翻译功能,仅返回原文本""" + logger.warning("翻译功能不可用,使用原文本") + return text + + def augment(self, text: str) -> str: + """ + 对文本进行回译增强 + + Args: + text: 原始文本 + + Returns: + 增强后的文本 + """ + if not text or self.translator is None: + return text + + # 随机选择一个目标语言 + target_lang = random.choice(self.target_langs) + + try: + # 将源语言翻译为目标语言 + translated = self.translate_func(text, self.source_lang, target_lang) + + # 将目标语言翻译回源语言 + back_translated = self.translate_func(translated, target_lang, self.source_lang) + + return back_translated + except Exception as e: + logger.error(f"回译失败: {e}") + return text + +================================================================================ +文件: data/__init__.py +================================================================================ + + + +================================================================================ +文件: data/dataset.py +================================================================================ + + + +================================================================================ +文件: data/dataloader.py +================================================================================ + +""" +数据加载模块:负责从文件系统加载原始文本数据 +""" +import os +import glob +import time +from pathlib import Path +from typing import List, Dict, Tuple, Optional, Any +from concurrent.futures import ThreadPoolExecutor, as_completed +import random +import numpy as np + +from config.system_config import ( + RAW_DATA_DIR, DATA_LOADING_WORKERS, CATEGORIES, + CATEGORY_TO_ID, ENCODING, MAX_MEMORY_GB, MAX_TEXT_PER_BATCH +) +from config.model_config import RANDOM_SEED +from utils.logger import get_logger +from utils.file_utils import read_text_file, read_files_parallel, list_files + +# 设置随机种子以保证可重复性 +random.seed(RANDOM_SEED) +np.random.seed(RANDOM_SEED) + +logger = get_logger("DataLoader") + + +class DataLoader: + """负责加载THUCNews数据集的类""" + + def __init__(self, data_dir: Optional[str] = None, + categories: Optional[List[str]] = None, + 
encoding: str = ENCODING, + max_workers: int = DATA_LOADING_WORKERS, + max_text_per_batch: int = MAX_TEXT_PER_BATCH): + """ + 初始化数据加载器 + + Args: + data_dir: 数据目录,默认使用配置文件中的路径 + categories: 要加载的类别列表,默认加载所有类别 + encoding: 文件编码 + max_workers: 最大工作线程数 + max_text_per_batch: 每批处理的最大文本数量 + """ + self.data_dir = Path(data_dir) if data_dir else RAW_DATA_DIR + self.categories = categories if categories else CATEGORIES + self.encoding = encoding + self.max_workers = max_workers + self.max_text_per_batch = max_text_per_batch + + # 验证数据目录是否存在 + if not self.data_dir.exists(): + raise FileNotFoundError(f"数据目录不存在: {self.data_dir}") + + # 验证类别是否存在 + for category in self.categories: + category_dir = self.data_dir / category + if not category_dir.exists(): + logger.warning(f"类别目录不存在: {category_dir}") + + # 存储类别目录的映射 + self.category_dirs = { + category: self.data_dir / category + for category in self.categories + if (self.data_dir / category).exists() + } + + # 记录类别文件数量 + self.category_file_counts = {} + + # 统计并记录每个类别的文件数量 + self._count_files() + logger.info(f"初始化完成,共找到 {sum(self.category_file_counts.values())} 个文本文件") + + def _count_files(self) -> None: + """统计每个类别的文件数量""" + for category, category_dir in self.category_dirs.items(): + files = list(category_dir.glob("*.txt")) + self.category_file_counts[category] = len(files) + logger.info(f"类别 [{category}] 包含 {len(files)} 个文本文件") + + def get_file_paths(self, category: Optional[str] = None, + sample_ratio: float = 1.0, + shuffle: bool = True) -> List[Tuple[str, str]]: + """ + 获取指定类别的文件路径列表 + + Args: + category: 类别名称,如果为None则获取所有类别 + sample_ratio: 采样比例,默认为1.0(全部) + shuffle: 是否打乱文件顺序 + + Returns: + 包含(文件路径, 类别)元组的列表 + """ + file_paths = [] + + # 确定要处理的类别 + categories_to_process = [category] if category else self.categories + + # 获取每个类别的文件路径 + for cat in categories_to_process: + if cat in self.category_dirs: + category_dir = self.category_dirs[cat] + cat_files = list(category_dir.glob("*.txt")) + + # 采样 + if sample_ratio < 1.0: + sample_size = int(len(cat_files) * sample_ratio) + if shuffle: + cat_files = random.sample(cat_files, sample_size) + else: + cat_files = cat_files[:sample_size] + + # 添加文件路径和对应的类别 + file_paths.extend([(str(file), cat) for file in cat_files]) + + # 打乱全局顺序(如果需要) + if shuffle: + random.shuffle(file_paths) + + return file_paths + + def load_texts(self, file_paths: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + """ + 加载指定路径的文本内容 + + Args: + file_paths: 包含(文件路径, 类别)元组的列表 + + Returns: + 包含(文本内容, 类别)元组的列表 + """ + start_time = time.time() + texts_with_labels = [] + + # 提取文件路径列表 + paths = [path for path, _ in file_paths] + labels = [label for _, label in file_paths] + + # 并行加载文本内容 + contents = read_files_parallel(paths, max_workers=self.max_workers, encoding=self.encoding) + + # 将内容与标签配对 + for content, label in zip(contents, labels): + if content: # 确保内容不为空 + texts_with_labels.append((content, label)) + + elapsed = time.time() - start_time + logger.info(f"加载了 {len(texts_with_labels)} 个文本,用时 {elapsed:.2f} 秒") + + return texts_with_labels + + def load_data(self, categories: Optional[List[str]] = None, + sample_ratio: float = 1.0, + shuffle: bool = True, + return_generator: bool = False) -> Any: + """ + 加载指定类别的所有数据 + + Args: + categories: 要加载的类别列表,默认为所有类别 + sample_ratio: 采样比例,默认为1.0(全部) + shuffle: 是否打乱数据顺序 + return_generator: 是否返回生成器(批量加载) + + Returns: + 如果return_generator为False,返回包含(文本内容, 类别)元组的列表 + 如果return_generator为True,返回一个生成器,每次产生一批数据 + """ + # 确定要处理的类别 + cats_to_process = categories if categories else self.categories + + # 验证类别是否存在 + for 
cat in cats_to_process: + if cat not in self.category_dirs: + logger.warning(f"类别 {cat} 不存在,将被忽略") + + # 筛选存在的类别 + cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs] + + # 获取所有文件路径 + all_file_paths = [] + for cat in cats_to_process: + cat_files = self.get_file_paths(cat, sample_ratio=sample_ratio, shuffle=shuffle) + all_file_paths.extend(cat_files) + + # 打乱全局顺序(如果需要) + if shuffle: + random.shuffle(all_file_paths) + + # 如果需要返回生成器,分批次加载数据 + if return_generator: + def data_generator(): + for i in range(0, len(all_file_paths), self.max_text_per_batch): + batch_paths = all_file_paths[i:i + self.max_text_per_batch] + batch_data = self.load_texts(batch_paths) + yield batch_data + + return data_generator() + + # 否则,一次性加载所有数据 + return self.load_texts(all_file_paths) + + def load_balanced_data(self, n_per_category: int = 1000, + categories: Optional[List[str]] = None, + shuffle: bool = True) -> List[Tuple[str, str]]: + """ + 加载平衡的数据集(每个类别的样本数量相同) + + Args: + n_per_category: 每个类别加载的样本数量 + categories: 要加载的类别列表,默认为所有类别 + shuffle: 是否打乱数据顺序 + + Returns: + 包含(文本内容, 类别)元组的列表 + """ + # 确定要处理的类别 + cats_to_process = categories if categories else self.categories + cats_to_process = [cat for cat in cats_to_process if cat in self.category_dirs] + + balanced_data = [] + + for cat in cats_to_process: + # 获取该类别的文件路径 + cat_files = self.get_file_paths(cat, shuffle=shuffle) + + # 限制数量 + cat_files = cat_files[:n_per_category] + + # 加载文本 + cat_data = self.load_texts(cat_files) + balanced_data.extend(cat_data) + + # 打乱全局顺序(如果需要) + if shuffle: + random.shuffle(balanced_data) + + return balanced_data + + def get_category_distribution(self) -> Dict[str, int]: + """ + 获取数据集的类别分布 + + Returns: + 包含各类别样本数量的字典 + """ + return self.category_file_counts + + def get_data_stats(self) -> Dict[str, Any]: + """ + 获取数据集的统计信息 + + Returns: + 包含统计信息的字典 + """ + # 计算总样本数 + total_samples = sum(self.category_file_counts.values()) + + # 计算各类别占比 + category_percentages = { + cat: count / total_samples * 100 + for cat, count in self.category_file_counts.items() + } + + # 采样几个文件计算平均文本长度 + sample_files = [] + for cat in self.categories: + if cat in self.category_dirs: + cat_files = list((self.data_dir / cat).glob("*.txt")) + if cat_files: + # 每个类别最多采样10个文件 + sample_files.extend(random.sample(cat_files, min(10, len(cat_files)))) + + # 加载采样的文件内容 + sample_contents = [] + for file_path in sample_files: + content = read_text_file(str(file_path), encoding=self.encoding) + if content: + sample_contents.append(content) + + # 计算平均文本长度(字符数) + avg_char_length = sum(len(content) for content in sample_contents) / len( + sample_contents) if sample_contents else 0 + + # 返回统计信息 + return { + "total_samples": total_samples, + "category_counts": self.category_file_counts, + "category_percentages": category_percentages, + "average_text_length": avg_char_length, + "categories": self.categories, + "num_categories": len(self.categories), + } + + +================================================================================ +文件: data/data_manager.py +================================================================================ + +""" +数据管理模块:负责数据的存储、读取和转换 +""" +import os +import pickle +import json +import time +from typing import List, Dict, Tuple, Optional, Any, Union +import numpy as np +import pandas as pd +from collections import Counter +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split + +from config.system_config import ( + PROCESSED_DATA_DIR, ENCODING, CATEGORY_TO_ID, ID_TO_CATEGORY +) +from 
config.model_config import ( + VALIDATION_SPLIT, TEST_SPLIT, RANDOM_SEED +) +from utils.logger import get_logger +from utils.file_utils import ( + save_pickle, load_pickle, save_json, load_json, ensure_dir +) +from data.dataloader import DataLoader + +logger = get_logger("DataManager") + + +class DataManager: + """数据管理类,负责数据的存储、读取和转换""" + + def __init__(self, processed_dir: Optional[str] = None): + """ + 初始化数据管理器 + + Args: + processed_dir: 处理后数据的存储目录,默认使用配置文件中的路径 + """ + self.processed_dir = processed_dir or PROCESSED_DATA_DIR + ensure_dir(self.processed_dir) + + # 数据分割后的存储 + self.train_texts = [] + self.train_labels = [] + self.val_texts = [] + self.val_labels = [] + self.test_texts = [] + self.test_labels = [] + + # 数据统计信息 + self.stats = {} + + # 标签编码映射 + self.label_to_id = CATEGORY_TO_ID + self.id_to_label = ID_TO_CATEGORY + + logger.info(f"数据管理器初始化完成,处理后数据将存储在 {self.processed_dir}") + + def load_and_split_data(self, data_loader: DataLoader, + categories: Optional[List[str]] = None, + val_split: float = VALIDATION_SPLIT, + test_split: float = TEST_SPLIT, + sample_ratio: float = 1.0, + balanced: bool = False, + n_per_category: int = 1000, + save: bool = True) -> Dict[str, Any]: + """ + 加载并分割数据集 + + Args: + data_loader: 数据加载器实例 + categories: 要包含的类别列表,默认为所有类别 + val_split: 验证集比例 + test_split: 测试集比例 + sample_ratio: 采样比例,默认为1.0(全部) + balanced: 是否平衡各类别的样本数量 + n_per_category: 平衡模式下每个类别的样本数量 + save: 是否保存处理后的数据 + + Returns: + 包含分割后数据集的字典 + """ + start_time = time.time() + + # 加载数据 + if balanced: + logger.info(f"加载平衡数据集,每个类别 {n_per_category} 个样本") + data = data_loader.load_balanced_data( + n_per_category=n_per_category, + categories=categories, + shuffle=True + ) + else: + logger.info(f"加载数据集,采样比例 {sample_ratio}") + data = data_loader.load_data( + categories=categories, + sample_ratio=sample_ratio, + shuffle=True, + return_generator=False + ) + + logger.info(f"加载了 {len(data)} 个样本") + + # 分离文本和标签 + texts = [text for text, _ in data] + labels = [label for _, label in data] + + # 进行标签编码 + encoded_labels = np.array([self.label_to_id[label] for label in labels]) + + # 计算数据统计信息 + self._compute_stats(texts, labels) + + # 划分训练集、验证集和测试集 + # 先分出测试集 + if test_split > 0: + train_val_texts, self.test_texts, train_val_labels, self.test_labels = train_test_split( + texts, encoded_labels, + test_size=test_split, + random_state=RANDOM_SEED, + stratify=encoded_labels if len(set(encoded_labels)) > 1 else None + ) + else: + train_val_texts, train_val_labels = texts, encoded_labels + self.test_texts, self.test_labels = [], [] + + # 再划分训练集和验证集 + if val_split > 0: + self.train_texts, self.val_texts, self.train_labels, self.val_labels = train_test_split( + train_val_texts, train_val_labels, + test_size=val_split / (1 - test_split), + random_state=RANDOM_SEED, + stratify=train_val_labels if len(set(train_val_labels)) > 1 else None + ) + else: + self.train_texts, self.train_labels = train_val_texts, train_val_labels + self.val_texts, self.val_labels = [], [] + + # 打印数据集划分结果 + logger.info(f"数据集划分结果:") + logger.info(f" 训练集:{len(self.train_texts)} 个样本") + logger.info(f" 验证集:{len(self.val_texts)} 个样本") + logger.info(f" 测试集:{len(self.test_texts)} 个样本") + + # 保存处理后的数据 + if save: + self.save_data() + + elapsed = time.time() - start_time + logger.info(f"数据加载和分割完成,用时 {elapsed:.2f} 秒") + + return { + "train_texts": self.train_texts, + "train_labels": self.train_labels, + "val_texts": self.val_texts, + "val_labels": self.val_labels, + "test_texts": self.test_texts, + "test_labels": self.test_labels, + "stats": self.stats + } + + def 
_compute_stats(self, texts: List[str], labels: List[str]) -> None: + """ + 计算数据统计信息 + + Args: + texts: 文本列表 + labels: 标签列表 + """ + # 文本数量 + num_samples = len(texts) + + # 类别分布 + label_counter = Counter(labels) + label_distribution = {label: count / num_samples * 100 for label, count in label_counter.items()} + + # 文本长度统计 + text_lengths = [len(text) for text in texts] + avg_length = sum(text_lengths) / len(text_lengths) + max_length = max(text_lengths) + min_length = min(text_lengths) + + # 前5个最长和最短的文本的长度 + sorted_lengths = sorted(text_lengths) + shortest_lengths = sorted_lengths[:5] + longest_lengths = sorted_lengths[-5:] + + # 95%的文本长度分位数 + percentile_95 = np.percentile(text_lengths, 95) + + # 存储统计信息 + self.stats = { + "num_samples": num_samples, + "num_categories": len(label_counter), + "label_counter": label_counter, + "label_distribution": label_distribution, + "text_length": { + "average": avg_length, + "max": max_length, + "min": min_length, + "percentile_95": percentile_95, + "shortest_5": shortest_lengths, + "longest_5": longest_lengths + } + } + + def save_data(self, save_dir: Optional[str] = None) -> None: + """ + 保存处理后的数据 + + Args: + save_dir: 保存目录,默认使用初始化时设置的目录 + """ + save_dir = save_dir or self.processed_dir + ensure_dir(save_dir) + + # 保存训练集 + save_pickle( + {"texts": self.train_texts, "labels": self.train_labels}, + os.path.join(save_dir, "train_data.pkl") + ) + + # 保存验证集 + if len(self.val_texts) > 0: + save_pickle( + {"texts": self.val_texts, "labels": self.val_labels}, + os.path.join(save_dir, "val_data.pkl") + ) + + # 保存测试集 + if len(self.test_texts) > 0: + save_pickle( + {"texts": self.test_texts, "labels": self.test_labels}, + os.path.join(save_dir, "test_data.pkl") + ) + + # 保存标签编码映射 + save_json( + {"label_to_id": self.label_to_id, "id_to_label": self.id_to_label}, + os.path.join(save_dir, "label_mapping.json") + ) + + # 保存数据统计信息 + # 将Counter对象转换为普通字典以便JSON序列化 + stats_for_json = self.stats.copy() + if "label_counter" in stats_for_json: + stats_for_json["label_counter"] = dict(stats_for_json["label_counter"]) + + save_json( + stats_for_json, + os.path.join(save_dir, "data_stats.json") + ) + + logger.info(f"已将处理后的数据保存到 {save_dir}") + + def load_data(self, load_dir: Optional[str] = None) -> Dict[str, Any]: + """ + 加载处理后的数据 + + Args: + load_dir: 加载目录,默认使用初始化时设置的目录 + + Returns: + 包含加载的数据集的字典 + """ + load_dir = load_dir or self.processed_dir + + # 加载训练集 + train_data_path = os.path.join(load_dir, "train_data.pkl") + if os.path.exists(train_data_path): + train_data = load_pickle(train_data_path) + self.train_texts = train_data["texts"] + self.train_labels = train_data["labels"] + logger.info(f"已加载训练集,包含 {len(self.train_texts)} 个样本") + else: + logger.warning(f"训练集文件不存在: {train_data_path}") + self.train_texts, self.train_labels = [], [] + + # 加载验证集 + val_data_path = os.path.join(load_dir, "val_data.pkl") + if os.path.exists(val_data_path): + val_data = load_pickle(val_data_path) + self.val_texts = val_data["texts"] + self.val_labels = val_data["labels"] + logger.info(f"已加载验证集,包含 {len(self.val_texts)} 个样本") + else: + logger.warning(f"验证集文件不存在: {val_data_path}") + self.val_texts, self.val_labels = [], [] + + # 加载测试集 + test_data_path = os.path.join(load_dir, "test_data.pkl") + if os.path.exists(test_data_path): + test_data = load_pickle(test_data_path) + self.test_texts = test_data["texts"] + self.test_labels = test_data["labels"] + logger.info(f"已加载测试集,包含 {len(self.test_texts)} 个样本") + else: + logger.warning(f"测试集文件不存在: {test_data_path}") + self.test_texts, self.test_labels = [], 
[] + + # 加载标签编码映射 + mapping_path = os.path.join(load_dir, "label_mapping.json") + if os.path.exists(mapping_path): + mapping = load_json(mapping_path) + self.label_to_id = mapping["label_to_id"] + self.id_to_label = mapping["id_to_label"] + # 将字符串键转换为整数(JSON序列化会将所有键转为字符串) + self.id_to_label = {int(k): v for k, v in self.id_to_label.items()} + logger.info(f"已加载标签编码映射,共 {len(self.label_to_id)} 个类别") + + # 加载数据统计信息 + stats_path = os.path.join(load_dir, "data_stats.json") + if os.path.exists(stats_path): + self.stats = load_json(stats_path) + logger.info("已加载数据统计信息") + + return { + "train_texts": self.train_texts, + "train_labels": self.train_labels, + "val_texts": self.val_texts, + "val_labels": self.val_labels, + "test_texts": self.test_texts, + "test_labels": self.test_labels, + "stats": self.stats + } + + def get_label_distribution(self, dataset: str = "train") -> Dict[str, float]: + """ + 获取指定数据集的标签分布 + + Args: + dataset: 数据集名称,可选值:'train', 'val', 'test' + + Returns: + 标签分布字典,键为类别名称,值为比例 + """ + if dataset == "train": + labels = self.train_labels + elif dataset == "val": + labels = self.val_labels + elif dataset == "test": + labels = self.test_labels + else: + raise ValueError(f"不支持的数据集名称: {dataset}") + + # 计算标签分布 + label_counter = Counter(labels) + num_samples = len(labels) + + # 将数字标签转换为类别名称 + distribution = {} + for label_id, count in label_counter.items(): + label_name = self.id_to_label.get(label_id, str(label_id)) + distribution[label_name] = count / num_samples * 100 + + return distribution + + def visualize_label_distribution(self, dataset: str = "train", + save_path: Optional[str] = None) -> None: + """ + 可视化标签分布 + + Args: + dataset: 数据集名称,可选值:'train', 'val', 'test', 'all' + save_path: 图表保存路径,默认为None(显示而不保存) + """ + plt.figure(figsize=(12, 8)) + + if dataset == "all": + # 显示所有数据集的标签分布 + train_dist = self.get_label_distribution("train") + val_dist = self.get_label_distribution("val") if len(self.val_labels) > 0 else {} + test_dist = self.get_label_distribution("test") if len(self.test_labels) > 0 else {} + + # 准备数据 + categories = list(train_dist.keys()) + train_values = [train_dist.get(cat, 0) for cat in categories] + val_values = [val_dist.get(cat, 0) for cat in categories] + test_values = [test_dist.get(cat, 0) for cat in categories] + + # 绘制条形图 + x = np.arange(len(categories)) + width = 0.25 + + plt.bar(x - width, train_values, width, label="Training") + if val_values: + plt.bar(x, val_values, width, label="Validation") + if test_values: + plt.bar(x + width, test_values, width, label="Testing") + + plt.xlabel("Categories") + plt.ylabel("Percentage (%)") + plt.title("Label Distribution Across Datasets") + plt.xticks(x, categories, rotation=45, ha="right") + plt.legend() + plt.tight_layout() + else: + # 显示单个数据集的标签分布 + distribution = self.get_label_distribution(dataset) + + # 按值排序 + sorted_items = sorted(distribution.items(), key=lambda x: x[1], reverse=True) + categories = [item[0] for item in sorted_items] + values = [item[1] for item in sorted_items] + + # 绘制条形图 + plt.bar(categories, values, color='skyblue') + plt.xlabel("Categories") + plt.ylabel("Percentage (%)") + plt.title(f"Label Distribution in {dataset.capitalize()} Dataset") + plt.xticks(rotation=45, ha="right") + plt.tight_layout() + + # 保存或显示图表 + if save_path: + plt.savefig(save_path) + logger.info(f"标签分布图已保存到 {save_path}") + else: + plt.show() + + def visualize_text_length_distribution(self, dataset: str = "train", + bins: int = 50, + save_path: Optional[str] = None) -> None: + """ + 可视化文本长度分布 + + Args: + dataset: 
数据集名称,可选值:'train', 'val', 'test' + bins: 直方图的箱数 + save_path: 图表保存路径,默认为None(显示而不保存) + """ + if dataset == "train": + texts = self.train_texts + elif dataset == "val": + texts = self.val_texts + elif dataset == "test": + texts = self.test_texts + else: + raise ValueError(f"不支持的数据集名称: {dataset}") + + # 计算文本长度 + text_lengths = [len(text) for text in texts] + + # 绘制直方图 + plt.figure(figsize=(10, 6)) + plt.hist(text_lengths, bins=bins, color='skyblue', alpha=0.7) + + # 计算并绘制一些统计量 + avg_length = sum(text_lengths) / len(text_lengths) + median_length = np.median(text_lengths) + percentile_95 = np.percentile(text_lengths, 95) + + plt.axvline(avg_length, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_length:.1f}') + plt.axvline(median_length, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_length:.1f}') + plt.axvline(percentile_95, color='purple', linestyle='dashed', linewidth=1, + label=f'95th Percentile: {percentile_95:.1f}') + + plt.xlabel('Text Length (characters)') + plt.ylabel('Frequency') + plt.title(f'Text Length Distribution in {dataset.capitalize()} Dataset') + plt.legend() + plt.tight_layout() + + # 保存或显示图表 + if save_path: + plt.savefig(save_path) + logger.info(f"文本长度分布图已保存到 {save_path}") + else: + plt.show() + + def get_data_summary(self) -> Dict[str, Any]: + """ + 获取数据集的摘要信息 + + Returns: + 包含数据摘要的字典 + """ + # 获取数据集的基本信息 + summary = { + "train_size": len(self.train_texts), + "val_size": len(self.val_texts), + "test_size": len(self.test_texts), + "num_categories": len(self.label_to_id), + "categories": list(self.label_to_id.keys()), + } + + # 添加训练集的标签分布 + if len(self.train_texts) > 0: + summary["train_label_distribution"] = self.get_label_distribution("train") + + # 添加验证集的标签分布 + if len(self.val_texts) > 0: + summary["val_label_distribution"] = self.get_label_distribution("val") + + # 添加测试集的标签分布 + if len(self.test_texts) > 0: + summary["test_label_distribution"] = self.get_label_distribution("test") + + # 添加更多统计信息(如果有) + if self.stats: + # 只添加一些关键的统计信息 + if "text_length" in self.stats: + summary["text_length_stats"] = self.stats["text_length"] + + return summary + + def export_to_pandas(self, dataset: str = "train") -> pd.DataFrame: + """ + 将数据导出为Pandas DataFrame + + Args: + dataset: 数据集名称,可选值:'train', 'val', 'test' + + Returns: + Pandas DataFrame + """ + if dataset == "train": + texts = self.train_texts + labels_ids = self.train_labels + elif dataset == "val": + texts = self.val_texts + labels_ids = self.val_labels + elif dataset == "test": + texts = self.test_texts + labels_ids = self.test_labels + else: + raise ValueError(f"不支持的数据集名称: {dataset}") + + # 将数字标签转换为类别名称 + labels = [self.id_to_label.get(label_id, str(label_id)) for label_id in labels_ids] + + # 创建DataFrame + df = pd.DataFrame({ + "text": texts, + "label_id": labels_ids, + "label": labels + }) + + return df + + def get_label_name(self, label_id: int) -> str: + """ + 获取标签ID对应的类别名称 + + Args: + label_id: 标签ID + + Returns: + 类别名称 + """ + return self.id_to_label.get(label_id, str(label_id)) + + def get_label_id(self, label_name: str) -> int: + """ + 获取类别名称对应的标签ID + + Args: + label_name: 类别名称 + + Returns: + 标签ID + """ + return self.label_to_id.get(label_name, -1) + + def get_data(self, dataset: str = "train") -> Tuple[List[str], np.ndarray]: + """ + 获取指定数据集的文本和标签 + + Args: + dataset: 数据集名称,可选值:'train', 'val', 'test' + + Returns: + (文本列表, 标签数组)的元组 + """ + if dataset == "train": + return self.train_texts, self.train_labels + elif dataset == "val": + return self.val_texts, self.val_labels + elif 
dataset == "test": + return self.test_texts, self.test_labels + else: + raise ValueError(f"不支持的数据集名称: {dataset}") diff --git a/python_files.txt b/python_files.txt new file mode 100644 index 0000000..1917c79 --- /dev/null +++ b/python_files.txt @@ -0,0 +1,54 @@ +export_all_pyfile.py +setup.py +main.py +interface/__init__.py +interface/api.py +interface/cli.py +interface/web/__init__.py +interface/web/app.py +interface/web/routes.py +config/model_config.py +config/__init__.py +config/system_config.py +training/__init__.py +training/callbacks.py +training/optimizer.py +training/scheduler.py +training/trainer.py +tests/__init__.py +tests/test_evaluation.py +tests/test_preprocessing.py +tests/test_models.py +utils/text_utils.py +utils/__init__.py +utils/time_utils.py +utils/logger.py +utils/file_utils.py +models/ensemble_model.py +models/transformer_model.py +models/__init__.py +models/base_model.py +models/rnn_model.py +models/model_factory.py +models/cnn_model.py +models/layers/__init__.py +scripts/predict.py +scripts/train.py +scripts/evaluate.py +inference/predictor.py +inference/batch_processor.py +inference/__init__.py +evaluation/metrics.py +evaluation/__init__.py +evaluation/visualization.py +evaluation/evaluator.py +preprocessing/__init__.py +preprocessing/tokenization.py +preprocessing/vectorizer.py +preprocessing/text_cleaner.py +preprocessing/feature_extraction.py +preprocessing/data_augmentation.py +data/__init__.py +data/dataset.py +data/dataloader.py +data/data_manager.py diff --git a/scripts/train.py b/scripts/train.py index c473a80..86b354a 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -11,6 +11,19 @@ import numpy as np import tensorflow as tf import matplotlib.pyplot as plt +# 检测 GPU +physical_devices = tf.config.list_physical_devices('GPU') +print("可用的物理 GPU 设备:", physical_devices) + +if physical_devices: + try: + # 设置 GPU 内存增长模式 + for gpu in physical_devices: + tf.config.experimental.set_memory_growth(gpu, True) + print("已设置 GPU 内存增长模式") + except RuntimeError as e: + print(f"设置 GPU 内存增长时出错: {e}") + # 将项目根目录添加到系统路径 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(project_root) diff --git a/training/trainer.py b/training/trainer.py index b955e33..cdedc66 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -190,20 +190,58 @@ class Trainer: } self.training_logger.log_training_start({**model_config, **train_config}) + # 检查和配置 GPU 使用 + physical_devices = tf.config.list_physical_devices('GPU') + logger.info(f"可用的物理 GPU 设备: {physical_devices}") + + # 记录当前使用的设备情况 + logger.info(f"TensorFlow 版本: {tf.__version__}") + if physical_devices: + logger.info(f"模型将使用 GPU 进行训练") + try: + # 设置 GPU 内存增长模式 + for gpu in physical_devices: + tf.config.experimental.set_memory_growth(gpu, True) + logger.info(f"已设置 GPU 内存增长模式") + except RuntimeError as e: + logger.warning(f"设置 GPU 内存增长时出错: {e}") + else: + logger.warning(f"未检测到 GPU,将使用 CPU 进行训练") + + # 尝试强制使用 GPU + if physical_devices: + try: + # 将运算放到 GPU 上 + with tf.device('/GPU:0'): + logger.info("已强制指定使用 GPU:0 进行训练") + except RuntimeError as e: + logger.warning(f"指定 GPU 设备时出错: {e}") + # 准备验证数据 validation_data = None if x_val is not None and y_val is not None: validation_data = (x_val, y_val) # 训练模型 - history = self.model.fit( - x_train, y_train, - validation_data=validation_data, - epochs=self.epochs, - callbacks=callbacks, - class_weights=class_weights, - verbose=1 - ) + if physical_devices: + with tf.device('/GPU:0'): + history = self.model.fit( + x_train, y_train, + 
validation_data=validation_data, + epochs=self.epochs, + callbacks=callbacks, + class_weights=class_weights, + verbose=1 + ) + else: + history = self.model.fit( + x_train, y_train, + validation_data=validation_data, + epochs=self.epochs, + callbacks=callbacks, + class_weights=class_weights, + verbose=1 + ) # 计算训练时间 train_time = time.time() - start_time
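The GPU handling added above in scripts/train.py and training/trainer.py can be checked in isolation. The sketch below is illustrative only and is not part of the committed diff; it assumes nothing beyond TensorFlow 2.x being installed. Two points it encodes: set_memory_growth has to run before the GPUs are first initialized, and wrapping model.fit in tf.device('/GPU:0') is normally redundant, because TensorFlow already places operations on a visible GPU by default.

import tensorflow as tf

# Illustrative sketch, not part of the diff above: verify GPU visibility and
# memory-growth configuration before any op has touched the device.
gpus = tf.config.list_physical_devices('GPU')
print("Visible GPUs:", gpus)

if gpus:
    try:
        for gpu in gpus:
            # Must be called before the GPU is initialized, i.e. before any op runs on it.
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the runtime has already initialized the GPUs.
        print(f"Could not enable memory growth: {e}")

# Explicit placement is optional; TensorFlow defaults to GPU:0 when one is visible.
with tf.device('/GPU:0' if gpus else '/CPU:0'):
    x = tf.random.uniform((4, 8))
    y = tf.matmul(x, tf.transpose(x))
print("Computed on:", y.device)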
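Stepping back to the data modules dumped earlier in this file, the following is a minimal usage sketch of how DataLoader and DataManager are meant to be combined. It is not taken from the repository; it assumes the imported config constants (RAW_DATA_DIR, CATEGORIES, PROCESSED_DATA_DIR and the split ratios) resolve correctly and that RAW_DATA_DIR holds one sub-directory of .txt files per category.

# Hedged usage sketch (not from the repository): wire DataLoader into DataManager,
# split the corpus, and inspect the result.
from data.dataloader import DataLoader
from data.data_manager import DataManager

loader = DataLoader()            # defaults to RAW_DATA_DIR and all CATEGORIES
manager = DataManager()          # persists splits under PROCESSED_DATA_DIR

# Load a 10% sample, split it into train/val/test and save the pickled splits.
splits = manager.load_and_split_data(loader, sample_ratio=0.1, save=True)
print(manager.get_data_summary())

# Streaming alternative: iterate raw (text, label) pairs batch by batch
# instead of materialising the whole corpus in memory.
for batch in loader.load_data(sample_ratio=0.1, return_generator=True):
    for text, label in batch[:2]:
        print(label, text[:50])
    break

# The splits are also exportable as a DataFrame for ad-hoc analysis.
print(manager.export_to_pandas("train").head())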
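One small detail from DataManager.load_data is worth spelling out: label_mapping.json stores id_to_label with integer IDs as keys, and JSON object keys are always strings, which is why the loader rebuilds the dict with int(k). A tiny stand-alone illustration follows, using hypothetical label names.

import json

# JSON keys do not survive a round trip as integers.
id_to_label = {0: "category_a", 1: "category_b"}     # hypothetical mapping
restored = json.loads(json.dumps(id_to_label))
print(restored)          # {'0': 'category_a', '1': 'category_b'} -- keys are now str
restored = {int(k): v for k, v in restored.items()}  # same conversion DataManager.load_data applies
print(restored[0])       # 'category_a'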