# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from enum import Enum

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


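# Toy model with a small convolutional trunk and a linear head. The trunk is
# the part the tests freeze; with disable_autograd it runs under torch.no_grad.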
class Model(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.device = torch.cuda.current_device()
        self.head = nn.Linear(64, 10)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)


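# Variant of Model whose trunk is a sequence of conv blocks, so FSDP can wrap
# each block as well as the trunk itself (nested wrapping).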
class NestedTrunkModel(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        for name, child in self.trunk.named_children():
            wrapped_child = FSDP(child, **fsdp_kwargs)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)

    def _create_block(
        self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
    ):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return block


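# Two ways to freeze the trunk: flip requires_grad to False up front, or let
# gradients be computed and reset them to None after every backward.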
class FreezingMethod(str, Enum):
    GradToNone = "grad_to_none"
    RequiresGrad = "requires_grad"


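# Verifies that a frozen trunk trains identically under FSDP and under DDP
# for all combinations of the parametrizations below.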
class TestFreezingWeights(FSDPTest):
    def _create_model(
        self,
        with_fsdp,
        with_nested_trunk,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        if with_nested_trunk:
            model = NestedTrunkModel(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        else:
            model = Model(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        return model

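    # Runs three optimizer steps on one model configuration and returns its
    # full (unsharded) parameters so the DDP and FSDP runs can be compared.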
    def _dist_train(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        with_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        fsdp_kwargs = {
            "device_id": self.rank,
            "forward_prefetch": forward_prefetch,
        }

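        # When the trunk runs under no_grad, its parameters never receive
        # gradients, so DDP has to search for unused parameters.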
        ddp_kwargs = {
            "device_ids": [self.rank],
            "find_unused_parameters": disable_autograd,
        }

        model = self._create_model(
            with_fsdp,
            with_nested_trunk,
            freeze_after_wrap_fsdp,
            disable_autograd,
            fsdp_kwargs,
        )
        model = model.cuda()

        # Freeze the trunk in place by disabling gradients on its parameters.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

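        # With freeze_after_wrap_fsdp, the submodules were already wrapped in
        # the constructor (before freezing); otherwise they are wrapped only
        # now, after requires_grad was flipped. The root is wrapped either way.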
        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap(fsdp_kwargs)
            model = FSDP(model, **fsdp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

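        # With GradToNone, the trunk gradients are dropped after each backward
        # so the optimizer step leaves the trunk unchanged.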
        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                for param in model.module.trunk.parameters():
                    param.grad = None
            optimizer.step()

        if with_fsdp:
            return get_full_params(model)

        return list(model.parameters())

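    # Runs the same configuration once with DDP and once with FSDP and
    # requires the resulting parameters to match.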
    @skip_if_lt_x_gpu(2)
    @parametrize("with_nested_trunk", [True, False])
    @parametrize(
        "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
    )
    @parametrize("freeze_after_wrap_fsdp", [True, False])
    @parametrize("disable_autograd", [True, False])
    @parametrize("forward_prefetch", [True, False])
    def test_freezing_weights(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        # DDP
        ddp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=False,
            disable_autograd=disable_autograd,
            forward_prefetch=False,  # does not apply to DDP
        )

        # FSDP
        fsdp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=True,
            disable_autograd=disable_autograd,
            forward_prefetch=forward_prefetch,
        )

        self.assertEqual(
            ddp_state,
            fsdp_state,
            exact_device=True,
            msg="FullyShardedDataParallel states didn't match PyTorch DDP states",
        )

        if freezing_method == FreezingMethod.RequiresGrad:
            for ddp_param, fsdp_param in zip(ddp_state, fsdp_state):
                self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)


instantiate_parametrized_tests(TestFreezingWeights)

if __name__ == "__main__":
    run_tests()