aurora

Форк
0
/
collator.py 
51 строка · 2.1 Кб
1
import torch
2
from dataclasses import dataclass
3
from typing import Any, Dict, List, Sequence, Tuple
4
from transformers import DataCollatorForSeq2Seq
5

6

7
@dataclass
8
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
9
    r"""
10
    Data collator for pairwise data.
11
    """
12

13
    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
14
        padded_labels = []
15
        for feature, (prompt_len, answer_len) in zip(batch, positions):
16
            if self.tokenizer.padding_side == "left":
17
                start, end = feature.size(0) - answer_len, feature.size(0)
18
            else:
19
                start, end = prompt_len, prompt_len + answer_len
20
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
21
            padded_tensor[start:end] = feature[start:end]
22
            padded_labels.append(padded_tensor)
23
        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
24

25
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
26
        r"""
27
        Pads batched data to the longest sequence in the batch.
28

29
        We generate 2 * n examples where the first n examples represent chosen examples and
30
        the last n examples represent rejected examples.
31
        """
32
        concatenated_features = []
33
        label_positions = []
34
        for key in ("chosen_ids", "rejected_ids"):
35
            for feature in features:
36
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
37
                concatenated_features.append({
38
                    "input_ids": feature["prompt_ids"] + feature[key],
39
                    "attention_mask": [1] * (prompt_len + answer_len)
40
                })
41
                label_positions.append((prompt_len, answer_len))
42

43
        batch = self.tokenizer.pad(
44
            concatenated_features,
45
            padding=self.padding,
46
            max_length=self.max_length,
47
            pad_to_multiple_of=self.pad_to_multiple_of,
48
            return_tensors=self.return_tensors,
49
        )
50
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
51
        return batch
52

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.