lmops

Форк
0
/
biencoder_collator.py 
43 строки · 1.5 Кб
1
import torch
2

3
from dataclasses import dataclass
4
from typing import List, Dict, Any, Union, Optional
5
from transformers import BatchEncoding, PreTrainedTokenizerBase
6
from transformers.file_utils import PaddingStrategy
7

8
from config import Arguments
9

10

11
@dataclass
12
class BiencoderCollator:
13

14
    args: Arguments
15
    tokenizer: PreTrainedTokenizerBase
16
    padding: Union[bool, str, PaddingStrategy] = True
17
    max_length: Optional[int] = None
18
    pad_to_multiple_of: Optional[int] = None
19
    return_tensors: str = "pt"
20

21
    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
22
        queries: List[str] = [f['query'] for f in features]
23
        passages: List[str] = sum([f['passages'] for f in features], [])
24

25
        input_texts = queries + passages
26

27
        merged_batch_dict = self.tokenizer(
28
            input_texts,
29
            max_length=self.args.max_len,
30
            truncation=True,
31
            padding=self.padding,
32
            return_token_type_ids=False,
33
            pad_to_multiple_of=self.pad_to_multiple_of,
34
            return_tensors=self.return_tensors)
35

36
        # dummy placeholder for field "labels", won't use it to compute loss
37
        labels = torch.zeros(len(queries), dtype=torch.long)
38
        merged_batch_dict['labels'] = labels
39

40
        if 'kd_labels' in features[0]:
41
            kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float()
42
            merged_batch_dict['kd_labels'] = kd_labels
43
        return merged_batch_dict
44

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.