google-research
165 lines · 6.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Tests for Distributed Shampoo."""
17
18from absl.testing import absltest
19from flax import optim
20import numpy as np
21import scipy.stats
22
23from scalable_shampoo.jax import shampoo
24
25
class ShampooTest(absltest.TestCase):
  """Test cases for Distributed Shampoo."""

  def test_init_state(self):
    """Checks hyper-parameter wiring and initial preconditioner state.

    Verifies that constructor arguments are threaded into
    `_ShampooHyperParams`, and that `init_state` creates the expected
    number of (statistic, preconditioner) matrix pairs for parameters of
    several shapes, including block splitting and best-effort shape
    interpretation.
    """
    # Create an optimizer def and check the params are wired through.
    optimizer_def = shampoo.Shampoo(
        learning_rate=0.1,
        beta1=0.9,
        beta2=0.9,
        diagonal_epsilon=0.0,
        matrix_epsilon=1e-1,
        exponent_override=2,
        weight_decay=1e-4,
        start_preconditioning_step=1,
        preconditioning_compute_steps=1,
        statistics_compute_steps=1,
        best_effort_shape_interpretation=True,
        block_size=8,
        no_preconditioning_for_layers_with_dim_gt=1024,
        graft_type=shampoo.LayerwiseGrafting.SGD,
        nesterov=False,
        batch_axis_name=None)
    expected_hyper_params = shampoo._ShampooHyperParams(
        learning_rate=0.1,
        beta1=0.9,
        beta2=0.9,
        diagonal_eps=0.0,
        matrix_eps=1e-1,
        exponent_override=2,
        weight_decay=1e-4,
        start_preconditioning_step=1,
        preconditioning_compute_steps=1,
        statistics_compute_steps=1,
        best_effort_shape_interpretation=True,
        block_size=8,
        no_preconditioning_for_layers_with_dim_gt=1024,
        graft_type=shampoo.LayerwiseGrafting.SGD,
        nesterov=False,
        batch_axis_name=None)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)

    # A scalar-like (1,) parameter gets no statistics or preconditioners,
    # only the diagonal/momentum buffers.
    params = np.zeros((1,))
    state = optimizer_def.init_state(params)
    zeros_like_param = np.zeros((1,))
    expected_state = optim.OptimizerState(
        0,
        shampoo._ShampooDefaultParamState([], [], [], zeros_like_param,
                                          zeros_like_param))
    self.assertEqual(state, expected_state)

    # Rank-1 parameter of exactly block_size: one statistic and one
    # preconditioner, initialized to matrix_epsilon * I and I respectively.
    params = np.zeros((8,))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  # I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 1)
    # Fixed: originally asserted the statistics length twice and never
    # checked the preconditioners list.
    self.assertLen(state.param_states.preconditioners, 1)
    np.testing.assert_allclose(state.param_states.statistics[0], statistic)
    np.testing.assert_allclose(state.param_states.preconditioners[0],
                               preconditioner)

    # Rank-2 (8, 8) parameter: one pair of matrices per dimension.
    params = np.zeros((8, 8))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  # I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 2)
    self.assertLen(state.param_states.preconditioners, 2)
    np.testing.assert_allclose(state.param_states.statistics[0], statistic)
    np.testing.assert_allclose(state.param_states.statistics[1], statistic)
    np.testing.assert_allclose(state.param_states.preconditioners[0],
                               preconditioner)
    np.testing.assert_allclose(state.param_states.preconditioners[1],
                               preconditioner)

    # (16, 16) with block_size=8 splits into 4 blocks of (8, 8), each with
    # 2 matrices: 8 statistics/preconditioners in total.
    params = np.zeros((16, 16))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  # I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 8)
    self.assertLen(state.param_states.preconditioners, 8)
    for i in range(8):
      np.testing.assert_allclose(state.param_states.statistics[i], statistic)
      np.testing.assert_allclose(state.param_states.preconditioners[i],
                                 preconditioner)

    # Test best_effort_shape_interpretation:
    # (3, 2, 16) will be reshaped to (6, 16).
    # Last dim will be split into two, giving blocks (6, 8) and (6, 8).
    params = np.zeros((3, 2, 16))
    state = optimizer_def.init_state(params)
    identity_left = np.eye(6)
    statistic_left = identity_left * 1e-1  # I * matrix_epsilon
    preconditioner_left = identity_left
    identity_right = np.eye(8)
    statistic_right = identity_right * 1e-1  # I * matrix_epsilon
    preconditioner_right = identity_right
    self.assertLen(state.param_states.statistics, 4)
    self.assertLen(state.param_states.preconditioners, 4)
    # Matrices alternate: even indices are the (6, 6) "left" factor, odd
    # indices the (8, 8) "right" factor of each block.
    for i in range(4):
      if i % 2 == 0:
        np.testing.assert_allclose(state.param_states.statistics[i],
                                   statistic_left)
        np.testing.assert_allclose(state.param_states.preconditioners[i],
                                   preconditioner_left)
      else:
        np.testing.assert_allclose(state.param_states.statistics[i],
                                   statistic_right)
        np.testing.assert_allclose(state.param_states.preconditioners[i],
                                   preconditioner_right)

  def test_matrix_inverse_root(self):
    """Test for matrix inverse pth root."""

    def _gen_symmetric_matrix(dim, condition_number):
      # Build a random symmetric PSD matrix with log-spaced eigenvalues in
      # [1/condition_number, 1], so its condition number is exactly
      # `condition_number`.
      u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64)
      v = u.T
      diag = np.diag([condition_number ** (-i / (dim - 1)) for i in range(dim)])
      return u @ diag @ v

    # Accuracy degrades as the condition number grows; only assert below
    # the threshold where convergence is expected.
    for e in range(2, 12):
      condition_number = 10 ** e
      ms = _gen_symmetric_matrix(16, condition_number)
      # Sanity-check the generator: measured condition number within 1%.
      self.assertLess(
          np.abs(np.linalg.cond(ms) - condition_number),
          condition_number * 0.01)
      error = shampoo.matrix_inverse_pth_root(
          ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
      # No guarantee of success once e >= 7.
      if e < 7:
        self.assertLess(error, 0.1)
163
# Run the absl test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()
166