h2o-llmstudio
54 строки · 1.4 Кб
1import logging
2from abc import abstractmethod
3from typing import Any, Dict
4
5import torch
6from torch import nn
7
8logger = logging.getLogger(__name__)
9
10
class BaseNLPAug(nn.Module):
    """Base class for NLP augmentation.

    The default ``forward`` implements random token masking: every position
    of ``batch["input_ids"]`` is independently replaced by the configured
    mask token id with probability
    ``cfg.augmentation.token_mask_probability``.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters. ``forward`` reads
                ``cfg.augmentation.token_mask_probability`` and
                ``cfg._tokenizer_mask_token_id`` from it.
        """

        super().__init__()
        self.cfg = cfg

    # NOTE(review): the class does not derive from ``abc.ABC``, so this
    # decorator presumably acts only as a marker and does not block
    # instantiation - confirm before relying on it.
    @abstractmethod
    def forward(self, batch: Dict) -> Dict:
        """Randomly mask tokens of the batch.

        Nothing is changed unless the configured masking probability is
        greater than zero. Otherwise a Bernoulli mask with the shape of
        ``batch["input_ids"]`` is drawn and, at every selected position:

        * the input id is replaced by ``cfg._tokenizer_mask_token_id``
          (on a copy; the tensor originally stored in the batch is kept
          intact and the copy is stored back under ``"input_ids"``),
        * ``batch["attention_mask"]`` is set to 0 (in place),
        * ``batch["labels"]`` is set to -100 (in place), but only if labels
          are present, have at least two dimensions and their second
          dimension matches the one of the input ids.

        Args:
            batch: current batch; must contain ``"input_ids"`` and
                ``"attention_mask"`` tensors, ``"labels"`` is optional

        Returns:
            augmented batch (the same dict object that was passed in)
        """

        mask_probability = self.cfg.augmentation.token_mask_probability
        if mask_probability > 0:
            input_ids = batch["input_ids"].clone()

            # The mask is sampled without an explicit device and moved
            # afterwards, which keeps the random draws identical to the
            # previous implementation.
            # NOTE: special tokens are not excluded, they can be masked too.
            mask = (
                torch.bernoulli(
                    torch.full(input_ids.shape, float(mask_probability))
                )
                .to(input_ids.device)
                .bool()
            )

            input_ids[mask] = self.cfg._tokenizer_mask_token_id
            # ``input_ids`` is already a private copy, no second clone needed.
            batch["input_ids"] = input_ids
            batch["attention_mask"][mask] = 0

            # Labels are optional: skip them instead of failing when they
            # are missing or cannot be aligned with the token mask.
            labels = batch.get("labels")
            if (
                labels is not None
                and labels.dim() > 1
                and labels.shape[1] == input_ids.shape[1]
            ):
                # -100 is presumably the ignore index of the loss in use.
                labels[mask] = -100

        return batch
55