# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for distributed_shampoo."""

import functools
import itertools

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
import scipy

from scalable_shampoo.optax import distributed_shampoo


class PaddingTest(parameterized.TestCase):

  def assertAllClose(self, x, y, atol=1e-5, rtol=1e-5):
    np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

  @parameterized.named_parameters(
      {
          'testcase_name': 'NoPadding',
          'max_size': 3,
          'result': [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]],
      },
      {
          'testcase_name': 'Padding',
          'max_size': 5,
          'result': [[1., 1., 1., 0., 0.], [1., 1., 1., 0., 0.],
                     [1., 1., 1., 0., 0.], [0., 0., 0., 1., 0.],
                     [0., 0., 0., 0., 1.]],
      },
  )
  def test_pad_square_matrix(self, max_size, result):
    self.assertAllClose(
        distributed_shampoo.pad_square_matrix(
            mat=jnp.ones(shape=(3, 3), dtype=jnp.float32), max_size=max_size),
        jnp.asarray(result, dtype=jnp.float32))

  @parameterized.named_parameters(
      {
          'testcase_name': 'TooLarge',
          'shape': (3, 3),
          'max_size': 2
      },
      {
          'testcase_name': 'NotSquare',
          'shape': (3, 4),
          'max_size': 5
      },
  )
  def test_pad_square_matrix_error(self, shape, max_size):
    with self.assertRaises(ValueError):
      distributed_shampoo.pad_square_matrix(
          mat=jnp.ones(shape=shape), max_size=max_size)


def _pth_root_difference_cases():
  """Returns cases for _pth_root_difference() test."""
  cases = []
  # The test checks accuracy of
  #   (w + a)^(-1/p) - (w + b)^(-1/p)
  # so generate corresponding parameters.
  p_vals = [2, 4, 6, 8]
  a_vals = b_vals = [1e-6, 1e-5, 0.0, 1.0]
  w_vals = [1e-6, 1e-5, 1.0, 1e3]
  for p, a, b, w in itertools.product(p_vals, a_vals, b_vals, w_vals):
    cases.append({'p': p, 'a': a, 'b': b, 'w': w})
  return cases


class DistributedShampooTest(chex.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([[1., 3.],
                                   [2., 4.]]), jnp.array([[3., 4.], [3., 4.]]))
    self.per_step_updates = (jnp.array([[500., 5.], [500., 5.]]),
                             jnp.array([[300., 3.], [300., 3.]]))
    self.per_step_updates_custom_preconditioner = (self.per_step_updates,
                                                   (jnp.array([[200., 4.],
                                                               [200., 4.]]),
                                                    jnp.array([[600., 2.],
                                                               [600., 2.]])))
    self.rng = np.random.default_rng(1234)
    shape = ([2, 5], [6, 3])
    dt = self.init_params[0].dtype

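    # Draws random tensors with the shapes above; when `bigger_first_entry`
    # is set, the leading slice along the last axis is scaled by 100 so the
    # generated per-step updates are unevenly scaled.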
    def make_shape(bigger_first_entry):
      x = tuple(self.rng.standard_normal(size=s) for s in shape)
      if bigger_first_entry:
        for xx in x:
          xx[Ellipsis, 0] *= 100
      return tuple(jnp.array(xx).astype(dt) for xx in x)

    self.init_params_larger = make_shape(False)
    self.per_step_updates_larger = make_shape(True)

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      {
          'testcase_name': 'default',
          'best_effort_memory_usage_reduction': True,
          'expected_value': -0.57,
      },
      {
          'testcase_name': 'default_nomerge',
          'best_effort_memory_usage_reduction': True,
          'merge_small_dims_block_size': 1,
          'expected_value': -0.57,
      },
      {
          'testcase_name': 'default_larger',
          'best_effort_memory_usage_reduction': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'default_larger_nomerge',
          'best_effort_memory_usage_reduction': True,
          'slightly_larger': True,
          'merge_small_dims_block_size': 1,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'materialize_statistics',
          'best_effort_memory_usage_reduction': True,
      },
      {
          'testcase_name': 'blocked_statistics',
          'best_effort_memory_usage_reduction': True,
      },
      {
          'testcase_name': 'default_quantized',
      },
      {
          'testcase_name': 'materialize_statistics_quantized',
      },
      {
          'testcase_name': 'blocked_statistics_quantized',
      },
      {
          'testcase_name': 'no_training_metrics',
          'generate_training_metrics': False,
      },
      {
          'testcase_name': 'larger_reuse',
          'best_effort_memory_usage_reduction': True,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'larger_reuse_highmem',
          'best_effort_memory_usage_reduction': False,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
      {
          'testcase_name': 'larger_reuse_highmem_nomerge',
          'best_effort_memory_usage_reduction': False,
          'merge_small_dims_block_size': 1,
          'reuse_preconditioner': True,
          'slightly_larger': True,
          'expected_value': -0.17019942,
      },
  )
  def test_distributed_shampoo(
      self,
      best_effort_memory_usage_reduction=False,
      merge_small_dims_block_size=4096,
      generate_training_metrics=True,
      slightly_larger=False,
      expected_value=None,
      reuse_preconditioner=False,
  ):
    params = self.init_params_larger if slightly_larger else self.init_params

    optim = distributed_shampoo.distributed_shampoo(
        0.1,
        32,
        batch_axis_name='batch',
        preconditioning_compute_steps=2,
        best_effort_memory_usage_reduction=best_effort_memory_usage_reduction,
        merge_small_dims_block_size=merge_small_dims_block_size,
        generate_training_metrics=generate_training_metrics,
        reuse_preconditioner=reuse_preconditioner,
    )
    init_fn = self.variant(optim.init)
    transform_fn = self.variant(optim.update)

    if slightly_larger:
      updates = self.per_step_updates_larger
    else:
      updates = self.per_step_updates

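    # `_update` closes over `updates`, `state`, and `params` from this test;
    # `state` is assigned below before the pmapped function is invoked.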
    def _update(unused_batch):
      return transform_fn(updates, state, params)

    state = init_fn(params)
    chex.assert_tree_all_finite(state)
    pmap_fn = jax.pmap(_update, axis_name='batch')

    updates, state = pmap_fn(jnp.array([1.0]))
    chex.assert_tree_all_finite((params, updates, state))
    if expected_value is not None:
      last_entry = updates[1][-1, -1, -1]
      self.assertLess(
          abs(last_entry - expected_value),
          1e-4,
          msg=f'{last_entry=}, {expected_value=}')
    for _ in range(5):
      updates, state = pmap_fn(jnp.array([1.0]))
      chex.assert_tree_all_finite((params, updates, state))

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters([
      {
          'testcase_name': 'default',
      },
      {
          'testcase_name': 'no_training_metrics',
          'generate_training_metrics': False,
      },
  ])
  def test_distributed_shampoo_no_pmap(self, generate_training_metrics=True):
    params = self.init_params

    optim = distributed_shampoo.distributed_shampoo(
        0.1,
        32,
        batch_axis_name=None,
        preconditioning_compute_steps=2,
        generate_training_metrics=generate_training_metrics)
    init_fn = self.variant(optim.init)
    transform_fn = self.variant(optim.update)
    state = init_fn(params)
    chex.assert_tree_all_finite(state)
    updates, state = transform_fn(self.per_step_updates, state, params)
    chex.assert_tree_all_finite((params, updates, state))

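  # Builds a symmetric matrix with the requested condition number by rotating
  # a diagonal of eigenvalues, decaying geometrically from 1 down to
  # 1 / condition_number, into a random orthogonal basis.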
  def _gen_symmetrix_matrix(self, dim, condition_number):
    u = scipy.stats.ortho_group.rvs(
        dim=dim, random_state=self.rng).astype(np.float64)
    v = u.T
    diag = np.diag([condition_number**(-i / (dim - 1)) for i in range(dim)])
    return u @ diag @ v

  def test_matrix_inverse_root(self):
    """Test for matrix inverse pth root."""

    # The computed root loses accuracy once the condition number grows large
    # enough; only the e < 7 cases below assert an error bound.
    for e in range(2, 12):
      condition_number = 10**e
      ms = self._gen_symmetrix_matrix(16, condition_number)
      self.assertLess(
          np.abs(np.linalg.cond(ms) - condition_number),
          condition_number * 0.01)
      metrics = distributed_shampoo.matrix_inverse_pth_root(
          ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
      error = metrics.inverse_pth_root_errors
      if e < 7:
        self.assertLess(error, 0.1)
      else:
        # No guarantee of success after e >= 7
        pass

  @parameterized.parameters([{'sz': sz} for sz in [4, 32]])
  def test_matrix_inverse_root_padding(self, sz):
    """Test padding does not affect result much."""

    # Note sz == 1 case will not pass tests here b/c the method
    # is exact for scalars (but padding triggers coupled iteration).

    condition_number = 1e3
    ms = self._gen_symmetrix_matrix(sz, condition_number).astype(np.float32)

    # Shift matrix norm down by some large factor, so that improper padding
    # handling results in an error by increasing the condition number.
    ms = jnp.array(ms) * 1e-3

    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        ms, 4, ridge_epsilon=1e-3)
    err = metrics.inverse_pth_root_errors
    pad_ms = distributed_shampoo.pad_square_matrix(ms, sz * 2)
    pad_rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        pad_ms, 4, ridge_epsilon=1e-3, padding_start=sz)
    pad_err = metrics.inverse_pth_root_errors
    pad_rt_principal = pad_rt[:sz, :sz]
    np.testing.assert_allclose(
        rt,
        pad_rt_principal,
        # The fact that this is so large keeps vladf up at night,
        # but without padding_start argument it's even worse (>1).
        rtol=1e-2,
        err_msg=np.array2string(rt - pad_rt_principal))
    self.assertLessEqual(pad_err, 4 * err)
    self.assertEqual(np.abs(pad_rt[sz:]).sum(), 0)
    self.assertEqual(np.abs(pad_rt[:, sz:]).sum(), 0)

  def test_all_padding(self):
    """Test full padding matrix."""
    empty = jnp.zeros([0, 0])
    padded = distributed_shampoo.pad_square_matrix(empty, 10)
    rt, metrics = distributed_shampoo.matrix_inverse_pth_root(
        padded, 4, ridge_epsilon=1e-3, padding_start=0)
    err = metrics.inverse_pth_root_errors
    self.assertEqual(np.abs(rt).sum(), 0.0)
    self.assertEqual(np.abs(err).sum(), 0.0)

  def _make_pth_diff_message(self, w, alpha, beta, p):
    left = f'({w} + {alpha})^(-1.0 / {p}) - '
    right = f'({w} + {beta})^(-1.0 / {p})'
    return left + right

  @parameterized.parameters(_pth_root_difference_cases())
  def test_pth_root_difference(self, p, a, b, w):
    """Test stable difference computation."""
    pth_rt_diff = jax.jit(
        functools.partial(distributed_shampoo._pth_root_difference, p=p))
    actual = pth_rt_diff(w, a, b)
    # Reference value computed directly in float64 (Python floats).
    exp = (-1.0 / p)
    expected = (w + a)**exp - (w + b)**exp

    self.assertAlmostEqual(
        actual,
        expected,
        msg=self._make_pth_diff_message(w, a, b, p),
        delta=1e-2)

  @parameterized.parameters([{'p': p} for p in [2, 4, 8]])
  def test_lobpcg_preconditioning(self, p):
    """Checks that root calculation is valid with top-k preconditioning."""
    rng = np.random.RandomState(seed=42)
    n = 11
    epsilon = jnp.float32(1e-4)
    a_asymm = jnp.array(rng.random((n, n)), jnp.float32)
    a = jnp.matmul(a_asymm.T, a_asymm, precision=jax.lax.Precision.HIGHEST)
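    # `p` must be a power of two here: the inverse is recovered below by
    # squaring the inverse p-th root log2(p) times.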
    log2 = (p - 1).bit_length()
    assert 2**log2 == p, (p, log2)

    root = functools.partial(
        distributed_shampoo.matrix_inverse_pth_root, ridge_epsilon=epsilon, p=p)
    root_lobpcg = functools.partial(
        root, lobpcg_topk_precondition=2, lobpcg_max_iter=10)

    methods = {'default': root, 'precond': root_lobpcg}
    spectrum_err, entry_err = {}, {}
    for k, method in methods.items():
      rt = jax.jit(method)(a)[0]

      # Recover the inverse by repeated squaring of inverse p-th root.
      inv = np.asarray(rt).astype(np.float64)
      for _ in range(log2):
        inv = inv.dot(inv)

      approx_id = inv.dot(a)
      spectrum = np.linalg.eigvalsh(approx_id)
      spectrum_err[k] = np.abs(1 - spectrum)
      entry_err[k] = np.mean(np.abs(approx_id - np.eye(n)))

    with np.printoptions(precision=2):

      def print_dict(d):
        return '\n'.join(f'{k} {v}' for k, v in d.items())

      err_msg = (f'p={p} log2(p)={log2}\n'
                 f'spectrum error\n{print_dict(spectrum_err)}\n'
                 f'entry_err\n{print_dict(entry_err)}')

      self.assertLessEqual(
          np.median(spectrum_err['precond']),
          2 * np.median(spectrum_err['default']),
          msg=err_msg)

      self.assertLessEqual(
          entry_err['precond'], entry_err['default'] * 2, msg=err_msg)


if __name__ == '__main__':
  absltest.main()