h2o-llmstudio

Форк
0
54 строки · 1.4 Кб
1
import logging
2
from abc import abstractmethod
3
from typing import Any, Dict
4

5
import torch
6
from torch import nn
7

8
logger = logging.getLogger(__name__)
9

10

11
class BaseNLPAug(nn.Module):
12
    """Base class for NLP augmentation"""
13

14
    def __init__(self, cfg: Any):
15
        """
16
        Args:
17
            cfg: config with all the hyperparameters
18
        """
19

20
        super().__init__()
21
        self.cfg = cfg
22

23
    @abstractmethod
24
    def forward(self, batch: Dict) -> Dict:
25
        """Augmenting
26

27
        Args:
28
            batch: current batch
29

30
        Returns:
31
            augmented batch
32
        """
33

34
        if self.cfg.augmentation.token_mask_probability > 0:
35
            input_ids = batch["input_ids"].clone()
36
            # special_mask = ~batch["special_tokens_mask"].clone().bool()
37
            mask = (
38
                torch.bernoulli(
39
                    torch.full(
40
                        input_ids.shape,
41
                        float(self.cfg.augmentation.token_mask_probability),
42
                    )
43
                )
44
                .to(input_ids.device)
45
                .bool()
46
                # & special_mask
47
            ).bool()
48
            input_ids[mask] = self.cfg._tokenizer_mask_token_id
49
            batch["input_ids"] = input_ids.clone()
50
            batch["attention_mask"][mask] = 0
51
            if batch["labels"].shape[1] == batch["input_ids"].shape[1]:
52
                batch["labels"][mask] = -100
53

54
        return batch
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.