lmops
54 lines · 2.1 KB
import os

from typing import Optional

from transformers.trainer import Trainer

from logger_config import logger
from evaluation.metrics import accuracy, batch_mrr
from models import BiencoderOutput, BiencoderModel
from utils import AverageMeter
11
class BiencoderTrainer(Trainer):
    """Trainer specialized for biencoder retrieval training.

    During training it maintains running Acc@1 / Acc@3 / MRR meters over the
    current epoch, logs them every ``logging_steps`` steps, and resets them
    at each epoch boundary. Checkpointing is overridden because the model
    uses a custom ``save`` method rather than ``save_pretrained``.
    """

    def __init__(self, *pargs, **kwargs):
        super().__init__(*pargs, **kwargs)
        self.model: BiencoderModel

        # Running averages over the current epoch, rounded for display.
        self.acc1_meter = AverageMeter('Acc@1', round_digits=2)
        self.acc3_meter = AverageMeter('Acc@3', round_digits=2)
        self.mrr_meter = AverageMeter('mrr', round_digits=2)
        # Last epoch index at which the meters were reset.
        self.last_epoch = 0

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model (via its own ``save``) and the tokenizer.

        Args:
            output_dir: Target directory; defaults to ``args.output_dir``.
            state_dict: Accepted for Trainer-API compatibility; unused here
                because the model handles its own serialization.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Lazy %-formatting so the string is only built if the record is emitted.
        logger.info('Saving model checkpoint to %s', output_dir)
        self.model.save(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward the batch and, in training mode, update/log retrieval metrics.

        Args:
            model: The (possibly wrapped) model; called directly on ``inputs``.
            inputs: Batch dict passed through to the model.
            return_outputs: If True, return ``(loss, outputs)`` instead of loss.
            num_items_in_batch: Accepted (and ignored) for compatibility with
                transformers >= 4.46, which passes it to ``compute_loss``.
        """
        outputs: BiencoderOutput = model(inputs)
        loss = outputs.loss

        if self.model.training:
            # detach() so metric computation does not retain the autograd graph.
            scores = outputs.scores.detach()
            step_acc1, step_acc3 = accuracy(output=scores, target=outputs.labels, topk=(1, 3))
            step_mrr = batch_mrr(output=scores, target=outputs.labels)

            self.acc1_meter.update(step_acc1)
            self.acc3_meter.update(step_acc3)
            self.mrr_meter.update(step_mrr)

            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: %s, %s, %s, %s',
                            self.state.global_step, self.mrr_meter, self.acc1_meter, self.acc3_meter)

            self._reset_meters_if_needed()

        return (loss, outputs) if return_outputs else loss

    def _reset_meters_if_needed(self):
        """Reset the metric meters when a new epoch has started."""
        # state.epoch is None until the training loop has started; guard so
        # an out-of-loop call cannot raise TypeError on int(None).
        if self.state.epoch is not None and int(self.state.epoch) != self.last_epoch:
            self.last_epoch = int(self.state.epoch)
            self.acc1_meter.reset()
            self.acc3_meter.reset()
            self.mrr_meter.reset()
55