gpt2_collator.py
import torch

from dataclasses import dataclass
from typing import List, Dict, Any, Union, Optional
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy

from logger_config import logger

@dataclass
class ScoreCollator:
    """Collate (input, output) text pairs for scoring: concatenate them with a
    delimiter, right-pad the batch, and mask everything except the output
    tokens in ``labels`` so the LM loss is computed on the outputs only."""

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    delimiter: str = '\n'

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # Scoring needs right padding so that the cumulative sum over the
        # attention mask gives each token's 1-based position among valid tokens.
        self.tokenizer.padding_side = 'right'
        input_texts = [f['input_texts'] for f in features]
        output_texts = [f['output_texts'] for f in features]
        # The delimiter is inserted exactly once between input and output, so
        # neither side may already carry it at the boundary.
        assert all(not text.endswith(self.delimiter) for text in input_texts)
        assert all(not text.startswith(self.delimiter) for text in output_texts)
        concat_texts: List[str] = [self.delimiter.join([inp, out]) for inp, out in zip(input_texts, output_texts)]

        batch_dict = self.tokenizer(
            concat_texts,
            max_length=self.max_length,
            truncation=True,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # Ignore padding and input positions in the loss by setting them to -100.
        labels = batch_dict['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        num_valid_tokens = torch.cumsum(batch_dict['attention_mask'], dim=1)
        output_lengths: torch.LongTensor = torch.LongTensor(self._get_output_lengths(output_texts))
        logger.debug('output lengths: {}'.format(output_lengths))
        input_lengths: torch.LongTensor = torch.sum(batch_dict['attention_mask'], dim=1) - output_lengths
        labels[num_valid_tokens <= input_lengths[:, None]] = -100
        batch_dict['labels'] = labels

        return batch_dict

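    # Worked example of the masking above (illustrative numbers, ours, not from
    # the original file). Suppose a row has 5 valid tokens followed by 2 pad
    # tokens, and the output occupies the last 2 valid positions
    # (input_length = 5 - 2 = 3):
    #   attention_mask   = [1, 1, 1, 1, 1, 0, 0]
    #   num_valid_tokens = [1, 2, 3, 4, 5, 5, 5]   # cumsum of attention_mask
    #   num_valid_tokens <= 3 marks the input positions, so
    #   labels           = [-100, -100, -100, t4, t5, -100, -100]
    # where the trailing -100s come from the pad-token masking; only the two
    # output tokens t4, t5 contribute to the loss.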

    def _get_output_lengths(self, output_texts: List[str]) -> List[int]:
        output_ids: List[List[int]] = self.tokenizer(
            output_texts, max_length=self.max_length, truncation=True, padding=False
        )['input_ids']

        for idx in range(len(output_ids)):
            # The LLaMA tokenizer prepends a BOS token; drop it so the length
            # counts only the actual output tokens.
            if output_ids[idx][0] == self.tokenizer.bos_token_id:
                output_ids[idx] = output_ids[idx][1:]

        lengths: List[int] = [len(output_id) for output_id in output_ids]
        assert all(length > 0 for length in lengths), lengths

        return lengths

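# Usage sketch (ours, not part of the original file): a minimal example of how
# ScoreCollator might be driven with a GPT-2 tokenizer. GPT-2 ships without a
# pad token, so reusing EOS for padding is a common workaround; the feature
# dicts and the helper name _score_collator_demo are illustrative assumptions.
def _score_collator_demo() -> BatchEncoding:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token
    collator = ScoreCollator(tokenizer=tokenizer, max_length=128)
    features = [
        {'input_texts': 'Q: 2 + 2 = ?', 'output_texts': 'A: 4'},
        {'input_texts': 'Q: capital of France?', 'output_texts': 'A: Paris'},
    ]
    batch = collator(features)
    # batch['labels'] is -100 everywhere except the output tokens, so feeding
    # this batch to a causal LM yields a loss over the outputs only.
    return batch
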

@dataclass
class DecodeCollator:
    """Collate prompts for generation: left-pad the batch so that every prompt
    ends at the last position and ``generate`` can append tokens directly."""

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # batch_score requires right padding, but generate requires left padding
        self.tokenizer.padding_side = 'left'
        input_texts = [f['input_texts'] for f in features]

        batch_dict = self.tokenizer(
            input_texts,
            max_length=self.max_length,
            truncation=True,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        return batch_dict

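# Usage sketch (ours, not part of the original file): DecodeCollator feeds
# model.generate(), where left padding keeps every prompt flush against the
# first generated token. The model choice and the helper name
# _decode_collator_demo are illustrative assumptions.
def _decode_collator_demo() -> List[str]:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    collator = DecodeCollator(tokenizer=tokenizer, max_length=128)
    batch = collator([{'input_texts': 'Q: What is the capital of France?\nA:'}])
    # With left padding, prompts all end at the same position, so generate()
    # appends new tokens directly after each prompt.
    outputs = model.generate(**batch, max_new_tokens=16, pad_token_id=tokenizer.pad_token_id)
    # Strip the prompt tokens and decode only the continuation.
    return tokenizer.batch_decode(outputs[:, batch['input_ids'].shape[1]:], skip_special_tokens=True)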
