# lmops — 79 lines · 2.8 KB
import torch
import torch.nn.functional as F
import tqdm

from functools import partial
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput
from typing import Dict, List, Optional

from utils import pool, move_to_cuda
13
14
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List],
                    prompt: Optional[str] = None) -> BatchEncoding:
    """Tokenize a batch of raw texts for the encoder.

    Args:
        tokenizer: tokenizer used to convert texts to model inputs.
        examples: batch dict containing an ``'input_texts'`` list of strings
            (the format produced by ``Dataset.set_transform``).
        prompt: optional prefix (e.g. ``'query: '``) prepended to every text.

    Returns:
        A ``BatchEncoding`` with inputs padded within the batch and truncated
        to 256 tokens; token type ids are omitted.
    """
    texts = examples['input_texts']
    if prompt:
        # Build a new list instead of mutating the caller's dict in place.
        texts = [prompt + t for t in texts]
    batch_dict = tokenizer(
        texts,
        max_length=256,
        return_token_type_ids=False,
        padding=True,
        truncation=True,
    )

    return batch_dict
29
30
class SimpleEncoder(torch.nn.Module):
    """Sentence encoder wrapping a HuggingFace ``AutoModel``.

    Loads a pretrained encoder + tokenizer, moves the encoder to GPU
    (``DataParallel`` across all visible GPUs when more than one), and
    exposes :meth:`encode` to embed a list of sentences into a single
    tensor of (optionally L2-normalized) pooled representations.
    """

    def __init__(self, model_name_or_path: str,
                 l2_normalize: bool = True,
                 pool_type: str = 'avg',
                 prompt: str = 'query: '):
        """
        Args:
            model_name_or_path: HuggingFace model id or local checkpoint path.
            l2_normalize: if True, L2-normalize each embedding.
            pool_type: pooling strategy passed to ``utils.pool`` (e.g. 'avg').
            prompt: prefix added to every sentence; restricted to the
                E5-style prompts '', 'query: ', 'passage: '.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.gpu_count = torch.cuda.device_count()

        self.l2_normalize = l2_normalize
        self.pool_type = pool_type
        self.prompt = prompt
        assert self.prompt in ['', 'query: ', 'passage: ']

        # Inference only: freeze dropout/batchnorm and move to GPU.
        self.encoder.eval()
        self.encoder.cuda()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    @torch.no_grad()
    def encode(self, sentences: List[str], **kwargs) -> torch.Tensor:
        """Embed ``sentences`` and return a (len(sentences), dim) CPU tensor.

        Extra ``kwargs`` are accepted for interface compatibility (e.g. with
        MTEB-style callers) and ignored.
        """
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
        dataset.set_transform(partial(_transform_func, self.tokenizer, prompt=self.prompt))

        # Pad to a multiple of 8 for tensor-core friendly shapes.
        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        data_loader = DataLoader(
            dataset,
            # max(1, ...) guards against gpu_count == 0, which would
            # otherwise produce batch_size=0 and a DataLoader error.
            batch_size=128 * max(1, self.gpu_count),
            shuffle=False,
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        # Progress bar only for non-trivial workloads (>= 128 sentences).
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10, disable=len(sentences) < 128):
            batch_dict = move_to_cuda(batch_dict)

            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], self.pool_type)
                if self.l2_normalize:
                    embeds = F.normalize(embeds, p=2, dim=-1)
                # Move to CPU per batch to keep GPU memory bounded.
                encoded_embeds.append(embeds.cpu())

        return torch.cat(encoded_embeds, dim=0)
80