"""
|
||
数据增强模块:实现文本数据增强技术
|
||
"""
|
||
import random
|
||
import re
|
||
import jieba
|
||
import synonyms
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
|
||
import copy
|
||
|
||
from config.model_config import RANDOM_SEED
|
||
from utils.logger import get_logger
|
||
from preprocessing.tokenization import ChineseTokenizer
|
||
|
||
# 设置随机种子以保证可重复性
|
||
random.seed(RANDOM_SEED)
|
||
np.random.seed(RANDOM_SEED)
|
||
|
||
logger = get_logger("DataAugmentation")
|
||
|
||
|
||
class TextAugmenter:
    """Base class for text augmenters; defines the common interface."""

    def __init__(self):
        """Initialize the text augmenter."""
        pass

    def augment(self, text: str) -> str:
        """
        Augment a single text.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def batch_augment(self, texts: List[str]) -> List[str]:
        """
        Augment a batch of texts.

        Args:
            texts: List of original texts.

        Returns:
            List of augmented texts.
        """
        return [self.augment(text) for text in texts]

    def augment_with_label(self, text: str, label: Any) -> Tuple[str, Any]:
        """
        Augment a text while preserving its label.

        Args:
            text: Original text.
            label: The label.

        Returns:
            Tuple of (augmented text, label).
        """
        return self.augment(text), label

    def batch_augment_with_label(self, texts: List[str], labels: List[Any]) -> List[Tuple[str, Any]]:
        """
        Augment a batch of texts while preserving their labels.

        Args:
            texts: List of original texts.
            labels: List of labels.

        Returns:
            List of (augmented text, label) tuples.
        """
        return [self.augment_with_label(text, label) for text, label in zip(texts, labels)]

class SynonymReplacement(TextAugmenter):
    """Synonym replacement augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 replace_ratio: float = 0.1,
                 min_similarity: float = 0.7):
        """
        Initialize the synonym replacement augmenter.

        Args:
            tokenizer: Tokenizer; a new one is created if None.
            replace_ratio: Fraction of the words to replace.
            min_similarity: Minimum similarity; only synonyms scoring at or
                above this threshold are used as replacements.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.replace_ratio = replace_ratio
        self.min_similarity = min_similarity

    def _get_synonym(self, word: str) -> Optional[str]:
        """
        Look up a synonym for a word.

        Args:
            word: The original word.

        Returns:
            A synonym, or None if no suitable synonym is found.
        """
        # Use the synonyms package to retrieve candidate synonyms.
        try:
            # synonyms.nearby returns a tuple: a list of similar words and a
            # parallel list of similarity scores.
            words, similarities = synonyms.nearby(word)

            # Filter out the word itself and candidates below the threshold.
            valid_synonyms = [(w, s) for w, s in zip(words, similarities)
                              if s >= self.min_similarity and w != word]

            if valid_synonyms:
                # Sort by similarity and pick the closest candidate.
                valid_synonyms.sort(key=lambda x: x[1], reverse=True)
                return valid_synonyms[0][0]

            return None
        except Exception:
            return None

    def augment(self, text: str) -> str:
        """
        Augment a text by replacing words with synonyms.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        if not text:
            return text

        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)

        if not words:
            return text

        # Number of words to replace.
        n_replace = max(1, int(len(words) * self.replace_ratio))

        # Randomly pick the indices to replace.
        replace_indices = random.sample(range(len(words)), min(n_replace, len(words)))

        # Replace with a synonym where one is available.
        for idx in replace_indices:
            synonym = self._get_synonym(words[idx])
            if synonym:
                words[idx] = synonym

        # Join back into a string.
        augmented_text = ''.join(words)

        return augmented_text
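
# Usage sketch (illustrative, not part of the original pipeline): replace
# roughly 20% of the tokens with their closest synonyms. The exact output
# depends on the word vectors shipped with the `synonyms` package.
#
#     augmenter = SynonymReplacement(replace_ratio=0.2, min_similarity=0.7)
#     augmented = augmenter.augment("这部电影的情节非常精彩")
#     pairs = augmenter.batch_augment_with_label(["服务态度很好"], [1])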


class RandomDeletion(TextAugmenter):
    """Random deletion augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 delete_ratio: float = 0.1):
        """
        Initialize the random deletion augmenter.

        Args:
            tokenizer: Tokenizer; a new one is created if None.
            delete_ratio: Fraction of the words to delete.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.delete_ratio = delete_ratio

    def augment(self, text: str) -> str:
        """
        Augment a text by randomly deleting words.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        if not text:
            return text

        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)

        if len(words) <= 1:
            return text

        # Number of words to delete.
        n_delete = max(1, int(len(words) * self.delete_ratio))

        # Randomly pick the indices to delete, always keeping at least one word.
        delete_indices = set(random.sample(range(len(words)), min(n_delete, len(words) - 1)))

        # Drop the selected words.
        augmented_words = [words[i] for i in range(len(words)) if i not in delete_indices]

        # Join back into a string.
        augmented_text = ''.join(augmented_words)

        return augmented_text
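
# Usage sketch (illustrative): drop roughly 10% of the tokens at random,
# always keeping at least one.
#
#     deleter = RandomDeletion(delete_ratio=0.1)
#     augmented = deleter.augment("这家餐厅的菜品味道很不错")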


class RandomSwap(TextAugmenter):
    """Random swap augmenter."""

    def __init__(self, tokenizer: Optional[ChineseTokenizer] = None,
                 n_swaps: int = 1):
        """
        Initialize the random swap augmenter.

        Args:
            tokenizer: Tokenizer; a new one is created if None.
            n_swaps: Number of swaps to perform.
        """
        super().__init__()
        self.tokenizer = tokenizer or ChineseTokenizer()
        self.n_swaps = n_swaps

    def augment(self, text: str) -> str:
        """
        Augment a text by randomly swapping word pairs.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        if not text:
            return text

        # Tokenize.
        words = self.tokenizer.tokenize(text, return_string=False, cut_all=False)

        if len(words) <= 1:
            return text

        # Perform up to n_swaps random swaps.
        augmented_words = words.copy()
        for _ in range(min(self.n_swaps, len(words) // 2)):
            # Pick two distinct indices.
            idx1, idx2 = random.sample(range(len(augmented_words)), 2)

            # Swap the two words.
            augmented_words[idx1], augmented_words[idx2] = augmented_words[idx2], augmented_words[idx1]

        # Join back into a string.
        augmented_text = ''.join(augmented_words)

        return augmented_text
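
# Usage sketch (illustrative): swap two randomly chosen tokens once.
#
#     swapper = RandomSwap(n_swaps=1)
#     augmented = swapper.augment("快递送货速度非常快")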


class CompositeAugmenter(TextAugmenter):
    """Composite augmenter that combines several augmenters."""

    def __init__(self, augmenters: List[TextAugmenter],
                 probs: Optional[List[float]] = None):
        """
        Initialize the composite augmenter.

        Args:
            augmenters: List of augmenters.
            probs: Selection probability for each augmenter; uniform if None.
        """
        super().__init__()
        self.augmenters = augmenters

        # Distribute probabilities uniformly if none were given.
        if probs is None:
            self.probs = [1.0 / len(augmenters)] * len(augmenters)
        else:
            # Normalize so the probabilities sum to 1.
            total = sum(probs)
            self.probs = [p / total for p in probs]

        assert len(self.augmenters) == len(self.probs), \
            "Number of augmenters does not match number of probabilities"

    def augment(self, text: str) -> str:
        """
        Augment a text with one randomly selected augmenter.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        if not text:
            return text

        # Pick one augmenter according to the configured probabilities.
        augmenter = random.choices(self.augmenters, weights=self.probs, k=1)[0]

        # Apply the selected augmenter.
        return augmenter.augment(text)
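
# Usage sketch (illustrative): on each call, pick synonym replacement with
# probability 0.5, deletion with 0.25, and swap with 0.25. Sharing one
# tokenizer avoids loading the segmentation dictionary several times.
#
#     tokenizer = ChineseTokenizer()
#     composite = CompositeAugmenter(
#         [SynonymReplacement(tokenizer), RandomDeletion(tokenizer), RandomSwap(tokenizer)],
#         probs=[0.5, 0.25, 0.25],
#     )
#     augmented = composite.augment("商品质量与描述相符")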


class BackTranslation(TextAugmenter):
    """Back-translation augmenter."""

    def __init__(self, translator=None, source_lang: str = 'zh',
                 target_langs: Optional[List[str]] = None):
        """
        Initialize the back-translation augmenter.

        Args:
            translator: A translator implementing a translate method.
            source_lang: Source language code.
            target_langs: Target language codes; defaults are used if None.
        """
        super().__init__()

        # If no translator was supplied, try third-party translation libraries.
        if translator is None:
            try:
                # Try googletrans first (install separately:
                # pip install googletrans==4.0.0-rc1).
                try:
                    from googletrans import Translator
                    self.translator = Translator()
                    self.translate_func = self._google_translate
                except ImportError:
                    # Fall back to py-translate if googletrans is unavailable.
                    try:
                        import translate
                        self.translator = translate
                        self.translate_func = self._py_translate
                    except ImportError:
                        logger.warning(
                            "No translation library installed; back-translation is "
                            "unavailable. Install googletrans or py-translate.")
                        self.translator = None
                        self.translate_func = self._dummy_translate
            except Exception as e:
                logger.error(f"Failed to initialize translator: {e}")
                self.translator = None
                self.translate_func = self._dummy_translate
        else:
            self.translator = translator
            self.translate_func = self._custom_translate

        self.source_lang = source_lang
        self.target_langs = target_langs or ['en', 'fr', 'de', 'es', 'ja']

    def _google_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with googletrans."""
        try:
            result = self.translator.translate(text, src=source_lang, dest=target_lang)
            return result.text
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _py_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with py-translate."""
        try:
            return self.translator.translate(text, source_lang, target_lang)
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _custom_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate with a user-supplied translator."""
        try:
            return self.translator.translate(text, source_lang, target_lang)
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return text

    def _dummy_translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """No-op translation; returns the original text."""
        logger.warning("Translation unavailable; returning the original text")
        return text

    def augment(self, text: str) -> str:
        """
        Augment a text via back-translation.

        Args:
            text: Original text.

        Returns:
            The augmented text.
        """
        if not text or self.translator is None:
            return text

        # Pick a random target language.
        target_lang = random.choice(self.target_langs)

        try:
            # Translate from the source language into the target language.
            translated = self.translate_func(text, self.source_lang, target_lang)

            # Translate back into the source language.
            back_translated = self.translate_func(translated, target_lang, self.source_lang)

            return back_translated
        except Exception as e:
            logger.error(f"Back-translation failed: {e}")
            return text
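

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming the optional dependencies above
    # are installed; back-translation silently falls back to the original
    # text when no translation library is available).
    sample = "这部电影的情节跌宕起伏,非常精彩"

    for augmenter in (SynonymReplacement(), RandomDeletion(), RandomSwap()):
        print(augmenter.__class__.__name__, "->", augmenter.augment(sample))

    composite = CompositeAugmenter(
        [SynonymReplacement(), RandomDeletion(), RandomSwap()],
        probs=[0.5, 0.25, 0.25],
    )
    print("CompositeAugmenter ->", composite.augment(sample))

    back_translator = BackTranslation()
    print("BackTranslation ->", back_translator.augment(sample))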