lmops

Форк
0
/
biencoder_trainer.py 
54 строки · 2.1 Кб
1
import os
2

3
from typing import Optional
4
from transformers.trainer import Trainer
5

6
from logger_config import logger
7
from evaluation.metrics import accuracy, batch_mrr
8
from models import BiencoderOutput, BiencoderModel
9
from utils import AverageMeter
10

11

12
class BiencoderTrainer(Trainer):
13
    def __init__(self, *pargs, **kwargs):
14
        super(BiencoderTrainer, self).__init__(*pargs, **kwargs)
15
        self.model: BiencoderModel
16

17
        self.acc1_meter = AverageMeter('Acc@1', round_digits=2)
18
        self.acc3_meter = AverageMeter('Acc@3', round_digits=2)
19
        self.mrr_meter = AverageMeter('mrr', round_digits=2)
20
        self.last_epoch = 0
21

22
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
23
        output_dir = output_dir if output_dir is not None else self.args.output_dir
24
        os.makedirs(output_dir, exist_ok=True)
25
        logger.info("Saving model checkpoint to {}".format(output_dir))
26
        self.model.save(output_dir)
27
        if self.tokenizer is not None:
28
            self.tokenizer.save_pretrained(output_dir)
29

30
    def compute_loss(self, model, inputs, return_outputs=False):
31
        outputs: BiencoderOutput = model(inputs)
32
        loss = outputs.loss
33

34
        if self.model.training:
35
            step_acc1, step_acc3 = accuracy(output=outputs.scores.detach(), target=outputs.labels, topk=(1, 3))
36
            step_mrr = batch_mrr(output=outputs.scores.detach(), target=outputs.labels)
37

38
            self.acc1_meter.update(step_acc1)
39
            self.acc3_meter.update(step_acc3)
40
            self.mrr_meter.update(step_mrr)
41

42
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
43
                logger.info('step: {}, {}, {}, {}'.format(self.state.global_step, self.mrr_meter, self.acc1_meter, self.acc3_meter))
44

45
            self._reset_meters_if_needed()
46

47
        return (loss, outputs) if return_outputs else loss
48

49
    def _reset_meters_if_needed(self):
50
        if int(self.state.epoch) != self.last_epoch:
51
            self.last_epoch = int(self.state.epoch)
52
            self.acc1_meter.reset()
53
            self.acc3_meter.reset()
54
            self.mrr_meter.reset()
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.