# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._flat_param import (
    FlatParamHandle,
    FlatParamShardMetadata,
    HandleShardingStrategy,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFlattenParams(FSDPTest):
    """Tests parameter flattening and shard metadata logic."""

    @property
    def world_size(self) -> int:
        # Clamp the world size to 1 since these unit tests either exercise only
        # the flattening logic or check sharding subroutines directly without
        # requiring multiple ranks
        return 1

    def _get_default_config(self):
        return {
            "device": torch.device("cuda"),
            "sharding_strategy": HandleShardingStrategy.FULL_SHARD,
            "offload_params": False,
            "mp_param_dtype": None,
            "mp_reduce_dtype": None,
            "keep_low_precision_grads": False,
            "process_group": self.process_group,
            "use_orig_params": False,
            "fsdp_extension": None,
        }

    def _get_transformer(self, seed=0):
        torch.manual_seed(seed)  # keep everything deterministic
        module = torch.nn.Transformer(
            d_model=32,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.1,
        )
        module.register_buffer("dummy_buffer", torch.tensor(1.0))

        def get_input(device, dtype):
            torch.manual_seed(1)  # keep everything deterministic
            src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
            return (src, tgt)

        module.get_input = get_input
        return module

    def _get_shared_params_transformer(self, seed=0):
        module = self._get_transformer(seed=seed)
        # share the FFNs
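        # Tying the weights by direct assignment makes the decoder layers reuse
        # the same `nn.Parameter` objects as the encoder layers, so
        # `module.parameters()` yields each shared weight only once in the
        # shared-parameter tests below.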
        for enc_layer, dec_layer in zip(module.encoder.layers, module.decoder.layers):
            dec_layer.linear1.weight = enc_layer.linear1.weight
            dec_layer.linear2.weight = enc_layer.linear2.weight
        return module

    @skip_if_lt_x_gpu(1)
    def test_partial_flattening(self):
        """Tests flattening some submodules but not others."""
        self.run_subtests(
            {"half": [False, True]},
            self._test_partial_flattening,
        )

    def _test_partial_flattening(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        numel = sum(p.numel() for p in module.parameters())

        encoder_1_params = list(module.encoder.layers[1].parameters())
        decoder_0_params = list(module.decoder.layers[0].parameters())
        params_to_flatten = encoder_1_params + decoder_0_params
        num_params = [len(encoder_1_params), len(decoder_0_params)]
        numel_to_flatten = sum(p.numel() for p in params_to_flatten)
        module.encoder.layers[1] = FSDP(module.encoder.layers[1])
        module.decoder.layers[0] = FSDP(module.decoder.layers[0])
        flat_params = [
            module.encoder.layers[1]._flat_param,
            module.decoder.layers[0]._flat_param,
        ]

        self.assertEqual(sum(fp.numel() for fp in flat_params), numel_to_flatten)
        self.assertEqual(sum(p.numel() for p in module.parameters()), numel)

        # Check that flattened parameters have been replaced with a single
        # `FlatParameter`
        self.assertEqual(len(list(module.encoder.layers[1].parameters())), 1)
        self.assertEqual(len(list(module.decoder.layers[0].parameters())), 1)

        # Check that non-flattened parameters remain
        self.assertEqual(
            len(list(module.encoder.layers[0].parameters())), num_params[0]
        )
        self.assertEqual(
            len(list(module.decoder.layers[1].parameters())), num_params[1]
        )

        # Check that calling `module.to()` affects the `FlatParameter`s
        orig_dtype = params_to_flatten[0].dtype
        new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
        for flat_param in flat_params:
            self.assertEqual(flat_param.dtype, orig_dtype)
        self.assertTrue(
            all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
        )
        module = module.to(dtype=new_dtype)
        for flat_param in flat_params:
            self.assertEqual(flat_param.dtype, new_dtype)
        self.assertTrue(
            all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
        )

    def test_flatten_nothing(self):
        """
        Tests that constructing a ``FlatParamHandle`` with no parameters
        raises an error.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_flatten_nothing,
        )

    def _test_flatten_nothing(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        with self.assertRaisesRegex(
            ValueError,
            "Cannot construct a FlatParamHandle with an empty parameter list",
        ):
            FlatParamHandle(
                [],
                module,
                **self._get_default_config(),
            )

    @skip_if_lt_x_gpu(1)
    def test_empty_module(self):
        """
        Tests flattening an empty module (i.e. one without any parameters).
        """
        module = self._get_empty_module()
        in_data = torch.rand(1)
        ref_out = module(in_data)
        fsdp_module = FSDP(module)
        self.assertEqual(len(list(fsdp_module.parameters())), 0)
        self.assertIsNone(fsdp_module._flat_param)
        fsdp_out = fsdp_module(in_data)
        self.assertEqual(ref_out, fsdp_out)

    def _get_empty_module(self):
        """Returns a module with no parameters."""
        torch.manual_seed(0)  # keep everything deterministic

        class EmptyModule(torch.nn.Module):
            def forward(self, x):
                return x + 1

            def get_input(self, device, dtype):
                torch.manual_seed(1)  # keep everything deterministic
                return torch.rand(1).to(device=device, dtype=dtype)

        return EmptyModule()

    def test_numel_without_shared_params(self):
        """
        Tests that numel is preserved after flattening when there are no shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_numel_without_shared_params,
        )

    def _test_numel_without_shared_params(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        self._test_numel(module)

    def test_numel_with_shared_params(self):
        """
        Tests that numel is preserved after flattening when there are shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_numel_with_shared_params,
        )

    def _test_numel_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer()
        if half:
            module = module.half()
        self._test_numel(module)

    def _test_numel(self, module):
        ref_numel = sum(p.numel() for p in module.parameters())
        params_to_flatten = list(module.parameters())
        flat_param_handle = FlatParamHandle(
            params_to_flatten,
            module,
            **self._get_default_config(),
        )
        self.assertEqual(ref_numel, flat_param_handle.flat_param.numel())

    @skip_if_lt_x_gpu(1)
    def test_output_without_shared_params(self):
        """
        Tests a forward pass after flattening when there are no shared
        parameters in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_output_without_shared_params,
        )

    def _test_output_without_shared_params(self, half: bool):
        module = self._get_transformer()
        if half:
            module = module.half()
        self._test_output(module)

    @skip_if_lt_x_gpu(1)
    def test_output_with_shared_params(self):
        """
        Tests a forward pass after flattening when there are shared parameters
        in the module.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_output_with_shared_params,
        )

    def _test_output_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer()
        if half:
            module = module.half()
        self._test_output(module)

    def _test_output(self, module: nn.Module):
        module = module.to(self.rank)
        ref_output = self._get_output(module)
        fsdp_module = FSDP(module)
        fsdp_output = self._get_output(fsdp_module)
        self.assertEqual(ref_output, fsdp_output)

    def _get_output(self, module):
        device = next(module.parameters()).device
        dtype = next(module.parameters()).dtype
        input = module.get_input(device, dtype)
        return module(*input)

    @skip_if_lt_x_gpu(1)
    def test_pnorm_after_step_with_shared_params(self):
        """
        Tests for parameter Frobenius norm parity after an optimizer step when
        there are shared parameters in the module. If the parameter sharing is
        handled incorrectly, then an optimizer step should reveal that.
        """
        self.run_subtests(
            {"half": [False, True]},
            self._test_pnorm_after_step_with_shared_params,
        )

    def _test_pnorm_after_step_with_shared_params(self, half: bool):
        module = self._get_shared_params_transformer().to(self.rank)
        if half:
            module = module.half()
        ref_pnorm_after_step = self._get_pnorm_after_step(module)
        module = self._get_shared_params_transformer().to(self.rank)  # recreate
        if half:
            module = module.half()
        fsdp_module = FSDP(module)
        fsdp_pnorm_after_step = self._get_pnorm_after_step(fsdp_module)
        self.assertEqual(ref_pnorm_after_step, fsdp_pnorm_after_step)

    def _get_pnorm_after_step(self, module):
        optim = torch.optim.SGD(module.parameters(), lr=0.01)
        loss = self._get_output(module).sum()
        loss.backward()
        optim.step()
        return torch.norm(
            torch.stack([p.detach().norm() for p in module.parameters()])
        )

    def test_flat_param_shard_metadata_unaligned(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        without any explicit alignment padding.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
            torch.nn.Linear(10, 10, bias=False),
            nn.ReLU(),
        )
        params_to_flatten = list(module.parameters())
        handle = FlatParamHandle(
            params_to_flatten,
            module,
            **self._get_default_config(),
        )

        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=0,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 0)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=50,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 50)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=0,
            end=99,
            expected=FlatParamShardMetadata(
                param_names=["0.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=50,
            end=149,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(50, 99), (0, 49)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=50,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(50, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=99,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "2.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(99, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=199,
            expected=FlatParamShardMetadata(
                param_names=["2.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=299,
            expected=FlatParamShardMetadata(
                param_names=["2.weight", "4.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(0, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=100,
            end=1000,
            expected=FlatParamShardMetadata(
                param_names=["2.weight", "4.weight"],
                param_shapes=[(10, 10), (10, 10)],
                param_numels=[100, 100],
                param_offsets=[(0, 99), (0, 99)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            start=299,
            end=299,
            expected=FlatParamShardMetadata(
                param_names=["4.weight"],
                param_shapes=[(10, 10)],
                param_numels=[100],
                param_offsets=[(99, 99)],
            ),
        )

    def test_flat_param_shard_metadata_aligned_full_precision(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        with alignment padding and parameter full precision.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(3, 7, bias=False),  # 0.weight
            torch.nn.Linear(7, 5, bias=False),  # 1.weight
            torch.nn.Linear(5, 5, bias=False),  # 2.weight
        )
        params_to_flatten = list(module.parameters())
        handle_kwargs = self._get_default_config()
        handle_kwargs["use_orig_params"] = True
        handle = FlatParamHandle(params_to_flatten, module, **handle_kwargs)
        # For 32-bit full precision, FSDP pads up to 3 numel after each
        # original parameter to achieve 0 mod 4 numel (i.e. 0 mod 16 bytes).
        # Thus, the unsharded `FlatParameter` layout looks like:
        #   21 + (3) + 35 + (1) + 25
        # where (x) means x numel of padding. This gives a total of 85 numel.
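        # In detail: 0.weight has 21 numel and is padded by 3 to reach 24 (a
        # multiple of 4); 1.weight has 35 numel and is padded by 1 to reach 36;
        # 2.weight has 25 numel and, as the last parameter, gets no trailing
        # padding, giving 21 + 3 + 35 + 1 + 25 = 85.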

        # The `FlatParamShardMetadata` does not include the alignment padding
        # but does account for it
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 0 of 2 ranks
            start=0,
            end=42,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "1.weight"],
                param_shapes=[(7, 3), (5, 7)],
                param_numels=[21, 35],
                # 21 + (3) + 19 = 43
                param_offsets=[(0, 20), (0, 18)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 1 of 2 ranks
            start=43,
            end=85,
            expected=FlatParamShardMetadata(
                param_names=["1.weight", "2.weight"],
                param_shapes=[(5, 7), (5, 5)],
                param_numels=[35, 25],
                # 16 + (1) + 25 = 42
                param_offsets=[(19, 34), (0, 24)],
            ),
        )

    def test_flat_param_shard_metadata_aligned_mixed_precision(self):
        """
        Tests that ``FlatParameter`` shard metadata are computed as expected
        with alignment padding and parameter mixed precision.
        """
        module = torch.nn.Sequential(
            torch.nn.Linear(2, 5, bias=False),  # 0.weight
            torch.nn.Linear(5, 5, bias=False),  # 1.weight
            torch.nn.Linear(5, 3, bias=False),  # 2.weight
        )
        params_to_flatten = list(module.parameters())
        handle_kwargs = self._get_default_config()
        handle_kwargs["use_orig_params"] = True
        handle_kwargs["mp_param_dtype"] = torch.float16
        handle = FlatParamHandle(params_to_flatten, module, **handle_kwargs)
        # For 16-bit mixed precision, FSDP pads up to 7 numel after each
        # original parameter to achieve 0 mod 8 numel (i.e. 0 mod 16 bytes).
        # Thus, the unsharded `FlatParameter` layout looks like:
        #   10 + (6) + 25 + (7) + 15
        # where (x) means x numel of padding. This gives a total of 63 numel.
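        # In detail: 0.weight has 10 numel and is padded by 6 to reach 16 (a
        # multiple of 8); 1.weight has 25 numel and is padded by 7 to reach 32;
        # 2.weight has 15 numel and, as the last parameter, gets no trailing
        # padding, giving 10 + 6 + 25 + 7 + 15 = 63.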

        # The `FlatParamShardMetadata` does not include the alignment padding
        # but does account for it
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 0 of 2 ranks
            start=0,
            end=31,
            expected=FlatParamShardMetadata(
                param_names=["0.weight", "1.weight"],
                param_shapes=[(5, 2), (5, 5)],
                param_numels=[10, 25],
                # 10 + (6) + 16 = 32
                param_offsets=[(0, 9), (0, 15)],
            ),
        )
        self._test_flat_param_shard_metadata(
            handle,
            # Emulate rank 1 of 2 ranks
            start=32,
            end=63,
            expected=FlatParamShardMetadata(
                param_names=["1.weight", "2.weight"],
                param_shapes=[(5, 5), (3, 5)],
                param_numels=[25, 15],
                # 9 + (7) + 15 = 31
                param_offsets=[(16, 24), (0, 14)],
            ),
        )

    def _test_flat_param_shard_metadata(
        self,
        handle: FlatParamHandle,
        start: int,
        end: int,
        expected: FlatParamShardMetadata,
    ):
        """
        Tests the subroutine ``_get_shard_metadata()`` that computes shard
        metadata based on start and end indices in the unsharded flat
        parameter, where both indices are inclusive.

        We manually set the relevant attributes on the flat parameter to be
        able to check the effect of ``_get_shard_metadata()`` via
        ``shard_metadata()`` since normally the attributes are set in
        ``_init_shard_metadata()`` with the start and end indices fixed based
        on rank and world size.
        """
        flat_param = handle.flat_param
        flat_param._shard_param_infos = handle._get_shard_metadata(start, end)
        shard_metadata = handle.shard_metadata()
        self.assertEqual(
            shard_metadata,
            expected,
            msg=f"{handle.shard_metadata()}, {expected}",
        )


if __name__ == "__main__":
    run_tests()