# lmops
# 50 lines · 1.9 KB
import os

from typing import Optional

from transformers.trainer import Trainer, IntervalStrategy
from transformers.modeling_outputs import SequenceClassifierOutput

from logger_config import logger
from evaluation.metrics import accuracy
from utils import AverageMeter
11
class RewardTrainer(Trainer):
    """HuggingFace ``Trainer`` subclass for training a reward model.

    Adds a running training-accuracy meter (logged every ``logging_steps``)
    and a custom checkpoint save that writes the model with
    ``save_pretrained`` instead of a raw state dict.
    """

    def __init__(self, *pargs, **kwargs):
        super().__init__(*pargs, **kwargs)

        # Running accuracy over the current window; reset per epoch and
        # around step-based checkpoints (see _reset_meters_if_needed).
        self.acc_meter = AverageMeter('acc', round_digits=2)
        # Last integer epoch seen, used to detect epoch boundaries.
        self.last_epoch = 0

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model (and, on rank 0, the tokenizer) to ``output_dir``.

        Falls back to ``self.args.output_dir`` when no directory is given.
        ``state_dict`` is accepted for interface compatibility with the base
        ``Trainer._save`` but is not used here.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to {}".format(output_dir))

        self.model.save_pretrained(output_dir)

        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Forward pass + loss; tracks running accuracy while training.

        Returns the loss, or ``(loss, outputs)`` when ``return_outputs``
        is true (the contract expected by ``Trainer``).
        """
        # NOTE(review): the model is called with the raw inputs dict (not
        # **inputs) — assumes the model's forward accepts a single dict.
        outputs: SequenceClassifierOutput = model(inputs)
        loss = outputs.loss

        if self.model.training:
            labels = inputs['labels']
            # accuracy(...) returns a list; [0] is top-1 accuracy.
            step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

        self._reset_meters_if_needed()

        return (loss, outputs) if return_outputs else loss

    def _reset_meters_if_needed(self):
        """Reset the accuracy meter at epoch boundaries and at save steps."""
        # Epoch rolled over: start a fresh accuracy window.
        if int(self.state.epoch) != self.last_epoch:
            self.last_epoch = int(self.state.epoch)
            self.acc_meter.reset()
        # Step-based checkpointing: reset so the meter reflects only the
        # window since the last checkpoint.
        if self.args.save_strategy == IntervalStrategy.STEPS and self.state.global_step % self.args.save_steps == 0:
            self.acc_meter.reset()
51