google-research

Форк
0
138 строк · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility functions for logging and recording training metrics."""
17
import logging
18
from absl import logging as absl_logging
19
import jax
20
from jax import lax
21
import jax.numpy as jnp
22
import numpy as onp
23

24

25
def add_log_file(logfile):
  """Mirror absl log output into an extra, already-open file.

  A stream handler writing to `logfile` is attached to the 'absl' logger and
  formatted with absl's `PythonFormatter`. The handler is never detached or
  closed here; the caller is responsible for closing the logfile.

  Args:
    logfile: Open file to write log to
  """
  file_handler = logging.StreamHandler(logfile)
  file_handler.setFormatter(absl_logging.PythonFormatter())
  logging.getLogger('absl').addHandler(file_handler)
37

38

39
def to_state_list(obj):
  """Return the state of the model as a flattened list.

  Every leaf of `obj` is indexed with `[0]` before being fetched, so each leaf
  must have a leading axis; only its first entry is kept (presumably the
  per-device replica axis -- confirm against callers).

  Restore with `restore_state_list`.

  Args:
    obj: the object (pytree) to extract state from

  Returns:
    State as a list of host arrays, one per leaf, in pytree leaf order
  """
  # Use the `jax.tree_util` spelling: the top-level `jax.tree_leaves` alias is
  # deprecated and not available in current JAX releases.
  leaves = jax.tree_util.tree_leaves(obj)
  return jax.device_get([leaf[0] for leaf in leaves])
52

53

54
def restore_state_list(obj, state_list):
  """Restore model state from a state list.

  The state list is first passed through `replicate`, then arranged into the
  same pytree structure as `obj`. The leaf values of `obj` are not used, only
  its structure.

  Args:
    obj: the object that is to be duplicated with the
      restored state
    state_list: state as a list of arrays, one per leaf of `obj`

  Returns:
    a copy of `obj` with the parameters from state_list loaded

  >>> restored = restore_state_list(model, state_list)
  """
  state_list = replicate(state_list)
  structure = jax.tree_util.tree_structure(obj)
  # Use the `jax.tree_util` spelling: the top-level `jax.tree_unflatten` alias
  # is deprecated and not available in current JAX releases.
  return jax.tree_util.tree_unflatten(structure, state_list)
70

71

72
def replicate(xs, n_devices=None):
  """Copy `xs` once per device by returning it from a `jax.pmap`.

  Args:
    xs: value (pytree of arrays) to replicate
    n_devices: number of copies to make; defaults to
      `jax.local_device_count()` when None

  Returns:
    The output of a pmap over `n_devices` entries whose body returns `xs`,
    i.e. `xs` with a new leading axis of size `n_devices`
  """
  device_count = jax.local_device_count() if n_devices is None else n_devices

  def _return_xs(_):
    # The mapped index is ignored; every replica yields the same value.
    return xs

  replicated_fn = jax.pmap(_return_xs, axis_name='batch')
  return replicated_fn(jnp.arange(device_count))
77

78

79
def shard(xs):
  """Split the leading axis of every leaf across the local devices.

  A leaf of shape `(n, ...)` is reshaped to
  `(local_device_count, n // local_device_count, ...)`; the reshape fails if
  `n` is not divisible by the local device count.

  Args:
    xs: pytree of arrays, each exposing `reshape` and `shape`

  Returns:
    Pytree with the same structure as `xs` holding the reshaped leaves
  """
  local_device_count = jax.local_device_count()

  def _split_leading_axis(x):
    return x.reshape((local_device_count, -1) + x.shape[1:])

  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  return jax.tree_util.tree_map(_split_leading_axis, xs)
83

84

85
def onehot(labels, num_classes):
  """Encode class indices as one-hot float32 vectors.

  A trailing axis of size `num_classes` is appended to `labels`; entry `k` on
  that axis is 1.0 where the label equals `k` and 0.0 otherwise. Labels
  outside `[0, num_classes)` therefore map to an all-zero row.

  Args:
    labels: array of class indices
    num_classes: number of classes (size of the new trailing axis)

  Returns:
    float32 array of one-hot vectors
  """
  class_ids = jnp.arange(num_classes)[None]
  matches = labels[..., None] == class_ids
  return matches.astype(jnp.float32)
88

89

90
def pmean(tree, axis_name='batch'):
  """Average every leaf of `tree` over a named mapped axis.

  Must be called inside a context that binds `axis_name` (e.g. a `jax.pmap`
  or `jax.vmap` given the same `axis_name`). The mean is computed as the
  `lax.psum` of each leaf divided by the `lax.psum` of `1.` over that axis.

  Args:
    tree: pytree of values to average
    axis_name: name of the mapped axis to reduce over

  Returns:
    Pytree with the same structure as `tree` holding the per-leaf means
  """
  num_devices = lax.psum(1., axis_name)
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  return jax.tree_util.tree_map(
      lambda x: lax.psum(x, axis_name) / num_devices, tree)
93

94

95
def psum(tree, axis_name='batch'):
  """Sum every leaf of `tree` over a named mapped axis.

  Must be called inside a context that binds `axis_name` (e.g. a `jax.pmap`
  or `jax.vmap` given the same `axis_name`).

  Args:
    tree: pytree of values to sum
    axis_name: name of the mapped axis to reduce over

  Returns:
    Pytree with the same structure as `tree` holding the per-leaf sums
  """
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  return jax.tree_util.tree_map(lambda x: lax.psum(x, axis_name), tree)
97

98

99
def pad_classification_batch(batch, batch_size):
  """Pad a classification batch so that it has `batch_size` samples.

  The batch should be a dictionary of the form:

  {
    'image': <image>,
    'label': <GT label>
  }

  Padding is appended along the leading (sample) axis only: images are padded
  with 0 and labels with -1. Images and labels may have any rank >= 1. A
  padded result contains only the 'image' and 'label' entries.

  Args:
    batch: the batch to pad
    batch_size: the desired number of elements

  Returns:
    Padded batch as a dictionary. If the batch already holds at least
    `batch_size` samples, the input `batch` object itself is returned.
  """
  actual_size = len(batch['image'])
  if actual_size >= batch_size:
    return batch

  padding = batch_size - actual_size

  def _pad_samples(values, fill):
    # Spell out one (before, after) pair per axis: a single pair would be
    # broadcast by `numpy.pad` to every axis, not just the sample axis.
    pad_width = [[0, padding]] + [[0, 0]] * (onp.ndim(values) - 1)
    return onp.pad(values, pad_width, mode='constant', constant_values=fill)

  return {
      'label': _pad_samples(batch['label'], -1),
      'image': _pad_samples(batch['image'], 0),
  }
128

129

130
def stack_forest(forest):
  """Stack matching leaves from a sequence of pytrees into numpy arrays.

  All trees in `forest` must share the same structure. For each leaf
  position, the leaves of all trees are combined with `numpy.stack`, which
  adds a new leading axis of length `len(forest)`.

  Args:
    forest: non-empty sequence of pytrees with identical structure

  Returns:
    A single pytree with that structure whose leaves are the stacked arrays
  """
  def _stack_leaves(*leaves):
    return onp.stack(leaves)

  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  return jax.tree_util.tree_map(_stack_leaves, *forest)
133

134

135
def get_metrics(device_metrics):
  """Fetch a sequence of metric pytrees to the host and stack them.

  Each leaf is indexed with `[0]` (presumably selecting the first device
  replica -- confirm against callers), the result is transferred with
  `jax.device_get`, and the trees are then merged leaf-wise by `stack_forest`.

  Args:
    device_metrics: sequence of metric pytrees with identical structure, whose
      leaves each have a leading axis

  Returns:
    A single pytree whose leaves stack the corresponding metric values
  """
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`.
  device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
  metrics_np = jax.device_get(device_metrics)
  return stack_forest(metrics_np)
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.