# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._flat_param import (
    FlatParamHandle,
    FlatParamShardMetadata,
    HandleShardingStrategy,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFlattenParams(FSDPTest):
    """Tests parameter flattening and shard metadata logic."""

    @property
    def world_size(self) -> int:
        # Clamp the world size to 1 since these unit tests either exercise only
        # the flattening logic or check sharding subroutines directly without
        # requiring multiple ranks
        return 1

    def _get_default_config(self):
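        # Default keyword arguments used to construct a `FlatParamHandle` in
        # these tests; individual tests copy this dict and override entries
        # (e.g. `use_orig_params`, `mp_param_dtype`) as needed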
        return {
            "device": torch.device("cuda"),
            "sharding_strategy": HandleShardingStrategy.FULL_SHARD,
            "offload_params": False,
            "mp_param_dtype": None,
            "mp_reduce_dtype": None,
            "keep_low_precision_grads": False,
            "process_group": self.process_group,
            "use_orig_params": False,
            "fsdp_extension": None,
        }

    def _get_transformer(self, seed=0):
        torch.manual_seed(seed)  # keep everything deterministic
        module = torch.nn.Transformer(
            d_model=32,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.1,
        )
        module.register_buffer("dummy_buffer", torch.tensor(1.0))

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            return (src, tgt)

        module.get_input = get_input
        return module

    def _get_shared_params_transformer(self, seed=0):
        module = self._get_transformer(seed=seed)
        # share the FFNs
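        # Assigning the encoder's FFN weights to the decoder layers makes both
        # modules reference the same `nn.Parameter` objects, which is what the
        # shared-parameter tests below rely on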
        for enc_layer, dec_layer in zip(module.encoder.layers, module.decoder.layers):
            dec_layer.linear1.weight = enc_layer.linear1.weight
            dec_layer.linear2.weight = enc_layer.linear2.weight
        return module

    @skip_if_lt_x_gpu(1)
    def test_partial_flattening(self):
        """Tests flattening some submodules but not others."""
        self.run_subtests(
            {"half": [False, True]},
            self._test_partial_flattening,
        )

    def _test_partial_flattening(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        numel = sum(p.numel() for p in module.parameters())

        encoder_1_params = list(module.encoder.layers[1].parameters())
        decoder_0_params = list(module.decoder.layers[0].parameters())
        params_to_flatten = encoder_1_params + decoder_0_params
        num_params = [len(encoder_1_params), len(decoder_0_params)]
        numel_to_flatten = sum(p.numel() for p in params_to_flatten)
        module.encoder.layers[1] = FSDP(module.encoder.layers[1])
        module.decoder.layers[0] = FSDP(module.decoder.layers[0])
        flat_params = [
            module.encoder.layers[1]._flat_param,
            module.decoder.layers[0]._flat_param,
        ]

        self.assertEqual(sum(fp.numel() for fp in flat_params), numel_to_flatten)
        self.assertEqual(sum(p.numel() for p in module.parameters()), numel)

        # Check that flattened parameters have been replaced with a single
        # `FlatParameter`
        self.assertEqual(len(list(module.encoder.layers[1].parameters())), 1)
        self.assertEqual(len(list(module.decoder.layers[0].parameters())), 1)

        # Check that non-flattened parameters remain
        self.assertEqual(
            len(list(module.encoder.layers[0].parameters())), num_params[0]
        )
        self.assertEqual(
            len(list(module.decoder.layers[1].parameters())), num_params[1]
        )

        # Check that calling `module.to()` affects the `FlatParameter`s
        orig_dtype = params_to_flatten[0].dtype
        new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
        for flat_param in flat_params:
            self.assertEqual(flat_param.dtype, orig_dtype)
        self.assertTrue(
            all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
        )
        module = module.to(dtype=new_dtype)
        for flat_param in flat_params:
            self.assertEqual(flat_param.dtype, new_dtype)
        self.assertTrue(
            all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
        )

    def test_flatten_nothing(self):
        """
        Tests that constructing a ``FlatParamHandle`` with no parameters
        raises an error.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_flatten_nothing,
        )

    def _test_flatten_nothing(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        with self.assertRaisesRegex(
            ValueError,
            "Cannot construct a FlatParamHandle with an empty parameter list",
        ):
            FlatParamHandle(
                [],
                module,
                **self._get_default_config(),
            )

    @skip_if_lt_x_gpu(1)
    def test_empty_module(self):
        """
        Tests flattening an empty module (i.e. one without any parameters).
        """
        module = self._get_empty_module()
        in_data = torch.rand(1)
        ref_out = module(in_data)
        fsdp_module = FSDP(module)
        self.assertEqual(len(list(fsdp_module.parameters())), 0)
        self.assertIsNone(fsdp_module._flat_param)
        fsdp_out = fsdp_module(in_data)
        self.assertEqual(ref_out, fsdp_out)

    def _get_empty_module(self):
        """Returns a module with no parameters."""
        torch.manual_seed(0)  # keep everything deterministic

        class EmptyModule(torch.nn.Module):
            def forward(self, x):
                return x + 1

            def get_input(self, device, dtype):
                torch.manual_seed(1)  # keep everything deterministic
                return torch.rand(1).to(device=device, dtype=dtype)

        return EmptyModule()

    def test_numel_without_shared_params(self):
        """
        Tests that numel is preserved after flattening when there are no shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_numel_without_shared_params,
        )

    def _test_numel_without_shared_params(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        self._test_numel(module)

    def test_numel_with_shared_params(self):
        """
        Tests that numel is preserved after flattening when there are shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_numel_with_shared_params,
        )

    def _test_numel_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer()
        if half:
            module = module.half()
        self._test_numel(module)

    def _test_numel(self, module):
        ref_numel = sum(p.numel() for p in module.parameters())
        params_to_flatten = list(module.parameters())
        flat_param_handle = FlatParamHandle(
            params_to_flatten,
            module,
            **self._get_default_config(),
        )
        self.assertEqual(ref_numel, flat_param_handle.flat_param.numel())

    @skip_if_lt_x_gpu(1)
    def test_output_without_shared_params(self):
        """
        Tests a forward pass after flattening when there are no shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_output_without_shared_params,
        )

    def _test_output_without_shared_params(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        self._test_output(module)

    @skip_if_lt_x_gpu(1)
    def test_output_with_shared_params(self):
        """
        Tests a forward pass after flattening when there are shared parameters
        in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_output_with_shared_params,
        )

    def _test_output_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer()
        if half:
            module = module.half()
        self._test_output(module)

    def _test_output(self, module: nn.Module):
        module = module.to(self.rank)
        ref_output = self._get_output(module)
        fsdp_module = FSDP(module)
        fsdp_output = self._get_output(fsdp_module)
        self.assertEqual(ref_output, fsdp_output)

    def _get_output(self, module):
        device = next(module.parameters()).device
        dtype = next(module.parameters()).dtype
        input = module.get_input(device, dtype)
        return module(*input)

    @skip_if_lt_x_gpu(1)
    def test_pnorm_after_step_with_shared_params(self):
        """
        Tests for parameter Frobenius norm parity after an optimizer step when
        there are shared parameters in the module. If the parameter sharing is
        handled incorrectly, then an optimizer step should reveal that.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_pnorm_after_step_with_shared_params,
        )

    def _test_pnorm_after_step_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer().to(self.rank)
        if half:
            module = module.half()
        ref_pnorm_after_step = self._get_pnorm_after_step(module)
        module = self._get_shared_params_transformer().to(self.rank)  # recreate
        if half:
            module = module.half()
        fsdp_module = FSDP(module)
        fsdp_pnorm_after_step = self._get_pnorm_after_step(fsdp_module)
        self.assertEqual(ref_pnorm_after_step, fsdp_pnorm_after_step)

    def _get_pnorm_after_step(self, module):
        optim = torch.optim.SGD(module.parameters(), lr=0.01)
        loss = self._get_output(module).sum()
        loss.backward()
        optim.step()
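        # Reduce all parameters to a single scalar: stack the per-parameter
        # norms and take the norm of that vector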
        return torch.norm(
            torch.stack([p.detach().norm() for p in module.parameters()])
        )

    def test_flat_param_shard_metadata_unaligned(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        without any explicit alignment padding.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
        )
        params_to_flatten = list(module.parameters())
        handle = FlatParamHandle(
            params_to_flatten,
            module,
            **self._get_default_config(),
        )

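        # Each `Linear` weight is 10 x 10 = 100 numel, so the unsharded flat
        # parameter has 300 numel: "0.weight" at indices [0, 99], "2.weight" at
        # [100, 199], and "4.weight" at [200, 299] (the `ReLU` modules at
        # positions 1 and 3 contribute no parameters)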
        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=0,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 0)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=50,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 50)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=99,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=50,
            end=149,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(50, 99), (0, 49)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=50,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(50, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=99,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(99, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["2.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=299,
            expected=FlatParamShardMetadata(
                param_names=["2.weight", "4.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(0, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=1000,
            expected=FlatParamShardMetadata(
                param_names=["2.weight", "4.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(0, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=299,
            end=299,
            expected=FlatParamShardMetadata(
                param_names=["4.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(99, 99)],
            ),
        )

    def test_flat_param_shard_metadata_aligned_full_precision(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        with alignment padding and parameter full precision.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(3, 7, bias=False),  # 0.weight
            torch.nn.Linear(7, 5, bias=False),  # 1.weight
            torch.nn.Linear(5, 5, bias=False),  # 2.weight
        )
        params_to_flatten = list(module.parameters())
        handle_kwargs = self._get_default_config()
        handle_kwargs["use_orig_params"] = True
        handle = FlatParamHandle(params_to_flatten, module, **handle_kwargs)
        # For 32-bit full precision, FSDP pads up to 3 numel after each
        # original parameter to achieve 0 mod 4 numel (i.e. 0 mod 16 bytes).
        # Thus, the unsharded `FlatParameter` layout looks like:
        #   21 + (3) + 35 + (1) + 25
        # where (x) means x numel of padding. This gives a total of 85 numel.
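        # In terms of inclusive indices into those 85 numel, the layout is:
        #   "0.weight" at [0, 20], padding at [21, 23],
        #   "1.weight" at [24, 58], padding at [59], "2.weight" at [60, 84]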

        # The `FlatParamShardMetadata` does not include the alignment padding
        # as a parameter, but the shard boundaries below do account for it
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 0 of 2 ranks
            start=0,
            end=42,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "1.weight"],
                param_shapes=[(7, 3), (5, 7)],
                param_numels=[21, 35],
                # 21 + (3) + 19 = 43
                param_offsets=[(0, 20), (0, 18)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 1 of 2 ranks
            start=43,
            end=85,
            expected=FlatParamShardMetadata(
                param_names=["1.weight", "2.weight"],
                param_shapes=[(5, 7), (5, 5)],
                param_numels=[35, 25],
                # 16 + (1) + 25 = 42
                param_offsets=[(19, 34), (0, 24)],
            ),
        )

    def test_flat_param_shard_metadata_aligned_mixed_precision(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        with alignment padding and parameter mixed precision.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(2, 5, bias=False),  # 0.weight
            torch.nn.Linear(5, 5, bias=False),  # 1.weight
            torch.nn.Linear(5, 3, bias=False),  # 2.weight
        )
        params_to_flatten = list(module.parameters())
        handle_kwargs = self._get_default_config()
        handle_kwargs["use_orig_params"] = True
        handle_kwargs["mp_param_dtype"] = torch.float16
        handle = FlatParamHandle(params_to_flatten, module, **handle_kwargs)
        # For 16-bit mixed precision, FSDP pads up to 7 numel after each
        # original parameter to achieve 0 mod 8 numel (i.e. 0 mod 16 bytes).
        # Thus, the unsharded `FlatParameter` layout looks like:
        #   10 + (6) + 25 + (7) + 15
        # where (x) means x numel of padding. This gives a total of 63 numel.
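        # In terms of inclusive indices into those 63 numel, the layout is:
        #   "0.weight" at [0, 9], padding at [10, 15],
        #   "1.weight" at [16, 40], padding at [41, 47], "2.weight" at [48, 62]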

        # The `FlatParamShardMetadata` does not include the alignment padding
        # as a parameter, but the shard boundaries below do account for it
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 0 of 2 ranks
            start=0,
            end=31,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "1.weight"],
                param_shapes=[(5, 2), (5, 5)],
                param_numels=[10, 25],
                # 10 + (6) + 16 = 32
                param_offsets=[(0, 9), (0, 15)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 1 of 2 ranks
            start=32,
            end=63,
            expected=FlatParamShardMetadata(
                param_names=["1.weight", "2.weight"],
                param_shapes=[(5, 5), (3, 5)],
                param_numels=[25, 15],
                # 9 + (7) + 15 = 31
                param_offsets=[(16, 24), (0, 14)],
            ),
        )

    def _test_flat_param_shard_metadata(
        self,
        handle: FlatParamHandle,
        start: int,
        end: int,
        expected: FlatParamShardMetadata,
    ):
        """
        Tests the subroutine ``_get_shard_metadata()`` that computes shard
        metadata based on start and end indices in the unsharded flat
        parameter, where both indices are inclusive.

        We manually set the relevant attributes on the flat parameter to be
        able to check the effect of ``_get_shard_metadata()`` via
        ``shard_metadata()`` since normally the attributes are set in
        ``_init_shard_metadata()`` with the start and end indices fixed based
        on rank and world size.
        """
        flat_param = handle.flat_param
        flat_param._shard_param_infos = handle._get_shard_metadata(start, end)
        shard_metadata = handle.shard_metadata()
        self.assertEqual(
            shard_metadata,
            expected,
            msg=f"{handle.shard_metadata()}, {expected}",
        )


if __name__ == "__main__":
    run_tests()