# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distributed Shampoo Implementation."""
# An implementation of the distributed Shampoo optimizer from:
#
#   Scalable Second Order Optimization for Deep Learning
#   Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
#   Preprint Paper: https://arxiv.org/abs/2002.09018
#
# This implementation moves computation of the inverse pth root back to the
# accelerator (if higher precision is available). We will present the details
# in an ArXiv note soon.
#
# This implementation has been verified to work on ResNet-50 training to 75.9%
# accuracy, which is the MLPerf benchmark at 32K batch size. At the time of
# writing this comment, it achieves this in 1729 steps whereas the best known
# first-order method trains in 2512 steps.
#
# Authors: Rohan Anil (rohananil at google dot com)
#          & Vineet Gupta (vineet at google dot com)
#
import enum
import itertools

from flax import struct
from flax.optim.base import OptimizerDef
from flax.optim.base import OptimizerState
import jax
from jax import lax
import jax.numpy as jnp
import numpy as onp

# Precision to use for matrix inverse pth root. Switch to f64 if you have
# hardware that supports it.
_INVERSE_PTH_ROOT_DATA_TYPE = jnp.float32

# Numerics are hard. Inverses fail sometimes. We determine that using this
# threshold.
_INVERSE_PTH_ROOT_FAILURE_THRESHOLD = 0.1

# Inverse pth root precision (XLA related) flag.
#
# Options are:
# a. lax.Precision.DEFAULT (Better step time, but not precise)
# b. lax.Precision.HIGH (Increased precision, slower)
# c. lax.Precision.HIGHEST (Best possible precision, slowest)
#
_INVERSE_PTH_ROOT_PRECISION = lax.Precision.HIGHEST


# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. Moreover, it
# allows us to plug the Shampoo optimizer into settings where SGD/AdaGrad
# is already well tuned.
class LayerwiseGrafting(enum.IntEnum):
  SGD = 1
  ADAGRAD = 2

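
# Illustrative sketch of the grafting idea (the actual logic lives in
# Shampoo.apply_per_param_gradient below): keep the direction of the Shampoo
# update, but rescale it to the per-layer norm of the well-tuned first-order
# (SGD/AdaGrad) update.
def _grafting_scale_example(shampoo_update, grafting_update):
  """Returns the Shampoo update rescaled to the grafted update's norm."""
  return shampoo_update * (
      jnp.linalg.norm(grafting_update) /
      (jnp.linalg.norm(shampoo_update) + 1e-16))
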
@struct.dataclass
class _ShampooHyperParams:
  """Shampoo hyperparameters."""

  learning_rate: float
  # Momentum (in Heavy-Ball or Nesterov, if nesterov is True).
  beta1: onp.ndarray
  # Parameter for the exponential moving average of Shampoo second-moment
  # statistics; if set to 1.0, statistics are summed instead of averaged.
  beta2: onp.ndarray
  # Only used if the layerwise grafting mode is AdaGrad. This is the epsilon
  # for the AdaGrad update.
  diagonal_eps: float

  # Epsilon to add to statistics before computing inverse pth root. If you are
  # running in f32 precision for inverse pth root (recommended today)
  # this can go up to 1e-6. If you have recent hardware with native f64
  # precision, set this up to 1e-12.
  matrix_eps: float

  # Weight decay parameter for regularization.
  weight_decay: float

  # Step at which to start the Shampoo update; before that the diagonal
  # (grafted) update is used, because we do not have enough information to
  # compute a stable inverse earlier.
  start_preconditioning_step: int

  # Performance tuning params for controlling memory and compute requirements.
  # How often to compute the preconditioner. Ideally set both params to 1.
  preconditioning_compute_steps: int
  # How often to compute statistics.
  statistics_compute_steps: int

  # Block size for large layers (if > 0).
  block_size: int

  # If there are some small dimensions, collapse them:
  # e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if block = 1024
  #      [1, 2, 768, 1, 2048] --> [2, 768, 2048]
  best_effort_shape_interpretation: bool

  # Type of grafting (SGD or AdaGrad).
  # https://arxiv.org/pdf/2002.11803.pdf
  graft_type: int

  # Avoids preconditioning large layers to reduce overall memory usage if any
  # of the dimensions are greater than this value.
  no_preconditioning_for_layers_with_dim_gt: int

  # Nesterov momentum.
  nesterov: bool
  # Exponent override (if > 0).
  exponent_override: int
  # Batch axis name (for data-parallel code).
  batch_axis_name: str


class BlockPartitioner:
  """Partitions a tensor into smaller tensors."""

  def __init__(self, param, hps):
    self._shape = param.shape
    self._splits = []
    split_sizes = []
    # We split params into smaller blocks. Here we store the metadata to make
    # that split.
    for i, d in enumerate(param.shape):
      if hps.block_size > 0 and d > hps.block_size:
        # d-1, otherwise split appends a 0-size array.
        nsplit = (d-1) // hps.block_size
        indices = (onp.arange(nsplit, dtype=onp.int32) + 1) * hps.block_size
        sizes = onp.ones(nsplit + 1, dtype=onp.int32) * hps.block_size
        sizes[-1] = d - indices[-1]
        self._splits.append((i, indices))
        split_sizes.append(sizes)
      else:
        split_sizes.append(onp.array([d], dtype=onp.int32))
    self._num_splits = len(split_sizes)
    self._preconditioner_shapes = []
    for t in itertools.product(*split_sizes):
      self._preconditioner_shapes.extend([[d, d] for d in t])

  def shapes_for_preconditioners(self):
    return self._preconditioner_shapes

  def num_splits(self):
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""

    assert tensor.shape == self._shape
    tensors = [tensor]
    for (i, indices) in self._splits:
      tensors_local = []
      for t in tensors:
        tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
      tensors = tensors_local
    return tensors

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""

    for (i, indices) in reversed(self._splits):
      n = len(indices) + 1
      partial_merged_tensors = []
      ind = 0
      while ind < len(partitions):
        partial_merged_tensors.append(
            jnp.concatenate(partitions[ind:ind + n], axis=i))
        ind += n
      partitions = partial_merged_tensors
    assert len(partitions) == 1
    return partitions[0]

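
# Example (illustrative sketch, not called anywhere): partition/merge round
# trip for a (6, 512) parameter with block_size=256. Only `block_size` is read
# from the hyperparameters here, so a simple stand-in object is used instead
# of a full _ShampooHyperParams.
def _block_partitioner_example():
  """Splits the 512 dimension into two 256 blocks and merges them back."""

  class _Hps:  # Minimal stand-in; only block_size is used by BlockPartitioner.
    block_size = 256

  param = jnp.ones((6, 512))
  partitioner = BlockPartitioner(param, _Hps())
  blocks = partitioner.partition(param)  # Two blocks of shape (6, 256).
  merged = partitioner.merge_partitions(blocks)
  assert merged.shape == (6, 512)
  # Each block gets a (6, 6) and a (256, 256) preconditioner shape.
  return partitioner.shapes_for_preconditioners()
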
def _merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.

  If there are some small dimensions, we collapse them:
  e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
       [1, 2, 768, 1, 2048] --> [2, 768, 2048]

  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.

  Returns:
    Merged shape.
  """
  resulting_shape = []
  product = 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d
    else:
      if product > 1:
        resulting_shape.append(product)
      product = d
  if product > 1:
    resulting_shape.append(product)
  return resulting_shape


class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, param, hps):
    self._hps = hps
    self._original_shape = param.shape
    self._transformed_shape = param.shape
    if hps.best_effort_shape_interpretation:
      self._transformed_shape = _merge_small_dims(
          self._original_shape, hps.block_size)

    reshaped_param = jnp.reshape(param, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_param, hps)

  def statistics_from_grad(self, grad):
    """Compute statistics from gradients.

    Args:
      grad: Gradient to compute statistics from.

    Returns:
      A list of gradient statistics for each partition.
    """
    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    stats = []
    for grad in partitioned_grads:
      grad_stats = []
      rank = len(grad.shape)
      for i in range(rank):
        axes = list(range(i)) + list(range(i + 1, rank))
        stat = jnp.tensordot(grad, grad, axes=(axes, axes))
        grad_stats.append(stat)
      stats.extend(grad_stats)
    return stats

  def shapes_for_preconditioners(self):
    """Returns shape from statistics."""
    return self._partitioner.shapes_for_preconditioners()

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    return 2 * len(self._transformed_shape)

  def preconditioned_grad(self, grad, preconditioners):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.
      preconditioners: A list of preconditioners to apply.

    Returns:
      A preconditioned gradient.
    """

    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, grad in enumerate(partitioned_grads):
      preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
                                                 num_splits]
      rank = len(grad.shape)
      precond_grad = grad
      for j in range(rank):
        precond_grad = jnp.tensordot(
            precond_grad, preconditioners_for_grad[j], axes=[[0], [0]])
      preconditioned_partitioned_grads.append(precond_grad)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return jnp.reshape(merged_grad, self._original_shape)

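
# Example (illustrative sketch): for a rank-2 gradient G of shape (m, n) with
# no blocking, statistics_from_grad returns the two Kronecker factors
# [G @ G.T, G.T @ G], and the inverse-pth-root exponent is 2 * rank = 4.
def _preconditioner_statistics_example():
  """Computes Shampoo statistics for a small (2, 3) gradient."""

  class _Hps:  # Minimal stand-in; only these two fields are read here.
    block_size = 0
    best_effort_shape_interpretation = False

  grad = jnp.arange(6.0).reshape((2, 3))
  preconditioner = Preconditioner(grad, _Hps())
  stats = preconditioner.statistics_from_grad(grad)
  # stats[0] == grad @ grad.T with shape (2, 2);
  # stats[1] == grad.T @ grad with shape (3, 3).
  assert preconditioner.exponent_for_preconditioner() == 4
  return stats
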
@struct.dataclass
class _ShampooDefaultParamState:
  """Shampoo default parameter state."""

  # Accumulator for diagonal preconditioner
  diagonal_statistics: onp.ndarray
  # Statistics
  statistics: onp.ndarray
  # Preconditioners
  preconditioners: onp.ndarray
  # Momentum for the diagonal preconditioner
  diagonal_momentum: onp.ndarray
  # Momentum for the shampoo preconditioner
  momentum: onp.ndarray


def power_iter(mat_g, error_tolerance=1e-6, num_iters=100):
  """Power iteration.

  Args:
    mat_g: the symmetric PSD matrix.
    error_tolerance: Iterative exit condition.
    num_iters: Number of iterations.

  Returns:
    eigen vector, eigen value, num_iters
  """
  mat_g_size = mat_g.shape[-1]

  def _iter_condition(state):
    i, unused_v, unused_s, unused_s_v, run_step = state
    return jnp.logical_and(i < num_iters, run_step)

  def _iter_body(state):
    """One step of power iteration."""
    i, new_v, s, s_v, unused_run_step = state
    new_v = new_v / jnp.linalg.norm(new_v)

    s_v = jnp.einsum(
        'ij,j->i', mat_g, new_v, precision=_INVERSE_PTH_ROOT_PRECISION)
    s_new = jnp.einsum(
        'i,i->', new_v, s_v, precision=_INVERSE_PTH_ROOT_PRECISION)
    return (i + 1, s_v, s_new, s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance))

  # Figure out how to use step as seed for random.
  v_0 = onp.random.uniform(-1.0, 1.0, mat_g_size).astype(mat_g.dtype)

  init_state = tuple([0, v_0, jnp.zeros([], dtype=mat_g.dtype), v_0, True])
  num_iters, v_out, s_out, _, _ = lax.while_loop(
      _iter_condition, _iter_body, init_state)
  v_out = v_out / jnp.linalg.norm(v_out)
  return v_out, s_out, num_iters

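
# Example (illustrative sketch): the dominant eigenvalue estimated by
# power_iter can be checked against a dense eigendecomposition of a small
# symmetric PSD matrix.
def _power_iter_example():
  """Compares power_iter's top eigenvalue with jnp.linalg.eigh."""
  mat = jnp.array([[2.0, 1.0], [1.0, 3.0]])
  _, max_ev, _ = power_iter(mat)
  dense_max_ev = jnp.linalg.eigh(mat)[0][-1]  # Largest eigenvalue, ~3.618.
  return max_ev, dense_max_ev
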
def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
  """Computes mat_g^(-1/p), where p is a positive integer.

  Coupled Newton iterations for matrix inverse pth root.

  Args:
    mat_g: the symmetric PSD matrix whose power is to be computed.
    p: exponent, for p a positive integer.
    iter_count: Maximum number of iterations.
    error_tolerance: Error indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^(-1/p) and the error max|mat_m - I| at termination.
  """
  mat_g_size = mat_g.shape[0]
  alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
  identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
  _, max_ev, _ = power_iter(mat_g)
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

  def _unrolled_mat_pow_1(mat_m):
    """Computes mat_m^1."""
    return mat_m

  def _unrolled_mat_pow_2(mat_m):
    """Computes mat_m^2."""
    return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_4(mat_m):
    """Computes mat_m^4."""
    mat_pow_2 = _unrolled_mat_pow_2(mat_m)
    return jnp.matmul(
        mat_pow_2, mat_pow_2, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_8(mat_m):
    """Computes mat_m^8."""
    mat_pow_4 = _unrolled_mat_pow_4(mat_m)
    return jnp.matmul(
        mat_pow_4, mat_pow_4, precision=_INVERSE_PTH_ROOT_PRECISION)

  def mat_power(mat_m, p):
    """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
    # We unrolled the loop for performance reasons.
    exponent = jnp.round(jnp.log2(p))
    return lax.switch(
        jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))

  def _iter_condition(state):
    (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
     run_step) = state
    error_above_threshold = jnp.logical_and(
        error > error_tolerance, run_step)
    return jnp.logical_and(i < iter_count, error_above_threshold)

  def _iter_body(state):
    (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    new_mat_m = jnp.matmul(
        mat_power(mat_m_i, p), mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_mat_h = jnp.matmul(
        mat_h, mat_m_i, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_error = jnp.max(jnp.abs(new_mat_m - identity))
    # Sometimes the error increases after an iteration before decreasing and
    # converging. The 1.2 factor bounds the maximal allowed increase.
    return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
            new_error < error * 1.2)

  if mat_g_size == 1:
    resultant_mat_h = (mat_g + ridge_epsilon)**alpha
    error = 0
  else:
    damped_mat_g = mat_g + ridge_epsilon * identity
    z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
    new_mat_m_0 = damped_mat_g * z
    new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
    new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
    init_state = tuple(
        [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
    _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
        _iter_condition, _iter_body, init_state)
    error = jnp.max(jnp.abs(mat_m - identity))
    is_converged = jnp.asarray(convergence, old_mat_h.dtype)
    resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
    resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
  return resultant_mat_h, error

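
# Example (illustrative sketch): for p=2 on a small SPD matrix, the result of
# matrix_inverse_pth_root squared should be close to the inverse of the
# (slightly ridge-damped) input, up to f32 iteration error.
def _matrix_inverse_pth_root_example():
  """Checks A^(-1/2) by squaring it and comparing against inv(A)."""
  mat = jnp.array([[4.0, 1.0], [1.0, 3.0]])
  root, unused_error = matrix_inverse_pth_root(mat, p=2)
  return root @ root, jnp.linalg.inv(mat)
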
class Shampoo(OptimizerDef):
  """Shampoo optimizer.

  Scalable Second Order Optimization for Deep Learning,
  Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer

  Preprint: https://arxiv.org/abs/2002.09018
  """

  def __init__(self,
               learning_rate=None,
               beta1=0.9,
               beta2=0.999,
               diagonal_epsilon=1e-10,
               matrix_epsilon=1e-6,
               weight_decay=0.0,
               start_preconditioning_step=1,
               preconditioning_compute_steps=1,
               statistics_compute_steps=1,
               block_size=128,
               best_effort_shape_interpretation=True,
               graft_type=LayerwiseGrafting.SGD,
               no_preconditioning_for_layers_with_dim_gt=8192,
               nesterov=True,
               exponent_override=0,
               batch_axis_name=None):
    """Constructor for the Shampoo optimizer.

    Args:
      learning_rate: the step size used to update the parameters.
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise
        grafting to AdaGrad is enabled).
      matrix_epsilon: epsilon to add to statistics before computing inverse pth
        root. If you are running in f32 precision for inverse pth root
        (recommended today) this can go up to 1e-6. If you have recent hardware
        with native f64 precision, set this up to 1e-12.
      weight_decay: Weight decay for regularization.
      start_preconditioning_step: Step at which to start the Shampoo update;
        before that the diagonal (grafted) update is used, because we don't yet
        have enough information to compute a stable inverse.
      preconditioning_compute_steps: How often to compute the preconditioner.
        Performance tuning param for controlling memory and compute
        requirements. Ideally set both this and statistics_compute_steps to 1.
      statistics_compute_steps: How often to compute statistics.
      block_size: Block size for large layers (if > 0). The preconditioning
        compute operation is cubic in the dimension of the tensor. Block size
        allows us to chunk the layers into sub-layers of maximal dimension
        dictated by this value. Use 128 as default (increase if you have
        compute budget).
      best_effort_shape_interpretation: If True, collapse small dimensions
        before blocking, e.g. [1, 2, 768, 1, 2048] --> [2, 768, 2048].
      graft_type: Options are: LayerwiseGrafting.SGD, LayerwiseGrafting.ADAGRAD.
      no_preconditioning_for_layers_with_dim_gt: Avoids preconditioning large
        layers to reduce overall memory usage.
      nesterov: Nesterov momentum.
      exponent_override: Override the exponent used in matrix inverse.
      batch_axis_name: the pmap axis name used for data-parallel training.
    """
    hps = _ShampooHyperParams(
        learning_rate,
        beta1,
        beta2,
        diagonal_epsilon,
        matrix_epsilon,
        weight_decay,
        start_preconditioning_step,
        preconditioning_compute_steps,
        statistics_compute_steps,
        block_size,
        best_effort_shape_interpretation,
        graft_type=graft_type,
        no_preconditioning_for_layers_with_dim_gt=no_preconditioning_for_layers_with_dim_gt,
        nesterov=nesterov,
        exponent_override=exponent_override,
        batch_axis_name=batch_axis_name)
    print(hps)
    super().__init__(hps)

  def init_param_state(self, param):
    """Initialize parameter state."""
    hps = self.hyper_params
    statistics = []
    preconditioners = []
    if not self._skip_preconditioning(param, hps):
      preconditioner = Preconditioner(param, hps)
      shapes = preconditioner.shapes_for_preconditioners()
      statistics = [
          self.hyper_params.matrix_eps * jnp.eye(s[0]) for s in shapes
      ]
      preconditioners = [jnp.eye(s[0]) for s in shapes]

    adagrad_statistics = []
    if hps.graft_type == LayerwiseGrafting.ADAGRAD:
      adagrad_statistics = jnp.zeros_like(param)

    return _ShampooDefaultParamState(adagrad_statistics, statistics,
                                     preconditioners, jnp.zeros_like(param),
                                     jnp.zeros_like(param))

  def _skip_preconditioning(self, param, hps):
    return (len(param.shape) < 1 or any([
        s > hps.no_preconditioning_for_layers_with_dim_gt for s in param.shape
    ]))

  def fast_cond(self, predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA."""

    def _iter_body(unused_state):
      results = compute_fn(*args, **kwargs)
      return tuple([False] + list(results))

    def _iter_condition(state):
      return state[0]

    results = lax.while_loop(_iter_condition, _iter_body,
                             tuple([predicate] + init_state))
    return tuple(results[1:])

  def compute_shampoo_statistics(self, step, hps, param, state, grad):
    """Compute statistics."""
    preconditioner = Preconditioner(param, hps)
    assert hps.learning_rate is not None, 'no learning rate provided.'
    new_statistics = [[]] * len(state.statistics)
    w1 = hps.beta2
    w2 = hps.beta2 if hps.beta2 == 1.0 else (1.0 - hps.beta2)
    if not self._skip_preconditioning(param, hps):

      def compute_updated_statistics():
        new_stats = preconditioner.statistics_from_grad(grad)
        new_stats_accumulators = []
        for stat, stat_accumulator in zip(new_stats, state.statistics):
          new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
        return new_stats_accumulators

      if hps.statistics_compute_steps > 1:
        perform_step = step % hps.statistics_compute_steps == 0
        init_state = state.statistics
        new_statistics = list(
            self.fast_cond(perform_step, compute_updated_statistics,
                           init_state))
      else:
        new_statistics = compute_updated_statistics()
    new_state = _ShampooDefaultParamState(state.diagonal_statistics,
                                          new_statistics,
                                          state.preconditioners,
                                          state.diagonal_momentum,
                                          state.momentum)
    return new_state

  def compute_preconditioners_from_statistics(self, states, params, hps, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state, param in zip(states, params):
      preconditioner = Preconditioner(param, hps)
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        for statistic in state.statistics:
          exponents.append(preconditioner.exponent_for_preconditioner() if hps
                           .exponent_override == 0 else hps.exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    def pack(mat, max_size):
      """Pack a matrix to a max_size for inverse on TPUs with static shapes.

      Args:
        mat: Matrix for computing inverse pth root.
        max_size: Matrix size to pack to.

      Returns:
        Given M returns [[M, 0], [0, I]]
      """
      size = mat.shape[0]
      assert size <= max_size
      if size == max_size:
        return mat
      pad_size = max_size - size
      zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
      zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
      eye = jnp.eye(pad_size, dtype=mat.dtype)
      mat = jnp.concatenate([mat, zs1], 1)
      mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
      return mat

    if not hps.batch_axis_name:
      num_devices = jax.local_device_count()
    else:
      num_devices = lax.psum(1, hps.batch_axis_name)

    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pack(stat, max_size) for stat in statistics]
    to_pad = -num_statistics % num_devices
    packed_statistics.extend([
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    # Batch statistics and exponents so that the leading axis is num_devices.
    def _batch(statistics, exponents, num_devices):
      assert len(statistics) == len(exponents)
      n = len(statistics)
      b = int(n / num_devices)
      batched_statistics = [
          jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
      ]
      batched_exponents = [
          jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
      ]
      return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

    # Unbatch values across leading axis and return a list of elements.
    def _unbatch(batched_values):
      b1, b2 = batched_values.shape[0], batched_values.shape[1]
      results = []
      for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
        v_array = jnp.squeeze(v_array)
        # b2 = batches (number of preconditioner computations) per core.
        if b2 > 1:
          for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
            results.append(jnp.squeeze(v))
        else:
          results.append(v_array)

      return results

    all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                           num_devices)

    def _matrix_inverse_pth_root(xs, ps):
      mi_pth_root = lambda x, y: matrix_inverse_pth_root(  # pylint: disable=g-long-lambda
          x, y, ridge_epsilon=hps.matrix_eps)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    if not hps.batch_axis_name:
      preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
          all_statistics, all_exponents)
      preconditioners_flat = _unbatch(preconditioners)
      errors_flat = _unbatch(errors)
    else:

      def _internal_inverse_pth_root_all():
        preconditioners = jnp.array(all_statistics)
        current_replica = lax.axis_index(hps.batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners,
                                             hps.batch_axis_name)
        errors = jax.lax.all_gather(errors, hps.batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if hps.preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics instead of preconditioners, as they are similarly
        # shaped tensors. Since the error we pass is the failure threshold,
        # these values will be ignored.
        preconditioners_init = packed_statistics
        errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                       len(packed_statistics))
        init_state = [preconditioners_init, errors_init]
        perform_step = step % hps.preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = self.fast_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
      return jnp.logical_or(
          jnp.isnan(error),
          error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
      if num_statistics == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                             num_statistics]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          _ShampooDefaultParamState(state.diagonal_statistics, state.statistics,
                                    new_preconditioners,
                                    state.diagonal_momentum, state.momentum))

    return new_states

  def apply_per_param_gradient(self, step, hps, param, state, grad):
    """Apply per-parameter gradients."""
    preconditioner = Preconditioner(param, hps)
    assert hps.learning_rate is not None, 'no learning rate provided.'
    sgd_update = grad
    new_diagonal_statistics = state.diagonal_statistics
    if hps.graft_type == LayerwiseGrafting.ADAGRAD:
      new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
      adagrad_update = grad / (
          jnp.sqrt(new_diagonal_statistics) + hps.diagonal_eps)
      grafting_update = adagrad_update
    else:
      grafting_update = sgd_update

    precond_grad = grad
    if not self._skip_preconditioning(param, hps):
      precond_grad = preconditioner.preconditioned_grad(precond_grad,
                                                        state.preconditioners)
    else:
      precond_grad = grafting_update

    grafting_update_norm = jnp.linalg.norm(grafting_update)
    precond_grad_norm = jnp.linalg.norm(precond_grad)
    shampoo_update = precond_grad * (
        grafting_update_norm / (precond_grad_norm + 1e-16))

    shampoo_update_with_wd = shampoo_update
    grafting_update_with_wd = grafting_update
    if hps.weight_decay != 0:
      shampoo_update_with_wd = shampoo_update + hps.weight_decay * param
      grafting_update_with_wd = grafting_update + hps.weight_decay * param

    shampoo_update_with_wd_momentum = (
        state.momentum * hps.beta1 + shampoo_update_with_wd)
    grafting_update_with_wd_momentum = (
        state.diagonal_momentum * hps.beta1 + grafting_update_with_wd)

    run_shampoo = (step >= hps.start_preconditioning_step).astype(
        grafting_update_with_wd_momentum.dtype)

    momentum_update = (
        run_shampoo * shampoo_update_with_wd_momentum +
        (1.0 - run_shampoo) * grafting_update_with_wd_momentum)

    wd_update = (
        run_shampoo * shampoo_update_with_wd +
        (1.0 - run_shampoo) * grafting_update_with_wd)

    if hps.nesterov:
      momentum_update = wd_update + hps.beta1 * momentum_update

    new_param = param - hps.learning_rate * momentum_update
    new_state = _ShampooDefaultParamState(new_diagonal_statistics,
                                          state.statistics,
                                          state.preconditioners,
                                          grafting_update_with_wd_momentum,
                                          shampoo_update_with_wd_momentum)
    return new_param, new_state

  def apply_gradient(self, hyper_params, params, state, grads):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.

    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)

    new_states_flat = [
        self.compute_shampoo_statistics(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, states_flat, grads_flat)
    ]

    new_states_flat = self.compute_preconditioners_from_statistics(
        new_states_flat, params_flat, hyper_params, step)

    out = [
        self.apply_per_param_gradient(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, new_states_flat, grads_flat)
    ]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)
    return new_params, new_state
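
# Example (illustrative sketch, assuming the legacy flax.optim interface that
# OptimizerDef above comes from; the parameter/gradient trees and
# hyperparameter values below are placeholders).
def _shampoo_usage_example(params, grads):
  """One Shampoo step via the legacy flax.optim Optimizer wrapper."""
  optimizer_def = Shampoo(
      learning_rate=0.1,
      beta1=0.9,
      beta2=0.999,
      block_size=128,
      graft_type=LayerwiseGrafting.SGD)
  optimizer = optimizer_def.create(params)  # Builds per-parameter state.
  optimizer = optimizer.apply_gradient(grads)  # Applies one update step.
  return optimizer.target  # The updated parameters.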