# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils related to Flax models."""

import collections
import functools
import operator as op
import pprint
import time

from absl import logging
from flax import jax_utils
from flax import optim
from flax.deprecated import nn
from flax.training import common_utils
import gin
import jax
from jax import lax
import jax.nn
import jax.numpy as jnp
import numpy as np
import tree


def l2_norm(params):
  return jax.tree_util.tree_map(lambda x: jnp.sum(x**2), params)


def l2_regularization(params):
  """Computes l2 regularization term for parameters."""
  return jax.tree_util.tree_reduce(op.add, l2_norm(params))
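

# A minimal usage sketch (the pytree below is a made-up example, not part of
# this module): l2_regularization sums the squared entries of every leaf.
#
#   params = {'dense': {'kernel': jnp.ones((2, 3)), 'bias': jnp.zeros((3,))}}
#   reg = l2_regularization(params)  # == 6.0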


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model_and_cache(rng, input_shape, model_def):
  """Create a model and cache definition.

  Args:
    rng: Init RNG.
    input_shape: Input shape.
    model_def: Model definition.

  Returns:
    Tuple of (model, cache_def).
  """
  # Create a cache object for efficient decoding.
  with nn.attention.Cache().mutate() as cache_def:
    _, model = model_def.create_by_shape(
        rng, [(input_shape, jnp.float32)], cache=cache_def)
  return model, cache_def


@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model(rng, input_shape, model_def):
  """Create a model (without a decoding cache).

  Args:
    rng: Init RNG.
    input_shape: Input shape.
    model_def: Model definition.

  Returns:
    The created model.
  """
  _, model = model_def.create_by_shape(
      rng, [(input_shape, jnp.float32)], cache=None)
  return model


def create_adam_optimizer(model,
                          learning_rate,
                          weight_decay=0.0,
                          beta1=0.9,
                          beta2=0.98,
                          eps=1e-9,
                          pmap=True):
  """Create (optionally replicated) Adam optimizer for `model`."""
  optimizer_def = optim.Adam(
      learning_rate,
      beta1=beta1,
      beta2=beta2,
      eps=eps,
      weight_decay=weight_decay)
  optimizer = optimizer_def.create(model)
  if pmap:
    optimizer = jax_utils.replicate(optimizer)
  return optimizer
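

# A minimal usage sketch, assuming a flax.deprecated.nn model definition
# (`TransformerLM`, `batch_size` and `max_len` below are hypothetical):
#
#   rng = jax.random.PRNGKey(0)
#   model_def = TransformerLM.partial(vocab_size=32)
#   model = create_model(rng, (batch_size, max_len), model_def)
#   optimizer = create_adam_optimizer(model, learning_rate=1e-3)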


def compute_weighted_cross_entropy(logits, targets,
                                   token_weights=None,
                                   example_weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  The loss is assumed to be sum_i example_weights[i] * logprob[i], where
  i indexes elements in the batch.

  logprob[i] is the log probability of sequence i, which is a weighted
  average of per-token log probabilities with weights according
  to token_weights. Typically token_weights is a mask for whether tokens are
  padding or not.

  Maximum likelihood training sets example_weights[i] = 1.
  Training with a REINFORCE-style objective may set example_weights[i]
  to any positive or negative number.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    token_weights: None or array of shape [batch, length].
    example_weights: None or array of shape [batch_size].

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if token_weights is not None:
    loss = loss * token_weights
    normalizing_factor = token_weights.sum()

  if example_weights is not None:
    loss = loss.sum(axis=1)
    loss *= example_weights

  return loss.sum(), normalizing_factor
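

# A minimal usage sketch (`pad_id` below is a hypothetical padding token id):
# mask out padding positions, then divide by the normalizing factor to obtain
# the mean per-token loss.
#
#   token_weights = jnp.where(targets != pad_id, 1.0, 0.0)
#   loss_sum, denom = compute_weighted_cross_entropy(
#       logits, targets, token_weights=token_weights)
#   mean_loss = loss_sum / denom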


def compute_weighted_mse(predictions, targets, weights):
  """Compute mean squared error between predictions and targets.

  Args:
    predictions: [batch, length, ...] float array.
    targets: float targets of same size as predictions.
    weights: weights of same shape as predictions.

  Returns:
    Tuple of summed loss and normalizing denominator (the sum of weights).
  """
  if predictions.shape != targets.shape:
    raise ValueError(
        f'Incorrect shapes. Got shape {predictions.shape} predictions'
        f' and {targets.shape} targets')
  per_position_loss = jnp.square(targets - predictions) * weights
  summed_loss = jnp.sum(per_position_loss)
  denominator = jnp.sum(weights)
  return summed_loss, denominator


def compute_weighted_accuracy(logits, targets, weights=None):
  """Compute weighted accuracy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor


def get_normalized_matrix(domain, freq_dict):
  """Compute the normalized matrix for soft-accuracy computation.

  Args:
    domain: A Sequin domain which provides the ordered list of tokens.
    freq_dict: A dict of dicts containing pairs of frequencies. E.g. for
      computing the normalized matrix based on the Blosum matrix use
      freq_dict=pfam_utils.BLOSUM62_TABLE.to_dict().

  Returns:
    An array of shape (vocab_size, vocab_size) containing the matrix to be
    used for soft-accuracy computation.
  """
  matrix = np.zeros((domain.vocab_size, domain.vocab_size))
  for idx_1, token_1 in enumerate(domain.vocab.tokens):
    for idx_2, token_2 in enumerate(domain.vocab.tokens):
      if token_1 in freq_dict:
        if token_2 in freq_dict[token_1]:
          matrix[idx_1][idx_2] = freq_dict[token_1][token_2]
  matrix -= np.min(matrix)
  matrix /= np.max(matrix)
  return matrix


def compute_weighted_soft_accuracy(logits, targets, weights=None, matrix=None):
  """Compute weighted soft-accuracy for log probs and targets.

  Based on Section 3.4 in
  [ProGen](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2).

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    matrix: [num_classes, num_classes] normalized matrix to use for
      soft-accuracy computation.

  Returns:
    Tuple of scalar soft-accuracy and batch normalizing factor.

  Raises:
    ValueError when the logits and targets have incorrect number of dimensions.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(f'Incorrect shapes. Got shape {logits.shape} for logits '
                     f'and {targets.shape} for targets.')

  # Compute hard accuracy.
  pred = np.argmax(logits, axis=-1)
  loss = np.equal(pred, targets).astype(np.float32)

  # Add matrix-based accuracy for incorrect predictions.
  if matrix is not None:
    matrix = matrix * (np.ones(len(matrix)) - np.eye(len(matrix)))
    loss_matrix = matrix[np.reshape(pred, [-1])]
    loss_matrix = np.transpose(loss_matrix)
    loss_matrix = loss_matrix[np.reshape(targets, [-1])]
    loss_matrix = np.reshape(np.diag(loss_matrix), pred.shape)
    loss += loss_matrix

  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
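

# A minimal usage sketch, assuming a `domain` compatible with
# get_normalized_matrix above (e.g. built from the BLOSUM62 frequency table):
#
#   matrix = get_normalized_matrix(domain, pfam_utils.BLOSUM62_TABLE.to_dict())
#   acc_sum, denom = compute_weighted_soft_accuracy(
#       logits, targets, weights=token_weights, matrix=matrix)
#   soft_accuracy = acc_sum / denom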


def _psum(target_tree, axis_name='batch'):
  return jax.tree_map(lambda x: lax.psum(x, axis_name), target_tree)


def compute_metrics(logits, labels, token_weights, example_weights=None):
  """Compute summary metrics with loss, accuracy, and normalizing factor."""
  loss, weight_sum = compute_weighted_cross_entropy(logits, labels,
                                                    token_weights,
                                                    example_weights)
  acc, _ = compute_weighted_accuracy(logits, labels, token_weights)
  metrics = {
      'loss': loss,
      'accuracy': acc,
      'denominator': weight_sum,
  }
  try:
    metrics = _psum(metrics)
  except NameError:
    pass  # We are not inside pmap. No need to psum.
  return metrics
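

# A minimal usage sketch inside a pmapped eval step (names here are
# hypothetical): _psum uses axis_name='batch', so the surrounding pmap must
# use the same axis name for the cross-device sum to apply.
#
#   def eval_step(model, batch):
#     logits = model(batch['inputs'], train=False)
#     token_weights = jnp.where(batch['targets'] > 0, 1.0, 0.0)
#     return compute_metrics(logits, batch['targets'], token_weights)
#
#   p_eval_step = jax.pmap(eval_step, axis_name='batch')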


def get_params(model):
  """Get model parameters."""
  return model.optimizer.target.params


def param_count(model):
  """Get total parameter count."""
  params = get_params(model)
  num_params = sum(x.size for x in jax.tree_leaves(params))
  return num_params


def param_pprint(model):
  """Returns a pretty-printed string of the parameter tree (sizes / 1024)."""
  params = get_params(model)
  x = tree.map_structure(lambda x: x.size / 1024, params)
  as_str = pprint.pformat(x)
  return as_str


def param_reduce(model, log=False):
  """Return a dict containing param counts per submodule."""
  params = get_params(model)
  sizes = collections.defaultdict(int)
  for path, x in tree.flatten_with_path(params):
    size = x.size
    for i in range(len(path)):
      k = path[:i]
      sizes[k] += size
  for k in sorted(sizes):
    if log:
      logging.info('%s: %s', k, sizes[k])
  return sizes


def batchify(inputs, batch_size):
  """Reshapes and pads inputs to include an additional batch dimension.

  The inputs can be of arbitrary length. Their length does not need to be a
  multiple of batch_size, in which case padding will be added.

  Args:
    inputs: An np.ndarray or iterable of np.ndarray of shape [input_size, ...].
    batch_size: The size of the batches to group the data into.

  Returns:
    batched_inputs: np.ndarray of size [num_batches, batch_size, ...],
      where num_batches is ceil(input_size / batch_size).
    pad_size: Number of examples in the final batch that are padding. We use
      copies of inputs[0] as padding.
  """

  inputs = np.asarray(inputs)

  pad_size = -len(inputs) % batch_size
  if pad_size:
    padding = np.tile(inputs[:1], [pad_size, 1])
    padded_inputs = np.concatenate([inputs, padding], axis=0)
  else:
    padded_inputs = inputs
  batched_shape = (-1, batch_size) + padded_inputs.shape[1:]
  batched_inputs = np.reshape(padded_inputs, batched_shape)
  return batched_inputs, pad_size
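

# A minimal usage sketch: 10 examples with batch_size=4 are padded with two
# copies of inputs[0] and reshaped into 3 batches.
#
#   inputs = np.arange(20).reshape(10, 2)
#   batched, pad_size = batchify(inputs, batch_size=4)
#   # batched.shape == (3, 4, 2), pad_size == 2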


def batch_apply(fn, inputs, batch_size):
  """Applies fn() to inputs in batches of size batch_size.

  The inputs can be of arbitrary length. Their length does not need to be a
  multiple of batch_size. Padding will be added (and then removed) such that
  fn() is always called on inputs of size exactly batch_size. fn() is assumed
  to operate independently across the batch dimension of its inputs (e.g.
  computing predictions of a model on inputs) instead of performing an
  operation where the batch elements interact (e.g. performing a gradient
  step of a model on a batch of inputs).

  Args:
    fn: The function to map across the inputs.
    inputs: An np.ndarray or iterable of np.ndarray. fn() is mapped along the
      first dimension.
    batch_size: The size of the batches to evaluate fn() on.

  Returns:
    np.ndarray where outputs[i] = fn(inputs[i]).
  """

  batched_inputs, pad_size = batchify(inputs, batch_size)
  results = np.concatenate([fn(batch) for batch in batched_inputs])
  if pad_size:
    results = results[:-pad_size]
  return results
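

# A minimal usage sketch with a hypothetical per-example function; padding is
# added and stripped internally, so the output has one row per input example.
#
#   outputs = batch_apply(lambda batch: batch * 2.0,
#                         np.arange(10.).reshape(10, 1), batch_size=4)
#   # outputs.shape == (10, 1)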


@gin.configurable
def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=8000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000):
  """Creates learning rate schedule.

  Interprets factors in the factors string which can consist of:
  * constant: interpreted as the constant value,
  * linear_warmup: interpreted as linear warmup until warmup_steps,
  * rsqrt_decay: divide by square root of max(step, warmup_steps),
  * rsqrt_normalized_decay: rsqrt decay, scaled so the factor is 1 at
    warmup_steps,
  * decay_every: every k steps decay the learning rate by decay_factor,
  * cosine_decay: cyclic cosine decay, uses steps_per_cycle parameter.

  Args:
    factors: A string with factors separated by '*' that defines the schedule.
    base_learning_rate: Float, the starting constant for the lr schedule.
    warmup_steps: How many steps to warm up for in the warmup schedule.
    decay_factor: The amount to decay the learning rate by.
    steps_per_decay: How often to decay the learning rate.
    steps_per_cycle: Steps per cycle when using cosine decay.

  Returns:
    A function learning_rate(step): float -> float, the step-dependent lr.
  """
  factors = [n.strip() for n in factors.split('*')]

  def step_fn(step):
    """Step to learning rate function."""
    ret = 1.0
    for name in factors:
      if name == 'constant':
        ret *= base_learning_rate
      elif name == 'linear_warmup':
        ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'rsqrt_normalized_decay':
        ret *= jnp.sqrt(warmup_steps)
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'decay_every':
        ret *= (decay_factor**(step // steps_per_decay))
      elif name == 'cosine_decay':
        progress = jnp.maximum(0.0,
                               (step - warmup_steps) / float(steps_per_cycle))
        ret *= jnp.maximum(0.0,
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
      else:
        raise ValueError('Unknown factor %s.' % name)
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn
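

# A minimal usage sketch: the default factor string yields the base learning
# rate scaled by a linear warmup over warmup_steps and an inverse square-root
# decay afterwards.
#
#   learning_rate_fn = create_learning_rate_scheduler(
#       factors='constant * linear_warmup * rsqrt_decay',
#       base_learning_rate=0.5, warmup_steps=8000)
#   lr = learning_rate_fn(step)  # step-dependent scalar learning rate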


class Timer(object):
  """Context manager for logging timing.

  Example usage:
    with Timer('my function'):
      my_function(inputs)

  Attributes:
    elapsed: The time in seconds that it took to execute the context.
  """

  def __init__(self, message=None, verbose=True):
    """Creates an instance of this class.

    Args:
      message: The message to be used for logging. If `None`, does not log.
      verbose: Whether to log messages to the console.
    """
    self._message = message
    self._elapsed = None
    self._verbose = verbose

  def _log(self, msg, *args, **kwargs):
    if self._message and self._verbose:
      logging.info(msg, *args, **kwargs)
      logging.flush()

  def __enter__(self):
    self._log('Starting: %s', self._message)
    self._elapsed = None
    self._start = time.time()
    return self

  def __exit__(self, *args):
    self._elapsed = time.time() - self._start
    self._log('Finished: %s. Elapsed seconds: %f', self._message,
              self._elapsed)

  @property
  def elapsed(self):
    if self._elapsed is None:
      raise ValueError('Timer not executed!')
    return self._elapsed


def get_random_state(seed_or_state):
  """Returns a np.random.RandomState given an integer seed or RandomState."""
  if isinstance(seed_or_state, int):
    return np.random.RandomState(seed_or_state)
  elif seed_or_state is None:
    # This returns the current global np random state.
    return np.random.random.__self__
  elif not isinstance(seed_or_state, np.random.RandomState):
    raise ValueError('Numpy RandomState or integer seed expected! Got: %s' %
                     seed_or_state)
  else:
    return seed_or_state