# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic helper utilities."""

import functools
from typing import Any, Optional

import chex
from flax import core
from flax.training import checkpoints
import flax.training.train_state as ts
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

from hct.common import typing


def param_count(params):
  return sum(x.size for x in jax.tree_util.tree_leaves(params))
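
# Example (illustrative sketch; `params` is a hypothetical pytree):
#   params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros((4,))}
#   param_count(params)  # -> 16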


def check_params_finite(params):
  return jnp.array(
      [jnp.isfinite(x).all() for x in jax.tree_util.tree_leaves(params)]).all()
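
# Example (illustrative sketch, reusing the hypothetical `params` above):
#   check_params_finite(params)                       # -> True
#   check_params_finite({'w': jnp.array([jnp.inf])})  # -> False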


class TrainStateBN(ts.TrainState):
  """Train-state with batchnorm batch-stats."""
  batch_stats: core.FrozenDict[str, Any]
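
# Example (a minimal sketch; `model` and `variables` are hypothetical, e.g.
# from `variables = model.init(rng, dummy_batch)` for a flax module that
# tracks batchnorm statistics):
#   state = TrainStateBN.create(
#       apply_fn=model.apply,
#       params=variables['params'],
#       batch_stats=variables['batch_stats'],
#       tx=make_optax_adam(learning_rate=1e-3, weight_decay=0.0))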


def make_optax_adam(learning_rate, weight_decay):
  if weight_decay > 0:
    return optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
  else:
    return optax.adam(learning_rate=learning_rate)
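
# Example (illustrative): `make_optax_adam(1e-3, 1e-4)` builds an AdamW
# optimizer, while `make_optax_adam(1e-3, 0.0)` falls back to plain Adam.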


@functools.partial(jax.jit, static_argnums=(0,))
def split_across_devices(num, x):
  """Split batch across devices."""
  return jnp.reshape(x, (num, x.shape[0] // num) + x.shape[1:])
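
# Example (illustrative sketch): with 2 devices, a (64, 28, 28) batch gains a
# leading device axis suitable for `jax.pmap`:
#   x = jnp.zeros((64, 28, 28))
#   split_across_devices(2, x).shape  # -> (2, 32, 28, 28)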


def insert_batch_axis(x):
  return jax.tree_util.tree_map(lambda leaf: leaf[None, Ellipsis], x)


def remove_batch_axis(x):
  return jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), x)


def unbatch_flax_fn(fn, has_params=True):
  """Wraps a batched flax `fn` so it accepts and returns unbatched args."""
  # All positional args get a singleton batch axis inserted before the call,
  # and the result has it squeezed back out; kwargs are broadcast unchanged.
  if has_params:
    def unbatched_fn(params, *args, **kwargs):
      batched_args = map(insert_batch_axis, args)
      return remove_batch_axis(fn(params, *batched_args, **kwargs))
    return unbatched_fn
  else:
    def unbatched_fn(*args, **kwargs):
      batched_args = map(insert_batch_axis, args)
      return remove_batch_axis(fn(*batched_args, **kwargs))
    return unbatched_fn
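
# Example (a minimal sketch; `model` is a hypothetical flax module whose
# `apply` expects inputs with a leading batch axis):
#   single_apply = unbatch_flax_fn(model.apply)
#   y = single_apply(params, x)  # `x` has no batch axis; neither does `y`.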


class BatchManager:
  """A simple batch manager."""

  def __init__(self, key, dataset, batch_size):
    self._prng = hk.PRNGSequence(key)
    self._num = len(dataset['images'])
    assert len(dataset['hf_obs']) == self._num
    assert len(dataset['actions']) == self._num

    # Keep the saved copy off the accelerator device.
    self._dataset = {k: jax.device_get(arr) for k, arr in dataset.items()}
    self._batch_size = min(batch_size, self._num)
    self._num_batches = self._num // self._batch_size
    self._epochs_completed = 0

    self._permutation = None
    self._current_batch_idx = None
    self._resample()

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def num_batches(self):
    return self._num_batches

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def _resample(self):
    self._permutation = jax.random.permutation(next(self._prng), self._num)
    total_points = self._batch_size * self._num_batches
    self._permutation = self._permutation[:total_points].reshape(
        (self._num_batches, self._batch_size))
    self._current_batch_idx = 0

  def _select(self, inds):
    return {k: arr[inds] for k, arr in self._dataset.items()}

  def next_batch(self):
    """Get the next batch."""
    assert self._permutation is not None
    inds = self._permutation[self._current_batch_idx]
    ret = self._select(inds)
    self._current_batch_idx += 1
    if self._current_batch_idx >= len(self._permutation):
      self._epochs_completed += 1
      self._resample()
    return ret

  def next_pmapped_batch(self, num_devices):
    # num_devices must evenly divide the batch size.
    assert self._batch_size % num_devices == 0
    ret = self.next_batch()
    return {k: split_across_devices(num_devices, arr) for k, arr in ret.items()}
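
# Example (illustrative sketch; the arrays are hypothetical stand-ins for a
# real dataset with 'images', 'hf_obs', and 'actions' entries):
#   dataset = {'images': np.zeros((100, 32, 32, 3)),
#              'hf_obs': np.zeros((100, 7)),
#              'actions': np.zeros((100, 4))}
#   bm = BatchManager(jax.random.PRNGKey(0), dataset, batch_size=10)
#   batch = bm.next_batch()             # each value has leading dim 10
#   sharded = bm.next_pmapped_batch(2)  # each value shaped (2, 5, ...)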


def compute_norm_stats(dataset):
  """Compute mean and std pytrees."""
  means = {k: np.mean(arr, axis=0) for k, arr in dataset.items()}
  stds = {k: np.std(arr, axis=0) for k, arr in dataset.items()}

  # Force normalization for images to be zero mean, std = 255.
  means['images'] = np.zeros_like(means['images'])
  stds['images'] = 255. * np.ones_like(stds['images'])
  return means, stds


def normalize(dataset, means, stds):
  """Normalize dataset."""
  return jax.tree_util.tree_map(
      lambda leaf, leaf_mean, leaf_std: (leaf - leaf_mean) / leaf_std,
      dataset, means, stds)


def unnormalize(arr, arr_mean, arr_std):
  """Unnormalize array."""
  return arr * arr_std + arr_mean
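
# Example (illustrative round trip, reusing the hypothetical `dataset` above):
#   means, stds = compute_norm_stats(dataset)
#   normed = normalize(dataset, means, stds)  # images are just divided by 255
#   actions = unnormalize(normed['actions'], means['actions'], stds['actions'])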


def save_model(checkpoint_dir, step, keep_every, state):
  """Checkpoints and saves models."""
  checkpoints.save_checkpoint(
      ckpt_dir=checkpoint_dir,
      target=state,
      step=step,
      overwrite=True,
      keep_every_n_steps=keep_every)


def restore_model(checkpoint_dir, state, step=None):
  """Restore all models."""
  # NOTE: Assumes states have pre-defined structures.
  return checkpoints.restore_checkpoint(checkpoint_dir, state, step=step)
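
# Example (a minimal sketch; `state` is a hypothetical TrainState and
# '/tmp/ckpts' a hypothetical directory):
#   save_model('/tmp/ckpts', step=1000, keep_every=500, state=state)
#   state = restore_model('/tmp/ckpts', state)  # restores the latest step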