""" 数据增强模块:实现文本数据增强技术 """ import random import re import jieba import synonyms import numpy as np from typing import List, Dict, Tuple, Optional, Any, Union, Callable import copy from config.model_config import RANDOM_SEED from utils.logger import get_logger from preprocessing.tokenization import ChineseTokenizer # 设置随机种子以保证可重复性 random.seed(RANDOM_SEED) np.random.seed(RANDOM_SEED) logger = get_logger("DataAugmentation") class TextAugmenter: """文本增强基类,定义通用接口""" def __init__(self): """初始化文本增强器""" pass def augment(self, text: str) -> str: """ 对文本进行增强 Args: text: 原始文本 Returns: 增强后的文本 """ raise NotImplementedError("子类必须实现此方法") def batch_augment(self, texts: List[str]) -> List[str]: """ 批量对文本进行增强 Args: texts: 原始文本列表 Returns: 增强后的文本列表 """ return [self.augment(text) for text in texts] def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]: """ 对文本进行增强,同时保留标签 Args: text: 原始文本 label: 标签 Returns: (增强后的文本, 标签)的元组 """ return self.augment(text), label def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]: """ 批量对文本进行增强,同时保留标签 Args: texts: 原始文本列表 labels: 标签列表 Returns: (增强后的文本, 标签)的元组列表 """ return [self.augment_with_label(text, label) for text, label in zip(texts, labels)] class SynonymReplacement(TextAugmenter): """同义词替换增强器""" def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, replace_ratio: float = 0.1, min_similarity: float = 0.7): """ 初始化同义词替换增强器 Args: tokenizer: 分词器,如果为None则创建一个新的分词器 replace_ratio: 替换比例,表示要替换的词占总词数的比例 min_similarity: 最小相似度,只有相似度大于该值的同义词才会被用于替换 """ super().__init__() self.tokenizer = tokenizer or ChineseTokenizer() self.replace_ratio = replace_ratio self.min_similarity = min_similarity def _get_synonym(self, word: str) -> Optional[str]: """ 获取词的同义词 Args: word: 原始词 Returns: 同义词,如果没有合适的同义词则返回None """ # 使用synonyms包获取同义词 try: synonyms_list = synonyms.nearby(word) # synonyms.nearby返回一个元组,第一个元素是相似词列表,第二个元素是相似度列表 words = synonyms_list[0] similarities = synonyms_list[1] # 过滤掉相似度低于阈值的词和原词本身 valid_synonyms = [(w, s) for w, s in zip(words, similarities) if s >= self.min_similarity and w != word] if valid_synonyms: # 按相似度排序,选择最相似的词 valid_synonyms.sort(key=lambda x: x[1], reverse=True) return valid_synonyms[0][0] return None except: return None def augment(self, text: str) -> str: """ 对文本进行同义词替换增强 Args: text: 原始文本 Returns: 增强后的文本 """ if not text: return text # 分词 words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) if not words: return text # 计算要替换的词数量 n_replace = max(1, int(len(words) * self.replace_ratio)) # 随机选择要替换的词索引 replace_indices = random.sample(range(len(words)), min(n_replace, len(words))) # 替换为同义词 for idx in replace_indices: synonym = self._get_synonym(words[idx]) if synonym: words[idx] = synonym # 拼接为文本 augmented_text = ''.join(words) return augmented_text class RandomDeletion(TextAugmenter): """随机删除增强器""" def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, delete_ratio: float = 0.1): """ 初始化随机删除增强器 Args: tokenizer: 分词器,如果为None则创建一个新的分词器 delete_ratio: 删除比例,表示要删除的词占总词数的比例 """ super().__init__() self.tokenizer = tokenizer or ChineseTokenizer() self.delete_ratio = delete_ratio def augment(self, text: str) -> str: """ 对文本进行随机删除增强 Args: text: 原始文本 Returns: 增强后的文本 """ if not text: return text # 分词 words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) if len(words) <= 1: return text # 计算要删除的词数量 n_delete = max(1, int(len(words) * self.delete_ratio)) # 随机选择要删除的词索引 delete_indices = random.sample(range(len(words)), min(n_delete, len(words) - 1)) # 删除选中的词 augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices] # 拼接为文本 augmented_text = ''.join(augmented_words) return augmented_text class RandomSwap(TextAugmenter): """随机交换增强器""" def __init__(self, tokenizer: Optional[ChineseTokenizer] = None, n_swaps: int = 1): """ 初始化随机交换增强器 Args: tokenizer: 分词器,如果为None则创建一个新的分词器 n_swaps: 交换次数 """ super().__init__() self.tokenizer = tokenizer or ChineseTokenizer() self.n_swaps = n_swaps def augment(self, text: str) -> str: """ 对文本进行随机交换增强 Args: text: 原始文本 Returns: 增强后的文本 """ if not text: return text # 分词 words = self.tokenizer.tokenize(text, return_string=False, cut_all=False) if len(words) <= 1: return text # 进行n_swaps次随机交换 augmented_words = words.copy() for _ in range(min(self.n_swaps, len(words) // 2)): # 随机选择两个不同的索引 idx1, idx2 = random.sample(range(len(augmented_words)), 2) # 交换两个词 augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1] # 拼接为文本 augmented_text = ''.join(augmented_words) return augmented_text class CompositeAugmenter(TextAugmenter): """组合增强器,组合多个增强器""" def __init__(self, augmenters: List[TextAugmenter], probs: Optional[List[float]] = None): """ 初始化组合增强器 Args: augmenters: 增强器列表 probs: 各增强器被选择的概率列表,如果为None则均匀选择 """ super().__init__() self.augmenters = augmenters # 如果没有提供概率,则均匀分配 if probs is None: self.probs = [1.0 / len(augmenters)] * len(augmenters) else: # 确保概率和为1 total = sum(probs) self.probs = [p / total for p in probs] assert len(self.augmenters) == len(self.probs), "增强器数量与概率数量不匹配" def augment(self, text: str) -> str: """ 对文本进行组合增强 Args: text: 原始文本 Returns: 增强后的文本 """ if not text: return text # 根据概率随机选择一个增强器 augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0] # 使用选中的增强器进行增强 return augmenter.augment(text) class BackTranslation(TextAugmenter): """回译增强器""" def __init__(self, translator=None, source_lang: str = 'zh', target_langs: List[str] = None): """ 初始化回译增强器 Args: translator: 翻译器,需要实现translate方法 source_lang: 源语言代码 target_langs: 目标语言代码列表,如果为None则使用默认语言 """ super().__init__() # 如果没有提供翻译器,尝试使用第三方翻译库 if translator is None: try: # 尝试导入多种翻译库 # 首先尝试使用googletrans (需要单独安装: pip install googletrans==4.0.0-rc1) try: from googletrans import Translator self.translator = Translator() self.translate_func = self._google_translate except ImportError: # 如果googletrans不可用,尝试使用py-translate try: import translate self.translator = translate self.translate_func = self._py_translate except ImportError: logger.warning("未安装翻译库,回译功能将不可用。请安装googletrans或py-translate") self.translator = None self.translate_func = self._dummy_translate except Exception as e: logger.error(f"初始化翻译器失败: {e}") self.translator = None self.translate_func = self._dummy_translate else: self.translator = translator self.translate_func = self._custom_translate self.source_lang = source_lang self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja'] def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str: """使用googletrans进行翻译""" try: result = self.translator.translate(text, src=source_lang, dest=target_lang) return result.text except Exception as e: logger.error(f"翻译失败: {e}") return text def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str: """使用py-translate进行翻译""" try: return self.translator.translate(text, source_lang, target_lang) except Exception as e: logger.error(f"翻译失败: {e}") return text def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str: """使用自定义翻译器进行翻译""" try: return self.translator.translate(text, source_lang, target_lang) except Exception as e: logger.error(f"翻译失败: {e}") return text def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str: """虚拟翻译功能,仅返回原文本""" logger.warning("翻译功能不可用,使用原文本") return text def augment(self, text: str) -> str: """ 对文本进行回译增强 Args: text: 原始文本 Returns: 增强后的文本 """ if not text or self.translator is None: return text # 随机选择一个目标语言 target_lang = random.choice(self.target_langs) try: # 将源语言翻译为目标语言 translated = self.translate_func(text, self.source_lang, target_lang) # 将目标语言翻译回源语言 back_translated = self.translate_func(translated, target_lang, self.source_lang) return back_translated except Exception as e: logger.error(f"回译失败: {e}") return text