"""
System-wide global configuration file.
"""

import os
import platform
from pathlib import Path

# Project root directory: two levels above this file, i.e. the parent of the
# directory that contains this file. It is derived from __file__ instead of
# being hard-coded, mainly so the project can be moved to, and run on,
# different platforms without editing paths.
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent

# Data-related paths
DATA_DIR = ROOT_DIR.joinpath("data")
RAW_DATA_DIR = DATA_DIR.joinpath("raw", "THUCNews")  # not created by this module
PROCESSED_DATA_DIR = DATA_DIR.joinpath("processed")
RESOURCES_DIR = DATA_DIR.joinpath("resources")
STOPWORDS_DIR = RESOURCES_DIR.joinpath("stopwords")
EMBEDDINGS_DIR = RESOURCES_DIR.joinpath("embeddings")

# Create the data directories this project writes to, if they are missing.
for directory in (
    PROCESSED_DATA_DIR,
    RESOURCES_DIR,
    STOPWORDS_DIR,
    EMBEDDINGS_DIR,
):
    directory.mkdir(parents=True, exist_ok=True)

# Paths where models are saved
SAVED_MODELS_DIR = ROOT_DIR.joinpath("saved_models")
TOKENIZERS_DIR = SAVED_MODELS_DIR.joinpath("tokenizers")
CLASSIFIERS_DIR = SAVED_MODELS_DIR.joinpath("classifiers")

# Create the model output directories if they are missing.
for directory in (SAVED_MODELS_DIR, TOKENIZERS_DIR, CLASSIFIERS_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# System resource settings
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 so the min() calls below never compare an int with None, which
# would raise TypeError at import time.
CPU_COUNT = os.cpu_count() or 1
USE_GPU = True
MULTI_GPU = False  # only a single GPU is used for now

# Parallel-processing limits chosen for an Intel Core i9-13900K; the min()
# caps them at the available CPU count on smaller machines.
DATA_LOADING_WORKERS = min(16, CPU_COUNT)  # number of data-loading threads
PREPROCESSING_WORKERS = min(24, CPU_COUNT)  # number of preprocessing workers

# Memory-related limits chosen for a machine with 64 GB of RAM. These are
# fixed values, not derived from the host; adjust them on other hardware.
MAX_MEMORY_GB = 48  # leaves headroom for the OS and other applications
MAX_TEXT_PER_BATCH = 10000  # maximum number of texts processed per batch

# Logging configuration
LOG_DIR = ROOT_DIR / "logs"
# parents=True matches the other directory-creation calls in this module, so
# every directory is created the same way regardless of which ancestors exist.
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_LEVEL = "INFO"
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Category label mappings (consistent with the THUCNews dataset).
# A label's position in CATEGORIES is its numeric class id. Labels must be
# unique; a duplicate would make CATEGORY_TO_ID and ID_TO_CATEGORY disagree.
CATEGORIES = [
    # sports, entertainment, home, lottery, real estate, education
    "体育", "娱乐", "家居", "彩票", "房产", "教育",
    # fashion, politics, horoscope, games, society, technology, stocks, finance
    "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经",
]
CATEGORY_TO_ID = {category: idx for idx, category in enumerate(CATEGORIES)}
ID_TO_CATEGORY = dict(enumerate(CATEGORIES))

# File encoding
ENCODING = "utf-8"

# Snapshot of the host platform, evaluated once when this module is imported
SYSTEM_INFO = dict(
    platform=platform.platform(),
    python_version=platform.python_version(),
    processor=platform.processor(),
)