# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Distributed Shampoo."""

from absl.testing import absltest
from flax import optim
import numpy as np
import scipy.stats

from scalable_shampoo.jax import shampoo

class ShampooTest(absltest.TestCase):
  """Test cases for Distributed Shampoo."""

  def test_init_state(self):
    """Checks hyperparameter wiring and blockwise preconditioner init.

    Covers params smaller than, equal to, and larger than ``block_size``,
    plus best-effort shape interpretation of a rank-3 parameter.
    """
    # Create an optimizer def and check the params are wired through.
    optimizer_def = shampoo.Shampoo(
        learning_rate=0.1,
        beta1=0.9,
        beta2=0.9,
        diagonal_epsilon=0.0,
        matrix_epsilon=1e-1,
        exponent_override=2,
        weight_decay=1e-4,
        start_preconditioning_step=1,
        preconditioning_compute_steps=1,
        statistics_compute_steps=1,
        best_effort_shape_interpretation=True,
        block_size=8,
        no_preconditioning_for_layers_with_dim_gt=1024,
        graft_type=shampoo.LayerwiseGrafting.SGD,
        nesterov=False,
        batch_axis_name=None)
    expected_hyper_params = shampoo._ShampooHyperParams(
        learning_rate=0.1,
        beta1=0.9,
        beta2=0.9,
        diagonal_eps=0.0,
        matrix_eps=1e-1,
        exponent_override=2,
        weight_decay=1e-4,
        start_preconditioning_step=1,
        preconditioning_compute_steps=1,
        statistics_compute_steps=1,
        best_effort_shape_interpretation=True,
        block_size=8,
        no_preconditioning_for_layers_with_dim_gt=1024,
        graft_type=shampoo.LayerwiseGrafting.SGD,
        nesterov=False,
        batch_axis_name=None)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)

    # A (1,) parameter gets no statistics/preconditioners, only the
    # diagonal/momentum buffers of the default param state.
    params = np.zeros((1,))
    state = optimizer_def.init_state(params)
    zeros_like_param = np.zeros((1,))
    expected_state = optim.OptimizerState(
        0,
        shampoo._ShampooDefaultParamState([], [], [], zeros_like_param,
                                          zeros_like_param))
    self.assertEqual(state, expected_state)

    # A (8,) parameter (exactly one block) gets a single 8x8 statistic
    # initialized to matrix_epsilon * I and an identity preconditioner.
    params = np.zeros((8,))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  #  I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 1)
    self.assertLen(state.param_states.preconditioners, 1)
    np.testing.assert_allclose(state.param_states.statistics[0], statistic)
    np.testing.assert_allclose(state.param_states.preconditioners[0],
                               preconditioner)

    # A (8, 8) parameter is a single block with one statistic per dimension.
    params = np.zeros((8, 8))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  #  I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 2)
    self.assertLen(state.param_states.preconditioners, 2)
    np.testing.assert_allclose(state.param_states.statistics[0], statistic)
    np.testing.assert_allclose(state.param_states.statistics[1], statistic)
    np.testing.assert_allclose(state.param_states.preconditioners[0],
                               preconditioner)
    np.testing.assert_allclose(state.param_states.preconditioners[1],
                               preconditioner)

    # A (16, 16) parameter splits into 2x2 = 4 blocks of (8, 8), each with
    # two statistics/preconditioners, for 8 of each in total.
    params = np.zeros((16, 16))
    state = optimizer_def.init_state(params)
    identity = np.eye(8)
    statistic = identity * 1e-1  #  I * matrix_epsilon
    preconditioner = identity
    self.assertLen(state.param_states.statistics, 8)
    self.assertLen(state.param_states.preconditioners, 8)
    for i in range(8):
      np.testing.assert_allclose(state.param_states.statistics[i], statistic)
      np.testing.assert_allclose(state.param_states.preconditioners[i],
                                 preconditioner)

    # Test best_effort_shape_interpretation
    # (3, 2, 16) will be reshaped to (6, 16)
    # Last dim will be split into two (6, 8) and (6, 8)
    params = np.zeros((3, 2, 16))
    state = optimizer_def.init_state(params)
    identity_left = np.eye(6)
    statistic_left = identity_left * 1e-1  #  I * matrix_epsilon
    preconditioner_left = identity_left
    identity_right = np.eye(8)
    statistic_right = identity_right * 1e-1  #  I * matrix_epsilon
    preconditioner_right = identity_right
    self.assertLen(state.param_states.statistics, 4)
    self.assertLen(state.param_states.preconditioners, 4)
    for i in range(4):
      if i % 2 == 0:
        # Even entries correspond to the (6,)-sided statistics.
        np.testing.assert_allclose(state.param_states.statistics[i],
                                   statistic_left)
        np.testing.assert_allclose(state.param_states.preconditioners[i],
                                   preconditioner_left)
      else:
        # Odd entries correspond to the (8,)-sided statistics.
        np.testing.assert_allclose(state.param_states.statistics[i],
                                   statistic_right)
        np.testing.assert_allclose(state.param_states.preconditioners[i],
                                   preconditioner_right)

  def test_matrix_inverse_root(self):
    """Test for matrix inverse pth root."""

    def _gen_symmetric_matrix(dim, condition_number):
      # Build U diag(d) U^T with log-spaced eigenvalues in
      # [1/condition_number, 1] so the condition number is (approximately)
      # the requested one.
      u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64)
      v = u.T
      diag = np.diag([condition_number ** (-i / (dim - 1)) for i in range(dim)])
      return u @ diag @ v

    # Fails after it reaches a particular condition number.
    for e in range(2, 12):
      condition_number = 10 ** e
      ms = _gen_symmetric_matrix(16, condition_number)
      # Sanity check: the generated matrix has the intended conditioning
      # (to within 1%).
      self.assertLess(
          np.abs(np.linalg.cond(ms) - condition_number),
          condition_number * 0.01)
      error = shampoo.matrix_inverse_pth_root(
          ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
      if e < 7:
        self.assertLess(error, 0.1)
      # No accuracy guarantee once the condition number reaches ~1e7
      # (float32 precision limits), so larger e values are not asserted on.
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()