# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma import tokenizer


class Sampler(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(
            1, output_positions).squeeze(dim=1)
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias

        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids
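
# Illustrative walk-through of the top-p / top-k masking above (not executed):
# with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.7, the mask
# keeps the first two entries (their preceding cumulative mass stays below
# 0.7); top_k = 3 would additionally zero the fourth entry. The survivors are
# re-normalized, scattered back to vocabulary order, and sampled with
# torch.multinomial.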


def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
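
# Shape note (illustrative): precompute_freqs_cis(dim=256, end=16384) returns
# a complex64 tensor of shape (16384, 128) -- one rotation factor per position
# and per pair of head dimensions. GemmaForCausalLM below builds this table
# with dim=head_dim and end=2 * max_position_embeddings.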


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out


class Linear(nn.Module):

    def __init__(self, in_features: int, out_features: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
        else:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.linear(x, weight)
        return output
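
# Note on the quant path above: int8 weights are stored with one float scale
# per output row (weight_scaler) and are dequantized on the fly at forward
# time by multiplying each row by its scale before the matmul. Embedding below
# uses the same per-row scheme.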


class Embedding(nn.Module):

    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
        else:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.embedding(x, weight)
        return output


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output
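
# Note: unlike a standard RMSNorm, the forward pass above scales by
# (1 + weight) rather than weight, with weight initialized to zeros, so the
# stored parameter is an offset from a unit scale.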


class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs
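
# Note: this is the GeGLU-style feed-forward used by Gemma:
# down_proj(gelu(gate_proj(x)) * up_proj(x)), with the tanh approximation of
# GELU.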


class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        quant: bool,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5

        self.qkv_proj = Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
            quant=quant)
        self.o_proj = Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            quant=quant)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)

        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)

        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output
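
# Note on the attention above: key/value projections use num_kv_heads heads
# (grouped-query attention when num_kv_heads < num_heads) and are broadcast to
# the query heads with repeat_interleave. The caches hold max_seq_len entries
# per layer; kv_write_indices selects the positions written this step, so
# prefill writes a block of positions and decoding writes a single one.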


class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class GemmaModel(nn.Module):

    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.config = config
        assert config.hidden_size % config.num_attention_heads == 0

        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size

        self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
        self.model = GemmaModel(config)
        self.sampler = Sampler(vocab_size)

        # Pre-compute rotary embedding table.
        rope_theta = getattr(config, 'rope_theta', 10000)
        freqs_cis = precompute_freqs_cis(head_dim,
                                         max_seq_len * 2,
                                         theta=rope_theta)
        self.register_buffer('freqs_cis', freqs_cis)

    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        freqs_cis = self.freqs_cis.index_select(0, input_positions)
        kv_write_indices = input_positions

        # [batch_size, input_len, hidden_size]
        hidden_states = self.embedder(input_token_ids)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens
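
    # Note: the sampler reuses the input embedding matrix as the output
    # projection (weight tying); logits are hidden_states @ embedding.T inside
    # Sampler.forward, with the quantization scale applied here when needed.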

    def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: float = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]

        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings

        # build KV caches
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

        # prepare inputs
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id, dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
            device)
        temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to(
            device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
            device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )

            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2,
                                                        input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        return results[0] if is_str_prompt else results

    def load_weights(self, model_path: str):
        self.load_state_dict(
            torch.load(
                model_path, mmap=True, weights_only=True,
            )['model_state_dict'],
            strict=False,
        )
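
# Minimal usage sketch (illustrative; not part of the original module). The
# config helper name below is an assumption -- check gemma/config.py in your
# checkout for the exact constructor:
#
#     import torch
#     from gemma import config as gemma_config
#     from gemma.model import GemmaForCausalLM
#
#     model_config = gemma_config.get_config_for_2b()   # assumed helper
#     model_config.tokenizer = '/path/to/tokenizer.model'
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
#     model = GemmaForCausalLM(model_config)
#     model.load_weights('/path/to/gemma-2b.ckpt')
#     model = model.to(device).eval()
#
#     print(model.generate('The capital of France is', device, output_len=16))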