lmops
57 строк · 1.9 Кб
1import torch
2import torch.nn as nn
3
4from typing import Optional, Dict
5from transformers import (
6PreTrainedModel,
7AutoModelForSequenceClassification
8)
9from transformers.modeling_outputs import SequenceClassifierOutput
10
11from config import Arguments
12
13
class Reranker(nn.Module):
    """Training-time wrapper that adds a grouped cross-entropy loss to a classifier.

    The wrapped Hugging Face model produces the logits. ``forward`` reshapes
    them into rows of ``args.train_n_passages`` columns and attaches a
    mean-reduced cross-entropy loss computed against ``batch['labels']``.
    """

    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        """Keep the backbone and run configuration, and build the loss module.

        Args:
            hf_model: Model called in ``forward``; its output must expose
                ``logits``.
            args: Run configuration; ``train_n_passages`` is read from it on
                every forward pass.
        """
        super().__init__()
        self.args = args
        self.hf_model = hf_model

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
        """Run the backbone and attach regrouped logits plus the loss.

        Args:
            batch: Mapping of keyword arguments for the backbone, plus a
                ``'labels'`` entry that is withheld from the backbone and
                used as the cross-entropy target.
                NOTE(review): labels are presumably one class index per row
                of the reshaped logits; confirm against the data collator.

        Returns:
            The backbone's output object with ``logits`` replaced by a view of
            shape ``(-1, args.train_n_passages)`` and ``loss`` set to the
            cross-entropy over that view.
        """
        # The loss is computed here, so the backbone must not receive labels.
        model_inputs = {
            name: tensor for name, tensor in batch.items() if name != 'labels'
        }
        outputs: SequenceClassifierOutput = self.hf_model(**model_inputs, return_dict=True)

        # Lay the logits out as rows of ``train_n_passages`` columns so that
        # cross-entropy treats each row as one classification problem.
        grouped_logits = outputs.logits.view(-1, self.args.train_n_passages)
        outputs.logits = grouped_logits
        outputs.loss = self.cross_entropy(grouped_logits, batch['labels'])

        return outputs

    def gradient_checkpointing_enable(self):
        """Forward the call to the wrapped model's ``gradient_checkpointing_enable``."""
        self.hf_model.gradient_checkpointing_enable()

    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        """Build a ``Reranker`` around a freshly loaded classification model.

        Args:
            all_args: Run configuration handed to the constructor.
            *args: Positional arguments passed through unchanged to
                ``AutoModelForSequenceClassification.from_pretrained``.
            **kwargs: Keyword arguments passed through unchanged to
                ``AutoModelForSequenceClassification.from_pretrained``.

        Returns:
            A new instance of ``cls`` wrapping the loaded model.
        """
        backbone = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(backbone, all_args)

    def save_pretrained(self, output_dir: str):
        """Save the wrapped model via its own ``save_pretrained``.

        Args:
            output_dir: Destination passed straight to the wrapped model.
        """
        self.hf_model.save_pretrained(output_dir)
42
43
class RerankerForInference(nn.Module):
    """Inference-only wrapper around a sequence-classification model.

    The wrapped model is put into eval mode at construction time and every
    forward pass runs under ``torch.no_grad()``.
    """

    def __init__(self, hf_model: Optional[PreTrainedModel] = None):
        """Store the backbone and switch it to eval mode when one is given.

        Args:
            hf_model: Model to wrap. May be ``None`` (the default), in which
                case a model must be assigned to ``self.hf_model`` before
                ``forward`` is called; the caller is then responsible for
                putting that model into eval mode.
        """
        super().__init__()
        self.hf_model = hf_model
        # The parameter defaults to None, so calling ``.eval()``
        # unconditionally made the no-argument constructor fail with
        # AttributeError. Only touch the model when there is one.
        if hf_model is not None:
            hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        """Run the wrapped model on ``batch`` without tracking gradients.

        Args:
            batch: Mapping unpacked as keyword arguments into the wrapped
                model.

        Returns:
            Whatever the wrapped model returns for these inputs.

        Raises:
            RuntimeError: If no model has been attached to this wrapper.
        """
        if self.hf_model is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable".
            raise RuntimeError(
                'RerankerForInference has no hf_model; pass one to the '
                'constructor or use RerankerForInference.from_pretrained().'
            )
        return self.hf_model(**batch)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        """Load a classification model and wrap it for inference.

        Args:
            pretrained_model_name_or_path: Identifier or path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.

        Returns:
            A new instance of ``cls`` wrapping the loaded model.
        """
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        return cls(hf_model)
58