# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma.xla_model_parallel import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
    reduce_from_model_parallel_region,
    scatter_to_model_parallel_region,
)


class Sampler(nn.Module):

    def __init__(self, vocab_size: int, world_size: int, rank: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.world_size = world_size
        self.rank = rank

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(
            1, output_positions).squeeze(dim=1)
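        # Tensor-parallel logits: shard the hidden states across ranks,
        # multiply by this rank's shard of the embedding matrix, then
        # all-reduce the partial products into full-vocabulary logits.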
        hidden_states_parallel = scatter_to_model_parallel_region(
            hidden_states,
            groups=None,
            world_size=self.world_size,
            rank=self.rank)
        hidden_states_parallel = torch.matmul(hidden_states_parallel,
                                              embedding.t())
        logits = reduce_from_model_parallel_region(
            hidden_states_parallel,
            groups=None,
            world_size=self.world_size,
            rank=self.rank,
        )
        if embedding_bias is not None:
            logits += embedding_bias

        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids
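# Rotary positional embeddings (RoPE): precompute_freqs_cis builds the complex
# rotation factors exp(i * position * theta^(-2k/dim)) for every position, and
# apply_rotary_emb rotates the query/key head dimensions by the factors
# selected for the current positions.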
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out
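# RMSNorm scales activations by the reciprocal root-mean-square of their
# features; with add_unit_offset the learned weight is applied as (1 + weight).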
class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output


class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        world_size: int,
        rank: int,
        quant: bool,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
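        # init_method is the identity: the actual weights are copied in later
        # from the checkpoint by GemmaForCausalLM.load_weights.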
        def init_method(x):
            return x
        self.gate_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )
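    # GeGLU feed-forward: the gate projection goes through tanh-approximate
    # GELU and is multiplied elementwise with the up projection before the
    # down projection maps back to hidden_size.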
    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs


class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        world_size: int,
        rank: int,
        quant: bool,
    ):
        super().__init__()
        self.rank = rank

        def init_method(x):
            return x

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size  # head per shard
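        # Shard KV heads across ranks. If there are fewer KV heads than ranks,
        # each rank still gets one KV head by replicating heads across ranks
        # (load_weights repeats the K/V weights accordingly).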
        if num_kv_heads < world_size:
            assert world_size % num_kv_heads == 0
            self.total_num_kv_heads = world_size
        else:
            assert num_kv_heads % world_size == 0
            self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = self.total_num_kv_heads // world_size  # kv head per shard

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(
            self.hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)

        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)
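        # Grouped-query attention: when there are fewer KV heads than query
        # heads, repeat each cached K/V head so that every group of
        # num_queries_per_kv query heads attends to its shared KV head.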
        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int,
    ):
        super().__init__()
        self.rank = rank
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
class GemmaModel(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.config = config
        self.rank = rank
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config, world_size, rank))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        world_size: int,
        rank: int,
        device: torch.device,
    ):
        super().__init__()
        self.config = config
        self.world_size = world_size
        self.rank = rank
        self.device = device

        assert config.num_attention_heads % world_size == 0
        assert config.hidden_size % config.num_attention_heads == 0

        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size

        def init_method(x):
            return x

        self.embedder = ParallelEmbedding(
            vocab_size,
            config.hidden_size,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=config.quant,
        )
        self.model = GemmaModel(config, world_size, rank)
        self.sampler = Sampler(vocab_size, world_size, rank)

        rope_theta = getattr(config, 'rope_theta', 10000)
        # [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
        freqs_cis = precompute_freqs_cis(head_dim,
                                         max_seq_len * 2,
                                         theta=rope_theta)
        self.register_buffer('freqs_cis', freqs_cis)
    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        freqs_cis = self.freqs_cis.index_select(0, input_positions)
        kv_write_indices = input_positions

        hidden_states = self.embedder(input_token_ids)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)
        # hidden_states should be [batch_size, input_len, hidden_size]

        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens
    def load_weights(self, model_path: str):
        checkpoint = torch.load(model_path, weights_only=True)
        model_state_dict = checkpoint['model_state_dict']

        num_attn_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.config.head_dim
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor
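        # Copy each checkpoint tensor into this rank's shard: norm weights and
        # quantization scalers are replicated on every rank; gate/up
        # projections are split along the output axis; the embedding and
        # down/o projections are split along the hidden/input axis; QKV
        # weights are split head-wise, replicating K/V heads when there are
        # fewer KV heads than ranks.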
        for k, v in model_state_dict.items():
            if k == 'freqs_cis':
                continue
            if (k == 'model.norm.weight' or re.fullmatch(
                    r'model.layers.\d+.input_layernorm.weight', k)
                    or re.fullmatch(
                        r'model.layers.\d+.post_attention_layernorm.weight',
                        k) or k.endswith('weight_scaler')):
                pass
            elif (k == 'embedder.weight' or re.fullmatch(
                    r'model.layers.\d+.mlp.down_proj.weight', k)):
                v = split(v, 1)
            elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
                  or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
                v = split(v, 0)
            elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
                              k):
                if num_kv_heads <= self.world_size:
                    num_replicas = self.world_size // num_kv_heads
                    v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
                                  hidden_size)
                    query = v[:num_attn_heads, ...]
                    key = v[num_attn_heads:num_attn_heads + num_kv_heads,
                            ...].repeat(num_replicas, 1, 1)
                    value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
                    v = torch.cat(
                        (split(query, 0), split(key, 0), split(value, 0)),
                        dim=0)
                else:
                    v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
                    v = split(v, 1)
                    v = v.reshape(-1, hidden_size)
            elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
            else:
                raise ValueError(f'Unrecognized key: {k}')
            self.state_dict()[k].copy_(v)