lmops

Форк
0
/
simple_encoder.py 
79 строк · 2.8 Кб
1
import torch
import torch.nn.functional as F
import tqdm

from functools import partial
from typing import Dict, List, Optional

from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput

from utils import pool, move_to_cuda
13

14

15
def _transform_func(tokenizer: PreTrainedTokenizerFast,
16
                    examples: Dict[str, List],
17
                    prompt: str = None) -> BatchEncoding:
18
    if prompt:
19
        examples['input_texts'] = [prompt + t for t in examples['input_texts']]
20
    batch_dict = tokenizer(
21
        examples['input_texts'],
22
        max_length=256,
23
        return_token_type_ids=False,
24
        padding=True,
25
        truncation=True,
26
    )
27

28
    return batch_dict
29

30

31
class SimpleEncoder(torch.nn.Module):
    """Batched sentence encoder built on a HuggingFace ``AutoModel``.

    Loads a pretrained encoder + tokenizer, moves the model to CUDA
    (DataParallel across all visible GPUs), and exposes :meth:`encode`,
    which returns pooled, optionally L2-normalized sentence embeddings
    on the CPU.

    Args:
        model_name_or_path: HuggingFace model id or local checkpoint path.
        l2_normalize: if True, L2-normalize each embedding along the last dim.
        pool_type: pooling strategy name forwarded to ``utils.pool``
            (default ``'avg'``).
        prompt: prefix added to every sentence before tokenization; must be
            one of ``''``, ``'query: '``, ``'passage: '`` (the E5 convention).

    Raises:
        ValueError: if ``prompt`` is not one of the three allowed values.
    """

    def __init__(self, model_name_or_path: str,
                 l2_normalize: bool = True,
                 pool_type: str = 'avg',
                 prompt: str = 'query: '):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.gpu_count = torch.cuda.device_count()

        self.l2_normalize = l2_normalize
        self.pool_type = pool_type
        self.prompt = prompt
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently accept an invalid prompt.
        if self.prompt not in ('', 'query: ', 'passage: '):
            raise ValueError(
                f"prompt must be one of '', 'query: ', 'passage: '; got {self.prompt!r}")

        self.encoder.eval()
        # NOTE(review): assumes a CUDA device is available; `move_to_cuda`
        # in encode() relies on the same assumption — confirm before running
        # on CPU-only hosts.
        self.encoder.cuda()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    @torch.no_grad()
    def encode(self, sentences: List[str], **kwargs) -> torch.Tensor:
        """Encode sentences into a ``(len(sentences), hidden_dim)`` CPU tensor.

        Tokenizes via a lazy dataset transform, runs the encoder under
        autocast, pools the last hidden state with ``utils.pool``, and
        concatenates per-batch results. ``**kwargs`` is accepted for
        interface compatibility (e.g. MTEB) and ignored.

        Raises:
            RuntimeError: if ``sentences`` is empty (``torch.cat`` on an
                empty list).
        """
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
        dataset.set_transform(partial(_transform_func, self.tokenizer, prompt=self.prompt))

        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        # max(1, gpu_count): the original `128 * self.gpu_count` yields
        # batch_size=0 on CPU-only hosts, which DataLoader rejects.
        # Identical on any machine with >= 1 GPU.
        data_loader = DataLoader(
            dataset,
            batch_size=128 * max(1, self.gpu_count),
            shuffle=False,
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        # Progress bar only for large inputs; mininterval keeps log noise down.
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10, disable=len(sentences) < 128):
            batch_dict = move_to_cuda(batch_dict)

            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], self.pool_type)
                if self.l2_normalize:
                    embeds = F.normalize(embeds, p=2, dim=-1)
                # Move to CPU per batch so GPU memory stays bounded.
                encoded_embeds.append(embeds.cpu())

        return torch.cat(encoded_embeds, dim=0)
80

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.