# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Flax Modules."""
17from flax.deprecated import nn
18from jax import lax
19import jax.numpy as jnp
20import jax.random as jrandom
21import numpy as np
22
23
def shift_right(x, bos_token):
  """Shift the input to the right by padding on axis 1 at train time."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (1, 0)  # Padding on axis=1
  padded = jnp.pad(
      x,
      pad_widths,
      mode='constant',
      constant_values=jnp.asarray(bos_token, dtype=x.dtype))
  return padded[:, :-1]
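
# Illustrative example (added; not part of the original module): for
# x = jnp.array([[5, 6, 7]]) and bos_token=0, shift_right(x, 0) returns
# [[0, 5, 6]] -- the BOS token is prepended and the last token is dropped, so
# the decoder input at step t only contains tokens from steps < t.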


def mask_uniform(inputs, rate, rng, mask_value):
  """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    rate: the probability of masking out a value.
    rng: a `jax.random.PRNGKey` used to sample the mask.
    mask_value: Value to mask with.

  Returns:
    The masked inputs.
  """
  if rate == 0.:
    return inputs
  keep_prob = 1. - rate
  mask = jrandom.bernoulli(rng, p=keep_prob, shape=inputs.shape)
  return lax.select(mask, inputs, jnp.full_like(inputs, mask_value))
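
# Illustrative example (added; not part of the original module): masking 15%
# of the entries of a hypothetical `tokens` array of int ids, assuming id 3
# stands for MASK:
#   masked = mask_uniform(
#       tokens, rate=0.15, rng=jrandom.PRNGKey(0), mask_value=3)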


class Tag(nn.Module):
  """Save a value to global state when running in stateful mode."""

  def apply(self, x):
    if self.is_stateful():
      tagged = self.state('tag')
      tagged.value = x
    return x


class Embed(nn.Module):
  """Embedding Module.

  A parameterized function from integers [0, n) to d-dimensional vectors.
  """

  def apply(self,
            inputs,
            num_embeddings,
            num_features,
            mode='input',
            emb_init=nn.initializers.normal(stddev=1.0)):
    """Applies the Embed module.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      num_embeddings: An int with the number of embeddings.
      num_features: An int with the size of the embedding dimension.
      mode: A string, 'input' or 'output' -> to share input/output embeddings.
      emb_init: A callable, the embedding initializer function.

    Returns:
      An array of shape (batch_size, length, num_features) with embedded data.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    embedding = self.param('embedding', (num_embeddings, num_features),
                           emb_init)
    if mode == 'input':
      if inputs.ndim == 2:  # Inputs are lists of integers.
        if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
          raise ValueError(
              'Input type must be an integer or unsigned integer.')
        return jnp.take(embedding, inputs, axis=0)

      # Inputs contain per-token probabilities.
      if inputs.shape[2] != num_embeddings:
        raise ValueError('Expected shape (..., %d), found (..., %d)' %
                         (num_embeddings, inputs.shape[2]))
      batch_size, length, _ = tuple(inputs.shape)

      # Tile embeddings to (batch_size, length, num_features, num_embeddings).
      emb = jnp.transpose(embedding)
      tiled_emb = jnp.tile(emb[None, None, Ellipsis],
                           [batch_size, length, 1, 1])

      # Accumulate embeddings proportional to token probabilities.
      accum_emb = jnp.matmul(tiled_emb, inputs[Ellipsis, None])
      return accum_emb[Ellipsis, 0]
    if mode == 'output':
      return jnp.einsum('bld,vd->blv', inputs, embedding)
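
# Note (added for clarity): in 'input' mode with 3-D inputs, Embed computes a
# probability-weighted mixture of embedding rows -- the tiled matmul above is
# equivalent to jnp.einsum('blv,vd->bld', inputs, embedding). In 'output' mode
# the same table projects features back to vocabulary logits, which is how the
# input and output embeddings are shared (weight tying).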


def get_positional_encodings(max_len, emb_size, concatenate=False):
  """Compute positional encodings as described in the Transformer paper.

  Positional encodings use sine and cosine functions of different frequencies:

    PE(pos, 2i) = sin(pos / (10000^(2i / emb_size)))
    PE(pos, 2i + 1) = cos(pos / (10000^(2i / emb_size)))

  where pos is the position and i is the dimension.

  Reference: Section 3.5 in
  [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)

  Args:
    max_len: An int with the maximum possible length for the input.
    emb_size: An int with the embedding size.
    concatenate: A bool indicating whether to concatenate or interleave the
      sines and cosines. The default is False to match the Transformer paper.

  Returns:
    An array of shape (1, max_len, emb_size) with positional embeddings.
  """
  def _get_angles_per_position(position, dim, emb_size):
    denominator = np.power(10000, (2 * (dim // 2)) / np.float32(emb_size))
    return position / denominator

  # Create the arguments for the sines and cosines.
  angles = _get_angles_per_position(np.arange(max_len)[:, np.newaxis],
                                    np.arange(emb_size)[np.newaxis, :],
                                    emb_size)

  # Apply sine to the even-indexed dimensions.
  sines = np.sin(angles[:, 0::2])

  # Apply cosine to the odd-indexed dimensions.
  cosines = np.cos(angles[:, 1::2])

  if concatenate:
    # See e.g. http://jalammar.github.io/illustrated-transformer/.
    output = np.concatenate([sines, cosines], axis=-1)
  else:
    # See e.g.
    # https://kazemnejad.com/blog/transformer_architecture_positional_encoding/.
    output = np.zeros_like(angles)
    output[:, 0::2] = sines
    output[:, 1::2] = cosines

  output = output[np.newaxis, :, :]
  return output
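
# Illustrative example (added; not part of the original module): with
# max_len=4 and emb_size=6 the returned array has shape (1, 4, 6), and for
# pos=1 the (non-concatenated) row interleaves
# sin(1 / 10000^(0/6)), cos(1 / 10000^(0/6)), sin(1 / 10000^(2/6)), ...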


def sinusoidal_init(max_len=2048):
  """Weight initializer based on sinusoidal positional embeddings.

  Args:
    max_len: An int with the maximum possible length for the input.

  Returns:
    Callable taking as input a key and a shape (..., emb_size) and returning
    positional embeddings of shape (1, max_len, emb_size).
  """

  def init(key, shape, dtype=np.float32):
    """Sinusoidal init."""
    del key, dtype
    return jnp.array(get_positional_encodings(max_len, shape[-1]))
  return init
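
# Illustrative example (added; not part of the original module): the returned
# initializer ignores the PRNG key and all but the last dimension of the
# requested shape, so
#   sinusoidal_init(max_len=16)(jrandom.PRNGKey(0), (1, 16, 64))
# always returns the same (1, 16, 64) table of sinusoidal encodings.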


class AddLearnedPositionalEncodings(nn.Module):
  """Adds learned positional embeddings to the inputs."""

  def apply(self,
            inputs,
            inputs_positions=None,
            max_len=2048,
            posemb_init=nn.initializers.normal(stddev=1.0),
            cache=None):
    """Applies the AddLearnedPositionalEncodings module.

    Args:
      inputs: input data
      inputs_positions: input position indices for packed sequences.
      max_len: maximum possible length for the input
      posemb_init: positional embedding initializer
      cache: flax attention cache for fast decoding.

    Returns:
      output: `(bs, timesteps, in_dim)`
    """
    if inputs.ndim != 3:
      raise ValueError('Wrong number of dimensions: found %d expected 3' %
                       inputs.ndim)
    length = inputs.shape[1]
    pos_emb_shape = (1, max_len, inputs.shape[-1])
    pos_embedding = self.param('pos_embedding', pos_emb_shape, posemb_init)
    pe = pos_embedding[:, :length, :]
    # We abuse the same attention Cache mechanism to run positional embeddings
    # in fast predict mode. We could use state variables instead, but this
    # simplifies invocation with a single top-level cache context manager.
    # We only use the cache's position index for tracking decoding position.
    if cache:
      if self.is_initializing():
        cache.store(np.array((4, 1, 1), dtype=np.int32))
      else:
        cache_entry = cache.retrieve(None)
        i = cache_entry.i
        one = jnp.array(1, jnp.uint32)
        cache_entry = cache_entry.replace(i=cache_entry.i + one)
        cache.store(cache_entry)
        _, _, df = pos_embedding.shape
        pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)),
                               (1, 1, df))
    if inputs_positions is None:
      # normal unpacked case:
      return inputs + pe
    else:
      # for packed data we need to use known position indices:
      return inputs + jnp.take(pe[0], inputs_positions, axis=0)
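
# Note (added for clarity): when a `cache` is supplied and the module is not
# initializing, AddLearnedPositionalEncodings slices out only the embedding at
# the current decode position `i`, so fast autoregressive decoding can be run
# one token at a time.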


class AddSinusoidalPositionalEncodings(nn.Module):
  """Adds the standard sinusoidal positional encodings to the inputs."""

  def apply(self, inputs, max_len=2048):
    """Applies the AddSinusoidalPositionalEncodings module.

    Args:
      inputs: An array of shape (batch_size, length, emb_size) with the token
        embeddings.
      max_len: An int with the maximum possible length for the input.

    Returns:
      An array of shape (batch_size, length, emb_size).
    """
    if inputs.ndim != 3:
      raise ValueError('Wrong number of dimensions: found %d expected 3' %
                       inputs.ndim)

    seq_len = inputs.shape[1]
    emb_size = inputs.shape[2]
    positional_encodings = get_positional_encodings(max_len, emb_size)
    positional_encodings = positional_encodings[:, :seq_len, :]
    return inputs + positional_encodings


class MlpBlock(nn.Module):
  """Transformer MLP block."""

  def apply(self,
            inputs,
            mlp_dim,
            out_dim=None,
            dropout_rate=0.1,
            deterministic=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
    x = nn.Dense(inputs, mlp_dim, kernel_init=kernel_init, bias_init=bias_init)
    x = nn.gelu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    output = nn.Dense(
        x, actual_out_dim, kernel_init=kernel_init, bias_init=bias_init)
    output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
    return output
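
# Note (added for clarity): MlpBlock is the position-wise feed-forward
# network: Dense(mlp_dim) -> GELU -> dropout -> Dense(out_dim) -> dropout,
# where out_dim defaults to the input feature dimension.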


class Transformer1DBlock(nn.Module):
  """Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr)."""

  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            cache=None):
    """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, whether to mask out future positions
      padding_mask: mask for padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      self_attention_module: Self attention module.
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.
    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs)
    if attention_fn is not None:
      self_attention_module = self_attention_module.partial(
          attention_fn=attention_fn)
    x = self_attention_module(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(x)
    y = MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    return x + y
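
# Note (added for clarity): Transformer1DBlock is a pre-LayerNorm block --
# LayerNorm is applied before the self-attention and MLP sub-layers, and each
# sub-layer's output is added back to its input as a residual connection.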


# TODO(levskaya): modify for 3 modes: train, eval and fast predict.
class Transformer(nn.Module):
  """Transformer Model for language modeling."""

  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            causal=True,
            cache=None,
            positional_encoding_module=AddLearnedPositionalEncodings,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            pad_token=None,
            output_head='logits'):
    """Applies Transformer model on the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer encoder layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network
        which follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention
        weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional
        encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's
        output head(s) to return.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
      single output head output; if output_head is an iterable, we return a
      dict with (output head name, output head output) key-value pairs.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    if inputs.ndim == 3:
      padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
    elif pad_token is None:
      padding_mask = jnp.ones_like(inputs)
    else:
      # Mask out padding tokens.
      padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

    heads = dict()
    x = inputs
    if inputs.ndim == 2:
      x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

    if positional_encoding_module == AddLearnedPositionalEncodings:
      x = positional_encoding_module(
          x,
          max_len=max_len,
          cache=cache,
          posemb_init=sinusoidal_init(max_len=max_len))
    else:
      x = positional_encoding_module(x, max_len=max_len)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    heads['input_emb'] = x
    for i in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          self_attention_module=self_attention_module,
          deterministic=not train,
          attention_fn=attention_fn,
          cache=cache,
      )
      heads['layer_%s' % i] = x
    x = nn.LayerNorm(x)
    heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
    if 'logits' in output_head:
      logits = nn.Dense(
          x,
          vocab_size,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      heads['logits'] = logits

    if 'regression' in output_head:
      regression = nn.Dense(
          x,
          1,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      regression = jnp.squeeze(regression, axis=-1)
      heads['regression'] = regression

    if isinstance(output_head, (tuple, list)):
      return {head: heads[head] for head in output_head}
    return heads[output_head]
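

# Note (added for clarity): for int token inputs of shape (batch, length) and
# the default output_head='logits', Transformer returns logits of shape
# (batch, length, vocab_size); passing e.g. output_head=('logits',
# 'output_emb') instead returns a dict with one entry per requested head.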