# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Standalone Flax models."""
17
18import abc19import enum20import functools21import math22import operator as op23import os24import pprint25import time26
27from absl import logging28from flax import jax_utils29from flax.deprecated import nn30from flax.training import checkpoints31from flax.training import common_utils32import gin33from gin import config34import jax35from jax import random as jrandom36import jax.example_libraries.optimizers37import jax.nn38import jax.numpy as jnp39import numpy as onp40import tensorflow.compat.v1 as tf41import tree42
43from protein_lm import data44from protein_lm import evaluation45from protein_lm import modules46from protein_lm import sampling47from protein_lm import utils48
49
class Mode(enum.Enum):
  train = 'train'
  eval = 'eval'
  predict = 'predict'
  sample = 'sample'


def parse_config(ckpt_dir):
  """Parses a FlaxLM config as a dict from checkpoint dir."""
  cfg = dict()
  with tf.gfile.GFile(os.path.join(ckpt_dir, 'config.gin')) as f:
    for line in f:
      if 'FlaxLM' in line and not line.startswith('#'):
        key, value = line.split(' = ')
        _, kwarg = key.split('.')
        value = config.parse_value(value)
        cfg[kwarg] = value
  return cfg


def save_model_kwargs(ckpt_dir, model):
  """Saves a dict FlaxLM config into the checkpoint dir."""
  model_kwargs = model.model_kwargs
  model_name = type(model).__name__
  with tf.gfile.GFile(os.path.join(ckpt_dir, 'config.gin'), 'w') as f:
    for key, value in model_kwargs.items():
      f.write('%s.%s = %s\n' % (model_name, key, str(value)))


@functools.lru_cache()
def load_model(ckpt_dir, model_cls, domain=None):
  """Loads a model from directory."""
  if domain is None:
    domain = data.protein_domain
  cfg = parse_config(ckpt_dir)
  print('Loading model with config:')
  pprint.pprint(cfg)
  model = model_cls(domain=domain, **cfg)
  model.load_checkpoint(ckpt_dir)
  return model


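# Usage sketch (illustrative, not part of the original module): round-tripping
# a model through save_model_kwargs / load_model. Assumes `model` is a trained
# FlaxLM (parse_config only picks up lines mentioning 'FlaxLM'); the checkpoint
# directory name is hypothetical.
def _example_save_and_load(model, ckpt_dir='/tmp/flaxlm_ckpt'):
  """Saves a FlaxLM's gin config and weights, then reloads both."""
  save_model_kwargs(ckpt_dir, model)  # Writes config.gin from model kwargs.
  model.save_checkpoint(ckpt_dir)  # Writes the Flax weight checkpoint.
  # load_model re-parses config.gin, rebuilds the model, restores weights.
  return load_model(ckpt_dir, model_cls=FlaxLM)

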
def train_step(optimizer,
               inputs,
               learning_rate_fn,
               dropout_rng,
               preprocess_fn,
               example_weights=None,
               grad_clip=None,
               epsilon=1e-9):
  """Performs a single training step. Masks out BOS/PAD positions.

  Args:
    optimizer: Flax optimizer.
    inputs: Inputs to model.preprocess which returns (inputs, targets,
      weights).
    learning_rate_fn: function from step idx --> learning rate.
    dropout_rng: RNG for dropout.
    preprocess_fn: function mapping
      (inputs, rng, mode) -> (inputs, targets, weights).
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    grad_clip: If not None, clip gradients to [-x, +x].
    epsilon: Epsilon for denominator of loss averaging.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """

  # We handle PRNG splitting inside the top pmap, rather
  # than handling it outside in the training loop - doing the
  # latter can add some stalls to the devices.
  dropout_rng, new_dropout_rng = jrandom.split(dropout_rng)
  dropout_rng, preprocess_rng = jrandom.split(dropout_rng)

  inputs, targets, weights = preprocess_fn(
      inputs, rng=preprocess_rng, mode=Mode.train)

  if isinstance(targets, dict):
    classification_targets = targets['classification']
    classification_weights = weights['classification']

    regression_targets = targets['regression']
    regression_weights = weights['regression']
  else:
    # Default to classification loss.
    classification_targets = targets
    classification_weights = weights
    regression_targets = None
    regression_weights = None

  if classification_targets is None and regression_targets is None:
    raise ValueError('No targets specified for train step.')

  if classification_weights is None and regression_weights is None:
    raise ValueError('No weights specified for train step.')

  def loss_fn(model):
    """Loss function used for training."""
    # Stateful collection for tracking internal state like activations.
    with nn.stateful() as batch_stats:
      with nn.stochastic(dropout_rng):
        outputs = model(inputs, train=True, cache=None)

    if isinstance(outputs, dict):
      logits = outputs.get('logits', None)
      regression_predictions = outputs.get('regression', None)
    else:
      logits = outputs
      regression_predictions = None

    mean_loss = 0.0

    # Classification loss.
    if classification_targets is not None:
      classification_loss, classification_weight_sum = utils.compute_weighted_cross_entropy(
          logits,
          classification_targets,
          token_weights=classification_weights,
          example_weights=example_weights)
      classification_weight_sum = jnp.maximum(classification_weight_sum,
                                              epsilon)
      # Handle case where nothing is masked out in BERT
      # (only occurs with very short sequences).
      mean_classification_loss = classification_loss / classification_weight_sum
      mean_loss += mean_classification_loss

    if regression_targets is not None:
      regression_loss, regression_weight_sum = utils.compute_weighted_mse(
          regression_predictions,
          regression_targets,
          weights=regression_weights)
      regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
      mean_regression_loss = regression_loss / regression_weight_sum
      outputs['regression_loss'] = mean_regression_loss

      # TODO(ddohan): Allow weighting each loss separately.
      mean_loss += mean_regression_loss

    return mean_loss, (outputs, batch_stats)

  step = optimizer.state.step
  lr = learning_rate_fn(step)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (outputs, batch_stats)), grad = grad_fn(optimizer.target)

  try:
    grad = jax.lax.pmean(grad, 'batch')
  except NameError:
    pass

  if grad_clip is not None:
    # Clip gradients after pmean aggregation.
    unclipped_grad = grad
    grad = jax.example_libraries.optimizers.clip_grads(grad, grad_clip)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  # TODO(ddohan): Avoid computing metrics except when needed.
  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
  else:
    logits = outputs

  metrics = dict()
  if logits is not None:
    classification_metrics = utils.compute_metrics(logits,
                                                   classification_targets,
                                                   classification_weights)
    metrics.update(classification_metrics)
  if regression_targets is not None:
    # TODO(ddohan): Implement regression metrics.
    logging.info('No regression targets yet')
    # regression = outputs.get('regression', None)
    # regression_metrics = utils.compute_metrics(logits, regression_targets,
    #                                            classification_weights)
  metrics['learning_rate'] = lr

  # Training metrics.
  metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

  # Gradient norms.
  grad_l2_tree = utils.l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['l2_grad_sum'] = grad_l2_sum
  metrics['l2_grad_max'] = grad_l2_max

  # Store any tagged metrics.
  batch_stats = batch_stats.as_dict()
  if batch_stats:

    def clean_name(k):
      return 'nn/' + k.replace('MultiHeadDotProductAttention_', '').replace(
          '/Transformer1DBlock_', '')

    stats = {clean_name(k): v['tag'] for k, v in batch_stats.items()}
    metrics.update(stats)

  if grad_clip is not None:
    # Unclipped gradient norms (if applicable).
    grad_l2_tree = utils.l2_norm(unclipped_grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['l2_noclip_grad_sum'] = grad_l2_sum
    metrics['l2_noclip_grad_max'] = grad_l2_max

  return new_optimizer, metrics, new_dropout_rng


def eval_step(model, inputs, preprocess_fn):
  inputs, targets, weights = preprocess_fn(inputs, rng=None, mode=Mode.eval)
  logits = model(inputs, train=False, cache=None)
  return utils.compute_metrics(logits, targets, weights)


def predict_step(model, inputs, preprocess_fn, output_head='logits'):
  inputs, _, _ = preprocess_fn(inputs, rng=None, mode=Mode.predict)
  logits = model(inputs, train=False, cache=None, output_head=output_head)
  return logits


def _tokens_to_logits(last_token, cache, model, internal_state=None):
  """Computes the next token logits.

  Args:
    last_token: An array of shape (batch_size, 1) containing last token ids.
    cache: A flax.deprecated.nn.attention.Cache object.
    model: A Jax decoder model to be used for computing the next token logits.
    internal_state: A dict with internal state received from the previous time
      step. If None, no information is shared across time steps.

  Returns:
    logits: An array of shape (batch_size, vocab_size) with the logits.
    new_cache: A flax.deprecated.nn.attention.Cache object with the updated
      cache.
    new_internal_state: A dict with internal state passed to the next time
      step.
  """
  del internal_state  # Not used.
  # The returned logits have shape (batch_size, 1, vocab_size).
  with cache.mutate() as new_cache:
    logits = model(last_token, train=False, cache=new_cache)

  # Remove the singleton dimension to return shape (batch_size, vocab_size).
  logits = logits.squeeze(axis=1)
  return logits, new_cache, None


def sample_step(prompt,
                model,
                cache,
                rng,
                masked_tokens,
                eos_token,
                pad_token,
                max_decode_len,
                tokens_to_logits=_tokens_to_logits,
                **sampling_kwargs):
  """Samples autoregressively from the model.

  Args:
    prompt: An array of shape (batch_size, prompt_length) containing the input
      prompt (the model consumes these tokens and starts generation after).
      For generic sampling, the prompt must be a single BOS token.
    model: A Jax decoder model to be used for computing the next token logits.
    cache: A flax.deprecated.nn.attention.Cache object.
    rng: A jax.random.PRNGKey object.
    masked_tokens: A list of ints indicating tokens to mask out during
      sampling.
    eos_token: An int indicating the EOS token id. If None, we decode until
      reaching the maximum sequence length.
    pad_token: An int token used to pad sequences after the eos token. If
      None, we set pad_token to eos_token.
    max_decode_len: An int indicating the maximum sequence length.
    tokens_to_logits: A callable that computes the next token logits given the
      current cache and previous token.
    **sampling_kwargs: Named arguments passed to sampling.temperature_sample.

  Returns:
    An array of shape (batch_size, max_decode_len) containing sampled
    sequences. If variable-length, the sequences are right-padded with the EOS
    token.
  """
  tokens_to_logits = functools.partial(tokens_to_logits, model=model)
  return sampling.temperature_sample(
      prompt,
      init_cache=cache,
      tokens_to_logits=tokens_to_logits,
      max_decode_len=max_decode_len,
      rng=rng,
      eos_token=eos_token,
      pad_token=pad_token,
      masked_tokens=masked_tokens,
      **sampling_kwargs,
  )


def compute_logprob(inputs, model, mask_token=None):
  """Returns an array of log probabilities for the input sequences."""

  assert inputs.ndim == 2

  targets = inputs
  weights = jnp.where(targets != model.pad_token, 1, 0)
  if mask_token is not None:
    weights *= jnp.where(targets != mask_token, 1, 0)
  logits = model.score(inputs)
  assert logits.ndim == 3

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  log_lik *= weights
  log_prob = jnp.sum(log_lik, axis=-1)

  return log_prob


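# Illustrative helper (not part of the original module): per-sequence
# perplexity derived from compute_logprob. Assumes `model` is any FlaxModel
# exposing `score` and `pad_token`, and `inputs` is an [N x length] int array.
def _example_perplexity(inputs, model):
  """Returns per-sequence perplexity under the model."""
  log_prob = compute_logprob(inputs, model)
  num_tokens = jnp.sum(inputs != model.pad_token, axis=-1)
  # Perplexity is the exponentiated mean negative log likelihood per token.
  return jnp.exp(-log_prob / num_tokens)

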
def preprocess_causal(batch, bos_token, pad_token, mode):
  """Preprocessing for causal language modeling.

  Right shifts and shards.

  Args:
    batch: [batch x length] tokens.
    bos_token: Int ID to use as beginning of sentence token.
    pad_token: Padding token which should be masked out in loss.
    mode: Mode value.

  Returns:
    Tuple of [batch x length] inputs, targets, per position weights. Inputs
    are right-shifted with a BOS token prepended (except in sample mode), and
    PAD positions receive zero weight in the loss.
  """
  if mode == Mode.sample:
    inputs = batch
  else:
    inputs = modules.shift_right(batch, bos_token=bos_token)

  targets = batch
  # Mask out PAD in loss.
  if pad_token is None:
    weights = jnp.ones_like(targets)
  else:
    weights = jnp.where(targets != pad_token, 1, 0)
  return inputs, targets, weights


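# Worked example (illustrative; assumes modules.shift_right prepends the BOS
# token and drops the final position, the standard right-shift behavior).
def _example_causal_shift():
  """Shows the (inputs, targets, weights) produced for a tiny batch."""
  batch = jnp.array([[5, 6, 7, 0]])  # One sequence; 0 is the PAD token here.
  inputs, targets, weights = preprocess_causal(
      batch, bos_token=1, pad_token=0, mode=Mode.train)
  # inputs  == [[1, 5, 6, 7]]  (BOS prepended, last token dropped)
  # targets == [[5, 6, 7, 0]]
  # weights == [[1, 1, 1, 0]]  (PAD position excluded from the loss)
  return inputs, targets, weights

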
@gin.configurable
class FlaxModel(abc.ABC):
  """Model built on Flax."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               random_seed=0,
               batch_size=None,
               grad_clip=None,
               learning_rate=0.001,
               weight_decay=0.1,
               cache=True,
               pmap=True,
               attention_fn=None,
               with_bos=False,
               with_mask=False,
               store_metrics=False,
               sampling_kwargs=None,
               **model_kwargs):
    """Creates a Flax model for sequence prediction.

    Args:
      domain: discrete domain.
      model_cls: flax.nn.Module to train.
      random_seed: Random seed.
      batch_size: Default batch size.
      grad_clip: Gradient clipping in optimizer.
      learning_rate: Learning rate in optimizer, or callable mapping a step to
        the current learning rate.
      weight_decay: L2 decay for AdamW.
      cache: Whether to create a cache.
      pmap: Whether to pmap inference (and JIT as a side effect).
      attention_fn: Function to use in place of nn.dot_product_attention.
      with_bos: Whether to ensure vocab contains BOS.
      with_mask: Whether to ensure vocab contains MASK.
      store_metrics: Whether to store train and evaluation metrics.
      sampling_kwargs: Additional config options for sample step.
      **model_kwargs: Additional config options for `model_cls.partial`.
    """
    self._batch_size = batch_size  # Default batch size.

    # TODO(b/157255958): Reenable tracking metrics inside class.
    self._store_metrics = store_metrics
    if store_metrics:
      self._metrics_train = []
      self._metrics_test = []
      self._epoch_train = []
      self._epoch_test = []

    self._pmap = pmap
    self._sampling_kwargs = sampling_kwargs
    self._model_kwargs = model_kwargs
    self._opt_hparams = dict(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        grad_clip=grad_clip)

    # TODO(ddohan): Reimplement __getstate__ and __setstate__ to support
    # pickle, and use these functions to init model.
    self._set_domain(domain=domain, with_bos=with_bos, with_mask=with_mask)
    self._init_model(
        model_cls=model_cls,
        random_seed=random_seed,
        pmap=pmap,
        cache=cache,
        attention_fn=attention_fn,
        sampling_kwargs=sampling_kwargs,
        model_kwargs=model_kwargs,
        **self._opt_hparams)

  def _set_domain(self, domain, with_bos, with_mask):
    """Sets vocabulary based on domain."""
    self.domain = domain
    self._length = domain.length
    self._bos_token = domain.vocab.bos
    self._eos_token = domain.vocab.eos
    self._pad_token = domain.vocab.pad
    self._mask_token = domain.vocab.mask

    vocab_size = domain.vocab_size
    if with_bos and self._bos_token is None:  # Add bos token.
      self._bos_token = vocab_size
      vocab_size += 1
    if with_mask and self._mask_token is None:  # Add mask token.
      self._mask_token = vocab_size
      vocab_size += 1
    self._vocab_size = vocab_size

  def _get_masked_tokens(self):
    """Gets the list of token IDs to mask for a given domain."""
    tokens = []
    for token in [self._bos_token, self._pad_token, self._mask_token]:
      if token is not None:
        tokens.append(token)
    return tokens

  def _init_model(self,
                  model_cls,
                  pmap,
                  learning_rate,
                  weight_decay,
                  grad_clip,
                  attention_fn,
                  random_seed,
                  cache=True,
                  sampling_kwargs=None,
                  model_kwargs=None):
    """Initializes model."""
    model_kwargs = model_kwargs or dict()
    model_def = model_cls.partial(
        vocab_size=self._vocab_size,
        max_len=self.domain.length,
        # Don't attend to PAD tokens.
        pad_token=self._pad_token,
        attention_fn=attention_fn,
        **model_kwargs)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda step: learning_rate

    train_fn = functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        grad_clip=grad_clip,
        preprocess_fn=self.preprocess)
    eval_fn = functools.partial(eval_step, preprocess_fn=self.preprocess)
    predict_fn = functools.partial(predict_step, preprocess_fn=self.preprocess)

    sampling_kwargs = sampling_kwargs or dict()
    masked_tokens = self._get_masked_tokens()
    sample_fn = functools.partial(
        sample_step,
        masked_tokens=masked_tokens,
        eos_token=self._eos_token,
        pad_token=self._pad_token,
        max_decode_len=self._length + 1,
        **sampling_kwargs)

    # Default to pmapped versions.
    if pmap:
      train_fn = jax.pmap(train_fn, axis_name='batch')
      eval_fn = jax.pmap(eval_fn, axis_name='batch')
      sample_fn = jax.pmap(sample_fn, axis_name='batch')
      predict_fn = jax.pmap(predict_fn, axis_name='batch')

    self._train_fn = train_fn
    self._predict_fn = predict_fn
    self._sample_fn = sample_fn
    self._eval_fn = eval_fn

    rng = jrandom.PRNGKey(random_seed)
    rng, init_rng = jrandom.split(rng)
    rng, self._sample_rng = jrandom.split(rng)

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    if self._pmap:
      self._dropout_rngs = jrandom.split(rng, jax.local_device_count())
    else:
      self._dropout_rngs = rng

    # Note: any batch size can be used later. This is arbitrary for init.
    input_shape = (self._batch_size or 2, self.domain.length)
    if cache:
      init_model, self._cache_def = utils.create_model_and_cache(
          init_rng, input_shape, model_def)
    else:
      init_model = utils.create_model(init_rng, input_shape, model_def)
      self._cache_def = None
    self._optimizer = utils.create_adam_optimizer(
        init_model,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        pmap=self._pmap)
    del init_model  # Delete initial model.

  def preprocess(self, batch, rng, mode):
    """Unpacks a batch of data to (inputs, targets, weights).

    batch may be one of:
      - a [batch x length] batch of input data.
        Results in (batch, None, None).
      - a tuple of (inputs, targets).
        Results in (inputs, targets, ones_like(targets)).
      - a tuple of (inputs, targets, weights).
        Passed through unchanged.
      - a dict containing 'inputs', 'targets', and optionally 'weights'.
        Results in (inputs, targets, weights or ones_like(targets)).

    Args:
      batch: Batch of data.
      rng: Ignored. Jax random seed.
      mode: member of Mode enum.

    Returns:
      Tuple of (inputs, targets, weights).
      `targets` and `weights` are None if `targets` is not provided.
    """
    del rng
    if isinstance(batch, tuple):
      if len(batch) == 2:
        inputs, targets = batch
        weights = jnp.ones_like(targets)
      elif len(batch) == 3:
        inputs, targets, weights = batch
      else:
        raise ValueError(
            'Must provide (inputs, targets) or (inputs, targets, weights)')
    elif isinstance(batch, dict):
      inputs = batch['inputs']
      targets = batch['targets']
      weights = batch.get('weights', None)
      if weights is None:
        weights = jnp.ones_like(targets)
    else:
      inputs = batch
      targets = None
      weights = None

    if targets is None and mode not in (Mode.predict, Mode.sample):
      raise ValueError('Must provide targets for train and eval.')

    return inputs, targets, weights

  @property
  def train_step(self):
    """Returns the current train step."""
    step = self.optimizer.state.step
    if self._pmap:
      step = step[0]
    return step

  @property
  def bos_token(self):
    """Returns the BOS token id."""
    return self._bos_token

  @property
  def eos_token(self):
    """Returns the EOS token id."""
    return self._eos_token

  @property
  def pad_token(self):
    """Returns the PAD token id."""
    return self._pad_token

  @property
  def mask_token(self):
    """Returns the MASK token id."""
    return self._mask_token

  @property
  def length(self):
    """Returns the maximum sequence length."""
    return self._length

  @property
  def vocab_size(self):
    """Returns the vocabulary size used for training."""
    return self._vocab_size

  @property
  def optimizer(self):
    """Returns Flax optimizer containing optimizer and model parameters."""
    return self._optimizer

  @property
  def model_kwargs(self):
    """Returns the model kwargs as a dictionary."""
    return self._model_kwargs

  @property
  def pmap(self):
    """Returns whether or not the optimizer was trained with pmap."""
    return self._pmap

  def set_weights(self, optimizer):
    """Sets weights from unreplicated optimizer."""
    if self._pmap:
      optimizer = jax_utils.replicate(optimizer)
    self._optimizer = optimizer

  def get_weights(self):
    """Returns unreplicated optimizer."""
    optimizer = self.optimizer
    if self._pmap:
      optimizer = jax_utils.unreplicate(optimizer)
    return optimizer

  def save_checkpoint(self, ckpt_dir):
    """Saves unreplicated optimizer to ckpt_dir."""
    optimizer = self.get_weights()
    checkpoints.save_checkpoint(
        ckpt_dir,
        target=optimizer,
        step=self.train_step,
    )

  def load_checkpoint(self, ckpt_dir):
    """Loads optimizer from ckpt_dir."""
    target = self.get_weights()
    optimizer = checkpoints.restore_checkpoint(ckpt_dir, target=target)
    if optimizer is target:
      raise ValueError('Unable to load checkpoint from %s' % ckpt_dir)
    self.set_weights(optimizer)

  def fit(self, xs, epochs=1, batch_size=None, max_steps=10**6):
    """Fits to sequences given as [N x length] token array."""
    if batch_size is None:
      batch_size = self._batch_size
    if hasattr(xs, 'as_numpy_iterator'):
      # TF Dataset.
      ds = xs.repeat(epochs)
      num_train_steps = max_steps
    elif hasattr(xs, 'element_spec'):
      # Dataset iterator.
      if epochs != 1:
        raise ValueError('Epochs must == 1 when using iterator input.')
      ds = xs
      num_train_steps = max_steps
    else:
      # Raw sequences which we turn into a dataset.
      ds = data.dataset_from_tensors(xs)
      ds = ds.shuffle(buffer_size=1024).repeat().batch(batch_size)
      num_train_steps = math.ceil((len(xs) * epochs) / float(batch_size))

    if max_steps:
      num_train_steps = min(num_train_steps, max_steps)

    if not num_train_steps:
      raise ValueError('Must set max_steps to nonzero value.')

    metrics = []
    start = time.time()
    max_steps = max_steps or 10**6
    for _, batch in zip(range(num_train_steps), ds):
      metrics.append(self.fit_batch(batch))
    finish = time.time()
    average = evaluation.combine_metrics(metrics)
    average['runtime'] = finish - start
    average['rate'] = len(metrics) / (finish - start)

    if self._store_metrics:
      average = tree.map_structure(onp.array, average)
      self._epoch_train.append(average)
    return dict(last=evaluation.combine_metrics([metrics[-1]]), average=average)

  def evaluate(self, ds, steps=None):
    """Tests model on data generator."""
    return evaluation.evaluate(model=self, eval_ds=ds, num_eval_steps=steps)

  def fit_batch(self, batch):
    """Updates model on batch of sequences of shape [batch x length]."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    self._optimizer, metrics, self._dropout_rngs = self._train_fn(
        optimizer=self.optimizer, inputs=batch, dropout_rng=self._dropout_rngs)
    if self._store_metrics:
      metrics = tree.map_structure(onp.array, metrics)
      self._metrics_train.append(metrics)
    return metrics

  def score(self, batch):
    """Predicts logits for given [batch x length] sequences."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    logits = self._predict_fn(self.optimizer.target, batch)
    # Undo pmap batching.
    if self._pmap:
      logits = jnp.reshape(logits, [-1, logits.shape[-2], logits.shape[-1]])
    return logits

  def evaluate_batch(self, batch):
    """Computes metrics for given [batch x length] sequences."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    metrics = self._eval_fn(self.optimizer.target, batch)
    if self._store_metrics:
      metrics = tree.map_structure(onp.array, metrics)
      self._metrics_test.append(metrics)
    return metrics


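# Illustrative sketch (not part of the original module): the batch formats
# accepted by the base FlaxModel.preprocess. Assumes `model` uses the base
# implementation (subclasses like FlaxLM/FlaxBERT override it), and `x`, `y`
# are [batch x length] int arrays.
def _example_preprocess_formats(model, x, y):
  a = model.preprocess(x, rng=None, mode=Mode.predict)  # -> (x, None, None)
  b = model.preprocess((x, y), rng=None, mode=Mode.train)  # ones weights
  c = model.preprocess(dict(inputs=x, targets=y), rng=None, mode=Mode.train)
  return a, b, c

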
@gin.configurable
class FlaxLM(FlaxModel):
  """Transformer with causal attention, right shift, and generative sampling."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               **kwargs):
    model_cls = model_cls.partial(causal=True)
    super().__init__(
        domain=domain, model_cls=model_cls, cache=True, with_bos=True, **kwargs)

  def preprocess(self, batch, rng, mode):
    del rng
    return preprocess_causal(
        batch=batch,
        bos_token=self._bos_token,
        pad_token=self._pad_token,
        mode=mode)

  @property
  def cache_def(self):
    """Returns the associated autoregressive cache_def."""
    return self._cache_def

  def sample_with_prompt(self, prompt, rng=None):
    """Draws prompt-guided samples from the model.

    # TODO(gandreea): We could handle variable length prompts by assuming the
    #   input prompt to be a list and padding with the out_of_prompt_token.

    Args:
      prompt: Iterable over equal-length sequences to use as input for
        sampling. The prompt is assumed to start with the BOS token.
      rng: A jax.random.PRNGKey object.

    Returns:
      An array of shape (len(prompt), self._length) containing sequences. If
      variable-length, the sequences are right-padded with the EOS token.
    """
    if rng is None:
      self._sample_rng, rng = jax.random.split(self._sample_rng)
    length = self._length + 1

    if self._pmap:
      prompt = common_utils.shard(prompt)
      cache = self.cache_def.initialize_cache((prompt.shape[1], length))
      cache = jax_utils.replicate(cache)
      rng = jax.random.split(rng, num=len(jax.local_devices()))
    else:
      cache = self.cache_def.initialize_cache((prompt.shape[0], length))

    samples = self._sample_fn(
        prompt=prompt,
        model=self.optimizer.target,
        cache=cache,
        rng=rng,
    )

    # Remove the BOS token from the sampled sequences.
    samples = samples[Ellipsis, 1:]

    # Undo pmap batching.
    if self._pmap:
      samples = jnp.reshape(samples, [-1, self._length])
    return samples

  def sample(self, batch_size, rng=None):
    """Draws samples from the model.

    Args:
      batch_size: An int indicating the number of samples to return.
      rng: A jax.random.PRNGKey object.

    Returns:
      An array of shape (batch_size, self._length) containing sequences. If
      variable-length, the sequences are right-padded with the EOS token.
    """
    # To draw generic samples, we initialize the prompt with the BOS token.
    prompt = jnp.ones((batch_size, 1)).astype(jnp.int32) * self._bos_token
    return self.sample_with_prompt(prompt, rng=rng)


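# End-to-end usage sketch (illustrative; the data shapes, step counts, and
# pmap=False setting are assumptions for a minimal single-device run, not
# values from the original file).
def _example_train_and_sample():
  """Fits a FlaxLM on random token data, then draws samples."""
  lm = FlaxLM(domain=data.protein_domain, batch_size=8, pmap=False)
  # Fake [N x length] training data drawn uniformly over the vocabulary.
  xs = onp.random.randint(0, lm.vocab_size, size=(64, lm.length))
  lm.fit(xs, epochs=1, max_steps=10)
  return lm.sample(batch_size=4)  # [4 x length] array of sampled tokens.

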
@gin.configurable
class FlaxBERT(FlaxModel):
  """Transformer with all-to-all attention and token dropout."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               mask_rate=0.15,
               random_token_proportion=0.8,
               mask_token_proportion=0.1,
               **kwargs):
    """Creates a BERT model.

    For each token in the input, masks with probability `mask_rate`. A masked
    token is replaced with:
    - MASK, with probability `mask_token_proportion`,
    - a random token, with probability `random_token_proportion`,
    - the original token (left unchanged but included in the loss) with the
      remaining probability.

    Args:
      domain: Domain to operate over.
      model_cls: Flax Module operating on sequences.
      mask_rate: Probability of replacing a token and including it in the
        loss.
      random_token_proportion: Portion of masked tokens to replace with a
        randomly chosen token.
      mask_token_proportion: Portion of masked tokens to replace with MASK.
      **kwargs: Arguments passed to FlaxModel.
    """
    model_cls = model_cls.partial(causal=False)
    self._mask_rate = mask_rate
    total = random_token_proportion + mask_token_proportion
    if total < 0 or total > 1:
      raise ValueError('Sum of random proportion and mask proportion must be'
                       ' in [0, 1] range.')
    self._masker = BertMasker(
        domain,
        mask_rate=mask_rate,
        mask_token_proportion=mask_token_proportion,
        random_token_proportion=random_token_proportion)

    super().__init__(
        domain=domain,
        model_cls=model_cls,
        cache=False,
        with_mask=True,
        **kwargs)

  def preprocess(self, batch, rng, mode):
    return self._masker(inputs=batch, mode=mode, rng=rng)

  def sample(self, masked_inputs, rng):
    """Fills in MASK positions in inputs."""
    mask_positions = masked_inputs == self.domain.vocab.mask
    logits = self.score(masked_inputs)

    # Mask out the MASK token itself so it is never sampled.
    mask = common_utils.onehot(
        jnp.array([self.domain.vocab.mask]),
        num_classes=logits.shape[-1],
        on_value=sampling.LARGE_NEGATIVE)
    logits = logits + mask
    samples = jax.random.categorical(rng, logits=logits)
    infilled = onp.where(mask_positions, samples, masked_inputs)
    return infilled


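# Usage sketch (illustrative, not part of the original module): infilling
# masked positions with FlaxBERT.sample. Assumes `bert` is a trained FlaxBERT
# and `seqs` is an [N x length] int array from its domain; masking the first
# position is an arbitrary choice for the example.
def _example_infill(bert, seqs, rng):
  """Masks one position per sequence and asks the model to fill it in."""
  masked = onp.array(seqs)
  masked[:, 0] = bert.mask_token  # Replace the first token with MASK.
  return bert.sample(masked, rng)  # MASK positions replaced by model samples.

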
def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate,
                      mask_token_proportion, random_token_proportion, mode,
                      rng):
  """Preprocesses inputs for masked language modeling.

  Args:
    inputs: [batch x length] input tokens.
    random_tokens: Set of tokens usable for replacing.
    mask_token: Int ID to mask blanks with.
    pad_token: Int ID for PAD token. These positions are left unchanged.
    mask_rate: Proportion of tokens to mask out.
    mask_token_proportion: Replace this proportion of chosen positions with
      MASK.
    random_token_proportion: Replace this proportion of chosen positions with
      randomly sampled tokens.
    mode: Mode key.
    rng: Jax RNG.

  Returns:
    Tuple of [batch x length] inputs, targets, per position weights. Inputs
    will have random positions masked out with either a MASK token, or a
    randomly chosen token from the vocabulary.
  """
  total = random_token_proportion + mask_token_proportion
  if total < 0 or total > 1:
    raise ValueError('Sum of random proportion and mask proportion must be'
                     ' in [0, 1] range.')
  targets = inputs

  if mode == Mode.predict:
    weights = jnp.full_like(targets, 1)
    masked_inputs = inputs  # Pass through.
  else:
    if rng is None:
      if mode is not Mode.eval:
        raise ValueError('Must provide RNG unless in eval mode.')
      # TODO(b/157055145): How to keep same eval set across runs?
      # Make each sequence's mask invariant to other members
      # of the batch. Right now there is batch size dependence.
      rng = jrandom.PRNGKey(jnp.sum(inputs))

    # Get positions to leave untouched.
    is_pad = inputs == pad_token

    # Positions to mask.
    rng, subrng = jax.random.split(rng)
    should_mask = jrandom.bernoulli(subrng, p=mask_rate, shape=inputs.shape)
    should_mask = jnp.where(is_pad, 0, should_mask)  # Don't mask out padding.

    # Generate full array of random tokens.
    rng, subrng = jax.random.split(rng)
    random_ids = jax.random.randint(
        subrng, inputs.shape, minval=0, maxval=len(random_tokens))

    fullrandom = random_tokens[random_ids]
    # Full array of MASK tokens.
    fullmask = jnp.full_like(inputs, mask_token)

    # Build up masked array by selecting from inputs/fullmask/fullrandom.
    rand = jax.random.uniform(rng, shape=inputs.shape)
    masked_inputs = inputs
    # Remaining probability mass keeps original values after MASK and RANDOM.
    # MASK tokens.
    masked_inputs = jnp.where(rand < mask_token_proportion, fullmask,
                              masked_inputs)
    # Random tokens.
    masked_inputs = jnp.where(
        jnp.logical_and(rand >= mask_token_proportion,
                        rand < mask_token_proportion + random_token_proportion),
        fullrandom, masked_inputs)

    # Only replace positions where `should_mask`.
    masked_inputs = jnp.where(should_mask, masked_inputs, inputs)
    weights = should_mask

  return masked_inputs, targets, weights


class BertMasker():
  """Constructs a BERT masker for a given domain."""

  def __init__(self,
               domain,
               mask_rate=0.15,
               mask_token_proportion=0.1,
               random_token_proportion=0.8):
    vocab = domain.vocab
    if vocab.mask is None:
      raise ValueError('Vocabulary must specify a MASK token.')
    special_tokens = [vocab.bos, vocab.eos, vocab.mask, vocab.pad]
    special_tokens = [x for x in special_tokens if x is not None]
    normal_tokens = [x for x in vocab.token_ids if x not in special_tokens]
    self._domain = domain
    self._special_tokens = jnp.array(special_tokens)
    self._normal_tokens = jnp.array(normal_tokens)
    self._mask_rate = mask_rate
    self._mask_token_proportion = mask_token_proportion
    self._random_token_proportion = random_token_proportion

  def __call__(self, inputs, mode, rng):
    inputs, targets, weights = preprocess_masked(
        inputs=inputs,
        mode=mode,
        rng=rng,
        random_tokens=self._normal_tokens,
        mask_token=self._domain.vocab.mask,
        pad_token=self._domain.vocab.pad,
        mask_rate=self._mask_rate,
        mask_token_proportion=self._mask_token_proportion,
        random_token_proportion=self._random_token_proportion)
    return inputs, targets, weights


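# Usage sketch (illustrative, not part of the original module): applying
# BertMasker directly to a batch. With the defaults above, a selected position
# is replaced by a random token 80% of the time, by MASK 10% of the time, and
# kept as-is (but still scored in the loss) the remaining 10%.
def _example_bert_masking(batch, rng):
  """Returns (masked_inputs, targets, weights) for a [batch x length] array."""
  masker = BertMasker(data.protein_domain)
  return masker(inputs=batch, mode=Mode.train, rng=rng)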