google-research
138 строк · 3.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utility functions for logging and recording training metrics."""
17import logging18from absl import logging as absl_logging19import jax20from jax import lax21import jax.numpy as jnp22import numpy as onp23
24
def add_log_file(logfile):
  """Mirror absl log output into an extra, already-open file.

  A stream handler writing to `logfile` is attached to the 'absl' logger and
  formatted with absl's `PythonFormatter`. The file is not closed here; that
  remains the caller's job.

  Args:
    logfile: Open file object that log records are written to.
  """
  file_handler = logging.StreamHandler(logfile)
  file_handler.setFormatter(absl_logging.PythonFormatter())
  logging.getLogger('absl').addHandler(file_handler)
38
def to_state_list(obj):
  """Return the state of the model as a flattened list.

  Restore with `restore_state_list`.

  Every leaf of `obj` is indexed with `[0]`, so only the first entry along
  each leaf's leading axis is kept (presumably the first device replica of a
  replicated model -- confirm against the caller). The selected values are
  then fetched to the host.

  Args:
    obj: the object (pytree) to extract state from

  Returns:
    State as a list of host arrays, one per leaf of `obj`, in pytree leaf
    order.
  """
  # `jax.tree_util.tree_leaves` is the stable spelling; the top-level
  # `jax.tree_leaves` alias is gone from current JAX releases.
  leaves = jax.tree_util.tree_leaves(obj)
  return jax.device_get([leaf[0] for leaf in leaves])
53
def restore_state_list(obj, state_list):
  """Restore model state from a state list.

  The state list is first replicated with `replicate` and then arranged into
  the pytree structure of `obj`.

  Args:
    obj: the object that is to be duplicated with the restored state; only
      its pytree structure is used
    state_list: state as a list of jax.numpy arrays, one per leaf of `obj`

  Returns:
    a copy of `obj` with the parameters from state_list loaded

  Example:
    restored = restore_state_list(model, state_list)
  """
  state_list = replicate(state_list)
  structure = jax.tree_util.tree_structure(obj)
  # Use the `jax.tree_util` spelling, matching `tree_structure` above; the
  # top-level `jax.tree_unflatten` alias is gone from current JAX releases.
  return jax.tree_util.tree_unflatten(structure, state_list)
71
def replicate(xs, n_devices=None):
  """Copy `xs` onto several devices by returning it from a pmapped function.

  Args:
    xs: value (pytree of arrays) to replicate
    n_devices: number of copies to make; when None,
      `jax.local_device_count()` is used

  Returns:
    The output of the pmapped function: `xs` with a new leading axis of
    length `n_devices` on every leaf.
  """
  device_total = jax.local_device_count() if n_devices is None else n_devices

  def _emit_copy(_):
    # The mapped index is ignored; each device simply returns `xs`.
    return xs

  per_device_index = jnp.arange(device_total)
  return jax.pmap(_emit_copy, axis_name='batch')(per_device_index)
78
def shard(xs):
  """Split the leading axis of every leaf of `xs` across the local devices.

  Each leaf of shape `(batch, ...)` is reshaped to
  `(local_device_count, batch // local_device_count, ...)`. The reshape fails
  if the leading dimension is not divisible by the local device count.

  Args:
    xs: pytree of arrays sharing a leading batch axis

  Returns:
    A pytree with the same structure whose leaves carry a new leading device
    axis.
  """
  local_device_count = jax.local_device_count()
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is gone from current JAX releases.
  return jax.tree_util.tree_map(
      lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
84
def onehot(labels, num_classes):
  """Encode integer class labels as float32 one-hot vectors.

  A trailing axis of length `num_classes` is appended to `labels`. Entry `k`
  along that axis is 1.0 where the label equals `k` and 0.0 elsewhere, so a
  label matching no class index yields an all-zero row.

  Args:
    labels: array of class indices
    num_classes: number of classes, i.e. length of the appended axis

  Returns:
    float32 array with one more dimension than `labels`.
  """
  class_ids = jnp.arange(num_classes)[jnp.newaxis, :]
  matches = labels[..., jnp.newaxis] == class_ids
  return matches.astype(jnp.float32)
89
def pmean(tree, axis_name='batch'):
  """Average every leaf of `tree` over the mapped axis `axis_name`.

  Must be called inside a context (e.g. `jax.pmap`) that binds `axis_name`.
  The mean is the cross-device sum divided by the axis size, where the size
  is obtained as `psum(1.)`.

  Args:
    tree: pytree of arrays to average
    axis_name: name of the mapped axis to reduce over

  Returns:
    A pytree with the same structure holding the per-leaf means.
  """
  num_devices = lax.psum(1., axis_name)
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is gone from current JAX releases.
  return jax.tree_util.tree_map(
      lambda x: lax.psum(x, axis_name) / num_devices, tree)
94
def psum(tree, axis_name='batch'):
  """Sum every leaf of `tree` over the mapped axis `axis_name`.

  Must be called inside a context (e.g. `jax.pmap`) that binds `axis_name`.

  Args:
    tree: pytree of arrays to sum
    axis_name: name of the mapped axis to reduce over

  Returns:
    A pytree with the same structure holding the per-leaf sums.
  """
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is gone from current JAX releases.
  return jax.tree_util.tree_map(lambda x: lax.psum(x, axis_name), tree)
98
def pad_classification_batch(batch, batch_size):
  """Pad a classification batch so that it has `batch_size` samples.

  The batch should be a dictionary of the form:

    {
      'image': <image>,
      'label': <GT label>
    }

  Padding is appended along the leading (sample) axis only, whatever the rank
  of the arrays: images are padded with 0 and labels with -1. A batch that
  already has at least `batch_size` samples is returned unchanged (the very
  same object).

  Args:
    batch: the batch to pad
    batch_size: the desired number of elements

  Returns:
    Padded batch as a dictionary with the keys 'label' and 'image'; any other
    keys of `batch` are not carried over when padding is applied.
  """
  actual_size = len(batch['image'])
  if actual_size >= batch_size:
    return batch

  padding = batch_size - actual_size

  def _pad_leading(values, fill_value):
    # Build the pad widths from the array's rank so only axis 0 grows. A
    # single `[[0, padding]]` pair would be broadcast by numpy to every axis.
    values = onp.asarray(values)
    pad_width = [[0, padding]] + [[0, 0]] * (values.ndim - 1)
    return onp.pad(values, pad_width, mode='constant',
                   constant_values=fill_value)

  return {
      'label': _pad_leading(batch['label'], -1),
      'image': _pad_leading(batch['image'], 0),
  }
129
def stack_forest(forest):
  """Stack a sequence of identically structured pytrees leaf by leaf.

  Args:
    forest: non-empty iterable of pytrees sharing one structure

  Returns:
    A single pytree of that structure in which each leaf is the `onp.stack`
    of the corresponding leaves from every tree, in iteration order.
  """
  def _stack_leaves(*leaves):
    return onp.stack(leaves)

  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is gone from current JAX releases.
  return jax.tree_util.tree_map(_stack_leaves, *forest)
134
def get_metrics(device_metrics):
  """Gather recorded metrics onto the host and stack them per metric.

  Every leaf of `device_metrics` is indexed with `[0]` (presumably selecting
  the first device replica of each recorded value -- confirm against the
  caller). The result is fetched to the host and handed to `stack_forest`.

  Args:
    device_metrics: sequence of metric pytrees, as accepted by `stack_forest`
      once transferred to the host

  Returns:
    The pytree returned by `stack_forest` for the host copies.
  """
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is gone from current JAX releases.
  device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
  metrics_np = jax.device_get(device_metrics)
  return stack_forest(metrics_np)