lmops

Форк
0
/
cross_encoder_model.py 
57 строк · 1.9 Кб
1
import torch
2
import torch.nn as nn
3

4
from typing import Optional, Dict
5
from transformers import (
6
    PreTrainedModel,
7
    AutoModelForSequenceClassification
8
)
9
from transformers.modeling_outputs import SequenceClassifierOutput
10

11
from config import Arguments
12

13

14
class Reranker(nn.Module):
15
    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
16
        super().__init__()
17
        self.hf_model = hf_model
18
        self.args = args
19

20
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
21

22
    def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
23
        input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'}
24

25
        outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)
26
        outputs.logits = outputs.logits.view(-1, self.args.train_n_passages)
27
        loss = self.cross_entropy(outputs.logits, batch['labels'])
28
        outputs.loss = loss
29

30
        return outputs
31

32
    def gradient_checkpointing_enable(self):
33
        self.hf_model.gradient_checkpointing_enable()
34

35
    @classmethod
36
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
37
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
38
        return cls(hf_model, all_args)
39

40
    def save_pretrained(self, output_dir: str):
41
        self.hf_model.save_pretrained(output_dir)
42

43

44
class RerankerForInference(nn.Module):
45
    def __init__(self, hf_model: Optional[PreTrainedModel] = None):
46
        super().__init__()
47
        self.hf_model = hf_model
48
        self.hf_model.eval()
49

50
    @torch.no_grad()
51
    def forward(self, batch) -> SequenceClassifierOutput:
52
        return self.hf_model(**batch)
53

54
    @classmethod
55
    def from_pretrained(cls, pretrained_model_name_or_path: str):
56
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
57
        return cls(hf_model)
58

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.