lmops

Форк
0
/
reward_trainer.py 
50 строк · 1.9 Кб
1
import os
2

3
from typing import Optional
4
from transformers.trainer import Trainer, IntervalStrategy
5
from transformers.modeling_outputs import SequenceClassifierOutput
6

7
from logger_config import logger
8
from evaluation.metrics import accuracy
9
from utils import AverageMeter
10

11

12
class RewardTrainer(Trainer):
13

14
    def __init__(self, *pargs, **kwargs):
15
        super(RewardTrainer, self).__init__(*pargs, **kwargs)
16

17
        self.acc_meter = AverageMeter('acc', round_digits=2)
18
        self.last_epoch = 0
19

20
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
21
        output_dir = output_dir if output_dir is not None else self.args.output_dir
22
        os.makedirs(output_dir, exist_ok=True)
23
        logger.info("Saving model checkpoint to {}".format(output_dir))
24

25
        self.model.save_pretrained(output_dir)
26

27
        if self.tokenizer is not None and self.is_world_process_zero():
28
            self.tokenizer.save_pretrained(output_dir)
29

30
    def compute_loss(self, model, inputs, return_outputs=False):
31
        outputs: SequenceClassifierOutput = model(inputs)
32
        loss = outputs.loss
33

34
        if self.model.training:
35
            labels = inputs['labels']
36
            step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0]
37
            self.acc_meter.update(step_acc)
38
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
39
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))
40

41
            self._reset_meters_if_needed()
42

43
        return (loss, outputs) if return_outputs else loss
44

45
    def _reset_meters_if_needed(self):
46
        if int(self.state.epoch) != self.last_epoch:
47
            self.last_epoch = int(self.state.epoch)
48
            self.acc_meter.reset()
49
        if self.args.save_strategy == IntervalStrategy.STEPS and self.state.global_step % self.args.save_steps == 0:
50
            self.acc_meter.reset()
51

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.