gemma_pytorch / model_xla.py
578 lines · 19.3 KB
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma.xla_model_parallel import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
    reduce_from_model_parallel_region,
    scatter_to_model_parallel_region,
)


class Sampler(nn.Module):

    def __init__(self, vocab_size: int, world_size: int, rank: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.world_size = world_size
        self.rank = rank

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(
            1, output_positions).squeeze(dim=1)

        hidden_states_parallel = scatter_to_model_parallel_region(
            hidden_states,
            groups=None,
            world_size=self.world_size,
            rank=self.rank)
        hidden_states_parallel = torch.matmul(hidden_states_parallel,
                                              embedding.t())
        logits = reduce_from_model_parallel_region(
            hidden_states_parallel,
            groups=None,
            world_size=self.world_size,
            rank=self.rank,
        )
        if embedding_bias is not None:
            logits += embedding_bias

        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids

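
# A minimal sketch of how the top-p / top-k masking in Sampler.forward behaves
# on a toy distribution. The helper name `_example_top_p_top_k` and the numbers
# are illustrative only; nothing else in this file uses them.
def _example_top_p_top_k():
    probs = torch.tensor([[0.5, 0.3, 0.1, 0.07, 0.03]])
    top_ps = torch.tensor([0.8])
    top_ks = torch.tensor([3])

    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p ...
    top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
    # ... and tokens ranked at or beyond top_k.
    top_ks_mask = torch.arange(probs.shape[-1]).expand(probs.shape[0], -1)
    top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
    probs_sort = torch.where(top_ps_mask | top_ks_mask, 0, probs_sort)
    # Renormalize and scatter back to the original vocabulary order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.gather(probs_sort,
                        dim=-1,
                        index=torch.argsort(probs_idx, dim=-1))
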

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out

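
# A small shape-check sketch for the two rotary-embedding helpers above,
# assuming the [batch_size, seq_len, num_heads, head_dim] query/key layout used
# by GemmaAttention below. The sizes and the helper name are illustrative only.
def _example_rotary_shapes():
    batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
    # One complex rotation per position and per pair of channels:
    # freqs_cis has shape [seq_len, head_dim // 2].
    freqs_cis = precompute_freqs_cis(head_dim, seq_len)
    xq = torch.randn(batch_size, seq_len, num_heads, head_dim)
    xq_rot = apply_rotary_emb(xq, freqs_cis=freqs_cis)
    # The rotation acts per position and per channel pair, so the shape is
    # unchanged.
    assert xq_rot.shape == xq.shape
    return xq_rot
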

class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output

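
# A quick numerical sketch of the RMSNorm above: after `_norm`, every feature
# vector has (approximately) unit root-mean-square, and the learned `weight`
# is then applied as a (1 + weight) scale when add_unit_offset is True. The
# sizes and the helper name are arbitrary.
def _example_rmsnorm():
    norm = RMSNorm(dim=8)
    x = torch.randn(2, 3, 8)
    normed = norm._norm(x)
    rms = normed.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
    return norm(x)
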

class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        world_size: int,
        rank: int,
        quant: bool,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        def init_method(x):
            return x

        self.gate_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs

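
# A minimal single-device sketch of the GeGLU computation in GemmaMLP.forward,
# written with plain nn.Linear layers instead of the model-parallel layers used
# above. The class name `_GeGLUSketch` is illustrative only.
class _GeGLUSketch(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GELU-gated linear unit: gelu(gate_proj(x)) gates up_proj(x)
        # elementwise, then down_proj maps back to hidden_size.
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        return self.down_proj(gate * self.up_proj(x))
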

class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        world_size: int,
        rank: int,
        quant: bool,
    ):
        super().__init__()
        self.rank = rank

        def init_method(x):
            return x

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size  # head per shard

        if num_kv_heads < world_size:
            assert world_size % num_kv_heads == 0
            self.total_num_kv_heads = world_size
        else:
            assert num_kv_heads % world_size == 0
            self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = self.total_num_kv_heads // world_size  # kv head per shard

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(
            self.hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)

        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)

        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output

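
# A shape sketch of the kv_cache contract assumed by GemmaAttention.forward:
# each cache tensor is [batch_size, max_seq_len, num_kv_heads, head_dim], and
# kv_write_indices lists the positions being written this step. All sizes here
# are arbitrary and the helper name is illustrative only.
def _example_kv_cache_write():
    batch_size, max_seq_len, num_kv_heads, head_dim = 2, 16, 1, 8
    input_len = 4
    k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim)
    xk = torch.randn(batch_size, input_len, num_kv_heads, head_dim)
    # Here we write the first input_len positions, as a prefill step would.
    kv_write_indices = torch.arange(input_len)
    k_cache.index_copy_(1, kv_write_indices, xk)
    assert torch.equal(k_cache[:, :input_len], xk)
    return k_cache
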

class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int,
    ):
        super().__init__()
        self.rank = rank
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class GemmaModel(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.config = config
        self.rank = rank
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config, world_size, rank))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int,
        device: torch.device,
    ):
        super().__init__()
        self.config = config
        self.world_size = world_size
        self.rank = rank
        self.device = device

        assert config.num_attention_heads % world_size == 0
        assert config.hidden_size % config.num_attention_heads == 0

        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size

        def init_method(x):
            return x

        self.embedder = ParallelEmbedding(
            vocab_size,
            config.hidden_size,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.model = GemmaModel(config, world_size, rank)
        self.sampler = Sampler(vocab_size, world_size, rank)

        rope_theta = getattr(config, 'rope_theta', 10000)
        # [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
        freqs_cis = precompute_freqs_cis(head_dim,
                                         max_seq_len * 2,
                                         theta=rope_theta)
        self.register_buffer('freqs_cis', freqs_cis)

    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        freqs_cis = self.freqs_cis.index_select(0, input_positions)
        kv_write_indices = input_positions

        hidden_states = self.embedder(input_token_ids)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)
        # hidden_states should be [batch_size, input_len, hidden_size]

        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens

    def load_weights(self, model_path: str):
        checkpoint = torch.load(model_path, weights_only=True)
        model_state_dict = checkpoint['model_state_dict']

        num_attn_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.config.head_dim
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor

        for k, v in model_state_dict.items():
            if k == 'freqs_cis':
                continue
            if (k == 'model.norm.weight' or re.fullmatch(
                    r'model.layers.\d+.input_layernorm.weight', k)
                    or re.fullmatch(
                        r'model.layers.\d+.post_attention_layernorm.weight',
                        k) or k.endswith('weight_scaler')):
                pass
            elif (k == 'embedder.weight' or re.fullmatch(
                    r'model.layers.\d+.mlp.down_proj.weight', k)):
                v = split(v, 1)
            elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
                  or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
                v = split(v, 0)
            elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
                              k):
                if num_kv_heads <= self.world_size:
                    num_replicas = self.world_size // num_kv_heads
                    v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
                                  hidden_size)
                    query = v[:num_attn_heads, ...]
                    key = v[num_attn_heads:num_attn_heads + num_kv_heads,
                            ...].repeat(num_replicas, 1, 1)
                    value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
                    v = torch.cat(
                        (split(query, 0), split(key, 0), split(value, 0)),
                        dim=0)
                else:
                    v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
                    v = split(v, 1)
                v = v.reshape(-1, hidden_size)
            elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
            else:
                raise ValueError(f'Unrecognized key: {k}')
            self.state_dict()[k].copy_(v)
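
# A standalone sketch of the sharding rule used by the local `split` helper in
# load_weights above: rank r out of world_size w keeps the r-th contiguous
# slice along the chosen axis. The function name `_example_shard` is
# illustrative only.
def _example_shard(tensor: torch.Tensor, axis: int, world_size: int,
                   rank: int) -> torch.Tensor:
    split_len = tensor.shape[axis] // world_size
    # torch.Tensor.narrow(dim, start, length) returns the contiguous slice
    # [start, start + length) along `dim`.
    return tensor.narrow(axis, split_len * rank, split_len)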
