import torch

from dataclasses import dataclass
from typing import List, Dict, Any, Union, Optional
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy

from logger_config import logger


@dataclass
class ScoreCollator:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    delimiter: str = '\n'

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # scoring uses right padding so cumulative attention counts line up with token positions
        self.tokenizer.padding_side = 'right'
        input_texts = [f['input_texts'] for f in features]
        output_texts = [f['output_texts'] for f in features]
        assert all(not text.endswith(self.delimiter) for text in input_texts)
        assert all(not text.startswith(self.delimiter) for text in output_texts)
        concat_texts: List[str] = [self.delimiter.join([inp, out]) for inp, out in zip(input_texts, output_texts)]

        batch_dict = self.tokenizer(
            concat_texts,
            max_length=self.max_length,
            truncation=True,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # mask padding and prompt positions so the LM loss is computed on the output tokens only
        labels = batch_dict['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        num_valid_tokens = torch.cumsum(batch_dict['attention_mask'], dim=1)
        output_lengths: torch.LongTensor = torch.LongTensor(self._get_output_lengths(output_texts))
        logger.debug('output lengths: {}'.format(output_lengths))
        input_lengths: torch.LongTensor = torch.sum(batch_dict['attention_mask'], dim=1) - output_lengths
        labels[num_valid_tokens <= input_lengths[:, None]] = -100
        batch_dict['labels'] = labels

        return batch_dict

    def _get_output_lengths(self, output_texts: List[str]) -> List[int]:
        output_ids: List[List[int]] = self.tokenizer(
            output_texts, max_length=self.max_length, truncation=True, padding=False
        )['input_ids']

        for idx in range(len(output_ids)):
            # the llama tokenizer prepends a bos token
            if output_ids[idx][0] == self.tokenizer.bos_token_id:
                output_ids[idx] = output_ids[idx][1:]

        lengths: List[int] = [len(output_id) for output_id in output_ids]
        assert all(length > 0 for length in lengths), lengths

        return lengths
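
# Worked example of the label masking above (illustrative numbers, not taken
# from the original code): suppose a feature has input_texts 'Q: 2 + 2 = ?'
# and output_texts 'A: 4', the concatenated text tokenizes to 9 valid tokens,
# and the output alone tokenizes to 4 tokens (after stripping any leading bos).
# Then input_lengths = 9 - 4 = 5, so the first 5 positions (prompt plus
# delimiter) satisfy num_valid_tokens <= 5 and are set to -100, while the
# remaining 4 positions keep their token ids and contribute to the LM loss.
# Padding positions are masked separately via pad_token_id.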


@dataclass
class DecodeCollator:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # batch_score requires right padding, but generate requires left padding
        self.tokenizer.padding_side = 'left'
        input_texts = [f['input_texts'] for f in features]

        batch_dict = self.tokenizer(
            input_texts,
            max_length=self.max_length,
            truncation=True,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        return batch_dict
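

if __name__ == '__main__':
    # Minimal usage sketch (an illustration added here, not part of the original
    # module).  It assumes a GPT-2 tokenizer and toy features; any causal LM
    # tokenizer fed the same 'input_texts'/'output_texts' dicts should behave
    # the same way.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    # GPT-2 ships without a pad token, so reuse eos for padding
    tokenizer.pad_token = tokenizer.eos_token

    features = [
        {'input_texts': 'Q: What is 2 + 2?', 'output_texts': 'A: 4'},
        {'input_texts': 'Q: Name a primary color.', 'output_texts': 'A: red'},
    ]

    score_batch = ScoreCollator(tokenizer=tokenizer, max_length=64)(features)
    # prompt and padding positions are -100; only answer tokens carry labels
    print(score_batch['labels'])

    decode_batch = DecodeCollator(tokenizer=tokenizer, max_length=64)(features)
    # left-padded input_ids, ready to be passed to model.generate()
    print(decode_batch['input_ids'].shape)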