h2o-llmstudio
from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXConfig,
    GPTNeoXModel,
    GPTNeoXPreTrainedModel,
)
from transformers.utils import ModelOutput

from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR


class GPTNeoXRewardModelConfig(GPTNeoXConfig):
    model_type = "gpt_neox_reward_model"

    pooling: Literal["mean", "last"]

    def __init__(
        self,
        pooling: Literal["mean", "last"] = "last",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pooling = pooling or "last"


@dataclass
class GPTNeoXRewardModelOutput(ModelOutput):
    """
    Reward model output.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Reward score
    """

    logits: torch.FloatTensor = None


class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):
    config_class = GPTNeoXRewardModelConfig

    def __init__(self, config):
        if type(config) is GPTNeoXConfig:
            # When a plain GPTNeoX config is loaded, it is converted into a reward
            # model config. The direct `type(config) is GPTNeoXConfig` comparison is
            # used (instead of `isinstance()`) since the configuration class of the
            # reward model is also derived from `GPTNeoXConfig`.
            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())
        super().__init__(config)

        self.gpt_neox = GPTNeoXModel(config)
        self.out_proj = nn.Linear(config.hidden_size, 1)
        self.pooling = config.pooling

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> GPTNeoXRewardModelOutput:
        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.pooling == "mean":
            if attention_mask is None:
                pooled = hidden_states.mean(dim=1)
            else:
                # Masked mean over the sequence dimension; the mask is unsqueezed so
                # it broadcasts against the hidden dimension.
                pooled = (hidden_states * attention_mask.unsqueeze(-1)).sum(
                    dim=1
                ) / attention_mask.sum(dim=1, keepdim=True)
        elif self.pooling == "last":
            if attention_mask is None:
                pooled = hidden_states[:, -1]
            else:
                # Index of the last non-padded token in each sequence.
                last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)
                pooled = hidden_states.gather(
                    1, last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
                ).squeeze(1)
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")

        logits = self.out_proj(pooled)

        if not return_dict:
            return (logits,) + outputs[1:]

        return GPTNeoXRewardModelOutput(logits=logits)

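# Illustrative note (not used at runtime): the "last" pooling above locates the
# final non-padded position via cumsum().argmax(), since the cumulative sum stays
# flat once padding starts and argmax() returns the first index of the maximum:
#
#   mask = torch.tensor([[1, 1, 1, 0, 0]])
#   mask.cumsum(dim=1)                # tensor([[1, 2, 3, 3, 3]])
#   mask.cumsum(dim=1).argmax(dim=1)  # tensor([2]) -> index of the last real token
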
class RewardModel(nn.Module):
    def __init__(self, cfg):
        super(RewardModel, self).__init__()

        AutoConfig.register("gpt_neox_reward_model", GPTNeoXRewardModelConfig)
        AutoModelForSequenceClassification.register(
            GPTNeoXRewardModelConfig, GPTNeoXRewardModel
        )

        self.cfg = cfg
        self.model_name = cfg.reward_model
        self.device = cfg.environment._device
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            torch_dtype=(
                torch.float16
                if (torch.cuda.is_available() and len(cfg.environment.gpus) > 0)
                else torch.float32
            ),
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, max_model_input_sizes=2048
        )

    def get_score(
        self,
        prompts=None,
        answers=None,
    ):
        scores = []
        for prompt, answer in zip(prompts, answers):
            if "deberta-v3" in self.model_name:
                inputs = self.tokenizer(
                    " ".join(prompt.split(TEXT_SEPARATOR)),
                    answer,
                    return_tensors="pt",
                    max_length=2048,
                ).to(self.device)
            elif self.model_name in [
                "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
                "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1",
            ]:
                prompt = prompt.split(TEXT_SEPARATOR)

                input_text = ""

                # Rebuild the conversation in the OASST chat format, walking the
                # prompt turns from most recent to oldest and alternating
                # <|prompter|> / <|assistant|> prefixes.
                for i, prompt_part in enumerate(prompt[::-1]):
                    if i % 2 == 0:
                        prefix = "<|prompter|>"
                    else:
                        prefix = "<|assistant|>"
                    input_text = f"{prefix}{prompt_part}<|endoftext|>" + input_text

                input_text = input_text + f"<|assistant|>{answer}<|endoftext|>"

                inputs = self.tokenizer(
                    input_text, return_tensors="pt", max_length=2048
                ).to(self.device)
            else:
                raise ValueError(
                    f"Reward model {self.model_name} not supported for scoring."
                )

            scores.append(self.model(**inputs).logits[0].cpu().detach().item())
            del inputs
        return scores
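
# Minimal usage sketch (illustrative only, not part of the library API). It assumes
# a config object exposing the attributes accessed above (`reward_model`,
# `environment._device`, `environment.gpus`); `SimpleNamespace` merely stands in
# for the real LLM Studio config, and the checkpoint is pulled from the Hugging
# Face Hub on first use.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        reward_model="OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
        environment=SimpleNamespace(_device="cpu", gpus=[]),
    )
    reward_model = RewardModel(cfg)
    scores = reward_model.get_score(
        prompts=["What is the capital of France?"],
        answers=["The capital of France is Paris."],
    )
    print(scores)  # one reward score (float) per prompt/answer pair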