aurora / trainer.py
import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model

from llmtuner.extras.constants import IGNORE_INDEX

if TYPE_CHECKING:
    from transformers import PreTrainedModel


class CustomDPOTrainer(DPOTrainer):
    r"""DPO trainer that skips `DPOTrainer.__init__` and initializes the parent `Trainer` directly."""

    def __init__(
        self,
        beta: float,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
        **kwargs
    ):
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        # Set the attributes that DPOTrainer methods expect, without running DPOTrainer.__init__.
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.ref_model = ref_model
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable generation during evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = beta
        self.loss_type = loss_type
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False)
                    or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error

        # Single forward pass over the whole batch; cast logits to float32 for stable log-prob computation.
        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        # The batch holds the chosen examples in its first half and the rejected examples in its second half.
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
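
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical example
# of constructing this trainer. The checkpoint path, dataset/collator variables,
# and TrainingArguments values below are placeholders, not taken from this
# repository; extra keyword arguments are forwarded to `Trainer.__init__`.
#
# from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
#
# tokenizer = AutoTokenizer.from_pretrained("path/to/sft-checkpoint")
# model = AutoModelForCausalLM.from_pretrained("path/to/sft-checkpoint")      # policy to optimize
# ref_model = AutoModelForCausalLM.from_pretrained("path/to/sft-checkpoint")  # frozen reference
#
# trainer = CustomDPOTrainer(
#     beta=0.1,                         # temperature of the DPO preference loss
#     model=model,
#     ref_model=ref_model,
#     loss_type="sigmoid",              # or "hinge"
#     args=TrainingArguments(output_dir="dpo_out", per_device_train_batch_size=2),
#     train_dataset=paired_dataset,     # batches must stack chosen then rejected halves
#     data_collator=pairwise_collator,  # should pad labels with IGNORE_INDEX
#     tokenizer=tokenizer,
# )
# trainer.train()
# ---------------------------------------------------------------------------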
