# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma import tokenizer

class Sampler(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Optional[torch.Tensor],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(
            1, output_positions).squeeze(dim=1)
        # Compute logits against the (tied) embedding matrix.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias

        # Greedy decoding when no temperature is given.
        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids

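# A toy walk-through of the top-p masking above, with assumed values:
#   probs_sort = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
#   probs_sum = torch.cumsum(probs_sort, dim=-1)  # [[0.50, 0.80, 0.95, 1.00]]
#   mask = (probs_sum - probs_sort) > 0.8         # [[False, False, False, True]]
# A token is dropped only when the tokens ranked *before* it already cover
# top_p, so the token that crosses the threshold is itself kept.
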
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

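# Shape sketch with assumed toy values: precompute_freqs_cis(dim=4, end=8)
# returns an (8, 2) complex64 tensor whose row t holds exp(i * t * f_j) for
# the dim // 2 frequencies f_j = theta**(-2j / dim).
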
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out

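# Rotary-embedding usage sketch (assumed toy shapes):
#   x = torch.randn(2, 6, 8, 16)             # (batch, seq, heads, head_dim)
#   freqs_cis = precompute_freqs_cis(16, 6)  # (6, 8) complex64
#   y = apply_rotary_emb(x, freqs_cis)       # same shape as x
# Channel pairs are viewed as complex numbers and rotated by a
# position-dependent angle, which is what makes attention scores depend on
# relative position.
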
class Linear(nn.Module):

    def __init__(self, in_features: int, out_features: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
        else:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.linear(x, weight)
        return output

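# Dequantization sketch for quant=True, with assumed toy values:
#   w_int8 = torch.tensor([[127, -128]], dtype=torch.int8)  # (out, in)
#   scaler = torch.tensor([0.01])                           # one scale per row
#   w_float = w_int8 * scaler.unsqueeze(-1)                 # [[1.27, -1.28]]
# Embedding below reuses the same per-row-scale scheme for its weight.
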
class Embedding(nn.Module):

    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
        else:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.embedding(x, weight)
        return output

class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output

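# RMSNorm computes x / sqrt(mean(x**2) + eps) * (1 + weight); storing the
# scale as a zero-initialized offset means a freshly constructed layer is a
# pure RMS normalization. Sanity check with assumed values:
#   y = RMSNorm(dim=4)(torch.randn(2, 4))  # rows of y have RMS ~= 1
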
class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs

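# GemmaMLP is a GeGLU-style feed-forward:
# down_proj(gelu(gate_proj(x)) * up_proj(x)). A minimal shape check with
# assumed toy sizes:
#   mlp = GemmaMLP(hidden_size=8, intermediate_size=32, quant=False)
#   out = mlp(torch.randn(1, 3, 8))  # -> (1, 3, 8)
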
class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        quant: bool,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5

        self.qkv_proj = Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
            quant=quant)
        self.o_proj = Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            quant=quant)

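    # With num_kv_heads < num_heads this is grouped-/multi-query attention:
    # e.g. the 2B config (num_heads=8, num_kv_heads=1) packs 8 query heads
    # plus a single shared K and V head into one qkv_proj matmul, and
    # forward() broadcasts the shared heads via repeat_interleave.
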
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)

        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)

        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output

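# The matmul/softmax/matmul above is roughly what (PyTorch >= 2.0)
#   F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# would compute; the explicit form keeps the additive mask and the float32
# softmax easy to see. Scores span the full cache length (max_seq_len), and
# positions not yet written are masked out by the causal mask.
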
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

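# Each decoder layer is pre-norm: x = x + attn(norm(x)), then
# x = x + mlp(norm(x)); RMSNorm is applied before each sublayer rather than
# after it.
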
class GemmaModel(nn.Module):

    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states

class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.config = config
        assert config.hidden_size % config.num_attention_heads == 0

        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size

        self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
        self.model = GemmaModel(config)
        self.sampler = Sampler(vocab_size)

        # Pre-compute rotary embedding table.
        rope_theta = getattr(config, 'rope_theta', 10000)
        freqs_cis = precompute_freqs_cis(head_dim,
                                         max_seq_len * 2,
                                         theta=rope_theta)
        self.register_buffer('freqs_cis', freqs_cis)

    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Optional[torch.Tensor],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        freqs_cis = self.freqs_cis.index_select(0, input_positions)
        kv_write_indices = input_positions

        # [batch_size, input_len, hidden_size]
        hidden_states = self.embedder(input_token_ids)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens

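    # forward() above is a single decoding step: embed `input_token_ids`,
    # run the stack against the caller-provided KV caches, and sample one
    # next token per sequence at `output_positions`. generate() below drives
    # it in a loop.
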
    def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: float = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]

        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings

        # Build KV caches.
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

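        # Each layer owns a zero-initialized (batch, max_seq_len,
        # num_kv_heads, head_dim) K and V buffer; attention writes into them
        # at kv_write_indices = input_positions on every step.
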
        # Prepare inputs.
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id, dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        # Large negative value acting as -inf in the additive causal mask.
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
            device)
        temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to(
            device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
            device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )

            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            # Keep the prompt token if this position is still inside the
            # prompt; otherwise use the sampled token.
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2,
                                                        input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        return results[0] if is_str_prompt else results

    def load_weights(self, model_path: str):
        self.load_state_dict(
            torch.load(
                model_path, mmap=True, weights_only=True,
            )['model_state_dict'],
            strict=False,
        )

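# Typical usage sketch (hypothetical paths and device; see scripts/run.py in
# the repository for the complete flow, including dtype selection):
#   model_config = gemma_config.get_config_for_2b()
#   model = GemmaForCausalLM(model_config)
#   model.load_weights('/path/to/gemma-2b.ckpt')
#   model = model.to('cuda').eval()
#   print(model.generate('What is the capital of France?', device='cuda'))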