# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)


# TODO (rohan-varma): Add FSDP & DDP tests once supported

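# Rough sketch of the API under test, as exercised by the tests below (not an
# authoritative description of the implementation): _apply_optimizer_in_backward
# attaches a per-parameter optimizer to each passed-in parameter and, unless
# register_hook=False, registers an autograd hook so that the optimizer step for
# a parameter runs during backward(), as soon as that parameter's gradient has
# been accumulated. No explicit opt.step() call is needed afterwards, e.g.:
#
#   _apply_optimizer_in_backward(
#       torch.optim.SGD, model.parameters(), optimizer_kwargs={"lr": 0.01}
#   )
#   model(inp).sum().backward()  # parameters are updated here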

def _validate_params(params_list, fn):
    # Compare every parameter list against the first (reference) list using the
    # supplied assertion function.
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):
    def _run_training_loop_and_validate(self, inp, models, optimizers):
        # Run several iterations. `optimizers` are the conventional optimizers
        # for the baseline model; models set up with in-backward optimizers are
        # updated during backward() itself. After each iteration, all models'
        # parameters should agree.
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

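    # Builds a baseline model stepped by conventional SGD optimizers and a deep
    # copy that uses in-backward SGD with the same hyperparameters; with plain
    # SGD the two update schedules are expected to produce identical parameters.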
    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply different in-backward optimizers for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

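    # With register_hook=False, the in-backward optimizer state is set up but no
    # autograd hook is registered, so backward() should leave the parameters
    # untouched; with the default (hooks registered), backward() must change them.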
    def test_no_register_hook(self):
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        # The hooked model's parameters should have moved away from their
        # initial values...
        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        # ...while the model without hooks should be unchanged.
        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

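    # Applying _apply_optimizer_in_backward twice attaches two in-backward
    # optimizers to every parameter; the result should match stepping two
    # conventional optimizers (opt_0 and opt_1) over the baseline model.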
    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

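    # _get_in_backward_optimizers(module) should collect every in-backward
    # optimizer instance attached to the module's parameters (exposed here via
    # the per-parameter _in_backward_optimizers attribute).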
    def test_get_optimizers_in_backward(self):
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)