h2o-llmstudio

text_rlhf_language_modeling_model.py 
141 lines · 4.2 KB
import logging
from typing import Any, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM

from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)


class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for GPT2 that returns a scalar for each
    output token.

    Based on the implementation of trl library:
    https://github.com/lvwerra/trl/blob/main/trl/models/modeling_value_head.py
    """

    def __init__(self, config):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = 0.1
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = (
            nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
        )

        # some models such as OPT have a projection layer before the word embeddings
        # e.g. OPT-350m
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        else:
            hidden_size = config.hidden_size

        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        output = self.dropout(hidden_states)

        # For now force upcast in fp32 if needed. Let's keep the
        # output in fp32 for numerical stability.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output

class Model(nn.Module):
    """
    Model for causal language modeling problem type.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """

        super(Model, self).__init__()

        self.cfg = cfg
        assert cfg.training.lora, "LoRA must be True for RLHF"

        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )

        self.backbone = prepare_lora(cfg=self.cfg, backbone=self.backbone)

        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

        self.value_head = ValueHead(self.backbone_config)
        self.value_head.summary.bias.data.zero_()

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        mask_key = "attention_mask"
        pad_keys = [
            "input_ids",
            "attention_mask",
            "special_tokens_mask",
            "labels",
        ]

        if padding:
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key=mask_key,
                pad_keys=pad_keys,
            )

        output = self.backbone(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            output_hidden_states=True,
        )

        if self.cfg.prediction.metric == "Perplexity" and not self.training:
            outputs["perplexity"] = self.perplexity(output.logits, batch["labels"])

        if self.training:
            last_hidden_state = output.hidden_states[-1]

            # force upcast in fp32 if logits are in half-precision
            if output.logits.dtype != torch.float32:
                output.logits = output.logits.float()

            outputs["logits"] = output.logits
            # one scalar value estimate per token: shape [batch_size, seq_len]
            outputs["value"] = self.value_head(last_hidden_state).squeeze(-1)

        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs

    def generate(self, batch: Dict, cfg: Any, streamer=None):
        return generate(self.backbone, batch, cfg, streamer)
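As a quick illustration of what the value head computes (this snippet is not part of the upstream file; GPT2Config is used only as a stand-in config object that exposes hidden_size), ValueHead maps the backbone's last hidden states to one scalar per token:

import torch
from transformers import GPT2Config

head = ValueHead(GPT2Config())           # hidden_size resolves to n_embd = 768
hidden_states = torch.randn(2, 8, 768)   # fake backbone activations [batch, seq, hidden]
print(head(hidden_states).shape)         # torch.Size([2, 8, 1])

Model.forward then squeezes away the trailing dimension, so the "value" entry it returns has shape [batch_size, seq_len].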

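A hedged usage sketch of the forward pass: it assumes `model` is a Model instance already built from a fully populated LLM Studio config (not shown here) and that the batch has been tokenized elsewhere; the tensors below are illustrative placeholders only.

import torch

# hypothetical pre-tokenized batch; real batches come from the LLM Studio data pipeline
batch = {
    "input_ids": torch.tensor([[1, 2, 3, 4]]),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
    "special_tokens_mask": torch.zeros(1, 4, dtype=torch.long),
    "labels": torch.tensor([[1, 2, 3, 4]]),
}

model.train()                  # training branch returns logits and per-token values
outputs = model(batch)
# outputs["logits"]: float32 logits of shape [batch_size, seq_len, vocab_size]
# outputs["value"]:  per-token value estimates of shape [batch_size, seq_len]

model.eval()                   # with cfg.prediction.metric == "Perplexity",
outputs = model(batch)         # evaluation returns outputs["perplexity"] instead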