# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Standalone Flax models."""

import abc
import enum
import functools
import math
import operator as op
import os
import pprint
import time

from absl import logging
from flax import jax_utils
from flax.deprecated import nn
from flax.training import checkpoints
from flax.training import common_utils
import gin
from gin import config
import jax
from jax import random as jrandom
import jax.example_libraries.optimizers
import jax.nn
import jax.numpy as jnp
import numpy as onp
import tensorflow.compat.v1 as tf
import tree

from protein_lm import data
from protein_lm import evaluation
from protein_lm import modules
from protein_lm import sampling
from protein_lm import utils


class Mode(enum.Enum):
  train = 'train'
  eval = 'eval'
  predict = 'predict'
  sample = 'sample'


def parse_config(ckpt_dir):
  """Parses a FlaxLM config as a dict from the checkpoint dir."""
  cfg = dict()
  with tf.gfile.GFile(os.path.join(ckpt_dir, 'config.gin')) as f:
    for line in f:
      if 'FlaxLM' in line and not line.startswith('#'):
        key, value = line.split(' = ')
        _, kwarg = key.split('.')
        value = config.parse_value(value)
        cfg[kwarg] = value
  return cfg


def save_model_kwargs(ckpt_dir, model):
  """Saves a FlaxLM config dict into the checkpoint dir."""
  model_kwargs = model.model_kwargs
  model_name = type(model).__name__
  with tf.gfile.GFile(os.path.join(ckpt_dir, 'config.gin'), 'w') as f:
    for key, value in model_kwargs.items():
      f.write('%s.%s = %s\n' % (model_name, key, str(value)))


@functools.lru_cache()
def load_model(ckpt_dir, model_cls, domain=None):
  """Loads a model from a checkpoint directory."""
  if domain is None:
    domain = data.protein_domain
  cfg = parse_config(ckpt_dir)
  print('Loading model with config:')
  pprint.pprint(cfg)
  model = model_cls(domain=domain, **cfg)
  model.load_checkpoint(ckpt_dir)
  return model

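# Example (illustrative sketch): `save_model_kwargs` and `parse_config` round-trip
# constructor kwargs through a `config.gin` file in the checkpoint directory, so
# `load_model` can rebuild the model before restoring its weights. The path is a
# placeholder, and `num_layers` is assumed to be one of the Transformer kwargs
# forwarded via `**model_kwargs`.
#
#   lm = FlaxLM(domain=data.protein_domain, num_layers=2)
#   save_model_kwargs('/tmp/lm_ckpt', lm)  # writes a line like 'FlaxLM.num_layers = 2'
#   lm.save_checkpoint('/tmp/lm_ckpt')
#   restored = load_model('/tmp/lm_ckpt', model_cls=FlaxLM)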

def train_step(optimizer,
               inputs,
               learning_rate_fn,
               dropout_rng,
               preprocess_fn,
               example_weights=None,
               grad_clip=None,
               epsilon=1e-9):
  """Performs a single training step. Masks out BOS/PAD positions.

  Args:
    optimizer: Flax optimizer.
    inputs: Inputs to model.preprocess which returns (inputs, targets, weights).
    learning_rate_fn: function from step idx --> learning rate.
    dropout_rng: RNG for dropout.
    preprocess_fn: function mapping
      (inputs, rng, mode) -> (inputs, targets, weights).
    example_weights: Optional [batch] weights for the loss on each example.
      See utils.compute_weighted_cross_entropy for details.
    grad_clip: If not None, clip gradients to [-x, +x].
    epsilon: Epsilon for denominator of loss averaging.

  Returns:
    new_optimizer, metrics, new_dropout_rng
  """

  # We handle PRNG splitting inside the top pmap, rather
  # than handling it outside in the training loop - doing the
  # latter can add some stalls to the devices.
  dropout_rng, new_dropout_rng = jrandom.split(dropout_rng)
  dropout_rng, preprocess_rng = jrandom.split(dropout_rng)

  inputs, targets, weights = preprocess_fn(
      inputs, rng=preprocess_rng, mode=Mode.train)

  if isinstance(targets, dict):
    classification_targets = targets['classification']
    classification_weights = weights['classification']

    regression_targets = targets['regression']
    regression_weights = weights['regression']
  else:
    # Default to classification loss.
    classification_targets = targets
    classification_weights = weights
    regression_targets = None
    regression_weights = None

  if classification_targets is None and regression_targets is None:
    raise ValueError('No targets specified for train step.')

  if classification_weights is None and regression_weights is None:
    raise ValueError('No weights specified for train step.')

  def loss_fn(model):
    """Loss function used for training."""
    # Stateful collection for tracking internal state like activations.
    with nn.stateful() as batch_stats:
      with nn.stochastic(dropout_rng):
        outputs = model(inputs, train=True, cache=None)

      if isinstance(outputs, dict):
        logits = outputs.get('logits', None)
        regression_predictions = outputs.get('regression', None)
      else:
        logits = outputs
        regression_predictions = None

    mean_loss = 0.0

    # Classification loss
    if classification_targets is not None:
      classification_loss, classification_weight_sum = utils.compute_weighted_cross_entropy(
          logits,
          classification_targets,
          token_weights=classification_weights,
          example_weights=example_weights)
      classification_weight_sum = jnp.maximum(classification_weight_sum,
                                              epsilon)
      # Handle case where nothing is masked out in BERT
      # (Only occurs with very short sequences).
      mean_classification_loss = classification_loss / classification_weight_sum
      mean_loss += mean_classification_loss

    if regression_targets is not None:
      regression_loss, regression_weight_sum = utils.compute_weighted_mse(
          regression_predictions,
          regression_targets,
          weights=regression_weights)
      regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
      mean_regression_loss = regression_loss / regression_weight_sum
      outputs['regression_loss'] = mean_regression_loss

      # TODO(ddohan): Allow weighting each loss separately.
      mean_loss += mean_regression_loss

    return mean_loss, (outputs, batch_stats)

  step = optimizer.state.step
  lr = learning_rate_fn(step)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (outputs, batch_stats)), grad = grad_fn(optimizer.target)

  try:
    grad = jax.lax.pmean(grad, 'batch')
  except NameError:
    pass

  if grad_clip is not None:
    # Clip gradients after pmean aggregation
    unclipped_grad = grad
    grad = jax.example_libraries.optimizers.clip_grads(grad, grad_clip)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

  # TODO(ddohan): Avoid computing metrics except when needed.
  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
  else:
    logits = outputs

  metrics = dict()
  if logits is not None:
    classification_metrics = utils.compute_metrics(logits,
                                                   classification_targets,
                                                   classification_weights)
    metrics.update(classification_metrics)
  if regression_targets is not None:
    # TODO(ddohan): Implement regression metrics.
    logging.info('No regression targets yet')
    # regression = outputs.get('regression', None)
    # regression_metrics = utils.compute_metrics(logits, regression_targets,
    #                                                classification_weights)
  metrics['learning_rate'] = lr

  # Training metrics
  metrics['l2_param_sum'] = utils.l2_regularization(optimizer.target.params)

  # Gradient norms
  grad_l2_tree = utils.l2_norm(grad)
  grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
  grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
  metrics['l2_grad_sum'] = grad_l2_sum
  metrics['l2_grad_max'] = grad_l2_max

  # Store any tagged metrics
  batch_stats = batch_stats.as_dict()
  if batch_stats:

    def clean_name(k):
      return 'nn/' + k.replace('MultiHeadDotProductAttention_', '').replace(
          '/Transformer1DBlock_', '')

    stats = {clean_name(k): v['tag'] for k, v in batch_stats.items()}
    metrics.update(stats)

  if grad_clip is not None:
    # Unclipped gradient norms (if applicable).
    grad_l2_tree = utils.l2_norm(unclipped_grad)
    grad_l2_sum = jax.tree_util.tree_reduce(op.add, grad_l2_tree)
    grad_l2_max = jax.tree_util.tree_reduce(jnp.maximum, grad_l2_tree)
    metrics['l2_noclip_grad_sum'] = grad_l2_sum
    metrics['l2_noclip_grad_max'] = grad_l2_max

  return new_optimizer, metrics, new_dropout_rng

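# Example (illustrative sketch): `train_step` is normally wrapped the way
# `FlaxModel._init_model` wraps it below: bind the static arguments with
# `functools.partial`, then `jax.pmap` over an axis named 'batch' so the
# `jax.lax.pmean(grad, 'batch')` call above resolves. The optimizer and dropout
# RNGs are assumed to already be replicated across local devices, as in
# `FlaxModel`; shapes are placeholders.
#
#   p_train_step = jax.pmap(
#       functools.partial(
#           train_step,
#           learning_rate_fn=lambda step: 1e-3,
#           grad_clip=1.0,
#           preprocess_fn=model.preprocess),
#       axis_name='batch')
#   sharded = common_utils.shard(batch)  # [num_devices, per_device_batch, length]
#   optimizer, metrics, dropout_rngs = p_train_step(
#       optimizer=optimizer, inputs=sharded, dropout_rng=dropout_rngs)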

def eval_step(model, inputs, preprocess_fn):
  """Computes evaluation metrics for a batch of inputs."""
  inputs, targets, weights = preprocess_fn(inputs, rng=None, mode=Mode.eval)
  logits = model(inputs, train=False, cache=None)
  return utils.compute_metrics(logits, targets, weights)


def predict_step(model, inputs, preprocess_fn, output_head='logits'):
  """Runs the model and returns the requested output head (logits by default)."""
  inputs, _, _ = preprocess_fn(inputs, rng=None, mode=Mode.predict)
  logits = model(inputs, train=False, cache=None, output_head=output_head)
  return logits


def _tokens_to_logits(last_token, cache, model, internal_state=None):
  """Computes the next token logits.

  Args:
    last_token: An array of shape (batch_size, 1) containing last token ids.
    cache: A flax.deprecated.nn.attention.Cache object.
    model: A Jax decoder model to be used for computing the next token logits.
    internal_state: A dict with internal state received from the previous time
      step. If None, no information is shared across time steps.

  Returns:
    logits: An array of shape (batch_size, vocab_size) with the logits.
    new_cache: A flax.deprecated.nn.attention.Cache object with the updated
      cache.
    new_internal_state: A dict with internal state passed to the next time step.
  """
  del internal_state  # Not used.
  # The returned logits have shape (batch_size, 1, vocab_size).
  with cache.mutate() as new_cache:
    logits = model(last_token, train=False, cache=new_cache)

  # Remove the singleton dimension to return shape (batch_size, vocab_size).
  logits = logits.squeeze(axis=1)
  return logits, new_cache, None


def sample_step(prompt,
                model,
                cache,
                rng,
                masked_tokens,
                eos_token,
                pad_token,
                max_decode_len,
                tokens_to_logits=_tokens_to_logits,
                **sampling_kwargs):
  """Samples autoregressively from the model.

  Args:
    prompt: An array of shape (batch_size, prompt_length) containing the input
      prompt (the model consumes these tokens and starts generation after). For
      generic sampling, the prompt must be a single BOS token.
    model: A Jax decoder model to be used for computing the next token logits.
    cache: A flax.deprecated.nn.attention.Cache object.
    rng: A jax.random.PRNGKey object.
    masked_tokens: A list of ints indicating tokens to mask out during sampling.
    eos_token: An int indicating the EOS token id. If None, we decode until
      reaching the maximum sequence length.
    pad_token: An int token used to pad sequences after the eos token. If None,
      we set pad_token to eos_token.
    max_decode_len: An int indicating the maximum sequence length.
    tokens_to_logits: A callable that computes the next token logits given the
      current cache and previous token.
    **sampling_kwargs: Named arguments passed to sampling.temperature_sample.

  Returns:
    An array of shape (batch_size, max_decode_len) containing sampled sequences.
      If variable-length, the sequences are right-padded with the EOS token.
  """
  tokens_to_logits = functools.partial(tokens_to_logits, model=model)
  return sampling.temperature_sample(
      prompt,
      init_cache=cache,
      tokens_to_logits=tokens_to_logits,
      max_decode_len=max_decode_len,
      rng=rng,
      eos_token=eos_token,
      pad_token=pad_token,
      masked_tokens=masked_tokens,
      **sampling_kwargs,
  )


def compute_logprob(inputs, model, mask_token=None):
  """Returns an array of log probabilities for the input sequences."""

  assert inputs.ndim == 2

  targets = inputs
  weights = jnp.where(targets != model.pad_token, 1, 0)
  if mask_token is not None:
    weights *= jnp.where(targets != mask_token, 1, 0)
  logits = model.score(inputs)
  assert logits.ndim == 3

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  log_lik *= weights
  log_prob = jnp.sum(log_lik, axis=-1)

  return log_prob

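# Example (illustrative sketch): scoring sequences with `compute_logprob`. PAD
# positions (and MASK positions, if `mask_token` is given) are excluded, so the
# result is the total log-likelihood of the remaining tokens in each sequence.
# `lm` is a trained FlaxLM and `seqs` a placeholder [batch, length] token array.
#
#   log_probs = compute_logprob(jnp.asarray(seqs), lm)  # shape [batch]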

def preprocess_causal(batch, bos_token, pad_token, mode):
  """Preprocessing for causal language modeling.

  Right-shifts the batch to produce decoder inputs.

  Args:
    batch: [batch x length] tokens.
    bos_token: Int ID to use as beginning of sentence token.
    pad_token: Padding token which should be masked out in loss.
    mode: Mode value.

  Returns:
    Tuple of [batch x length] inputs, targets, per-position weights. Inputs are
      the targets shifted right by one position with a BOS token prepended;
      PAD positions receive weight 0 so they are masked out of the loss.
  """
  if mode == Mode.sample:
    inputs = batch
  else:
    inputs = modules.shift_right(batch, bos_token=bos_token)

  targets = batch
  # Mask out PAD in loss.
  if pad_token is None:
    weights = jnp.ones_like(targets)
  else:
    weights = jnp.where(targets != pad_token, 1, 0)
  return inputs, targets, weights

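# Example (illustrative sketch): for a batch [[5, 7, PAD]] with the domain's
# BOS and PAD ids, `preprocess_causal` in train/eval mode returns
#   inputs  = [[BOS, 5, 7]]   # batch shifted right with BOS prepended
#   targets = [[5, 7, PAD]]
#   weights = [[1, 1, 0]]     # PAD positions dropped from the loss
# In Mode.sample the batch is passed through unshifted.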

@gin.configurable
class FlaxModel(abc.ABC):
  """Model built on Flax."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               random_seed=0,
               batch_size=None,
               grad_clip=None,
               learning_rate=0.001,
               weight_decay=0.1,
               cache=True,
               pmap=True,
               attention_fn=None,
               with_bos=False,
               with_mask=False,
               store_metrics=False,
               sampling_kwargs=None,
               **model_kwargs):
    """Creates a Flax model for sequence prediction.

    Args:
      domain: discrete domain.
      model_cls: Flax.nn.Module to train.
      random_seed: Random seed.
      batch_size: Default batch size.
      grad_clip: Gradient clipping in optimizer.
      learning_rate: learning rate in optimizer, or callable mapping a step to
        current learning rate.
      weight_decay: L2 decay for AdamW.
      cache: Whether to create a cache.
      pmap: Whether to pmap inference (and JIT as a side effect).
      attention_fn: Function to use in place of nn.dot_product_attention.
      with_bos: Whether to ensure vocab contains BOS.
      with_mask: Whether to ensure vocab contains MASK.
      store_metrics: Whether to store train and evaluation metrics.
      sampling_kwargs: Additional config options for sample step.
      **model_kwargs: Additional config options for `model_cls.partial`.
    """
    self._batch_size = batch_size  # Default batch size

    # TODO(b/157255958): Reenable tracking metrics inside class.
    self._store_metrics = store_metrics
    if store_metrics:
      self._metrics_train = []
      self._metrics_test = []
      self._epoch_train = []
      self._epoch_test = []

    self._pmap = pmap
    self._sampling_kwargs = sampling_kwargs
    self._model_kwargs = model_kwargs
    self._opt_hparams = dict(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        grad_clip=grad_clip)

    # TODO(ddohan): Reimplement __getstate__ and __setstate__ to support pickle,
    # and use these functions to init model.
    self._set_domain(domain=domain, with_bos=with_bos, with_mask=with_mask)
    self._init_model(
        model_cls=model_cls,
        random_seed=random_seed,
        pmap=pmap,
        cache=cache,
        attention_fn=attention_fn,
        sampling_kwargs=sampling_kwargs,
        model_kwargs=model_kwargs,
        **self._opt_hparams)

  def _set_domain(self, domain, with_bos, with_mask):
    """Set vocabulary based on domain."""
    self.domain = domain
    self._length = domain.length
    self._bos_token = domain.vocab.bos
    self._eos_token = domain.vocab.eos
    self._pad_token = domain.vocab.pad
    self._mask_token = domain.vocab.mask

    vocab_size = domain.vocab_size
    if with_bos and self._bos_token is None:  # Add bos token.
      self._bos_token = vocab_size
      vocab_size += 1
    if with_mask and self._mask_token is None:  # Add mask token.
      self._mask_token = vocab_size
      vocab_size += 1
    self._vocab_size = vocab_size

  def _get_masked_tokens(self):
    """Get list of token IDs to mask for a given domain."""
    tokens = []
    for token in [self._bos_token, self._pad_token, self._mask_token]:
      if token is not None:
        tokens.append(token)
    return tokens

  def _init_model(self,
                  model_cls,
                  pmap,
                  learning_rate,
                  weight_decay,
                  grad_clip,
                  attention_fn,
                  random_seed,
                  cache=True,
                  sampling_kwargs=None,
                  model_kwargs=None):
    """Initialize model."""
    model_kwargs = model_kwargs or dict()
    model_def = model_cls.partial(
        vocab_size=self._vocab_size,
        max_len=self.domain.length,
        # Don't attend to PAD tokens
        pad_token=self._pad_token,
        attention_fn=attention_fn,
        **model_kwargs)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda step: learning_rate

    train_fn = functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        grad_clip=grad_clip,
        preprocess_fn=self.preprocess)
    eval_fn = functools.partial(eval_step, preprocess_fn=self.preprocess)
    predict_fn = functools.partial(predict_step, preprocess_fn=self.preprocess)

    sampling_kwargs = sampling_kwargs or dict()
    masked_tokens = self._get_masked_tokens()
    sample_fn = functools.partial(
        sample_step,
        masked_tokens=masked_tokens,
        eos_token=self._eos_token,
        pad_token=self._pad_token,
        max_decode_len=self._length + 1,
        **sampling_kwargs)

    # Default to pmapped versions.
    if pmap:
      train_fn = jax.pmap(train_fn, axis_name='batch')
      eval_fn = jax.pmap(eval_fn, axis_name='batch')
      sample_fn = jax.pmap(sample_fn, axis_name='batch')
      predict_fn = jax.pmap(predict_fn, axis_name='batch')

    self._train_fn = train_fn
    self._predict_fn = predict_fn
    self._sample_fn = sample_fn
    self._eval_fn = eval_fn

    rng = jrandom.PRNGKey(random_seed)
    rng, init_rng = jrandom.split(rng)
    rng, self._sample_rng = jrandom.split(rng)

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    if self._pmap:
      self._dropout_rngs = jrandom.split(rng, jax.local_device_count())
    else:
      self._dropout_rngs = rng

    # Note: any batch size can be used later. This is arbitrary for init.
    input_shape = (self._batch_size or 2, self.domain.length)
    if cache:
      init_model, self._cache_def = utils.create_model_and_cache(
          init_rng, input_shape, model_def)
    else:
      init_model = utils.create_model(init_rng, input_shape, model_def)
      self._cache_def = None
    self._optimizer = utils.create_adam_optimizer(
        init_model,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        pmap=self._pmap)
    del init_model  # Delete initial model.

  def preprocess(self, batch, rng, mode):
    """Unpack batch of data to (inputs, targets, weights).

    batch may be one of:
      - a [batch x length] batch of input data.
        Results in (batch, None, None)
      - a tuple of (inputs, targets)
        Results in (inputs, targets, ones_like(targets))
      - a tuple of (inputs, targets, weights)
        Passed through unchanged.
      - a dict containing 'inputs', 'targets', and
        optionally 'weights'.
        Results in (inputs, targets, weights or ones_like(targets))

    Args:
      batch: Batch of data.
      rng: Ignored. Jax random seed.
      mode: member of Mode enum.

    Returns:
      Tuple of (inputs, targets, weights).
        `targets` and `weights` are None if `targets` is not provided.
    """
    del rng
    if isinstance(batch, tuple):
      if len(batch) == 2:
        inputs, targets = batch
        weights = jnp.ones_like(targets)
      elif len(batch) == 3:
        inputs, targets, weights = batch
      else:
        raise ValueError(
            'Must provide (inputs, targets) or (inputs, targets, weights)')
    elif isinstance(batch, dict):
      inputs = batch['inputs']
      targets = batch['targets']
      weights = batch.get('weights', None)
      if weights is None:
        weights = jnp.ones_like(targets)
    else:
      inputs = batch
      targets = None
      weights = None

    if targets is None and mode not in (Mode.predict, Mode.sample):
      raise ValueError('Must provide targets for train and eval.')

    return inputs, targets, weights

  @property
  def train_step(self):
    """Returns the current train step."""
    step = self.optimizer.state.step
    if self._pmap:
      step = step[0]
    return step

  @property
  def bos_token(self):
    """Returns the BOS token id."""
    return self._bos_token

  @property
  def eos_token(self):
    """Returns the EOS token id."""
    return self._eos_token

  @property
  def pad_token(self):
    """Returns the PAD token id."""
    return self._pad_token

  @property
  def mask_token(self):
    """Returns the MASK token id."""
    return self._mask_token

  @property
  def length(self):
    """Returns the maximum sequence length."""
    return self._length

  @property
  def vocab_size(self):
    """Returns the vocabulary size used for training."""
    return self._vocab_size

  @property
  def optimizer(self):
    """Returns Flax optimizer containing optimizer and model parameters."""
    return self._optimizer

  @property
  def model_kwargs(self):
    """Returns the model kwargs as a dictionary."""
    return self._model_kwargs

  @property
  def pmap(self):
    """Returns whether or not the optimizer was trained with pmap."""
    return self._pmap

  def set_weights(self, optimizer):
    """Sets weights from unreplicated optimizer."""
    if self._pmap:
      optimizer = jax_utils.replicate(optimizer)
    self._optimizer = optimizer

  def get_weights(self):
    """Returns unreplicated optimizer."""
    optimizer = self.optimizer
    if self._pmap:
      optimizer = jax_utils.unreplicate(optimizer)
    return optimizer

  def save_checkpoint(self, ckpt_dir):
    """Saves unreplicated optimizer to ckpt_dir."""
    optimizer = self.get_weights()
    checkpoints.save_checkpoint(
        ckpt_dir,
        target=optimizer,
        step=self.train_step,
    )

  def load_checkpoint(self, ckpt_dir):
    """Loads optimizer from ckpt_dir."""
    target = self.get_weights()
    optimizer = checkpoints.restore_checkpoint(ckpt_dir, target=target)
    if optimizer is target:
      raise ValueError('Unable to load checkpoint from %s' % ckpt_dir)
    self.set_weights(optimizer)

  def fit(self, xs, epochs=1, batch_size=None, max_steps=10**6):
    """Fits to sequences given as [N x length] token array."""
    if batch_size is None:
      batch_size = self._batch_size
    if hasattr(xs, 'as_numpy_iterator'):
      # TF Dataset
      ds = xs.repeat(epochs)
      num_train_steps = max_steps
    elif hasattr(xs, 'element_spec'):
      # Dataset iterator.
      if epochs != 1:
        raise ValueError('Epochs must == 1 when using iterator input.')
      ds = xs
      num_train_steps = max_steps
    else:
      # Raw sequences which we turn into a dataset.
      ds = data.dataset_from_tensors(xs)
      ds = ds.shuffle(buffer_size=1024).repeat().batch(batch_size)
      num_train_steps = math.ceil((len(xs) * epochs) / float(batch_size))

      if max_steps:
        num_train_steps = min(num_train_steps, max_steps)

    if not num_train_steps:
      raise ValueError('Must set max_steps to nonzero value.')

    metrics = []
    start = time.time()
    max_steps = max_steps or 10**6
    for _, batch in zip(range(num_train_steps), ds):
      metrics.append(self.fit_batch(batch))
    finish = time.time()
    average = evaluation.combine_metrics(metrics)
    average['runtime'] = finish - start
    average['rate'] = len(metrics) / (finish - start)

    if self._store_metrics:
      average = tree.map_structure(onp.array, average)
      self._epoch_train.append(average)
    return dict(last=evaluation.combine_metrics([metrics[-1]]), average=average)

  def evaluate(self, ds, steps=None):
    """Test model on data generator."""
    return evaluation.evaluate(model=self, eval_ds=ds, num_eval_steps=steps)

  def fit_batch(self, batch):
    """Update model on batch of sequences of shape [batch x length]."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    self._optimizer, metrics, self._dropout_rngs = self._train_fn(
        optimizer=self.optimizer, inputs=batch, dropout_rng=self._dropout_rngs)
    if self._store_metrics:
      metrics = tree.map_structure(onp.array, metrics)
      self._metrics_train.append(metrics)
    return metrics

  def score(self, batch):
    """Predicts logits for given [batch x length] sequences."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    logits = self._predict_fn(self.optimizer.target, batch)
    # Undo pmap batching
    if self._pmap:
      logits = jnp.reshape(logits, [-1, logits.shape[-2], logits.shape[-1]])
    return logits

  def evaluate_batch(self, batch):
    """Computes metrics for given [batch x length] sequences."""
    batch = tree.map_structure(jnp.asarray, batch)
    if self._pmap:
      batch = common_utils.shard(batch)
    metrics = self._eval_fn(self.optimizer.target, batch)
    if self._store_metrics:
      metrics = tree.map_structure(onp.array, metrics)
      self._metrics_test.append(metrics)
    return metrics

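# Example (illustrative sketch): the batch formats accepted by
# `FlaxModel.preprocess` and the (inputs, targets, weights) triples they yield.
# `x`, `y`, and `w` are placeholder [batch, length] arrays.
#
#   model.preprocess(x, rng=None, mode=Mode.predict)       # (x, None, None)
#   model.preprocess((x, y), rng=None, mode=Mode.eval)     # weights = ones_like(y)
#   model.preprocess((x, y, w), rng=None, mode=Mode.eval)  # passed through
#   model.preprocess({'inputs': x, 'targets': y}, rng=None, mode=Mode.eval)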

@gin.configurable
class FlaxLM(FlaxModel):
  """Transformer with causal attention, right shift, and generative sampling."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               **kwargs):

    model_cls = model_cls.partial(causal=True)
    super().__init__(
        domain=domain, model_cls=model_cls, cache=True, with_bos=True, **kwargs)

  def preprocess(self, batch, rng, mode):
    del rng
    return preprocess_causal(
        batch=batch,
        bos_token=self._bos_token,
        pad_token=self._pad_token,
        mode=mode)

  @property
  def cache_def(self):
    """Returns the associated autoregressive cache_def."""
    return self._cache_def

  def sample_with_prompt(self, prompt, rng=None):
    """Draws prompt-guided samples from the model.

    # TODO(gandreea): We could handle variable length prompts by assuming the
    #   input prompt to be a list and padding with the out_of_prompt_token.

    Args:
      prompt: Iterable over equal-length sequences to use as input for sampling.
        The prompt is assumed to start with the BOS token.
      rng: A jax.random.PRNGKey object.

    Returns:
      An array of shape (len(prompt), self._length) containing sequences. If
        variable-length, the sequences are right-padded with the EOS token.
    """
    if rng is None:
      self._sample_rng, rng = jax.random.split(self._sample_rng)
    length = self._length + 1

    if self._pmap:
      prompt = common_utils.shard(prompt)
      cache = self.cache_def.initialize_cache((prompt.shape[1], length))
      cache = jax_utils.replicate(cache)
      rng = jax.random.split(rng, num=len(jax.local_devices()))
    else:
      cache = self.cache_def.initialize_cache((prompt.shape[0], length))

    samples = self._sample_fn(
        prompt=prompt,
        model=self.optimizer.target,
        cache=cache,
        rng=rng,
    )

    # Remove the BOS token from the sampled sequences.
    samples = samples[Ellipsis, 1:]

    # Undo pmap batching
    if self._pmap:
      samples = jnp.reshape(samples, [-1, self._length])
    return samples

  def sample(self, batch_size, rng=None):
    """Draws samples from the model.

    Args:
      batch_size: An int indicating the number of samples to return.
      rng: A jax.random.PRNGKey object.

    Returns:
      An array of shape (batch_size, self._length) containing sequences. If
        variable-length, the sequences are right-padded with the EOS token.
    """
    # To draw generic samples, we initialize the prompt with the BOS token.
    prompt = jnp.ones((batch_size, 1)).astype(jnp.int32) * self._bos_token
    return self.sample_with_prompt(prompt, rng=rng)

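# Example (illustrative sketch): training a small causal LM on an [N, length]
# token array and sampling from it. `num_layers` and `emb_dim` are assumed to be
# Transformer kwargs forwarded through `**model_kwargs`; `train_tokens` is a
# placeholder integer array.
#
#   lm = FlaxLM(domain=data.protein_domain, batch_size=8,
#               num_layers=1, emb_dim=64)
#   lm.fit(train_tokens, epochs=1)            # train_tokens: [N, length] ints
#   samples = lm.sample(batch_size=4)         # [4, length] sampled tokens
#   log_probs = compute_logprob(samples, lm)  # log-likelihood of each sample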

@gin.configurable
class FlaxBERT(FlaxModel):
  """Transformer with all-to-all attention and token dropout."""

  def __init__(self,
               domain=data.protein_domain,
               model_cls=modules.Transformer,
               mask_rate=0.15,
               random_token_proportion=0.8,
               mask_token_proportion=0.1,
               **kwargs):
    """Create BERT model.

    For each token in the input, masks with probability `mask_rate`. A masked
    token is replaced with:
    - MASK with probability `mask_token_proportion`,
    - a random token with probability `random_token_proportion`,
    - its original value, still included in the loss, with the remaining
      probability.

    Args:
      domain: Domain to operate over.
      model_cls: Flax Module operating on sequences.
      mask_rate: Probability of replacing a token and including it in the loss.
      random_token_proportion: Portion of masked tokens to replace with a
        randomly chosen token.
      mask_token_proportion: Portion of masked tokens to replace with MASK.
      **kwargs: Arguments passed to FlaxModel.
    """
    model_cls = model_cls.partial(causal=False)
    self._mask_rate = mask_rate
    total = random_token_proportion + mask_token_proportion
    if total < 0 or total > 1:
      raise ValueError('Sum of random proportion and mask proportion must be'
                       ' in [0, 1] range.')
    self._masker = BertMasker(
        domain,
        mask_rate=mask_rate,
        mask_token_proportion=mask_token_proportion,
        random_token_proportion=random_token_proportion)

    super().__init__(
        domain=domain,
        model_cls=model_cls,
        cache=False,
        with_mask=True,
        **kwargs)

  def preprocess(self, batch, rng, mode):
    return self._masker(inputs=batch, mode=mode, rng=rng)

  def sample(self, masked_inputs, rng):
    """Fill in MASK positions in inputs."""
    mask_positions = masked_inputs == self.domain.vocab.mask
    logits = self.score(masked_inputs)

    # Mask out MASK token.
    mask = common_utils.onehot(
        jnp.array([self.domain.vocab.mask]),
        num_classes=logits.shape[-1],
        on_value=sampling.LARGE_NEGATIVE)
    logits = logits + mask
    samples = jax.random.categorical(rng, logits=logits)
    infilled = onp.where(mask_positions, samples, masked_inputs)
    return infilled

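# Example (illustrative sketch): masked-token infilling with FlaxBERT. MASK
# positions in the input are replaced with draws from the model's predictive
# distribution (the MASK token itself is excluded from sampling above).
# `batch_of_tokens` is a placeholder [batch, length] integer array.
#
#   bert = FlaxBERT(domain=data.protein_domain, batch_size=8)
#   masked = onp.array(batch_of_tokens)
#   masked[:, 10] = bert.mask_token  # hide one position per sequence
#   filled = bert.sample(masked, rng=jrandom.PRNGKey(0))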

def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate,
                      mask_token_proportion, random_token_proportion, mode,
                      rng):
  """Preprocess inputs for masked language modeling.

  Args:
    inputs: [batch x length] input tokens.
    random_tokens: Array of token ids usable as random replacements.
    mask_token: Int ID to mask blanks with.
    pad_token: Int ID for PAD token. Positions left unchanged.
    mask_rate: Proportion of tokens to mask out.
    mask_token_proportion: Replace this proportion of chosen positions with
      MASK.
    random_token_proportion: Replace this proportion of chosen positions with
      randomly sampled tokens.
    mode: Mode key.
    rng: Jax RNG.

  Returns:
    Tuple of [batch x length] masked inputs, targets, per-position weights.
      Selected positions in the inputs are replaced with either the MASK token
      or a randomly chosen token from the vocabulary (or kept as-is), and the
      weights are 1 at exactly those selected positions.
  """
  total = random_token_proportion + mask_token_proportion
  if total < 0 or total > 1:
    raise ValueError('Sum of random proportion and mask proportion must be'
                     ' in [0, 1] range.')
  targets = inputs

  if mode == Mode.predict:
    weights = jnp.full_like(targets, 1)
    masked_inputs = inputs  # Pass through
  else:
    if rng is None:
      if mode is not Mode.eval:
        raise ValueError('Must provide RNG unless in eval mode.')
      # TODO(b/157055145): How to keep same eval set across runs?
      # Make each sequence's mask invariant to other members
      # of the batch. Right now there is batch size dependence.
      rng = jrandom.PRNGKey(jnp.sum(inputs))

    # Get positions to leave untouched
    is_pad = inputs == pad_token

    # Positions to mask
    rng, subrng = jax.random.split(rng)
    should_mask = jrandom.bernoulli(subrng, p=mask_rate, shape=inputs.shape)
    should_mask = jnp.where(is_pad, 0, should_mask)  # Don't mask out padding.

    # Generate full array of random tokens.
    rng, subrng = jax.random.split(rng)
    random_ids = jax.random.randint(
        subrng, inputs.shape, minval=0, maxval=len(random_tokens))

    fullrandom = random_tokens[random_ids]
    # Full array of MASK tokens
    fullmask = jnp.full_like(inputs, mask_token)

    # Build up masked array by selecting from inputs/fullmask/fullrandom.
    rand = jax.random.uniform(rng, shape=inputs.shape)
    masked_inputs = inputs
    # Remaining probability mass stays original values after MASK and RANDOM.
    # MASK tokens.
    masked_inputs = jnp.where(rand < mask_token_proportion, fullmask,
                              masked_inputs)
    # Random tokens.
    masked_inputs = jnp.where(
        jnp.logical_and(rand >= mask_token_proportion,
                        rand < mask_token_proportion + random_token_proportion),
        fullrandom, masked_inputs)

    # Only replace positions where `should_mask`
    masked_inputs = jnp.where(should_mask, masked_inputs, inputs)
    weights = should_mask

  return masked_inputs, targets, weights


class BertMasker():
  """Constructs a BERT masker for a given domain."""

  def __init__(self,
               domain,
               mask_rate=0.15,
               mask_token_proportion=0.1,
               random_token_proportion=0.8):
    vocab = domain.vocab
    if vocab.mask is None:
      raise ValueError('Vocabulary must specify a MASK token.')
    special_tokens = [vocab.bos, vocab.eos, vocab.mask, vocab.pad]
    special_tokens = [x for x in special_tokens if x is not None]
    normal_tokens = [x for x in vocab.token_ids if x not in special_tokens]
    self._domain = domain
    self._special_tokens = jnp.array(special_tokens)
    self._normal_tokens = jnp.array(normal_tokens)
    self._mask_rate = mask_rate
    self._mask_token_proportion = mask_token_proportion
    self._random_token_proportion = random_token_proportion

  def __call__(self, inputs, mode, rng):
    inputs, targets, weights = preprocess_masked(
        inputs=inputs,
        mode=mode,
        rng=rng,
        random_tokens=self._normal_tokens,
        mask_token=self._domain.vocab.mask,
        pad_token=self._domain.vocab.pad,
        mask_rate=self._mask_rate,
        mask_token_proportion=self._mask_token_proportion,
        random_token_proportion=self._random_token_proportion)
    return inputs, targets, weights
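
# Example (illustrative sketch): applying `BertMasker` directly, assuming the
# domain's vocabulary defines a MASK token (as FlaxBERT requires). With the
# default proportions, a selected position becomes MASK with probability 0.1,
# a random non-special token with probability 0.8, and keeps its original value
# (while still contributing to the loss) with probability 0.1.
#
#   masker = BertMasker(data.protein_domain)
#   masked_inputs, targets, weights = masker(
#       inputs=batch_of_tokens, mode=Mode.train, rng=jrandom.PRNGKey(0))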
