# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic helper utilities."""

import functools
from typing import Any, Optional

import chex
from flax import core
from flax.training import checkpoints
import flax.training.train_state as ts
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

from hct.common import typing


def param_count(params):
  """Returns the total number of scalar parameters in a pytree."""
  return sum(x.size for x in jax.tree_util.tree_leaves(params))


def check_params_finite(params):
  """Returns True iff every leaf of `params` is finite (no NaNs or infs)."""
  return jnp.array(
      [jnp.isfinite(x).all() for x in jax.tree_util.tree_leaves(params)]).all()


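# Usage sketch (illustrative values; `params` stands in for any parameter
# pytree, e.g. the output of a flax `Module.init`):
#
#   params = {'dense': {'w': jnp.ones((4, 2)), 'b': jnp.zeros((2,))}}
#   param_count(params)          # -> 10
#   check_params_finite(params)  # -> True (as a jnp bool scalar)

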
class TrainStateBN(ts.TrainState):
  """Train-state with batchnorm batch-stats."""
  batch_stats: core.FrozenDict[str, Any]


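# Construction sketch (hypothetical `model`, `rng`, and `dummy_batch`; a flax
# module with batchnorm returns a 'batch_stats' collection from `init`, and
# `TrainState.create` forwards extra kwargs to extra fields):
#
#   variables = model.init(rng, dummy_batch, train=False)
#   state = TrainStateBN.create(
#       apply_fn=model.apply,
#       params=variables['params'],
#       tx=make_optax_adam(learning_rate=1e-3, weight_decay=0.0),
#       batch_stats=variables['batch_stats'])

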
def make_optax_adam(learning_rate, weight_decay):
  """Returns adamw when weight_decay > 0, otherwise plain adam."""
  if weight_decay > 0:
    return optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
  else:
    return optax.adam(learning_rate=learning_rate)


@functools.partial(jax.jit, static_argnums=(0,))
def split_across_devices(num, x):
  """Split batch across devices."""
  return jnp.reshape(x, (num, x.shape[0] // num) + x.shape[1:])


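# Shape sketch: the leading batch axis is folded into a device axis in the
# layout `jax.pmap` expects; the batch size must divide evenly by `num`:
#
#   x = jnp.zeros((8, 64, 64, 3))
#   split_across_devices(2, x).shape  # -> (2, 4, 64, 64, 3)

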
def insert_batch_axis(x):
  """Adds a leading size-1 batch axis to every leaf of `x`."""
  return jax.tree_util.tree_map(lambda leaf: leaf[None, Ellipsis], x)


def remove_batch_axis(x):
  """Squeezes the leading size-1 batch axis from every leaf of `x`."""
  return jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), x)


def unbatch_flax_fn(fn, has_params=True):
  """Wraps a batched flax fn so it takes and returns unbatched arguments."""
  # All positional args get a size-1 batch axis inserted before the call;
  # all kwargs are passed through (broadcast) unchanged.
  if has_params:
    def unbatched_fn(params, *args, **kwargs):
      batched_args = map(insert_batch_axis, args)
      return remove_batch_axis(fn(params, *batched_args, **kwargs))
    return unbatched_fn
  else:
    def unbatched_fn(*args, **kwargs):
      batched_args = map(insert_batch_axis, args)
      return remove_batch_axis(fn(*batched_args, **kwargs))
    return unbatched_fn


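# Usage sketch (hypothetical `model` and shapes): lets a batched flax
# apply-fn be called on a single example; the output batch axis is squeezed:
#
#   single_apply = unbatch_flax_fn(model.apply)
#   y = single_apply(params, x)  # x: (obs_dim,), y comes back unbatched

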
class BatchManager:
  """A simple batch manager."""

  def __init__(self, key, dataset, batch_size):
    self._prng = hk.PRNGSequence(key)
    self._num = len(dataset['images'])
    assert len(dataset['hf_obs']) == self._num
    assert len(dataset['actions']) == self._num

    # Keep the saved copy off the accelerator device.
    self._dataset = {k: jax.device_get(arr) for k, arr in dataset.items()}
    self._batch_size = min(batch_size, self._num)
    self._num_batches = self._num // self._batch_size
    self._epochs_completed = 0

    self._permutation = None
    self._current_batch_idx = None
    self._resample()

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def num_batches(self):
    return self._num_batches

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def _resample(self):
    self._permutation = jax.random.permutation(next(self._prng), self._num)
    total_points = self._batch_size * self._num_batches
    self._permutation = self._permutation[:total_points].reshape(
        (self._num_batches, self._batch_size))
    self._current_batch_idx = 0

  def _select(self, inds):
    return {k: arr[inds] for k, arr in self._dataset.items()}

  def next_batch(self):
    """Get the next batch."""
    assert self._permutation is not None
    inds = self._permutation[self._current_batch_idx]
    ret = self._select(inds)
    self._current_batch_idx += 1
    if self._current_batch_idx >= len(self._permutation):
      self._epochs_completed += 1
      self._resample()
    return ret

  def next_pmapped_batch(self, num_devices):
    """Get the next batch, reshaped for pmap across `num_devices`."""
    assert self._batch_size % num_devices == 0
    ret = self.next_batch()
    return {k: split_across_devices(num_devices, arr) for k, arr in ret.items()}


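# Usage sketch (hypothetical shapes; the dataset dict must hold 'images',
# 'hf_obs', and 'actions' arrays sharing a leading example axis):
#
#   bm = BatchManager(jax.random.PRNGKey(0), dataset, batch_size=256)
#   batch = bm.next_batch()  # dict of (256, ...) arrays
#   sharded = bm.next_pmapped_batch(jax.local_device_count())
#   # sharded values have shape (num_devices, 256 // num_devices, ...)

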
def compute_norm_stats(dataset):
  """Compute mean and std pytrees."""
  means = {k: np.mean(arr, axis=0) for k, arr in dataset.items()}
  stds = {k: np.std(arr, axis=0) for k, arr in dataset.items()}

  # Force normalization for images to be zero mean, std = 255.
  means['images'] = np.zeros_like(means['images'])
  stds['images'] = 255. * np.ones_like(stds['images'])
  return means, stds


def normalize(dataset, means, stds):
  """Normalize dataset."""
  return jax.tree_util.tree_map(
      lambda leaf, leaf_mean, leaf_std: (leaf - leaf_mean) / leaf_std,
      dataset, means, stds)


def unnormalize(arr, arr_mean, arr_std):
  """Unnormalize array."""
  return arr * arr_std + arr_mean


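# Round-trip sketch: `normalize` and `unnormalize` are inverses per leaf
# (illustrative key names):
#
#   means, stds = compute_norm_stats(dataset)
#   normed = normalize(dataset, means, stds)
#   back = unnormalize(normed['actions'], means['actions'], stds['actions'])
#   # back ~= dataset['actions'], up to floating-point error

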
def save_model(checkpoint_dir, step, keep_every, state):
  """Checkpoints and saves models."""
  checkpoints.save_checkpoint(
      ckpt_dir=checkpoint_dir,
      target=state,
      step=step,
      overwrite=True,
      keep_every_n_steps=keep_every)


def restore_model(checkpoint_dir, state, step=None):
  """Restore all models."""
  # NOTE: Assumes states have pre-defined structures.
  return checkpoints.restore_checkpoint(checkpoint_dir, state, step=step)
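

# Checkpointing sketch (hypothetical paths): `save_model` writes flax
# checkpoints under `checkpoint_dir`; `restore_model` with `step=None`
# restores the latest checkpoint into a state of the same pytree structure:
#
#   save_model('/tmp/ckpts', step=1000, keep_every=5000, state=state)
#   state = restore_model('/tmp/ckpts', state)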
