distributed_shampoo_test.py 
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for distributed_shampoo."""

import functools
import itertools

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
import scipy

from scalable_shampoo.optax import distributed_shampoo


class PaddingTest(parameterized.TestCase):

  def assertAllClose(self, x, y, atol=1e-5, rtol=1e-5):
    np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

  @parameterized.named_parameters(
      {
          'testcase_name': 'NoPadding',
          'max_size': 3,
          'result': [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]],
      },
      {
          'testcase_name':
              'Padding',
          'max_size':
              5,
          'result': [[1., 1., 1., 0., 0.], [1., 1., 1., 0., 0.],
                     [1., 1., 1., 0., 0.], [0., 0., 0., 1., 0.],
                     [0., 0., 0., 0., 1.]],
      },
  )
  def test_pad_square_matrix(self, max_size, result):
    self.assertAllClose(
        distributed_shampoo.pad_square_matrix(
            mat=jnp.ones(shape=(3, 3), dtype=jnp.float32), max_size=max_size),
        jnp.asarray(result, dtype=jnp.float32))

  @parameterized.named_parameters(
      {
          'testcase_name': 'TooLarge',
          'shape': (3, 3),
          'max_size': 2
      },
      {
          'testcase_name': 'NotSquare',
          'shape': (3, 4),
          'max_size': 5
      },
  )
  def test_pad_square_matrix_error(self, shape, max_size):
    with self.assertRaises(ValueError):
      distributed_shampoo.pad_square_matrix(
          mat=jnp.ones(shape=shape), max_size=max_size)


def _pth_root_difference_cases():
  """Returns cases for _pth_root_difference() test."""
  cases = []
  # The test checks accuracy of
  # (w + a)^(-1/p) - (w + b)^(-1/p)
  # so generate corresponding parameters.
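  # As a rough illustration of the cancellation this guards against (numbers
  # are approximate and not part of the test): with w=1e3, a=1e-6, b=0.0, p=2
  # the true difference is about -1.6e-11, far below the float32 spacing
  # (~4e-9) around the ~3.2e-2 operands, so a naive float32 subtraction would
  # lose essentially all significant digits.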
  p_vals = [2, 4, 6, 8]
  a_vals = b_vals = [1e-6, 1e-5, 0.0, 1.0]
  w_vals = [1e-6, 1e-5, 1.0, 1e3]
  for p, a, b, w in itertools.product(p_vals, a_vals, b_vals, w_vals):
    cases.append({'p': p, 'a': a, 'b': b, 'w': w})
  return cases


class DistributedShampooTest(chex.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([[1., 3.],
                                   [2., 4.]]), jnp.array([[3., 4.], [3., 4.]]))
    self.per_step_updates = (jnp.array([[500., 5.], [500., 5.]]),
                             jnp.array([[300., 3.], [300., 3.]]))
    self.per_step_updates_custom_preconditioner = (self.per_step_updates,
                                                   (jnp.array([[200., 4.],
                                                               [200., 4.]]),
                                                    jnp.array([[600., 2.],
                                                               [600., 2.]])))
    self.rng = np.random.default_rng(1234)
    shape = ([2, 5], [6, 3])
    dt = self.init_params[0].dtype

    def make_shape(bigger_first_entry):
      x = tuple(self.rng.standard_normal(size=s) for s in shape)
      if bigger_first_entry:
        for xx in x:
          xx[Ellipsis, 0] *= 100
      return tuple(jnp.array(xx).astype(dt) for xx in x)

    self.init_params_larger = make_shape(False)
    self.per_step_updates_larger = make_shape(True)

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      {
          'testcase_name': 'default',
          'best_effort_memory_usage_reduction': True,
          'expected_value': -0.57,
      },
      {
          'testcase_name': 'default_nomerge',
          'best_effort_memory_usage_reduction': True,
          'merge_small_dims_block_size': 1,
          'expected_value': -0.57,
      },
      {
          'testcase_name': 'default_larger',
          'best_effort_memory_usage_reduction': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'default_larger_nomerge',
          'best_effort_memory_usage_reduction': True,
          'slightly_larger': True,
          'merge_small_dims_block_size': 1,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'materialize_statistics',
          'best_effort_memory_usage_reduction': True,
      },
      {
          'testcase_name': 'blocked_statistics',
          'best_effort_memory_usage_reduction': True,
      },
      {
          'testcase_name': 'default_quantized',
      },
      {
          'testcase_name': 'materialize_statistics_quantized',
      },
      {
          'testcase_name': 'blocked_statistics_quantized',
      },
      {
          'testcase_name': 'no_training_metrics',
          'generate_training_metrics': False,
      },
      {
          'testcase_name': 'larger_reuse',
          'best_effort_memory_usage_reduction': True,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'larger_reuse_highmem',
          'best_effort_memory_usage_reduction': False,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'larger_reuse_highmem_nomerge',
          'best_effort_memory_usage_reduction': False,
          'merge_small_dims_block_size': 1,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
  )
  def test_distributed_shampoo(
      self,
      best_effort_memory_usage_reduction=False,
      merge_small_dims_block_size=4096,
      generate_training_metrics=True,
      slightly_larger=False,
      expected_value=None,
      reuse_preconditioner=False,
  ):
    params = self.init_params_larger if slightly_larger else self.init_params

    optim = distributed_shampoo.distributed_shampoo(
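        # The two positional arguments are the learning rate and the block
        # size (as named in the optimizer's signature); a small block size of
        # 32 keeps the preconditioner matrices tiny for testing.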
        0.1,
        32,
        batch_axis_name='batch',
        preconditioning_compute_steps=2,
        best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
        merge_small_dims_block_size=merge_small_dims_block_size,
        generate_training_metrics=generate_training_metrics,
        reuse_preconditioner=reuse_preconditioner,
    )
    init_fn = self.variant(optim.init)
    transform_fn = self.variant(optim.update)

    if slightly_larger:
      updates = self.per_step_updates_larger
    else:
      updates = self.per_step_updates

    def _update(unused_batch):
      return transform_fn(updates, state, params)

    state = init_fn(params)
    chex.assert_tree_all_finite(state)
    pmap_fn = jax.pmap(_update, axis_name='batch')

    updates, state = pmap_fn(jnp.array([1.0]))
    chex.assert_tree_all_finite((params, updates, state))
    if expected_value is not None:
      last_entry = updates[1][-1, -1, -1]
      self.assertLess(
          abs(last_entry - expected_value),
          1e-4,
          msg=f'{last_entry=}, {expected_value=}')
    for _ in range(5):
      updates, state = pmap_fn(jnp.array([1.0]))
      chex.assert_tree_all_finite((params, updates, state))

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters([
      {
          'testcase_name': 'default',
      },
      {
          'testcase_name': 'no_training_metrics',
          'generate_training_metrics': False,
      },
  ])
  def test_distributed_shampoo_no_pmap(self, generate_training_metrics=True):
    params = self.init_params

    optim = distributed_shampoo.distributed_shampoo(
        0.1,
        32,
        batch_axis_name=None,
        preconditioning_compute_steps=2,
        generate_training_metrics=generate_training_metrics)
    init_fn = self.variant(optim.init)
    transform_fn = self.variant(optim.update)
    state = init_fn(params)
    chex.assert_tree_all_finite(state)
    updates, state = transform_fn(self.per_step_updates, state, params)
    chex.assert_tree_all_finite((params, updates, state))

  def _gen_symmetrix_matrix(self, dim, condition_number):
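    # Builds a random symmetric PSD matrix with eigenvalues log-spaced between
    # 1 and 1/condition_number, so np.linalg.cond of the result is
    # approximately condition_number.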
    u = scipy.stats.ortho_group.rvs(
        dim=dim, random_state=self.rng).astype(np.float64)
    v = u.T
    diag = np.diag([condition_number**(-i / (dim - 1)) for i in range(dim)])
    return u @ diag @ v

  def test_matrix_inverse_root(self):
    """Test for matrix inverse pth root."""

    # Fails after it reaches a particular condition number.
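    # Empirically, float32 accuracy degrades once the condition number nears
    # 1 / float32 eps (about 1e7), which is why only e < 7 is checked below.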
    for e in range(2, 12):
      condition_number = 10**e
      ms = self._gen_symmetrix_matrix(16, condition_number)
      self.assertLess(
          np.abs(np.linalg.cond(ms) - condition_number),
          condition_number * 0.01)
      metrics = distributed_shampoo.matrix_inverse_pth_root(
          ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
      error = metrics.inverse_pth_root_errors
      if e < 7:
        self.assertLess(error, 0.1)
      else:
        # No guarantee of success after e >= 7
        pass

  @parameterized.parameters([{'sz': sz} for sz in [4, 32]])
  def test_matrix_inverse_root_padding(self, sz):
    """Test padding does not affect result much."""

    # Note sz == 1 case will not pass tests here b/c the method
    # is exact for scalars (but padding triggers coupled iteration).

    condition_number = 1e3
    ms = self._gen_symmetrix_matrix(sz, condition_number).astype(np.float32)

    # Shift matrix norm down by some large factor, so that improper padding
    # handling results in an error by increasing the condition number.
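    # (pad_square_matrix places 1.0 on the diagonal of the padded block, which
    # would dwarf the scaled-down principal block and inflate the condition
    # number unless padding_start tells the root computation to ignore it.)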
    ms = jnp.array(ms) * 1e-3

    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        ms, 4, ridge_epsilon=1e-3)
    err = metrics.inverse_pth_root_errors
    pad_ms = distributed_shampoo.pad_square_matrix(ms, sz * 2)
    pad_rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        pad_ms, 4, ridge_epsilon=1e-3, padding_start=sz)
    pad_err = metrics.inverse_pth_root_errors
    pad_rt_principal = pad_rt[:sz, :sz]
    np.testing.assert_allclose(
        rt,
        pad_rt_principal,
        # The fact that this is so large keeps vladf up at night,
        # but without padding_start argument it's even worse (>1).
        rtol=1e-2,
        err_msg=np.array2string(rt - pad_rt_principal))
    self.assertLessEqual(pad_err, 4 * err)
    self.assertEqual(np.abs(pad_rt[sz:]).sum(), 0)
    self.assertEqual(np.abs(pad_rt[:, sz:]).sum(), 0)

  def test_all_padding(self):
    """Test full padding matrix."""
    empty = jnp.zeros([0, 0])
    padded = distributed_shampoo.pad_square_matrix(empty, 10)
    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        padded, 4, ridge_epsilon=1e-3, padding_start=0)
    err = metrics.inverse_pth_root_errors
    self.assertEqual(np.abs(rt).sum(), 0.0)
    self.assertEqual(np.abs(err).sum(), 0.0)

  def _make_pth_diff_message(self, w, alpha, beta, p):
    left = f'({w} + {alpha})^(-1.0 / {p}) - '
    right = f'({w} + {beta})^(-1.0 / {p})'
    return left + right

  @parameterized.parameters(_pth_root_difference_cases())
  def test_pth_root_difference(self, p, a, b, w):
    """Test stable difference computation."""
    pth_rt_diff = jax.jit(
        functools.partial(distributed_shampoo._pth_root_difference, p=p))
    actual = pth_rt_diff(w, a, b)
    # in float64
    exp = (-1.0 / p)
    expected = (w + a)**exp - (w + b)**exp

    self.assertAlmostEqual(
        actual,
        expected,
        msg=self._make_pth_diff_message(w, a, b, p),
        delta=1e-2)

  @parameterized.parameters([{'p': p} for p in [2, 4, 8]])
  def test_lobpcg_preconditioning(self, p):
    """Checks that root calculation is valid with top-k preconditioning."""
    rng = np.random.RandomState(seed=42)
    n = 11
    epsilon = jnp.float32(1e-4)
    a_asymm = jnp.array(rng.random((n, n)), jnp.float32)
    a = jnp.matmul(a_asymm.T, a_asymm, precision=jax.lax.Precision.HIGHEST)
    log2 = (p - 1).bit_length()
    assert 2**log2 == p, (p, log2)

    root = functools.partial(
        distributed_shampoo.matrix_inverse_pth_root, ridge_epsilon=epsilon, p=p)
    root_lobpcg = functools.partial(
        root, lobpcg_topk_precondition=2, lobpcg_max_iter=10)

    methods = {'default': root, 'precond': root_lobpcg}
    spectrum_err, entry_err = {}, {}
    for k, method in methods.items():
      rt = jax.jit(method)(a)[0]

      # Recover the inverse by repeated squaring of inverse p-th root.
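      # Since p == 2**log2 (asserted above), squaring the inverse p-th root
      # log2 times yields an approximation of a^-1, so inv.dot(a) below should
      # be close to the identity.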
      inv = np.asarray(rt).astype(np.float64)
      for _ in range(log2):
        inv = inv.dot(inv)

      approx_id = inv.dot(a)
      spectrum = np.linalg.eigvalsh(approx_id)
      spectrum_err[k] = np.abs(1 - spectrum)
      entry_err[k] = np.mean(np.abs(approx_id - np.eye(n)))

    with np.printoptions(precision=2):

      def print_dict(d):
        return '\n'.join(f'{k} {v}' for k, v in d.items())

      err_msg = (f'p={p} log2(p)={log2}\n'
                 f'spectrum error\n{print_dict(spectrum_err)}\n'
                 f'entry_err\n{print_dict(entry_err)}')

      self.assertLessEqual(
          np.median(spectrum_err['precond']),
          2 * np.median(spectrum_err['default']),
          msg=err_msg)

      self.assertLessEqual(
          entry_err['precond'], entry_err['default'] * 2, msg=err_msg)


if __name__ == '__main__':
  absltest.main()
