aurora

Форк
0
/
collator.py 
27 строк · 928.0 Байт
1
import torch
2
from dataclasses import dataclass
3
from typing import Any, Dict, Sequence
4
from transformers import DataCollatorWithPadding
5

6

7
@dataclass
8
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
9
    r"""
10
    Data collator for pairwise data.
11
    """
12

13
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
14
        r"""
15
        Pads batched data to the longest sequence in the batch.
16

17
        We generate 2 * n examples where the first n examples represent chosen examples and
18
        the last n examples represent rejected examples.
19
        """
20
        features = [
21
            {
22
                "input_ids": feature["prompt_ids"] + feature[key],
23
                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
24
            }
25
            for key in ("chosen_ids", "rejected_ids") for feature in features
26
        ]
27
        return super().__call__(features)
28

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.