h2o-llmstudio

text_reward_model.py 
170 lines · 5.5 KB
from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXConfig,
    GPTNeoXModel,
    GPTNeoXPreTrainedModel,
)
from transformers.utils import ModelOutput

from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR


class GPTNeoXRewardModelConfig(GPTNeoXConfig):
    model_type = "gpt_neox_reward_model"

    pooling: Literal["mean", "last"]

    def __init__(
        self,
        pooling: Literal["mean", "last"] = "last",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pooling = pooling or "last"


@dataclass
class GPTNeoXRewardModelOutput(ModelOutput):
    """
    Reward model output.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Reward score.
    """

    logits: torch.FloatTensor = None


class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):
    config_class = GPTNeoXRewardModelConfig

    def __init__(self, config):
        if type(config) is GPTNeoXConfig:
            # When a plain GPTNeoX config is passed in, convert it into a reward
            # model config. The direct `type(config) is GPTNeoXConfig` comparison
            # is used (instead of `isinstance()`) since the configuration class of
            # the reward model is also derived from `GPTNeoXConfig`.
            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())
        super().__init__(config)

        self.gpt_neox = GPTNeoXModel(config)
        self.out_proj = nn.Linear(config.hidden_size, 1)
        self.pooling = config.pooling

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> GPTNeoXRewardModelOutput:
        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.pooling == "mean":
            if attention_mask is None:
                pooled = hidden_states.mean(dim=1)
            else:
                # Average only over non-padded positions; the mask needs a
                # trailing singleton dim to broadcast over the hidden size.
                mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        elif self.pooling == "last":
            if attention_mask is None:
                pooled = hidden_states[:, -1]
            else:
                # Index of the last non-padded token per sequence: the cumulative
                # sum of the mask peaks at the final 1, and argmax returns it.
                last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)
                pooled = hidden_states.gather(
                    1, last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
                ).squeeze(1)
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")

        logits = self.out_proj(pooled)

        if not return_dict:
            return (logits,) + outputs[1:]

        return GPTNeoXRewardModelOutput(logits=logits)


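# --- Illustrative sketch (not part of the original file) ---------------------
# Hypothetical helper showing how the reward head above can be used on its own:
# load one of the OASST checkpoints referenced in `RewardModel.get_score` below,
# format a single-turn exchange in the OASST chat template, and read the scalar
# reward from `logits`. The function name and checkpoint choice are illustrative.
def _score_single_turn(
    prompt: str,
    answer: str,
    checkpoint: str = "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
) -> float:
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = GPTNeoXRewardModel.from_pretrained(checkpoint).eval()
    text = f"<|prompter|>{prompt}<|endoftext|><|assistant|>{answer}<|endoftext|>"
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).logits[0].item()

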
class RewardModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Register the custom config/model pair so the "gpt_neox_reward_model"
        # model_type resolves through the Auto* factories below.
        AutoConfig.register("gpt_neox_reward_model", GPTNeoXRewardModelConfig)
        AutoModelForSequenceClassification.register(
            GPTNeoXRewardModelConfig, GPTNeoXRewardModel
        )

        self.cfg = cfg
        self.model_name = cfg.reward_model
        self.device = cfg.environment._device
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            torch_dtype=(
                torch.float16
                if (torch.cuda.is_available() and len(cfg.environment.gpus) > 0)
                else torch.float32
            ),
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, max_model_input_sizes=2048
        )

    def get_score(
        self,
        prompts=None,
        answers=None,
    ):
        scores = []
        for prompt, answer in zip(prompts, answers):
            if "deberta-v3" in self.model_name:
                # DeBERTa reward models score (question, answer) pairs; collapse
                # a multi-turn prompt into a single question string.
                inputs = self.tokenizer(
                    " ".join(prompt.split(TEXT_SEPARATOR)),
                    answer,
                    return_tensors="pt",
                    max_length=2048,
                    truncation=True,
                ).to(self.device)
            elif self.model_name in [
                "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
                "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1",
            ]:
                # Rebuild the conversation in the OASST chat format, walking the
                # turns from the most recent one backwards and alternating
                # <|prompter|> and <|assistant|> prefixes.
                prompt = prompt.split(TEXT_SEPARATOR)

                input_text = ""

                for i, prompt_part in enumerate(prompt[::-1]):
                    if i % 2 == 0:
                        prefix = "<|prompter|>"
                    else:
                        prefix = "<|assistant|>"
                    input_text = f"{prefix}{prompt_part}<|endoftext|>" + input_text

                # The answer being scored is appended as the final assistant turn.
                input_text = input_text + f"<|assistant|>{answer}<|endoftext|>"

                inputs = self.tokenizer(
                    input_text,
                    return_tensors="pt",
                    max_length=2048,
                    truncation=True,
                ).to(self.device)
            else:
                raise ValueError(
                    f"Reward model {self.model_name} not supported for scoring."
                )

            scores.append(self.model(**inputs).logits[0].cpu().detach().item())
            del inputs
        return scores

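
# --- Illustrative usage sketch (not part of the original file) ---------------
# Minimal example of driving `RewardModel.get_score`. The real cfg comes from
# llm_studio's configuration objects; here a SimpleNamespace stands in for it,
# exposing only the attributes read above (`reward_model`,
# `environment._device`, `environment.gpus`).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        reward_model="OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
        environment=SimpleNamespace(_device="cpu", gpus=[]),
    )
    reward_model = RewardModel(cfg)
    # Multi-turn prompts would be joined with TEXT_SEPARATOR; here a single turn.
    print(
        reward_model.get_score(
            prompts=["What is the capital of France?"],
            answers=["The capital of France is Paris."],
        )
    )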