pytorch
427 lines · 14.6 KB
1# Owner(s): ["oncall: distributed"]
2
3# Copyright (c) Meta Platforms, Inc. and affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import unittest10
11import torch12import torch.nn as nn13
14from torch.distributed.optim import _NamedOptimizer15
16
17def _run_model_training(model_optim_lists):18for _ in range(2):19x = torch.rand(5, 8)20for model_optim_list in model_optim_lists:21model = model_optim_list[0]22optim_list = model_optim_list[1]23y = model(x)24y.sum().backward()25for optim in optim_list:26optim.step()27
28
class TestDummyModel(torch.nn.Module):
    """Small four-stage MLP used as a fixture for optimizer tests.

    ``torch.manual_seed(0)`` is called in ``__init__`` so that every
    instance is initialized with identical weights, allowing two copies
    of the model to be trained side by side and compared exactly.
    """

    def __init__(self):
        super().__init__()
        # Fixed seed: submodules below must be created in this exact order
        # so each instance consumes the RNG identically.
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        """Feed ``x`` through net1 -> net2 -> net3 -> net4."""
        out = x
        for stage in (self.net1, self.net2, self.net3, self.net4):
            out = stage(out)
        return out
41
class NamedOptimizerTest(unittest.TestCase):
    """Tests for ``_NamedOptimizer``: parameter-name-keyed state dicts,
    ``load_state_dict`` round-trips, multiple param groups,
    ``add_param_group``, and ``init_state``."""

    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
        """Compare one optimizer state entry (or param_group dict) with its
        named counterpart.

        With ``assert_equal=False`` the values are asserted to DIFFER —
        used to show that two differently-configured optimizers actually
        diverged before a ``load_state_dict`` round-trip.
        """
        for key, val in group.items():
            if key != "params":
                self.assertTrue(
                    key in named_group, f"{key} not in named optimizer state dict"
                )
                err_msg = (
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
                )
                if isinstance(val, torch.Tensor):
                    # Tensor values need allclose, not ==.
                    fn = self.assertTrue if assert_equal else self.assertFalse
                    fn(torch.allclose(val, named_group[key]), err_msg)
                else:
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
                    fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        """Assert two ``param_groups`` lists match element-wise."""
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for groups in zip(param_groups_1, param_groups_2):
            self._compare_param_group(groups[0], groups[1])

    def _compare_param_group(self, group_1, group_2):
        """Assert two param-group dicts match; "params" tensors via allclose,
        everything else (lr, momentum, ...) via equality."""
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                for tensors in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensors[0], tensors[1]))

    def test_state_dict(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )

        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        sd = optim.state_dict()
        named_sd = named_optim.state_dict()

        # Compare "state" in optim state dict: the plain optimizer keys state
        # by flat parameter index, the named optimizer by parameter FQN.
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_state_dict_multi_param_group(self):
        """Check that NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        optim_2 = torch.optim.Adam(
            [
                {"params": m.net2.parameters()},
                {"params": m.net4.parameters(), "lr": 1e-5},
            ]
        )

        named_optim_1 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        named_optim_2 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.Adam,
            [
                {"params": m_dup.net2.parameters()},
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)

        _run_model_training(
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()

        # Compare "state" in optim state dict
        # (flat index within each optimizer <-> FQN key in its named twin).
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][1],
            named_sd_2["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_1["state"][2],
            named_sd_1["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][3],
            named_sd_2["state"]["net4.1.bias"],
            assert_equal=True,
        )

        # Compare "param_groups" in optim state dict
        self._compare_state_dict_group(
            sd_1["param_groups"][0],
            named_sd_1["param_groups"][0],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True
        )

    def test_load_state_dict(self):
        """Check that NamedOptimizer's load_state_dict works as expected."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        # Second optimizer uses a different momentum (0.6 vs 0.9) so its
        # state is guaranteed to diverge before the load below.
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        state_dict_before_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: must DIFFER before loading.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_before_load["state"]["net1.0.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_before_load["state"]["net2.0.bias"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_before_load["state"]["net3.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_before_load["state"]["net4.1.bias"],
            assert_equal=False,
        )

        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: must MATCH after loading.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_after_load["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_after_load["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_load_state_dict_conditional_training(self):
        """Check that NamedOptimizer load_state_dict works under conditional training case."""
        # First optimizer only trains net1 and net3; the state dict it
        # produces covers a subset of the second optimizer's parameters.
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        _run_model_training([(m, [named_optim_2])])
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()

        # Compare "state" in optim state dict: only the parameters present
        # in the loaded state dict (net1, net3) are checked.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )

    def test_load_state_dict_error(self):
        """Loading into a NamedOptimizer that has never stepped (so its inner
        optimizer is uninitialized) must raise ValueError."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()

        # named_optim_2 is never stepped before load_state_dict.
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )

        err_msg = (
            "Expects the optim to be initialized before load but found not initialized"
        )
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim_2.load_state_dict(state_dict_to_load)

    def test_add_param_group(self):
        """add_param_group on NamedOptimizer must track a plain optimizer's
        add_param_group, for both iterable-of-params and single-tensor forms."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        # Add a group from a parameters() iterator.
        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

        # Add a group from a single bare tensor.
        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

    def test_add_param_group_error(self):
        """add_param_group with a tensor that is not a parameter of the
        tracked module must raise ValueError."""
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )

        err_msg = "some parameters are not in the module"
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})

    def test_init_state(self):
        """init_state must materialize per-parameter optimizer state (zeroed
        momentum buffers and grads) without running a training step."""
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_sd = named_optim.state_dict()
        # Before init_state: no grads, empty state.
        self.assertTrue(m.net1[0].weight.grad is None)
        self.assertTrue(len(named_sd["state"]) == 0)
        named_optim.init_state()
        named_sd = named_optim.state_dict()
        # After init_state: grads exist and momentum buffers are all-zero
        # (torch.all(...) over a zero tensor is False).
        self.assertTrue(m.net1[0].weight.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
        )
        self.assertTrue(m.net3.bias.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
        self.assertFalse(
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
        )