# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn

from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)

# TODO (rohan-varma): Add FSDP & DDP tests once supported

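# Checks that corresponding parameters agree across all given models, using the
# first entry in params_list as the reference.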
def _validate_params(params_list, fn):
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)

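# An in-backward optimizer applies each parameter's update during backward()
# itself, with no explicit step() call, so the tests below compare such models
# against reference models whose regular optimizers are stepped manually.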
class ApplyOverlappedOptimizerTest(unittest.TestCase):
    def _run_training_loop_and_validate(self, inp, models, optimizers):
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

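    # Trains a baseline model (separate SGD optimizers for weights and biases,
    # stepped manually) alongside a deep copy whose parameters are updated in
    # backward with the same hyperparameters; parameters must stay identical.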
    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply different optimizer in backwards for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

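    # With register_hook=False no backward hook is registered, so backward()
    # alone must leave the parameters unchanged; with the default register_hook
    # the parameters are updated during backward without any step() call.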
    def test_no_register_hook(self):
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

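    # Applying _apply_optimizer_in_backward twice attaches two in-backward
    # optimizers to every parameter; the result must match stepping two regular
    # SGD optimizers over the same parameters.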
    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

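    # _get_in_backward_optimizers should return exactly the optimizers that
    # _apply_optimizer_in_backward attached to the module's parameters, one per
    # parameter here.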
    def test_get_optimizers_in_backward(self):
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)