# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax Modules."""
from flax.deprecated import nn
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np


def shift_right(x, bos_token):
  """Shift the input to the right by padding on axis 1 at train time."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (1, 0)  # Padding on axis=1
  padded = jnp.pad(
      x,
      pad_widths,
      mode='constant',
      constant_values=jnp.asarray(bos_token, dtype=x.dtype))
  return padded[:, :-1]
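

# Example (a hedged sketch, not part of the original module): prepend the BOS
# token along axis 1 and drop the last position, so that targets can be fed
# back in as decoder inputs at train time.
#
#   x = jnp.array([[5, 6, 7],
#                  [8, 9, 10]], dtype=jnp.int32)
#   shift_right(x, bos_token=0)
#   # -> [[0, 5, 6],
#   #     [0, 8, 9]]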


def mask_uniform(inputs, rate, rng, mask_value):
  """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    rate: the probability of masking out a value.
    rng: a `jax.random.PRNGKey` used to sample the mask.
    mask_value: Value to mask with.

  Returns:
    The masked inputs.
  """
  if rate == 0.:
    return inputs
  keep_prob = 1. - rate
  mask = jrandom.bernoulli(rng, p=keep_prob, shape=inputs.shape)
  return lax.select(mask, inputs, jnp.full_like(inputs, mask_value))
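

# Example (a hedged sketch, not part of the original module): each entry is
# independently replaced by `mask_value` with probability `rate`.
#
#   rng = jrandom.PRNGKey(0)
#   tokens = jnp.array([[5, 6, 7, 8]], dtype=jnp.int32)
#   mask_uniform(tokens, rate=0.25, rng=rng, mask_value=3)
#   # -> tokens with roughly 25% of positions set to 3.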


class Tag(nn.Module):
  """Save a value to global state when running in stateful mode."""

  def apply(self, x):
    if self.is_stateful():
      tagged = self.state('tag')
      tagged.value = x
    return x


class Embed(nn.Module):
  """Embedding Module.

  A parameterized function from integers [0, n) to d-dimensional vectors.
  """

  def apply(self,
            inputs,
            num_embeddings,
            num_features,
            mode='input',
            emb_init=nn.initializers.normal(stddev=1.0)):
    """Applies the Embed module.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      num_embeddings: An int with the number of embeddings.
      num_features: An int with the size of the embedding dimension.
      mode: A string, 'input' or 'output', allowing input and output embeddings
        to be shared.
      emb_init: A callable, the embedding initializer function.

    Returns:
      An array of shape (batch_size, length, num_features) with the embedded
        data in 'input' mode, or (batch_size, length, num_embeddings) with
        output logits in 'output' mode.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    embedding = self.param('embedding', (num_embeddings, num_features),
                           emb_init)
    if mode == 'input':
      if inputs.ndim == 2:  # Inputs are lists of integers.
        if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
          raise ValueError('Input type must be an integer or unsigned integer.')
        return jnp.take(embedding, inputs, axis=0)

      # Inputs contain per-token probabilities.
      if inputs.shape[2] != num_embeddings:
        raise ValueError('Expected shape (..., %d), found (..., %d)' %
                         (num_embeddings, inputs.shape[2]))
      batch_size, length, _ = tuple(inputs.shape)

      # Tile embeddings to (batch_size, length, num_features, num_embeddings).
      emb = jnp.transpose(embedding)
      tiled_emb = jnp.tile(emb[None, None, Ellipsis], [batch_size, length, 1, 1])

      # Accumulate embeddings proportional to token probabilities.
      accum_emb = jnp.matmul(tiled_emb, inputs[Ellipsis, None])
      return accum_emb[Ellipsis, 0]
    if mode == 'output':
      return jnp.einsum('bld,vd->blv', inputs, embedding)
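

# Note (a hedged sketch, not part of the original module): with integer tokens,
# mode='input' is a table lookup, while mode='output' maps hidden activations
# back to vocabulary logits with the same table, i.e.
#
#   logits[b, l, v] = sum_d inputs[b, l, d] * embedding[v, d]
#
# which is what allows input and output embeddings to be shared.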


def get_positional_encodings(max_len, emb_size, concatenate=False):
  """Compute positional encodings as described in the Transformer paper.

  Positional encodings use sine and cosine functions of different frequencies:

    PE(pos, 2i) = sin(pos / (10000^(2i / emb_size)))
    PE(pos, 2i + 1) = cos(pos / (10000^(2i / emb_size)))

  where pos is the position and i is the dimension index.

  Reference: Section 3.5 in
    [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)

  Args:
    max_len: An int with the maximum possible length for the input.
    emb_size: An int with the embedding size.
    concatenate: A bool indicating whether to concatenate or interleave the
      sines and cosines. The default is False to match the Transformer paper.

  Returns:
    An array of shape (1, max_len, emb_size) with positional embeddings.
  """
  def _get_angles_per_position(position, dim, emb_size):
    denominator = np.power(10000, (2 * (dim // 2)) / np.float32(emb_size))
    return position / denominator

  # Create the arguments for the sines and cosines.
  angles = _get_angles_per_position(np.arange(max_len)[:, np.newaxis],
                                    np.arange(emb_size)[np.newaxis, :],
                                    emb_size)

  # Apply sine to the even dimensions (2i).
  sines = np.sin(angles[:, 0::2])

  # Apply cosine to the odd dimensions (2i + 1).
  cosines = np.cos(angles[:, 1::2])

  if concatenate:
    # See e.g. http://jalammar.github.io/illustrated-transformer/.
    output = np.concatenate([sines, cosines], axis=-1)
  else:
    # See e.g.
    # https://kazemnejad.com/blog/transformer_architecture_positional_encoding/.
    output = np.zeros_like(angles)
    output[:, 0::2] = sines
    output[:, 1::2] = cosines

  output = output[np.newaxis, :, :]
  return output
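

# Example (a hedged sketch, not part of the original module): the returned
# table has shape (1, max_len, emb_size), with sines in the even feature
# dimensions and cosines in the odd ones by default.
#
#   pe = get_positional_encodings(max_len=4, emb_size=6)
#   pe.shape   # (1, 4, 6)
#   pe[0, 0]   # position 0 -> [0., 1., 0., 1., 0., 1.]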


def sinusoidal_init(max_len=2048):
  """Weight initializer based on sinusoidal positional embeddings.

  Args:
    max_len: An int with the maximum possible length for the input.

  Returns:
    Callable taking as input a key and a shape (..., emb_size) and returning
      positional embeddings of shape (1, max_len, emb_size).
  """

  def init(key, shape, dtype=np.float32):
    """Sinusoidal init."""
    del key, dtype
    return jnp.array(get_positional_encodings(max_len, shape[-1]))
  return init
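

# Example (a hedged sketch, not part of the original module): the initializer
# ignores the key and dtype and only reads the trailing embedding dimension of
# the requested shape.
#
#   init_fn = sinusoidal_init(max_len=16)
#   pe = init_fn(jrandom.PRNGKey(0), (1, 16, 8))
#   pe.shape  # (1, 16, 8)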


class AddLearnedPositionalEncodings(nn.Module):
  """Adds learned positional embeddings to the inputs."""

  def apply(self,
            inputs,
            inputs_positions=None,
            max_len=2048,
            posemb_init=nn.initializers.normal(stddev=1.0),
            cache=None):
    """Applies the AddLearnedPositionalEncodings module.

    Args:
      inputs: An array of shape `(bs, timesteps, in_dim)` with the input data.
      inputs_positions: Input position indices for packed sequences.
      max_len: An int with the maximum possible length for the input.
      posemb_init: A callable, the positional embedding initializer.
      cache: A flax attention cache for fast decoding.

    Returns:
      An array of shape `(bs, timesteps, in_dim)` with positional embeddings
        added to the inputs.
    """
    if inputs.ndim != 3:
      raise ValueError('Wrong number of dimensions: found %d expected 3' %
                       inputs.ndim)
    length = inputs.shape[1]
    pos_emb_shape = (1, max_len, inputs.shape[-1])
    pos_embedding = self.param('pos_embedding', pos_emb_shape, posemb_init)
    pe = pos_embedding[:, :length, :]
    # We abuse the same attention Cache mechanism to run positional embeddings
    # in fast predict mode. We could use state variables instead, but this
    # simplifies invocation with a single top-level cache context manager.
    # We only use the cache's position index for tracking decoding position.
    if cache:
      if self.is_initializing():
        cache.store(np.array((4, 1, 1), dtype=np.int32))
      else:
        cache_entry = cache.retrieve(None)
        i = cache_entry.i
        one = jnp.array(1, jnp.uint32)
        cache_entry = cache_entry.replace(i=cache_entry.i + one)
        cache.store(cache_entry)
        _, _, df = pos_embedding.shape
        pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)),
                               (1, 1, df))
    if inputs_positions is None:
      # Normal unpacked case.
      return inputs + pe
    else:
      # For packed data we need to use known position indices.
      return inputs + jnp.take(pe[0], inputs_positions, axis=0)


class AddSinusoidalPositionalEncodings(nn.Module):
  """Adds the standard sinusoidal positional encodings to the inputs."""

  def apply(self, inputs, max_len=2048):
    """Applies the AddSinusoidalPositionalEncodings module.

    Args:
      inputs: An array of shape (batch_size, length, emb_size) with the token
        embeddings.
      max_len: An int with the maximum possible length for the input.

    Returns:
      An array of shape (batch_size, length, emb_size).
    """
    if inputs.ndim != 3:
      raise ValueError('Wrong number of dimensions: found %d expected 3' %
                       inputs.ndim)

    seq_len = inputs.shape[1]
    emb_size = inputs.shape[2]
    positional_encodings = get_positional_encodings(max_len, emb_size)
    positional_encodings = positional_encodings[:, :seq_len, :]
    return inputs + positional_encodings


class MlpBlock(nn.Module):
  """Transformer MLP block."""

  def apply(self,
            inputs,
            mlp_dim,
            out_dim=None,
            dropout_rate=0.1,
            deterministic=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
    x = nn.Dense(inputs, mlp_dim, kernel_init=kernel_init, bias_init=bias_init)
    x = nn.gelu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    output = nn.Dense(
        x, actual_out_dim, kernel_init=kernel_init, bias_init=bias_init)
    output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
    return output


class Transformer1DBlock(nn.Module):
  """Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr)."""

  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            cache=None):
    """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, whether to mask out future positions
      padding_mask: mask applied to padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, whether to disable dropout
      self_attention_module: Self attention module.
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.
    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs)
    if attention_fn is not None:
      self_attention_module = self_attention_module.partial(
          attention_fn=attention_fn)
    x = self_attention_module(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(x)
    y = MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    return x + y
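

# Structure note (added for clarity, mirroring the code above): this is a
# pre-LayerNorm residual block:
#
#   x = inputs + Dropout(SelfAttention(LayerNorm(inputs)))
#   output = x + MlpBlock(LayerNorm(x))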


# TODO(levskaya): modify for 3 modes: train, eval and fast predict.
class Transformer(nn.Module):
  """Transformer Model for language modeling."""

  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            causal=True,
            cache=None,
            positional_encoding_module=AddLearnedPositionalEncodings,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            pad_token=None,
            output_head='logits'):
    """Applies the Transformer model to the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network which
        follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's output
        head(s) to return.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
        single output head output; if output_head is an iterable, we return a
        dict with (output head name, output head output) key-value pairs.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    if inputs.ndim == 3:
      padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
    elif pad_token is None:
      padding_mask = jnp.ones_like(inputs)
    else:
      # Mask out padding tokens.
      padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

    heads = dict()
    x = inputs
    if inputs.ndim == 2:
      x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

    if positional_encoding_module == AddLearnedPositionalEncodings:
      x = positional_encoding_module(
          x,
          max_len=max_len,
          cache=cache,
          posemb_init=sinusoidal_init(max_len=max_len))
    else:
      x = positional_encoding_module(x, max_len=max_len)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    heads['input_emb'] = x
    for i in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          self_attention_module=self_attention_module,
          deterministic=not train,
          attention_fn=attention_fn,
          cache=cache,
      )
      heads['layer_%s' % i] = x
    x = nn.LayerNorm(x)
    heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
    if 'logits' in output_head:
      logits = nn.Dense(
          x,
          vocab_size,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      heads['logits'] = logits

    if 'regression' in output_head:
      regression = nn.Dense(
          x,
          1,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      regression = jnp.squeeze(regression, axis=-1)
      heads['regression'] = regression

    if isinstance(output_head, (tuple, list)):
      return {head: heads[head] for head in output_head}
    return heads[output_head]
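

# Example (a hedged sketch of typical usage with the pre-Linen
# `flax.deprecated.nn` functional API; shapes and hyperparameters below are
# illustrative, not taken from the original code):
#
#   rng = jrandom.PRNGKey(0)
#   model_def = Transformer.partial(vocab_size=1000, max_len=128)
#   _, params = model_def.init_by_shape(rng, [((1, 128), jnp.int32)])
#   model = nn.Model(model_def, params)
#   tokens = jnp.zeros((2, 128), dtype=jnp.int32)
#   logits = model(tokens)  # (2, 128, 1000) with the default 'logits' head.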