pytorch

Форк
0
/
test_named_optimizer.py 
427 строк · 14.6 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
# Copyright (c) Meta Platforms, Inc. and affiliates.
4
# All rights reserved.
5
#
6
# This source code is licensed under the BSD-style license found in the
7
# LICENSE file in the root directory of this source tree.
8

9
import unittest
10

11
import torch
12
import torch.nn as nn
13

14
from torch.distributed.optim import _NamedOptimizer
15

16

17
def _run_model_training(model_optim_lists):
18
    for _ in range(2):
19
        x = torch.rand(5, 8)
20
        for model_optim_list in model_optim_lists:
21
            model = model_optim_list[0]
22
            optim_list = model_optim_list[1]
23
            y = model(x)
24
            y.sum().backward()
25
            for optim in optim_list:
26
                optim.step()
27

28

29
class TestDummyModel(torch.nn.Module):
    """Small four-stage MLP used as a fixture (8 -> 16 -> 32 -> 64 -> 8).

    ``torch.manual_seed(0)`` is called before the layers are created, so any
    two instances are initialized with identical weights. The layers must be
    constructed in the same order as the seed call consumes the RNG stream.
    """

    def __init__(self):
        super().__init__()
        # Fixed seed: every instance starts from identical parameters.
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        # Feed the input through the four stages in order.
        hidden = self.net1(x)
        hidden = self.net2(hidden)
        hidden = self.net3(hidden)
        return self.net4(hidden)
40

41

42
class NamedOptimizerTest(unittest.TestCase):
43
    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
44
        for key, val in group.items():
45
            if key != "params":
46
                self.assertTrue(
47
                    key in named_group, f"{key} not in named optimizer state dict"
48
                )
49
                err_msg = (
50
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
51
                )
52
                if isinstance(val, torch.Tensor):
53
                    fn = self.assertTrue if assert_equal else self.assertFalse
54
                    fn(torch.allclose(val, named_group[key]), err_msg)
55
                else:
56
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
57
                    fn(val, named_group[key], err_msg)
58

59
    def _compare_param_groups(self, param_groups_1, param_groups_2):
60
        self.assertTrue(isinstance(param_groups_1, list))
61
        self.assertTrue(isinstance(param_groups_2, list))
62
        for groups in zip(param_groups_1, param_groups_2):
63
            self._compare_param_group(groups[0], groups[1])
64

65
    def _compare_param_group(self, group_1, group_2):
66
        self.assertTrue(isinstance(group_1, dict))
67
        self.assertTrue(isinstance(group_2, dict))
68
        for key, val in group_1.items():
69
            self.assertTrue(key in group_2)
70
            if key != "params":
71
                self.assertEqual(val, group_2[key])
72
            else:
73
                for tensors in zip(val, group_2[key]):
74
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))
75

76
    def test_state_dict(self):
77
        """Check that NamedOptimizer exposes the expected state dict
78
        interface."""
79
        m = TestDummyModel()
80
        m_dup = TestDummyModel()
81
        optim = torch.optim.SGD(
82
            m.parameters(),
83
            lr=1e-2,
84
            momentum=0.9,
85
        )
86

87
        named_optim = _NamedOptimizer(
88
            m_dup.named_parameters(),
89
            torch.optim.SGD,
90
            lr=1e-2,
91
            momentum=0.9,
92
        )
93
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
94

95
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
96
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
97

98
        sd = optim.state_dict()
99
        named_sd = named_optim.state_dict()
100

101
        # Compare "state" in optim state dict
102
        self._compare_state_dict_group(
103
            sd["state"][0],
104
            named_sd["state"]["net1.0.weight"],
105
            assert_equal=True,
106
        )
107
        self._compare_state_dict_group(
108
            sd["state"][3],
109
            named_sd["state"]["net2.0.bias"],
110
            assert_equal=True,
111
        )
112
        self._compare_state_dict_group(
113
            sd["state"][4],
114
            named_sd["state"]["net3.weight"],
115
            assert_equal=True,
116
        )
117
        self._compare_state_dict_group(
118
            sd["state"][7],
119
            named_sd["state"]["net4.1.bias"],
120
            assert_equal=True,
121
        )
122

123
    def test_state_dict_multi_param_group(self):
124
        """Check that NamedOptimizer exposes the expected state dict
125
        interface when multiple param groups are specified."""
126
        m = TestDummyModel()
127
        m_dup = TestDummyModel()
128
        optim_1 = torch.optim.SGD(
129
            [
130
                {"params": m.net1.parameters()},
131
                {"params": m.net3.parameters(), "lr": 1e-3},
132
            ],
133
            lr=1e-2,
134
            momentum=0.9,
135
        )
136

137
        optim_2 = torch.optim.Adam(
138
            [
139
                {"params": m.net2.parameters()},
140
                {"params": m.net4.parameters(), "lr": 1e-5},
141
            ]
142
        )
143

144
        named_optim_1 = _NamedOptimizer(
145
            m_dup.named_parameters(),
146
            torch.optim.SGD,
147
            [
148
                {"params": m_dup.net1.parameters()},
149
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
150
            ],
151
            lr=1e-2,
152
            momentum=0.9,
153
        )
154

155
        named_optim_2 = _NamedOptimizer(
156
            m_dup.named_parameters(),
157
            torch.optim.Adam,
158
            [
159
                {"params": m_dup.net2.parameters()},
160
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
161
            ],
162
        )
163
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
164
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
165

166
        _run_model_training(
167
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
168
        )
169
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
170
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
171
        sd_1 = optim_1.state_dict()
172
        sd_2 = optim_2.state_dict()
173
        named_sd_1 = named_optim_1.state_dict()
174
        named_sd_2 = named_optim_2.state_dict()
175

176
        # Compare "state" in optim state dict
177
        self._compare_state_dict_group(
178
            sd_1["state"][0],
179
            named_sd_1["state"]["net1.0.weight"],
180
            assert_equal=True,
181
        )
182
        self._compare_state_dict_group(
183
            sd_2["state"][1],
184
            named_sd_2["state"]["net2.0.bias"],
185
            assert_equal=True,
186
        )
187
        self._compare_state_dict_group(
188
            sd_1["state"][2],
189
            named_sd_1["state"]["net3.weight"],
190
            assert_equal=True,
191
        )
192
        self._compare_state_dict_group(
193
            sd_2["state"][3],
194
            named_sd_2["state"]["net4.1.bias"],
195
            assert_equal=True,
196
        )
197

198
        # Compare "param_groups" in optim state dict
199
        self._compare_state_dict_group(
200
            sd_1["param_groups"][0],
201
            named_sd_1["param_groups"][0],
202
            assert_equal=True,
203
        )
204
        self._compare_state_dict_group(
205
            sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True
206
        )
207

208
    def test_load_state_dict(self):
209
        """Check that NamedOptimizer's load_state_dict works as expected."""
210
        m = TestDummyModel()
211
        named_optim_1 = _NamedOptimizer(
212
            m.named_parameters(),
213
            torch.optim.SGD,
214
            lr=1e-2,
215
            momentum=0.9,
216
        )
217

218
        _run_model_training([(m, [named_optim_1])])
219
        state_dict_to_load = named_optim_1.state_dict()
220

221
        named_optim_2 = _NamedOptimizer(
222
            m.named_parameters(),
223
            torch.optim.SGD,
224
            lr=1e-2,
225
            momentum=0.6,
226
        )
227

228
        _run_model_training([(m, [named_optim_2])])
229
        state_dict_before_load = named_optim_2.state_dict()
230

231
        # Compare "state" in optim state dict
232
        self._compare_state_dict_group(
233
            state_dict_to_load["state"]["net1.0.weight"],
234
            state_dict_before_load["state"]["net1.0.weight"],
235
            assert_equal=False,
236
        )
237
        self._compare_state_dict_group(
238
            state_dict_to_load["state"]["net2.0.bias"],
239
            state_dict_before_load["state"]["net2.0.bias"],
240
            assert_equal=False,
241
        )
242
        self._compare_state_dict_group(
243
            state_dict_to_load["state"]["net3.weight"],
244
            state_dict_before_load["state"]["net3.weight"],
245
            assert_equal=False,
246
        )
247
        self._compare_state_dict_group(
248
            state_dict_to_load["state"]["net4.1.bias"],
249
            state_dict_before_load["state"]["net4.1.bias"],
250
            assert_equal=False,
251
        )
252

253
        named_optim_2.load_state_dict(state_dict_to_load)
254
        state_dict_after_load = named_optim_2.state_dict()
255

256
        # Compare "state" in optim state dict
257
        self._compare_state_dict_group(
258
            state_dict_to_load["state"]["net1.0.weight"],
259
            state_dict_after_load["state"]["net1.0.weight"],
260
            assert_equal=True,
261
        )
262
        self._compare_state_dict_group(
263
            state_dict_to_load["state"]["net2.0.bias"],
264
            state_dict_after_load["state"]["net2.0.bias"],
265
            assert_equal=True,
266
        )
267
        self._compare_state_dict_group(
268
            state_dict_to_load["state"]["net3.weight"],
269
            state_dict_after_load["state"]["net3.weight"],
270
            assert_equal=True,
271
        )
272
        self._compare_state_dict_group(
273
            state_dict_to_load["state"]["net4.1.bias"],
274
            state_dict_after_load["state"]["net4.1.bias"],
275
            assert_equal=True,
276
        )
277

278
    def test_load_state_dict_conditional_training(self):
279
        """Check that NamedOptimizer load_state_dict works under conditional training case."""
280
        m = TestDummyModel()
281
        named_optim_1 = _NamedOptimizer(
282
            m.named_parameters(),
283
            torch.optim.SGD,
284
            [
285
                {"params": m.net1.parameters()},
286
                {"params": m.net3.parameters(), "lr": 1e-3},
287
            ],
288
            lr=1e-2,
289
            momentum=0.9,
290
        )
291

292
        _run_model_training([(m, [named_optim_1])])
293
        state_dict_to_load = named_optim_1.state_dict()
294

295
        named_optim_2 = _NamedOptimizer(
296
            m.named_parameters(),
297
            torch.optim.SGD,
298
            lr=1e-2,
299
            momentum=0.6,
300
        )
301

302
        _run_model_training([(m, [named_optim_2])])
303
        named_optim_2.load_state_dict(state_dict_to_load)
304
        state_dict_after_load = named_optim_2.state_dict()
305

306
        # Compare "state" in optim state dict
307
        self._compare_state_dict_group(
308
            state_dict_to_load["state"]["net1.0.weight"],
309
            state_dict_after_load["state"]["net1.0.weight"],
310
            assert_equal=True,
311
        )
312
        self._compare_state_dict_group(
313
            state_dict_to_load["state"]["net3.weight"],
314
            state_dict_after_load["state"]["net3.weight"],
315
            assert_equal=True,
316
        )
317

318
    def test_load_state_dict_error(self):
319
        m = TestDummyModel()
320
        named_optim_1 = _NamedOptimizer(
321
            m.named_parameters(),
322
            torch.optim.SGD,
323
            lr=1e-2,
324
            momentum=0.9,
325
        )
326

327
        _run_model_training([(m, [named_optim_1])])
328
        state_dict_to_load = named_optim_1.state_dict()
329

330
        named_optim_2 = _NamedOptimizer(
331
            m.named_parameters(),
332
            torch.optim.SGD,
333
            lr=1e-2,
334
            momentum=0.6,
335
        )
336

337
        err_msg = (
338
            "Expects the optim to be initialized before load but found not initialized"
339
        )
340
        with self.assertRaisesRegex(ValueError, err_msg):
341
            named_optim_2.load_state_dict(state_dict_to_load)
342

343
    def test_add_param_group(self):
344
        m = TestDummyModel()
345
        m_dup = TestDummyModel()
346
        optim = torch.optim.SGD(
347
            [
348
                {"params": m.net1.parameters()},
349
                {"params": m.net3.parameters(), "lr": 1e-3},
350
            ],
351
            lr=1e-2,
352
            momentum=0.9,
353
        )
354
        named_optim = _NamedOptimizer(
355
            m_dup.named_parameters(),
356
            torch.optim.SGD,
357
            [
358
                {"params": m_dup.net1.parameters()},
359
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
360
            ],
361
            lr=1e-2,
362
            momentum=0.9,
363
        )
364

365
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
366
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
367

368
        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
369
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
370
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
371
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
372

373
        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
374
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
375
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
376
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
377

378
    def test_add_param_group_error(self):
379
        m = TestDummyModel()
380
        named_optim = _NamedOptimizer(
381
            m.named_parameters(),
382
            torch.optim.SGD,
383
            [
384
                {"params": m.net1.parameters()},
385
                {"params": m.net3.parameters(), "lr": 1e-3},
386
            ],
387
            lr=1e-2,
388
            momentum=0.9,
389
        )
390

391
        err_msg = "some parameters are not in the module"
392
        with self.assertRaisesRegex(ValueError, err_msg):
393
            named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})
394

395
    def test_init_state(self):
396
        m = TestDummyModel()
397
        named_optim = _NamedOptimizer(
398
            m.named_parameters(),
399
            torch.optim.SGD,
400
            [
401
                {"params": m.net1.parameters()},
402
                {"params": m.net3.parameters(), "lr": 1e-3},
403
            ],
404
            lr=1e-2,
405
            momentum=0.9,
406
        )
407
        named_sd = named_optim.state_dict()
408
        self.assertTrue(m.net1[0].weight.grad is None)
409
        self.assertTrue(len(named_sd["state"]) == 0)
410
        named_optim.init_state()
411
        named_sd = named_optim.state_dict()
412
        self.assertTrue(m.net1[0].weight.grad is not None)
413
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
414
        self.assertFalse(
415
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
416
        )
417
        self.assertFalse(
418
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
419
        )
420
        self.assertTrue(m.net3.bias.grad is not None)
421
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
422
        self.assertFalse(
423
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
424
        )
425
        self.assertFalse(
426
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
427
        )
428

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.