# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils related to Flax models."""

import collections
import functools
import operator as op
import pprint
import time

from absl import logging
from flax import jax_utils
from flax import optim
from flax.deprecated import nn
from flax.training import common_utils
import gin
import jax
from jax import lax
import jax.nn
import jax.numpy as jnp
import numpy as np
import tree


def l2_norm(params):
  """Computes the squared l2 norm of each leaf in the parameter tree."""
  return jax.tree_util.tree_map(lambda x: jnp.sum(x**2), params)


def l2_regularization(params):
  """Computes l2 regularization term for parameters."""
  return jax.tree_util.tree_reduce(op.add, l2_norm(params))
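
# Illustrative usage (a sketch, not part of the original module):
# l2_regularization sums the per-leaf squared norms returned by l2_norm, so a
# typical weight-decay term would be added to the loss as, e.g.:
#
#   params = {'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros(2)}}
#   reg = 1e-4 * l2_regularization(params)  # 1e-4 * (4.0 + 0.0)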


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model_and_cache(rng, input_shape, model_def):
  """Create a model and cache definition.

  Args:
    rng: Init RNG.
    input_shape: Input shape.
    model_def: Model definition.

  Returns:
    Tuple of (model, cache_def)
  """
  # Create a cache object for efficient decoding.
  with nn.attention.Cache().mutate() as cache_def:
    _, model = model_def.create_by_shape(
        rng, [(input_shape, jnp.float32)], cache=cache_def)
  return model, cache_def


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model(rng, input_shape, model_def):
  """Create a model instance (without a cache).

  Args:
    rng: Init RNG.
    input_shape: Input shape.
    model_def: Model definition.

  Returns:
    The created model.
  """
  _, model = model_def.create_by_shape(
      rng, [(input_shape, jnp.float32)], cache=None)
  return model


def create_adam_optimizer(model,
                          learning_rate,
                          weight_decay=0.0,
                          beta1=0.9,
                          beta2=0.98,
                          eps=1e-9,
                          pmap=True):
  """Create (optionally replicated) Adam optimizer for `model`."""
  optimizer_def = optim.Adam(
      learning_rate,
      beta1=beta1,
      beta2=beta2,
      eps=eps,
      weight_decay=weight_decay)
  optimizer = optimizer_def.create(model)
  if pmap:
    optimizer = jax_utils.replicate(optimizer)
  return optimizer
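
# Illustrative usage (a sketch, not from the original module): creating an
# unreplicated optimizer for a model created above. `model` here stands for
# the output of create_model().
#
#   optimizer = create_adam_optimizer(
#       model, learning_rate=1e-3, weight_decay=1e-2, pmap=False)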


def compute_weighted_cross_entropy(logits, targets,
                                   token_weights=None,
                                   example_weights=None):
  """Compute weighted cross entropy for log probs and targets.

  The loss is assumed to be sum_i example_weights[i] * logprob[i], where
  i indexes elements in the batch.

  logprob[i] is the log probability of sequence i, which is a weighted
  average of per-token log probabilities with weights according
  to token_weights. Typically token_weights is a mask for whether tokens are
  padding or not.

  Maximum likelihood training sets example_weights[i] = 1.
  Training with a REINFORCE-style objective may set example_weights[i]
  to any positive or negative number.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   token_weights: None or array of shape [batch x length]
   example_weights: None or array of shape [batch_size]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if token_weights is not None:
    loss = loss * token_weights
    normalizing_factor = token_weights.sum()

  if example_weights is not None:
    loss = loss.sum(axis=1)
    loss *= example_weights

  return loss.sum(), normalizing_factor
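
# Worked example (illustrative only): with batch=1, length=2, num_classes=2,
# uniform logits give a per-token loss of log(2) ~= 0.693. Masking out the
# second token via token_weights leaves a summed loss of ~0.693 and a
# normalizing factor of 1.0:
#
#   logits = jnp.zeros((1, 2, 2))
#   targets = jnp.array([[0, 1]])
#   token_weights = jnp.array([[1.0, 0.0]])
#   loss, denom = compute_weighted_cross_entropy(logits, targets, token_weights)
#   mean_loss = loss / denom  # ~0.693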


def compute_weighted_mse(predictions, targets, weights):
  """Compute mean squared error between predictions and targets.

  Args:
   predictions: [batch, length, ...] float array.
   targets: float targets of same size as predictions.
   weights: weights of same shape as predictions.

  Returns:
    Tuple of summed weighted squared error and the sum of the weights, which
    can be used as a normalizing denominator.
  """
  if predictions.shape != targets.shape:
    raise ValueError(
        f'Incorrect shapes. Got shape {predictions.shape} predictions'
        f' and {targets.shape} targets')
  per_position_loss = jnp.square(targets - predictions) * weights
  summed_loss = jnp.sum(per_position_loss)
  denominator = jnp.sum(weights)
  return summed_loss, denominator


def compute_weighted_accuracy(logits, targets, weights=None):
  """Compute weighted accuracy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch x length]

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
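
# Illustrative usage (not in the original module): like the other
# compute_weighted_* helpers, this returns an unnormalized sum, so the mean
# accuracy over non-padding tokens is obtained by dividing by the second
# return value:
#
#   correct, denom = compute_weighted_accuracy(logits, targets, token_weights)
#   accuracy = correct / denom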


def get_normalized_matrix(domain, freq_dict):
  """Compute the normalized matrix for soft-accuracy computation.

  Args:
    domain: A Sequin domain which provides the ordered list of tokens.
    freq_dict: A dict of dicts containing pairs of frequencies. E.g. for
      computing the normalized matrix based on the Blosum matrix use
      freq_dict=pfam_utils.BLOSUM62_TABLE.to_dict().

  Returns:
    An array of shape (vocab_size, vocab_size) containing the matrix to be
      used for soft-accuracy computation.
  """
  matrix = np.zeros((domain.vocab_size, domain.vocab_size))
  for idx_1, token_1 in enumerate(domain.vocab.tokens):
    for idx_2, token_2 in enumerate(domain.vocab.tokens):
      if token_1 in freq_dict:
        if token_2 in freq_dict[token_1]:
          matrix[idx_1][idx_2] = freq_dict[token_1][token_2]
  matrix -= np.min(matrix)
  matrix /= np.max(matrix)
  return matrix
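
# Illustrative sketch (not from the original module; `domain` and the score
# table below are hypothetical): the raw scores are shifted by the global
# minimum and divided by the new maximum, so all entries land in [0, 1]. For a
# two-token vocabulary with scores
#
#   freq_dict = {'A': {'A': 4, 'B': -1}, 'B': {'A': -1, 'B': 5}}
#
# the raw matrix [[4, -1], [-1, 5]] becomes [[5/6, 0], [0, 1]] after shifting
# by +1 and dividing by 6.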


def compute_weighted_soft_accuracy(logits, targets, weights=None, matrix=None):
  """Compute weighted soft-accuracy for log probs and targets.

  Based on Section 3.4 in
    [ProGen](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2).

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch x length]
   matrix: [num_classes, num_classes] normalized matrix to use for
    soft-accuracy computation.

  Returns:
    Tuple of scalar soft-accuracy and batch normalizing factor.

  Raises:
    ValueError when the logits and targets have incorrect number of dimensions.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(f'Incorrect shapes. Got shape {logits.shape} for logits '
                     f'and {targets.shape} for targets.')

  # Compute hard accuracy.
  pred = np.argmax(logits, axis=-1)
  loss = np.equal(pred, targets).astype(np.float32)

  # Add matrix-based accuracy for incorrect predictions.
  if matrix is not None:
    # Zero out the diagonal so exact matches are not double-counted, then look
    # up the (prediction, target) entries of the normalized matrix.
    matrix = matrix * (np.ones(len(matrix)) - np.eye(len(matrix)))
    loss_matrix = matrix[np.reshape(pred, [-1])]
    loss_matrix = np.transpose(loss_matrix)
    loss_matrix = loss_matrix[np.reshape(targets, [-1])]
    loss_matrix = np.reshape(np.diag(loss_matrix), pred.shape)
    loss += loss_matrix

  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
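
# Illustrative note (not from the original module): with the matrix argument,
# an exact match still scores 1.0, while a mismatch receives partial credit
# equal to the normalized substitution score between the predicted and target
# tokens, e.g. a mismatch whose matrix entry is 0.8 contributes 0.8 instead
# of 0.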


def _psum(target_tree, axis_name='batch'):
  return jax.tree_map(lambda x: lax.psum(x, axis_name), target_tree)


def compute_metrics(logits, labels, token_weights, example_weights=None):
  """Compute summary metrics with loss, accuracy, and normalizing factor."""
  loss, weight_sum = compute_weighted_cross_entropy(logits, labels,
                                                    token_weights,
                                                    example_weights)
  acc, _ = compute_weighted_accuracy(logits, labels, token_weights)
  metrics = {
      'loss': loss,
      'accuracy': acc,
      'denominator': weight_sum,
  }
  try:
    metrics = _psum(metrics)
  except NameError:
    pass  # We are not inside pmap. No need to psum.
  return metrics
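
# Illustrative usage (a sketch, not from the original module): inside a
# pmapped step the 'batch' axis is bound and the metrics are summed across
# devices; outside pmap the lax.psum raises NameError and the metrics are
# returned as-is.
#
#   p_eval_step = jax.pmap(
#       lambda logits, labels, w: compute_metrics(logits, labels, w),
#       axis_name='batch')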


def get_params(model):
  """Get model parameters."""
  return model.optimizer.target.params


def param_count(model):
  """Get total parameter count."""
  params = get_params(model)
  num_params = sum(x.size for x in jax.tree_leaves(params))
  return num_params


def param_pprint(model):
  """Returns a pretty-printed string of parameter sizes (in units of 1024)."""
  params = get_params(model)
  x = tree.map_structure(lambda x: x.size / 1024, params)
  as_str = pprint.pformat(x)
  return as_str


def param_reduce(model, log=False):
  """Return a dict containing param counts per submodule."""
  params = get_params(model)
  sizes = collections.defaultdict(int)
  for path, x in tree.flatten_with_path(params):
    size = x.size
    for i in range(len(path)):
      k = path[:i]
      sizes[k] += size
  for k in sorted(sizes):
    if log:
      logging.info('%s: %s', k, sizes[k])
  return sizes
312

313

314
def batchify(inputs, batch_size):
315
  """Reshapes and pads inputs to include an additional batch dimension.
316

317
  The inputs can be of arbitrary length. They length does not need to be a
318
  multiple of batch_size, in which case padding will be added.
319

320
  Args:
321
    inputs: An np.ndarray or iterable of np.ndarray of shape [input_size, ...].
322
    batch_size:
323
      The size of the batches to group the data into.
324
  Returns:
325
    batch_inputs: np.ndarray of size [num_batches, batch_size, ...],
326
    where num_batches is ceil(input_size / batch_size).
327
    pad_size: Number of examples in the final batch that are padding. We use
328
      copies of inputs[0] as padding.
329
  """
330

331
  inputs = np.asarray(inputs)
332

333
  pad_size = -len(inputs) % batch_size
334
  if pad_size:
335
    padding = np.tile(inputs[:1], [pad_size, 1])
336
    padded_inputs = np.concatenate([inputs, padding], axis=0)
337
  else:
338
    padded_inputs = inputs
339
  batched_shape = (-1, batch_size) + padded_inputs.shape[1:]
340
  batched_inputs = np.reshape(padded_inputs, batched_shape)
341
  return batched_inputs, pad_size
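
# Worked example (illustrative only): five inputs with batch_size=2 are padded
# with one copy of the first example and reshaped into three batches.
#
#   batched, pad_size = batchify(np.arange(10).reshape(5, 2), batch_size=2)
#   batched.shape  # (3, 2, 2)
#   pad_size       # 1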
342

343

344
def batch_apply(fn, inputs, batch_size):
345
  """Applies fn() to inputs in batches of size batch_size.
346

347
  The inputs can be of arbitrary length. They length does not need to be a
348
  multiple of batch_size. Padding will be added (and then removed) such that
349
  fn() is always called on inputs of size exactly batch_size. fn() is assumed
350
  to operate independently across the batch dimension of its inputs (e.g.
351
  computing predictions of a model on inputs) instead of performing an operation
352
  where the batch elements interact (e.g. performing a gradient step of a model
353
  on a batch of inputs).
354

355
  Args:
356
    fn: The function to map across the inputs.
357
    inputs: An np.ndarray or iterable of np.ndarray. fn() is mapped
358
      along the first dimension.
359
    batch_size:
360
      The size of the batches to evaluate fn() on.
361
  Returns:
362
    np.ndarray where outputs[i] = fn(inputs[i])
363
  """
364

365
  batched_inputs, pad_size = batchify(inputs, batch_size)
366
  results = np.concatenate([fn(batch) for batch in batched_inputs])
367
  if pad_size:
368
    results = results[:-pad_size]
369
  return results
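
# Illustrative usage (not from the original module): applying a per-example
# function to 5 inputs in padded batches of 2; results for the padded examples
# are dropped before returning.
#
#   outputs = batch_apply(lambda x: x.sum(axis=-1),
#                         np.arange(10).reshape(5, 2), batch_size=2)
#   outputs  # array([1, 5, 9, 13, 17])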
370

371

372
@gin.configurable
373
def create_learning_rate_scheduler(
374
    factors='constant * linear_warmup * rsqrt_decay',
375
    base_learning_rate=0.5,
376
    warmup_steps=8000,
377
    decay_factor=0.5,
378
    steps_per_decay=20000,
379
    steps_per_cycle=100000):
380
  """Creates learning rate schedule.
381

382
  Interprets factors in the factors string which can consist of:
383
  * constant: interpreted as the constant value,
384
  * linear_warmup: interpreted as linear warmup until warmup_steps,
385
  * rsqrt_decay: divide by square root of max(step, warmup_steps)
386
  * decay_every: Every k steps decay the learning rate by decay_factor.
387
  * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
388

389
  Args:
390
    factors: A string with factors separated by '*' that defines the schedule.
391
    base_learning_rate: Float, the starting constant for the lr schedule.
392
    warmup_steps: How many steps to warm up for in the warmup schedule.
393
    decay_factor: The amount to decay the learning rate by.
394
    steps_per_decay: How often to decay the learning rate.
395
    steps_per_cycle: Steps per cycle when using cosine decay.
396

397
  Returns:
398
    a function learning_rate(step): float -> {'learning_rate': float}, the
399
    step-dependent lr.
400
  """
401
  factors = [n.strip() for n in factors.split('*')]
402

403
  def step_fn(step):
404
    """Step to learning rate function."""
405
    ret = 1.0
406
    for name in factors:
407
      if name == 'constant':
408
        ret *= base_learning_rate
409
      elif name == 'linear_warmup':
410
        ret *= jnp.minimum(1.0, step / warmup_steps)
411
      elif name == 'rsqrt_decay':
412
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
413
      elif name == 'rsqrt_normalized_decay':
414
        ret *= jnp.sqrt(warmup_steps)
415
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
416
      elif name == 'decay_every':
417
        ret *= (decay_factor**(step // steps_per_decay))
418
      elif name == 'cosine_decay':
419
        progress = jnp.maximum(0.0,
420
                               (step - warmup_steps) / float(steps_per_cycle))
421
        ret *= jnp.maximum(0.0,
422
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
423
      else:
424
        raise ValueError('Unknown factor %s.' % name)
425
    return jnp.asarray(ret, dtype=jnp.float32)
426

427
  return step_fn
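
# Illustrative usage (a sketch, not from the original module): composing the
# default factors and inspecting the schedule at a few steps.
#
#   lr_fn = create_learning_rate_scheduler(
#       factors='constant * linear_warmup * rsqrt_decay',
#       base_learning_rate=0.5, warmup_steps=8000)
#   lr_fn(1000)   # 0.5 * (1000 / 8000) / sqrt(8000)
#   lr_fn(32000)  # 0.5 / sqrt(32000)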
428

429

430
class Timer(object):
431
  """Context manager for logging timing.
432

433
  Example usage:
434
    with Timer('my function'):
435
      my_function(inputs)
436

437
  Attributes:
438
    elapsed: The time in seconds that it took to execute the context.
439
  """
440

441
  def __init__(self, message=None, verbose=True):
442
    """Creates and instance of this class.
443

444
    Args:
445
      message: The message to be used for logging. If `None`, does not log.
446
      verbose: Whether to log messages to the console.
447
    """
448
    self._message = message
449
    self._elapsed = None
450
    self._verbose = verbose
451

452
  def _log(self, msg, *args, **kwargs):
453
    if self._message and self._verbose:
454
      logging.info(msg, *args, **kwargs)
455
      logging.flush()
456

457
  def __enter__(self):
458
    self._log('Starting: %s', self._message)
459
    self._elapsed = None
460
    self._start = time.time()
461
    return self
462

463
  def __exit__(self, *args):
464
    self._elapsed = time.time() - self._start
465
    self._log('Finished: %s. Elapsed seconds: %f', self._message, self._elapsed)
466

467
  @property
468
  def elapsed(self):
469
    if self._elapsed is None:
470
      raise ValueError('Timer not executed!')
471
    return self._elapsed


def get_random_state(seed_or_state):
  """Returns a np.random.RandomState given an integer seed or RandomState."""
  if isinstance(seed_or_state, int):
    return np.random.RandomState(seed_or_state)
  elif seed_or_state is None:
    # This returns the current global np random state.
    return np.random.random.__self__
  elif not isinstance(seed_or_state, np.random.RandomState):
    raise ValueError('Numpy RandomState or integer seed expected! Got: %s' %
                     seed_or_state)
  else:
    return seed_or_state