# flake8: noqa: E266, C417, B950
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
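
# For illustration: find_multiple(170, 256) == 256, find_multiple(512, 256) == 512.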


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [
            config
            for config in transformer_configs
            if config in str(name).upper() or config in str(name)
        ]

        # Two or more configs may match (e.g. "7B" and "Mistral-7B"); pick the best match by
        # taking the longer name, since it matches more characters.
        if len(config) > 1:
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(
                config[1]
            ), name  # make sure only one 'best' match

        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "CodeLlama-7b-Python-hf": dict(
        block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
    ),
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "13B": dict(n_layer=40, n_head=40, dim=5120),
    "30B": dict(n_layer=60, n_head=52, dim=6656),
    "34B": dict(
        n_layer=48,
        n_head=64,
        dim=8192,
        vocab_size=32000,
        n_local_heads=8,
        intermediate_size=22016,
        rope_base=1000000,
    ),  # CodeLlama-34B-Python-hf
    "70B": dict(
        n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
    ),
    "Mistral-7B": dict(
        n_layer=32,
        n_head=32,
        n_local_heads=8,
        dim=4096,
        intermediate_size=14336,
        vocab_size=32000,
    ),
}
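
# Illustration (example names, not executed here): exact keys resolve directly, while other
# checkpoint names fall back to the fuzzy substring match in ModelArgs.from_name:
#     ModelArgs.from_name("7B")                        -> "7B" config (exact key)
#     ModelArgs.from_name("Mistral-7B-Instruct-v0.2")  -> "Mistral-7B" config (both "7B" and
#                                                         "Mistral-7B" match; the longer key wins)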


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
            )

        self.freqs_cis = precompute_freqs_cis(
            self.config.block_size,
            self.config.dim // self.config.n_head,
            self.config.rope_base,
        )
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Fuse separate wq/wk/wv weights from a checkpoint into the combined wqkv projection.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand the shared key/value heads to match the number of query heads (grouped-query attention).
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
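
# The returned RoPE cache has shape [seq_len, n_elem // 2, 2]: for each position and frequency
# it stores the (cos, sin) pair, i.e. the real and imaginary parts of the complex rotation that
# apply_rotary_emb below reads back as interleaved pairs.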


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # View the head dimension as (real, imag) pairs and apply the complex rotation
    # (a + bi) * (cos + i*sin) elementwise; x is [B, S, H, D], freqs_cis is [S, D // 2, 2].
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
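

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, with a deliberately tiny config and random
    # weights): shows how setup_caches() and forward() fit together for prefill and one decode
    # step. The cast to bfloat16 matches the dtype of the KV cache and RoPE cache above.
    args = ModelArgs(block_size=128, vocab_size=100, n_layer=2, n_head=4, dim=64)
    model = Transformer(args).to(dtype=torch.bfloat16).eval()
    model.setup_caches(max_batch_size=1, max_seq_length=128)

    with torch.no_grad():
        # Prefill: run the whole prompt at once, indexing the caches by absolute position.
        prompt = torch.randint(0, args.vocab_size, (1, 8))
        logits = model(prompt, input_pos=torch.arange(8))
        print(logits.shape)  # torch.Size([1, 8, 100])

        # Decode one token: feed it at the next position; cached keys/values are reused.
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        logits = model(next_token, input_pos=torch.tensor([8]))
        print(logits.shape)  # torch.Size([1, 1, 100])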