lmops
import torch

from dataclasses import dataclass
from typing import List, Dict, Any, Union, Optional
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy

from config import Arguments


@dataclass
class BiencoderCollator:

    args: Arguments
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        queries: List[str] = [f['query'] for f in features]
        # flatten each example's passage list into one list of passages
        passages: List[str] = sum([f['passages'] for f in features], [])

        # tokenize queries and passages together as a single batch
        input_texts = queries + passages

        merged_batch_dict = self.tokenizer(
            input_texts,
            max_length=self.args.max_len,
            truncation=True,
            padding=self.padding,
            return_token_type_ids=False,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # dummy placeholder for field "labels", won't use it to compute loss
        labels = torch.zeros(len(queries), dtype=torch.long)
        merged_batch_dict['labels'] = labels

        # optional knowledge-distillation labels, one row per example
        if 'kd_labels' in features[0]:
            kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float()
            merged_batch_dict['kd_labels'] = kd_labels
        return merged_batch_dict
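
For reference, a minimal usage sketch: the collator plugs into a torch DataLoader as its collate_fn, turning a batch of query/passages dicts into one tokenized batch. The Arguments stand-in with a max_len field, the bert-base-uncased checkpoint, and the toy examples below are illustrative assumptions, not part of the original file.

# Minimal usage sketch. Assumptions: config.Arguments exposes `max_len`
# (replaced here by a hypothetical stand-in), and the dataset yields dicts
# with a 'query' string and a 'passages' list of strings.
from dataclasses import dataclass

from torch.utils.data import DataLoader
from transformers import AutoTokenizer


@dataclass
class Arguments:  # hypothetical stand-in for config.Arguments
    max_len: int = 128


examples = [
    {'query': 'what is a biencoder',
     'passages': ['A biencoder encodes queries and passages separately.',
                  'An unrelated passage used as a negative.']},
    {'query': 'how are negatives sampled',
     'passages': ['Hard negatives can come from a retriever.',
                  'Another random negative passage.']},
]

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # illustrative checkpoint
collator = BiencoderCollator(args=Arguments(), tokenizer=tokenizer)

loader = DataLoader(examples, batch_size=2, collate_fn=collator)
batch = next(iter(loader))

# 2 queries + 4 flattened passages -> 6 tokenized rows in one batch
print(batch['input_ids'].shape)  # e.g. torch.Size([6, seq_len])
print(batch['labels'])           # tensor([0, 0]), the dummy placeholder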