# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distributed Shampoo Implementation."""
# An implementation of distributed Shampoo optimizer from:
#
#  Scalable Second Order Optimization for Deep Learning
#  Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
#  Preprint Paper: https://arxiv.org/abs/2002.09018
#
# This implementation moves computation of inverse pth root back to the
# accelerator (if higher precision is available). We will present the details
# in an ArXiv note soon.
#
# This implementation has been verified to work on ResNet-50 training to 75.9%
# accuracy, which is the MLPerf benchmark at 32K batch size. At the time of
# writing this comment it achieves this in 1729 steps, whereas the best known
# first order method trains in 2512 steps.
#
# Authors: Rohan Anil (rohananil at google dot com)
#    &     Vineet Gupta (vineet at google dot com)
#
import enum
import itertools

from flax import struct
from flax.optim.base import OptimizerDef
from flax.optim.base import OptimizerState
import jax
from jax import lax
import jax.numpy as jnp
import numpy as onp

# Precision to use for matrix inverse pth root. Switch to f64 if you have
# hardware that supports it.
_INVERSE_PTH_ROOT_DATA_TYPE = jnp.float32

# Numerics are hard. Inverses fail sometimes. We detect a failed inverse
# using this threshold.
_INVERSE_PTH_ROOT_FAILURE_THRESHOLD = 0.1

# Inverse pth root precision (XLA related) flag.
#
# Options are:
# a. lax.Precision.DEFAULT (Better step time, but not precise)
# b. lax.Precision.HIGH (Increased precision, slower)
# c. lax.Precision.HIGHEST (Best possible precision, slowest)
#
_INVERSE_PTH_ROOT_PRECISION = lax.Precision.HIGHEST


# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. Moreover, it
# allows us to plug the Shampoo optimizer into settings where SGD/AdaGrad
# is already well tuned.
class LayerwiseGrafting(enum.IntEnum):
  SGD = 1
  ADAGRAD = 2
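
# Concretely (see apply_per_param_gradient below), grafting keeps the Shampoo
# *direction* but borrows the per-layer step *size* from the grafted method:
# the preconditioned gradient is rescaled so that its norm matches the norm of
# the SGD or AdaGrad update for the same layer:
#
#   shampoo_update = precond_grad * ||grafting_update|| / (||precond_grad|| + 1e-16)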


@struct.dataclass
class _ShampooHyperParams:
  """Shampoo hyperparameters."""

  learning_rate: float
  # Momentum (in Heavy-Ball or Nesterov, if nesterov is True).
  beta1: onp.ndarray
  # Parameter for the exponential moving average of Shampoo second-moment
  # statistics; if set to 1.0, statistics are summed instead of averaged.
  beta2: onp.ndarray
  # Only used when layerwise grafting to AdaGrad is enabled. This is the
  # epsilon for the AdaGrad update.
  diagonal_eps: float

  # Epsilon to add to statistics before computing inverse pth root. If you are
  # running in f32 precision for inverse pth root (recommended today)
  # this can go up to 1e-6. If you have recent hardware with native f64
  # precision, this can be set as low as 1e-12.
  matrix_eps: float

  # Weight decay parameter for regularization.
  weight_decay: float

  # Step at which to start the Shampoo update; before this the diagonal
  # (grafted) update is used, because we do not yet have enough information
  # to compute a stable inverse.
  start_preconditioning_step: int

  # Performance tuning params for controlling memory and compute requirements.
  # How often to compute preconditioners. Ideally set both params to 1.
  preconditioning_compute_steps: int
  # How often to compute statistics.
  statistics_compute_steps: int

  # Block size for large layers (if > 0).
  block_size: int

  # If there are some small dimensions, collapse them:
  # e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if block = 1024
  # [1, 2, 768, 1, 2048] --> [2, 768, 2048]
  best_effort_shape_interpretation: bool

  # Type of grafting (SGD or AdaGrad).
  # https://arxiv.org/pdf/2002.11803.pdf
  graft_type: int

  # Skip preconditioning a layer (to reduce overall memory usage) if any of
  # its dimensions is greater than this value.
  no_preconditioning_for_layers_with_dim_gt: int

  # Nesterov momentum.
  nesterov: bool
  # Exponent override (if > 0).
  exponent_override: int
  # Batch axis name (for data-parallel code).
  batch_axis_name: str


class BlockPartitioner:
  """Partitions a tensor into smaller tensors."""

  def __init__(self, param, hps):
    self._shape = param.shape
    self._splits = []
    split_sizes = []
    # We split params into smaller blocks. Here we store the metadata to make
    # that split.
    for i, d in enumerate(param.shape):
      if hps.block_size > 0 and d > hps.block_size:
        # d-1, otherwise split appends a 0-size array.
        nsplit = (d-1) // hps.block_size
        indices = (onp.arange(nsplit, dtype=onp.int32) + 1) * hps.block_size
        sizes = onp.ones(nsplit + 1, dtype=onp.int32) * hps.block_size
        sizes[-1] = d - indices[-1]
        self._splits.append((i, indices))
        split_sizes.append(sizes)
      else:
        split_sizes.append(onp.array([d], dtype=onp.int32))
    self._num_splits = len(split_sizes)
    self._preconditioner_shapes = []
    for t in itertools.product(*split_sizes):
      self._preconditioner_shapes.extend([[d, d] for d in t])

  def shapes_for_preconditioners(self):
    return self._preconditioner_shapes

  def num_splits(self):
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""

    assert tensor.shape == self._shape
    tensors = [tensor]
    for (i, indices) in self._splits:
      tensors_local = []
      for t in tensors:
        tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
      tensors = tensors_local
    return tensors

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""

    for (i, indices) in reversed(self._splits):
      n = len(indices) + 1
      partial_merged_tensors = []
      ind = 0
      while ind < len(partitions):
        partial_merged_tensors.append(
            jnp.concatenate(partitions[ind:ind + n], axis=i))
        ind += n
      partitions = partial_merged_tensors
    assert len(partitions) == 1
    return partitions[0]
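
# Illustrative sketch (an assumed shape for illustration, not executed as part
# of this module): with hps.block_size == 1024, a (2048, 512) parameter is
# split only along axis 0:
#
#   partitioner = BlockPartitioner(param, hps)     # param.shape == (2048, 512)
#   blocks = partitioner.partition(param)          # two (1024, 512) blocks
#   partitioner.shapes_for_preconditioners()
#   # -> [[1024, 1024], [512, 512], [1024, 1024], [512, 512]]
#   partitioner.merge_partitions(blocks)           # back to a (2048, 512) array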


def _merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.

  If there are some small dimensions, we collapse them:
  e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
       [1, 2, 768, 1, 2048] --> [2, 768, 2048]

  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.

  Returns:
    Merged shape.
  """
  resulting_shape = []
  product = 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d
    else:
      if product > 1:
        resulting_shape.append(product)
      product = d
  if product > 1:
    resulting_shape.append(product)
  return resulting_shape
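
# The merge is greedy from left to right: dimensions are multiplied together
# until the running product would exceed max_dim. Worked by hand from the code
# above (illustrative, not part of the original module):
#   _merge_small_dims([4, 4, 32], 128)            -> [16, 32]
#   _merge_small_dims([1, 2, 768, 1, 2048], 1024) -> [2, 768, 2048]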


class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, param, hps):
    self._hps = hps
    self._original_shape = param.shape
    self._transformed_shape = param.shape
    if hps.best_effort_shape_interpretation:
      self._transformed_shape = _merge_small_dims(
          self._original_shape, hps.block_size)

    reshaped_param = jnp.reshape(param, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_param, hps)

  def statistics_from_grad(self, grad):
    """Compute statistics from gradients.

    Args:
      grad: Gradient to compute statistics from.

    Returns:
      A list of gradient statistics for each partition.
    """
    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    stats = []
    for grad in partitioned_grads:
      grad_stats = []
      rank = len(grad.shape)
      for i in range(rank):
        axes = list(range(i)) + list(range(i + 1, rank))
        stat = jnp.tensordot(grad, grad, axes=(axes, axes))
        grad_stats.append(stat)
      stats.extend(grad_stats)
    return stats

  def shapes_for_preconditioners(self):
    """Returns the shapes of the preconditioner matrices."""
    return self._partitioner.shapes_for_preconditioners()

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    return 2 * len(self._transformed_shape)

  def preconditioned_grad(self, grad, preconditioners):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.
      preconditioners: A list of preconditioners to apply.

    Returns:
      A preconditioned gradient.
    """

    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, grad in enumerate(partitioned_grads):
      preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
                                                 num_splits]
      rank = len(grad.shape)
      precond_grad = grad
      for j in range(rank):
        precond_grad = jnp.tensordot(
            precond_grad, preconditioners_for_grad[j], axes=[[0], [0]])
      preconditioned_partitioned_grads.append(precond_grad)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return jnp.reshape(merged_grad, self._original_shape)
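
# For intuition (this restates the code above for the unblocked rank-2 case):
# given a single (m, n) gradient block G, statistics_from_grad returns
# [G @ G.T, G.T @ G] (the left and right Kronecker factors), the inverse-root
# exponent is 2 * rank = 4, and preconditioned_grad computes
# L^{-1/4} @ G @ R^{-1/4}, where L and R are the damped, accumulated versions
# of those two factors, as in the Shampoo paper.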


@struct.dataclass
class _ShampooDefaultParamState:
  """Shampoo default parameter state."""

  # Accumulator for diagonal preconditioner
  diagonal_statistics: onp.ndarray
  # Statistics
  statistics: onp.ndarray
  # Preconditioners
  preconditioners: onp.ndarray
  # Momentum for the diagonal preconditioner
  diagonal_momentum: onp.ndarray
  # Momentum for the shampoo preconditioner
  momentum: onp.ndarray


def power_iter(mat_g, error_tolerance=1e-6, num_iters=100):
  """Power iteration.

  Args:
    mat_g: the symmetric PSD matrix.
    error_tolerance: Iterative exit condition.
    num_iters: Number of iterations.

  Returns:
    eigenvector, eigenvalue, num_iters
  """
  mat_g_size = mat_g.shape[-1]
  def _iter_condition(state):
    i, unused_v, unused_s, unused_s_v, run_step = state
    return jnp.logical_and(i < num_iters, run_step)

  def _iter_body(state):
    """One step of power iteration."""
    i, new_v, s, s_v, unused_run_step = state
    new_v = new_v / jnp.linalg.norm(new_v)

    s_v = jnp.einsum(
        'ij,j->i', mat_g, new_v, precision=_INVERSE_PTH_ROOT_PRECISION)
    s_new = jnp.einsum(
        'i,i->', new_v, s_v, precision=_INVERSE_PTH_ROOT_PRECISION)
    return (i + 1, s_v, s_new, s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance))

  # Figure out how to use step as seed for random.
  v_0 = onp.random.uniform(-1.0, 1.0, mat_g_size).astype(mat_g.dtype)

  init_state = tuple([0, v_0, jnp.zeros([], dtype=mat_g.dtype), v_0, True])
  num_iters, v_out, s_out, _, _ = lax.while_loop(
      _iter_condition, _iter_body, init_state)
  v_out = v_out / jnp.linalg.norm(v_out)
  return v_out, s_out, num_iters
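
# Example usage (illustrative only; note that power_iter uses numpy's global
# RNG for its start vector, per the comment above):
#   mat = jnp.array([[2.0, 0.0], [0.0, 1.0]])
#   _, max_ev, _ = power_iter(mat)   # max_ev is approximately 2.0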


def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
  """Computes mat_g^(-1/p), where p is a positive integer.

  Coupled Newton iterations for matrix inverse pth root.

  Args:
    mat_g: the symmetric PSD matrix whose inverse pth root is to be computed.
    p: exponent, for p a positive integer.
    iter_count: Maximum number of iterations.
    error_tolerance: Error indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^(-1/p) and the final error.
  """
  mat_g_size = mat_g.shape[0]
  alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
  identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
  _, max_ev, _ = power_iter(mat_g)
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

  def _unrolled_mat_pow_1(mat_m):
    """Computes mat_m^1."""
    return mat_m

  def _unrolled_mat_pow_2(mat_m):
    """Computes mat_m^2."""
    return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_4(mat_m):
    """Computes mat_m^4."""
    mat_pow_2 = _unrolled_mat_pow_2(mat_m)
    return jnp.matmul(
        mat_pow_2, mat_pow_2, precision=_INVERSE_PTH_ROOT_PRECISION)

  def _unrolled_mat_pow_8(mat_m):
    """Computes mat_m^8."""
    mat_pow_4 = _unrolled_mat_pow_4(mat_m)
    return jnp.matmul(
        mat_pow_4, mat_pow_4, precision=_INVERSE_PTH_ROOT_PRECISION)

  def mat_power(mat_m, p):
    """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
    # We unrolled the loop for performance reasons.
    exponent = jnp.round(jnp.log2(p))
    return lax.switch(
        jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))

  def _iter_condition(state):
    (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
     run_step) = state
    error_above_threshold = jnp.logical_and(
        error > error_tolerance, run_step)
    return jnp.logical_and(i < iter_count, error_above_threshold)

  def _iter_body(state):
    (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    new_mat_m = jnp.matmul(
        mat_power(mat_m_i, p), mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_mat_h = jnp.matmul(
        mat_h, mat_m_i, precision=_INVERSE_PTH_ROOT_PRECISION)
    new_error = jnp.max(jnp.abs(new_mat_m - identity))
    # Sometimes the error increases for an iteration before decreasing and
    # converging; the 1.2 factor bounds the maximal allowed increase.
    return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
            new_error < error * 1.2)

  if mat_g_size == 1:
    resultant_mat_h = (mat_g + ridge_epsilon)**alpha
    error = 0
  else:
    damped_mat_g = mat_g + ridge_epsilon * identity
    z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
    new_mat_m_0 = damped_mat_g * z
    new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
    new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
    init_state = tuple(
        [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
    _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
        _iter_condition, _iter_body, init_state)
    error = jnp.max(jnp.abs(mat_m - identity))
    is_converged = jnp.asarray(convergence, old_mat_h.dtype)
    resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
    resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
  return resultant_mat_h, error
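
# Example (illustrative; numbers are approximate because a small ridge term
# proportional to the largest eigenvalue is added before inversion):
#   mat = jnp.array([[4.0, 0.0], [0.0, 9.0]])
#   root, err = matrix_inverse_pth_root(mat, p=2)
#   # root is approximately diag(0.5, 1/3), i.e. mat^(-1/2), and err should be
#   # below _INVERSE_PTH_ROOT_FAILURE_THRESHOLD.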


class Shampoo(OptimizerDef):
  """Shampoo optimizer.

    Scalable Second Order Optimization for Deep Learning,
    Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer

    Preprint: https://arxiv.org/abs/2002.09018
  """

  def __init__(self,
               learning_rate = None,
               beta1=0.9,
               beta2=0.999,
               diagonal_epsilon=1e-10,
               matrix_epsilon=1e-6,
               weight_decay=0.0,
               start_preconditioning_step=1,
               preconditioning_compute_steps=1,
               statistics_compute_steps=1,
               block_size=128,
               best_effort_shape_interpretation=True,
               graft_type=LayerwiseGrafting.SGD,
               no_preconditioning_for_layers_with_dim_gt=8192,
               nesterov=True,
               exponent_override=0,
               batch_axis_name=None):
    """Constructor for the Shampoo optimizer.

    Args:
      learning_rate: the step size used to update the parameters.
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
        to AdaGrad is enabled).
      matrix_epsilon: epsilon to add to statistics before computing inverse pth
        root. If you are running in f32 precision for inverse pth root
        (recommended today) this can go up to 1e-6. If you have recent hardware
        with native f64 precision, this can be set as low as 1e-12.
      weight_decay: Weight decay for regularization.
      start_preconditioning_step: Step at which to start the Shampoo update;
        before this, the diagonal (grafted) update is used, because we don't
        yet have enough information to compute a stable inverse.
      preconditioning_compute_steps: How often to compute preconditioners.
        Performance tuning params for controlling memory and compute
        requirements. Ideally set both params to 1.
      statistics_compute_steps: How often to compute statistics.
      block_size: Block size for large layers (if > 0). Preconditioning compute
        operation is cubic in the dimension of the tensor. Block size allows us
        to chunk the layers into sub-layers of maximal dimension dictated by
        this value. Use 128 as default (increase if you have compute budget).
      best_effort_shape_interpretation: If True, collapse small dimensions of
        the parameter shape before blocking (see _merge_small_dims).
      graft_type: Options are: LayerwiseGrafting.SGD, LayerwiseGrafting.ADAGRAD
      no_preconditioning_for_layers_with_dim_gt: Avoids preconditioning large
        layers to reduce overall memory usage.
      nesterov: Nesterov momentum.
      exponent_override: Override the exponent used in matrix inverse.
      batch_axis_name: labeled pmap axis name the optimizer is used under for
        data-parallel training.
    """
    hps = _ShampooHyperParams(
        learning_rate,
        beta1,
        beta2,
        diagonal_epsilon,
        matrix_epsilon,
        weight_decay,
        start_preconditioning_step,
        preconditioning_compute_steps,
        statistics_compute_steps,
        block_size,
        best_effort_shape_interpretation,
        graft_type=graft_type,
        no_preconditioning_for_layers_with_dim_gt=no_preconditioning_for_layers_with_dim_gt,
        nesterov=nesterov,
        exponent_override=exponent_override,
        batch_axis_name=batch_axis_name)
    print(hps)
    super().__init__(hps)

  def init_param_state(self, param):
    """Initialize parameter state."""
    hps = self.hyper_params
    statistics = []
    preconditioners = []
    if not self._skip_preconditioning(param, hps):
      preconditioner = Preconditioner(param, hps)
      shapes = preconditioner.shapes_for_preconditioners()
      statistics = [
          self.hyper_params.matrix_eps * jnp.eye(s[0]) for s in shapes
      ]
      preconditioners = [jnp.eye(s[0]) for s in shapes]

    adagrad_statistics = []
    if hps.graft_type == LayerwiseGrafting.ADAGRAD:
      adagrad_statistics = jnp.zeros_like(param)

    return _ShampooDefaultParamState(adagrad_statistics, statistics,
                                     preconditioners, jnp.zeros_like(param),
                                     jnp.zeros_like(param))

  def _skip_preconditioning(self, param, hps):
    return (len(param.shape) < 1 or any([
        s > hps.no_preconditioning_for_layers_with_dim_gt for s in param.shape
    ]))

  def fast_cond(self, predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA."""

    def _iter_body(unused_state):
      results = compute_fn(*args, **kwargs)
      return tuple([False] + list(results))

    def _iter_condition(state):
      return state[0]

    results = lax.while_loop(_iter_condition, _iter_body,
                             tuple([predicate] + init_state))
    return tuple(results[1:])
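
  # Note on fast_cond (descriptive, added for clarity): the while loop runs
  # compute_fn at most once -- exactly once when `predicate` is True (the body
  # returns False as the new predicate), and zero times otherwise, in which
  # case `init_state` is returned unchanged. compute_shampoo_statistics and
  # compute_preconditioners_from_statistics below use it to recompute
  # statistics/preconditioners only every K steps, e.g.:
  #   self.fast_cond(step % k == 0, compute_updated_statistics, init_state)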

  def compute_shampoo_statistics(self, step, hps, param, state, grad):
    """Compute statistics."""
    preconditioner = Preconditioner(param, hps)
    assert hps.learning_rate is not None, 'no learning rate provided.'
    new_statistics = [[]] * len(state.statistics)
    w1 = hps.beta2
    w2 = hps.beta2 if hps.beta2 == 1.0 else (1.0 - hps.beta2)
    if not self._skip_preconditioning(param, hps):
      def compute_updated_statistics():
        new_stats = preconditioner.statistics_from_grad(grad)
        new_stats_accumulators = []
        for stat, stat_accumulator in zip(new_stats, state.statistics):
          new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
        return new_stats_accumulators

      if hps.statistics_compute_steps > 1:
        perform_step = step % hps.statistics_compute_steps == 0
        init_state = state.statistics
        new_statistics = list(
            self.fast_cond(perform_step, compute_updated_statistics,
                           init_state))
      else:
        new_statistics = compute_updated_statistics()
    new_state = _ShampooDefaultParamState(state.diagonal_statistics,
                                          new_statistics, state.preconditioners,
                                          state.diagonal_momentum,
                                          state.momentum)
    return new_state
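
  # The accumulation above follows the beta2 convention stated in
  # _ShampooHyperParams: with beta2 == 1.0 we have w1 == w2 == 1.0, so the
  # Kronecker-factor statistics are plain sums (AdaGrad-style); otherwise they
  # form an exponential moving average,
  #   stat_t = beta2 * stat_{t-1} + (1 - beta2) * (new per-axis factor).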

  def compute_preconditioners_from_statistics(self, states, params, hps, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state, param in zip(states, params):
      preconditioner = Preconditioner(param, hps)
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        for statistic in state.statistics:
          exponents.append(preconditioner.exponent_for_preconditioner() if hps
                           .exponent_override == 0 else hps.exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    def pack(mat, max_size):
      """Pack a matrix to a max_size for inverse on TPUs with static shapes.

      Args:
        mat: Matrix for computing inverse pth root.
        max_size: Matrix size to pack to.

      Returns:
        Given M returns [[M, 0], [0, I]]
      """
      size = mat.shape[0]
      assert size <= max_size
      if size == max_size:
        return mat
      pad_size = max_size - size
      zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
      zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
      eye = jnp.eye(pad_size, dtype=mat.dtype)
      mat = jnp.concatenate([mat, zs1], 1)
      mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
      return mat

    if not hps.batch_axis_name:
      num_devices = jax.local_device_count()
    else:
      num_devices = lax.psum(1, hps.batch_axis_name)

    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pack(stat, max_size) for stat in statistics]
    to_pad = -num_statistics % num_devices
    packed_statistics.extend([
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    # Batch statistics and exponents so that the leading axis is num_devices.
    def _batch(statistics, exponents, num_devices):
      assert len(statistics) == len(exponents)
      n = len(statistics)
      b = int(n / num_devices)
      batched_statistics = [
          jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
      ]
      batched_exponents = [
          jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
      ]
      return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

    # Unbatch values across leading axis and return a list of elements.
    def _unbatch(batched_values):
      b1, b2 = batched_values.shape[0], batched_values.shape[1]
      results = []
      for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
        v_array = jnp.squeeze(v_array)
        # b2 = batches (number of preconditioner computations) per core.
        if b2 > 1:
          for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
            results.append(jnp.squeeze(v))
        else:
          results.append(v_array)

      return results

    all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                           num_devices)

    def _matrix_inverse_pth_root(xs, ps):
      mi_pth_root = lambda x, y: matrix_inverse_pth_root(  # pylint: disable=g-long-lambda
          x, y, ridge_epsilon=hps.matrix_eps)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    if not hps.batch_axis_name:
      preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)(
          all_statistics, all_exponents)
      preconditioners_flat = _unbatch(preconditioners)
      errors_flat = _unbatch(errors)
    else:

      def _internal_inverse_pth_root_all():
        preconditioners = jnp.array(all_statistics)
        current_replica = lax.axis_index(hps.batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners,
                                             hps.batch_axis_name)
        errors = jax.lax.all_gather(errors, hps.batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if hps.preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics (instead of preconditioners) here, as they are
        # similarly shaped tensors; since the error passed in is the failure
        # threshold, these initial values will be ignored.
        preconditioners_init = packed_statistics
        errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] *
                       len(packed_statistics))
        init_state = [preconditioners_init, errors_init]
        perform_step = step % hps.preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = self.fast_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
      return jnp.logical_or(
          jnp.isnan(error),
          error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
      if num_statistics == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                             num_statistics]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          _ShampooDefaultParamState(state.diagonal_statistics, state.statistics,
                                    new_preconditioners,
                                    state.diagonal_momentum, state.momentum))

    return new_states
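
  # Sizing sketch for the padding/batching logic above (illustrative numbers):
  # with 10 statistics, max_size == 512 and 4 devices, every statistic is
  # packed to 512 x 512, to_pad == -10 % 4 == 2 identity matrices (with
  # exponent 1) are appended, and _batch yields arrays of shape
  # (4, 3, 512, 512) and (4, 3), so each device computes 3 inverse pth roots.
  # The padded entries are dropped implicitly because only the first
  # `num_statistics` unbatched results are zipped against original_shapes.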

  def apply_per_param_gradient(self, step, hps, param, state, grad):
    """Apply per-parameter gradients."""
    preconditioner = Preconditioner(param, hps)
    assert hps.learning_rate is not None, 'no learning rate provided.'
    sgd_update = grad
    new_diagonal_statistics = state.diagonal_statistics
    if hps.graft_type == LayerwiseGrafting.ADAGRAD:
      new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
      adagrad_update = grad / (
          jnp.sqrt(new_diagonal_statistics) + hps.diagonal_eps)
      grafting_update = adagrad_update
    else:
      grafting_update = sgd_update

    precond_grad = grad
    if not self._skip_preconditioning(param, hps):
      precond_grad = preconditioner.preconditioned_grad(precond_grad,
                                                        state.preconditioners)
    else:
      precond_grad = grafting_update

    grafting_update_norm = jnp.linalg.norm(grafting_update)
    precond_grad_norm = jnp.linalg.norm(precond_grad)
    shampoo_update = precond_grad * (
        grafting_update_norm / (precond_grad_norm + 1e-16))

    shampoo_update_with_wd = shampoo_update
    grafting_update_with_wd = grafting_update
    if hps.weight_decay != 0:
      shampoo_update_with_wd = shampoo_update + hps.weight_decay * param
      grafting_update_with_wd = grafting_update + hps.weight_decay * param

    shampoo_update_with_wd_momentum = (
        state.momentum * hps.beta1 + shampoo_update_with_wd)
    grafting_update_with_wd_momentum = (
        state.diagonal_momentum * hps.beta1 + grafting_update_with_wd)

    run_shampoo = (step >= hps.start_preconditioning_step).astype(
        grafting_update_with_wd_momentum.dtype)

    momentum_update = (
        run_shampoo * shampoo_update_with_wd_momentum +
        (1.0 - run_shampoo) * grafting_update_with_wd_momentum)

    wd_update = (
        run_shampoo * shampoo_update_with_wd +
        (1.0 - run_shampoo) * grafting_update_with_wd)

    if hps.nesterov:
      momentum_update = wd_update + hps.beta1 * momentum_update

    new_param = param - hps.learning_rate * momentum_update
    new_state = _ShampooDefaultParamState(new_diagonal_statistics,
                                          state.statistics,
                                          state.preconditioners,
                                          grafting_update_with_wd_momentum,
                                          shampoo_update_with_wd_momentum)
    return new_param, new_state
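
  # In equations (a restatement of the code above, not a change in behavior),
  # for parameter w with gradient g at step t:
  #   graft   = g                          (SGD grafting), or
  #             g / (sqrt(sum g^2) + eps)  (AdaGrad grafting)
  #   shampoo = P(g) * ||graft|| / (||P(g)|| + 1e-16),  P = Shampoo precond.
  #   m_t     = beta1 * m_{t-1} + (update + weight_decay * w)
  #   step    = (update + weight_decay * w) + beta1 * m_t  if nesterov else m_t
  #   w_next  = w - learning_rate * step
  # where `update` is the shampoo branch once step >= start_preconditioning_step
  # and the grafted branch before that (both momenta are tracked; run_shampoo
  # selects which one is applied).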

  def apply_gradient(self, hyper_params, params, state, grads):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.

    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)

    new_states_flat = [
        self.compute_shampoo_statistics(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, states_flat, grads_flat)
    ]

    new_states_flat = self.compute_preconditioners_from_statistics(
        new_states_flat, params_flat, hyper_params, step)

    out = [
        self.apply_per_param_gradient(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, new_states_flat, grads_flat)
    ]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)
    return new_params, new_state
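

# Example usage (a minimal sketch, assuming the legacy flax.optim OptimizerDef
# API this class is built on; `params` and `grads` stand in for your model's
# parameter and gradient pytrees and are not defined here):
#
#   optimizer_def = Shampoo(learning_rate=0.1, beta1=0.9, beta2=0.999,
#                           block_size=128, graft_type=LayerwiseGrafting.SGD)
#   optimizer = optimizer_def.create(params)     # builds per-param Shampoo state
#   optimizer = optimizer.apply_gradient(grads)  # one optimizer step
#
# When stepping inside pmap, pass batch_axis_name (e.g. 'batch') so that the
# inverse pth root computations in compute_preconditioners_from_statistics are
# sharded across replicas via lax.all_gather.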