# Owner(s): ["oncall: distributed"]

import itertools
import sys

from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import (
    always_wrap_policy as always_wrap,
    enable_wrap,
    ModuleWrapPolicy,
    wrap,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init
except ImportError:
    _TORCHDISTX_AVAIL = False
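# torchdistX is an optional dependency; the tests below that use
# ``deferred_init`` are skipped when it is not installed.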


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


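# Re-initializes a non-meta "baseline" model with the same deterministic
# ``reset_parameters()`` values that the meta-device path produces, so the two
# models can be compared in the tests below.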
def _reset_params_if_meta(is_meta: bool, model: nn.Module):
    # For torchdistX init, we don't need to call reset_params, as
    # deferred_init(model).materialize() is equivalent to model().
    if is_meta:
        for module in model.modules():
            # Assume that a module has `reset_parameters()` iff it has directly
            # managed parameters or buffers
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()


class MyLinear(nn.Linear):
    """
    Linear layer with deterministic reset_parameters for testing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        with torch.no_grad():
            # Use an initialization method that depends on shape
            torch.nn.init.xavier_uniform_(self.weight, 1.0)


class MyBuffer(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.register_buffer("buf", torch.empty((3, 3), device=device))

    def reset_parameters(self, *args, **kwargs):
        torch.manual_seed(42)
        # Use an initialization method that depends on shape
        torch.nn.init.xavier_uniform_(self.buf, 0.5)


class MyModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.buf_mod = MyBuffer(device)

    def forward(self, x):
        return self.lin2(self.lin1(x))


class NestedModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin1 = wrap(self.lin1)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.l3 = MyModel(device=device)
        self.l3 = wrap(self.l3)
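        # ``wrap()`` is a no-op outside of an ``enable_wrap`` context, so this
        # module can also be constructed without wrapping its submodules.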

    def forward(self, x):
        return self.l3(self.lin2(self.lin1(x)))


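# Example ``param_init_fn`` for the tests below: materialize each meta-device
# submodule on the current CUDA device and re-run its deterministic
# ``reset_parameters()``.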
def _init_with_reset_params(module: nn.Module):
    """
    to_empty + reset_parameters() init function example for modules
    initialized with device="meta"
    """
    has_meta_states = any(
        t.is_meta
        for t in itertools.chain(
            module.parameters(recurse=False), module.buffers(recurse=False)
        )
    )
    if has_meta_states:
        device = torch.device("cuda", torch.cuda.current_device())
        module.to_empty(device=device, recurse=False)
        module.reset_parameters()


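# Alternative ``param_init_fn`` for torchdistX deferred initialization: the
# ``check_fn`` skips modules that are already FSDP instances so that only the
# inner, not-yet-materialized modules are materialized.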
def _init_with_torchdistX(module: nn.Module):
    """
    torchdistX-based deferred module initialization function example
    using ``materialize_module``.
    """
    assert _TORCHDISTX_AVAIL

    def check_fn(k):
        return not isinstance(k, FSDP)

    deferred_init.materialize_module(module, check_fn=check_fn)


class TestFSDPWithMetaDevice(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

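    # ``summon_full_params`` all-gathers the sharded parameters so that the
    # two wrapped models can be compared elementwise on every rank.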
    def _compare_fsdp(self, fsdp1, fsdp2):
        with FSDP.summon_full_params(fsdp1):
            with FSDP.summon_full_params(fsdp2):
                for p1, p2 in zip(fsdp1.parameters(), fsdp2.parameters()):
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )

        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Test to make sure it is the same model parameters as regular FSDP
        # approach.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

        # Test that meta init works if all submodules are contained in only a
        # single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_reset_params(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(
            meta_module_fn, _init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_default_init(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_default_init(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_init_fn(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(
            meta_module_fn, init_fn=_init_with_torchdistX
        )

    def _test_nested_model_with_meta_device(
        self, auto_wrap, meta_module_fn, init_fn=None
    ):
        if auto_wrap:
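            # With auto wrapping, FSDP applies ``param_init_fn`` to the
            # meta-device submodules as it wraps them with ``always_wrap``.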
            module = meta_module_fn()
            is_meta = (
                next(module.parameters()).is_meta or next(module.buffers()).is_meta
            )
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP,
                param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non FSDP modules will still be initialized because they bubble up
                # to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Init and reset parameters before wrapping so that reset_params
            # matches up with meta device's initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare it before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_reset_params(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_reset_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_default_init(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_default_init(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_init_fn(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, "cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_meta_device_with_mixed_precision(self):
        """
        Tests meta device initialization with a ``param_init_fn`` when
        specifying mixed precision with ``param_dtype=torch.float32``.
        """

        class FakeLinear(nn.Module):
            def __init__(
                self, in_dim: int, out_dim: int, device: Union[torch.device, str]
            ) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5, device="meta")
                self.lin2 = FakeLinear(5, 5, device="meta")
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.lin2(self.relu(self.lin1(x)))

            def _module_init_fn(self, module: nn.Module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        def _param_init_fn(module: nn.Module) -> None:
            # TODO: `module.to_empty()` is not generally correct for meta
            # device initialization.
            # https://github.com/pytorch/pytorch/issues/90465
            module.to_empty(device=torch.device("cuda"))
            module.apply(model._module_init_fn)

        model = Model()
        # Wrap `lin1` and the top level `model` to create nested FSDP instances
        # where each instance has parameters
        FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            mixed_precision=MixedPrecision(
                param_dtype=torch.float32, reduce_dtype=torch.float16
            ),
            param_init_fn=_param_init_fn,
            device_id=torch.cuda.current_device(),
        )
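        # The test passes as long as construction succeeds: FSDP runs
        # ``_param_init_fn`` on the meta-device modules during ``__init__``,
        # so no forward pass is needed here.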


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()