# flake8: noqa: E266, C417, B950
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
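# For example: find_multiple(10922, 256) returns 11008, the next multiple of 256.
# This is how ModelArgs.__post_init__ below sizes the SwiGLU hidden layer when
# intermediate_size is not given (dim=4096 -> hidden_dim=16384 -> n_hidden=10922 -> 11008).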


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [
            config
            for config in transformer_configs
            if config in str(name).upper() or config in str(name)
        ]

        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B").
        # Take the longest name as the best match, since it has more characters matched.
        if len(config) > 1:
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(
                config[1]
            ), name  # make sure only one 'best' match

        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "CodeLlama-7b-Python-hf": dict(
        block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
    ),
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "13B": dict(n_layer=40, n_head=40, dim=5120),
    "30B": dict(n_layer=60, n_head=52, dim=6656),
    "34B": dict(
        n_layer=48,
        n_head=64,
        dim=8192,
        vocab_size=32000,
        n_local_heads=8,
        intermediate_size=22016,
        rope_base=1000000,
    ),  # CodeLlama-34B-Python-hf
    "70B": dict(
        n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
    ),
    "Mistral-7B": dict(
        n_layer=32,
        n_head=32,
        n_local_heads=8,
        dim=4096,
        intermediate_size=14336,
        vocab_size=32000,
    ),
}
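# Example of the fuzzy lookup (illustrative): ModelArgs.from_name("Mistral-7B-Instruct-v0.2")
# has no exact key here, so it matches both "7B" and "Mistral-7B" and, after sorting by
# length, resolves to the more specific "Mistral-7B" entry.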


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
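# KVCache usage sketch (hypothetical shapes): a cache for batch 1, 8 KV heads,
# head_dim 64, max length 2048; update() writes one decoded token at position 5 and
# returns the full-length buffers with that slot filled in:
#   cache = KVCache(max_batch_size=1, max_seq_length=2048, n_heads=8, head_dim=64)
#   k, v = cache.update(torch.tensor([5]), k_val, v_val)  # k_val, v_val: [1, 8, 1, 64]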


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
            )

        self.freqs_cis = precompute_freqs_cis(
            self.config.block_size,
            self.config.dim // self.config.n_head,
            self.config.rope_base,
        )
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))
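# Decode-loop sketch (illustrative; assumes the model weights have been cast to the
# same dtype as the KV cache, e.g. model.to(torch.bfloat16)): prefill the prompt, then
# feed one token at a time with its position so each layer's KVCache is reused.
#   model = Transformer.from_name("7B")
#   model.setup_caches(max_batch_size=1, max_seq_length=2048)
#   prompt = torch.randint(0, model.config.vocab_size, (1, 16))
#   logits = model(prompt, input_pos=torch.arange(16))
#   next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
#   logits = model(next_token, input_pos=torch.tensor([16]))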


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
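# Note on the packed projection: wqkv stacks n_head query heads plus 2 * n_local_heads
# key/value heads, e.g. for the "70B" config (n_head=64, n_local_heads=8, head_dim=128)
# total_head_dim = (64 + 16) * 128 = 10240. The k/v heads are repeat_interleave'd back
# up to n_head before scaled_dot_product_attention (grouped-query attention).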


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
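# This is the gated (SwiGLU-style) MLP used by Llama-family models:
# w2(silu(w1(x)) * w3(x)), with intermediate_size set in ModelArgs.__post_init__
# unless a config entry overrides it.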


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
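# RMSNorm computes x * rsqrt(mean(x^2) + eps) in float32, casts back to the input
# dtype, and only then applies the learned per-channel scale.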


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
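# The returned cache has shape [seq_len, n_elem // 2, 2], holding the real/imaginary
# (cos/sin) parts of each rotary frequency per position; Transformer.forward slices it
# with input_pos before handing it to apply_rotary_emb.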


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
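# Each consecutive pair (x0, x1) of the last dimension is rotated by its precomputed
# angle: (x0 * cos - x1 * sin, x1 * cos + x0 * sin), i.e. multiplication by the
# complex freqs_cis entry for that position and frequency.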