"""
数据增强模块:实现文本数据增强技术
"""
import random
import re
import jieba
import synonyms
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
import copy
from config.model_config import RANDOM_SEED
from utils.logger import get_logger
from preprocessing.tokenization import ChineseTokenizer
# 设置随机种子以保证可重复性
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
logger = get_logger("DataAugmentation")
class TextAugmenter:
    """Base class for text augmenters; defines the common interface."""

    def __init__(self):
        """Initialize the text augmenter."""
        pass

    def augment(self, text: str) -> str:
        """
        Augment a single text.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def batch_augment(self, texts: List[str]) -> List[str]:
        """
        Augment a batch of texts.

        Args:
            texts: List of original texts.

        Returns:
            List of augmented texts.
        """
        return [self.augment(text) for text in texts]

    def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]:
        """
        Augment a text while keeping its label.

        Args:
            text: Original text.
            label: Label.

        Returns:
            Tuple of (augmented text, label).
        """
        return self.augment(text), label

    def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]:
        """
        Augment a batch of texts while keeping their labels.

        Args:
            texts: List of original texts.
            labels: List of labels.

        Returns:
            List of (augmented text, label) tuples.
        """
        return [self.augment_with_label(text, label) for text, label in zip(texts, labels)]
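

# A minimal sketch of how a concrete augmenter plugs into this interface
# (illustrative only; NoOpAugmenter is a hypothetical name, not part of this
# module). Subclasses only override augment(); the batch and label helpers
# come for free:
#
#     class NoOpAugmenter(TextAugmenter):
#         def augment(self, text: str) -> str:
#             return text
#
#     pairs = NoOpAugmenter().batch_augment_with_label(["文本"], [1])
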
class SynonymReplacement(TextAugmenter):
    """Synonym replacement augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 replace_ratio: float = 0.1,
                 min_similarity: float = 0.7):
        """
        Initialize the synonym replacement augmenter.

        Args:
            tokenizer: Tokenizer; if None, a new tokenizer is created.
            replace_ratio: Fraction of the words in a text to replace.
            min_similarity: Minimum similarity; only synonyms whose similarity
                is at least this value are used as replacements.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.replace_ratio = replace_ratio
        self.min_similarity = min_similarity

    def _get_synonym(self, word: str) -> Optional[str]:
        """
        Look up a synonym for a word.

        Args:
            word: Original word.

        Returns:
            A synonym, or None if no suitable synonym is found.
        """
        # Use the synonyms package to find similar words.
        try:
            # synonyms.nearby returns a tuple: a list of similar words and a
            # list of their similarity scores.
            words, similarities = synonyms.nearby(word)
            # Drop candidates below the similarity threshold and the word itself.
            valid_synonyms = [(w, s) for w, s in zip(words, similarities)
                              if s >= self.min_similarity and w != word]
            if valid_synonyms:
                # Sort by similarity and pick the most similar candidate.
                valid_synonyms.sort(key=lambda x: x[1], reverse=True)
                return valid_synonyms[0][0]
            return None
        except Exception:
            return None

    def augment(self, text: str) -> str:
        """
        Augment a text by replacing words with synonyms.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        if not text:
            return text
        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
        if not words:
            return text
        # Number of words to replace (at least one).
        n_replace = max(1, int(len(words) * self.replace_ratio))
        # Randomly pick the indices of the words to replace.
        replace_indices = random.sample(range(len(words)), min(n_replace, len(words)))
        # Replace each picked word with a synonym, if one is found.
        for idx in replace_indices:
            synonym = self._get_synonym(words[idx])
            if synonym:
                words[idx] = synonym
        # Join the words back into a text.
        return ''.join(words)
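

# Usage sketch (illustrative; assumes the synonyms package has embeddings for
# the sampled words -- when no candidate clears min_similarity, the text is
# returned unchanged):
#
#     sr = SynonymReplacement(replace_ratio=0.2, min_similarity=0.7)
#     sr.augment("今天天气很好")   # e.g. "今日天气很好"
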
class RandomDeletion(TextAugmenter):
    """Random deletion augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 delete_ratio: float = 0.1):
        """
        Initialize the random deletion augmenter.

        Args:
            tokenizer: Tokenizer; if None, a new tokenizer is created.
            delete_ratio: Fraction of the words in a text to delete.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.delete_ratio = delete_ratio

    def augment(self, text: str) -> str:
        """
        Augment a text by randomly deleting words.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        if not text:
            return text
        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
        if len(words) <= 1:
            return text
        # Number of words to delete (at least one, but never all of them).
        n_delete = max(1, int(len(words) * self.delete_ratio))
        # Randomly pick the indices of the words to delete.
        delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1))
        # Keep every word that was not picked.
        augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices]
        # Join the words back into a text.
        return ''.join(augmented_words)
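

# Usage sketch (illustrative): with delete_ratio=0.1, a 10-word text loses one
# word per call; min(n_delete, len(words) - 1) guarantees at least one word
# always survives.
#
#     rd = RandomDeletion(delete_ratio=0.1)
#     rd.augment("今天天气很好")
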
class RandomSwap(TextAugmenter):
    """Random swap augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 n_swaps: int = 1):
        """
        Initialize the random swap augmenter.

        Args:
            tokenizer: Tokenizer; if None, a new tokenizer is created.
            n_swaps: Number of swaps to perform.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.n_swaps = n_swaps

    def augment(self, text: str) -> str:
        """
        Augment a text by randomly swapping words.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        if not text:
            return text
        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)
        if len(words) <= 1:
            return text
        # Perform up to n_swaps random swaps.
        augmented_words = words.copy()
        for _ in range(min(self.n_swaps, len(words) // 2)):
            # Pick two distinct indices.
            idx1, idx2 = random.sample(range(len(augmented_words)), 2)
            # Swap the two words.
            augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1]
        # Join the words back into a text.
        return ''.join(augmented_words)
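

# Usage sketch (illustrative): each iteration swaps one random pair, so the
# word multiset is preserved while the word order is perturbed.
#
#     rs = RandomSwap(n_swaps=2)
#     rs.augment("今天天气很好")
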
class CompositeAugmenter(TextAugmenter):
    """Composite augmenter that combines several augmenters."""

    def __init__(self, augmenters: List[TextAugmenter],
                 probs: Optional[List[float]] = None):
        """
        Initialize the composite augmenter.

        Args:
            augmenters: List of augmenters.
            probs: Selection probability for each augmenter; if None, the
                augmenters are chosen uniformly.
        """
        super().__init__()
        self.augmenters = augmenters
        # If no probabilities are given, distribute them uniformly.
        if probs is None:
            self.probs = [1.0 / len(augmenters)] * len(augmenters)
        else:
            # Normalize so the probabilities sum to 1.
            total = sum(probs)
            self.probs = [p / total for p in probs]
        assert len(self.augmenters) == len(self.probs), "Number of augmenters does not match number of probabilities"

    def augment(self, text: str) -> str:
        """
        Augment a text with one randomly chosen augmenter.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        if not text:
            return text
        # Pick one augmenter according to the probabilities.
        augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0]
        # Apply the chosen augmenter.
        return augmenter.augment(text)
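

# Usage sketch (illustrative): weights 2/1/1 normalize to 0.5/0.25/0.25, so
# synonym replacement is chosen for about half of the calls on average.
#
#     composite = CompositeAugmenter(
#         [SynonymReplacement(), RandomDeletion(), RandomSwap()],
#         probs=[2.0, 1.0, 1.0],
#     )
#     composite.augment("今天天气很好")
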
class BackTranslation(TextAugmenter):
    """Back-translation augmenter."""

    def __init__(self, translator=None, source_lang: str = 'zh',
                 target_langs: Optional[List[str]] = None):
        """
        Initialize the back-translation augmenter.

        Args:
            translator: Translator; must implement a translate method.
            source_lang: Source language code.
            target_langs: List of target language codes; if None, a default
                set of languages is used.
        """
        super().__init__()
        # If no translator is given, try third-party translation libraries.
        if translator is None:
            try:
                # Try googletrans first (install separately:
                # pip install googletrans==4.0.0-rc1).
                try:
                    from googletrans import Translator
                    self.translator = Translator()
                    self.translate_func = self._google_translate
                except ImportError:
                    # If googletrans is unavailable, fall back to py-translate.
                    try:
                        import translate
                        self.translator = translate
                        self.translate_func = self._py_translate
                    except ImportError:
                        logger.warning("No translation library installed; back translation is disabled. Please install googletrans or py-translate")
                        self.translator = None
                        self.translate_func = self._dummy_translate
            except Exception as e:
                logger.error(f"Failed to initialize translator: {e}")
                self.translator = None
                self.translate_func = self._dummy_translate
        else:
            self.translator = translator
            self.translate_func = self._custom_translate
        self.source_lang = source_lang
        self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja']

    def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with googletrans."""
        try:
            result = self.translator.translate(text, src=source_lang, dest=target_lang)
            return result.text
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with py-translate."""
        try:
            return self.translator.translate(text, source_lang, target_lang)
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with a user-supplied translator."""
        try:
            return self.translator.translate(text, source_lang, target_lang)
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Fallback that returns the original text unchanged."""
        logger.warning("Translation is unavailable; returning the original text")
        return text

    def augment(self, text: str) -> str:
        """
        Augment a text by back translation.

        Args:
            text: Original text.

        Returns:
            Augmented text.
        """
        if not text or self.translator is None:
            return text
        # Randomly pick a target language.
        target_lang = random.choice(self.target_langs)
        try:
            # Translate from the source language to the target language...
            translated = self.translate_func(text, self.source_lang, target_lang)
            # ...then translate back to the source language.
            back_translated = self.translate_func(translated, target_lang, self.source_lang)
            return back_translated
        except Exception as e:
            logger.error(f"Back translation failed: {e}")
            return text
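

if __name__ == "__main__":
    # Smoke test: a rough sketch exercising each offline augmenter on one
    # sentence. BackTranslation is left out because it needs a network-backed
    # translator; outputs vary with the tokenizer and synonym data installed.
    sample = "今天天气很好,我们一起去公园散步吧。"
    for augmenter in (SynonymReplacement(), RandomDeletion(), RandomSwap()):
        logger.info(f"{augmenter.__class__.__name__}: {augmenter.augment(sample)}")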