h2o-llmstudio
import logging
from typing import Any, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM

from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)

class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for GPT2 that returns a scalar for each
    output token.

    Based on the implementation of the trl library:
    https://github.com/lvwerra/trl/blob/main/trl/models/modeling_value_head.py
    """

    def __init__(self, config):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = 0.1
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = (
            nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
        )

        # some models such as OPT have a projection layer before the word embeddings,
        # e.g. OPT-350m
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        else:
            hidden_size = config.hidden_size

        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        output = self.dropout(hidden_states)

        # For now, force upcast to fp32 if needed and keep the
        # output in fp32 for numerical stability.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output

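# Illustrative usage of the value head (a sketch, not part of the module API):
# it maps the backbone's last hidden states to one scalar per token, which
# Model.forward below exposes as outputs["value"]. For hidden states of shape
# (batch_size, seq_len, hidden_size):
#
#     head = ValueHead(backbone_config)    # backbone_config: the backbone's transformers config
#     values = head(last_hidden_state)     # -> (batch_size, seq_len, 1)
#     values = values.squeeze(-1)          # -> (batch_size, seq_len)
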
class Model(nn.Module):
    """
    Model for causal language modeling problem type.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """

        super(Model, self).__init__()

        self.cfg = cfg
        assert cfg.training.lora, "LoRA must be True for RLHF"

        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )

        self.backbone = prepare_lora(cfg=self.cfg, backbone=self.backbone)

        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

        self.value_head = ValueHead(self.backbone_config)
        self.value_head.summary.bias.data.zero_()

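    # Descriptive summary (comment only): __init__ asserts cfg.training.lora
    # (LoRA is required for RLHF), builds the causal-LM backbone via
    # create_nlp_backbone, wraps it with LoRA adapters via prepare_lora,
    # optionally attaches a Perplexity metric, and adds a ValueHead whose bias
    # is zeroed, presumably so the initial value estimates start near zero.
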
    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        mask_key = "attention_mask"
        pad_keys = [
            "input_ids",
            "attention_mask",
            "special_tokens_mask",
            "labels",
        ]

        if padding:
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key=mask_key,
                pad_keys=pad_keys,
            )

        output = self.backbone(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            output_hidden_states=True,
        )

        if self.cfg.prediction.metric == "Perplexity" and not self.training:
            outputs["perplexity"] = self.perplexity(output.logits, batch["labels"])

        if self.training:
            last_hidden_state = output.hidden_states[-1]

            # force upcast to fp32 if logits are in half-precision
            if output.logits.dtype != torch.float32:
                output.logits = output.logits.float()

            outputs["logits"] = output.logits
            outputs["value"] = self.value_head(last_hidden_state).squeeze(-1)

        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs

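    # Illustrative summary of the forward() contract (a sketch, not enforced in
    # this file): the batch is expected to contain "input_ids" and
    # "attention_mask" (plus "special_tokens_mask" and "labels" when padding is
    # applied); the returned dict holds "perplexity" during evaluation when the
    # configured metric is Perplexity, and "logits" plus per-token "value"
    # estimates during training, e.g.
    #
    #     model.train()
    #     out = model(batch)      # batch: dict of tensors as above
    #     out["logits"].shape     # (batch_size, seq_len, vocab_size)
    #     out["value"].shape      # (batch_size, seq_len)
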
    def generate(self, batch: Dict, cfg: Any, streamer=None):
        return generate(self.backbone, batch, cfg, streamer)