# pytorch · Fork 0 / test_fsdp_misc.py · 1105 lines · 40.9 KB
# Owner(s): ["oncall: distributed"]

import functools
import os
import sys
import warnings
from collections import namedtuple
from contextlib import nullcontext
from copy import deepcopy
from itertools import chain
from typing import Any, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FlatParameter,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp._flat_param import _FSDP_USE_UNSAFE_SETATTR
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    ModuleWrapPolicy,
    transformer_auto_wrap_policy,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    FSDPTestMultiThread,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

# Exit with status 0 (i.e. "skipped", not "failed") when this module cannot
# run in the current environment.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class MyModel(nn.Module):
    """Minimal two-layer model computing ``b(a(x + y))``.

    Both ``a`` and ``b`` are ``nn.Linear(2, 2)``, so ``x`` and ``y`` must have
    a trailing dimension of size 2.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = nn.Linear(2, 2)

    def forward(self, x, y):
        return self.b(self.a(x + y))


class TestFSDPMiscMultiProcess(FSDPTest):
    """Miscellaneous FSDP tests that run with one process per rank."""

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_index", [True, False])
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
          - Wrapping a CPU module should move the module to the GPU matching
          ``device_id``
          - Wrapping a GPU module already on the GPU matching ``device_id``
          should not raise an error
          - Wrapping a GPU module already on GPU and passing a GPU device
          without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
        dev_id = (
            torch.cuda.current_device()
            if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters() if isinstance(p, FlatParameter)
            }
            # Exactly one device; this also guarantees ``devices`` is non-empty
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have an explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")},
            )
        _check_device_matches(
            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_zero2_eval_with_prefetch(self):
        """
        Tests that FSDP with ``SHARD_GRAD_OP`` and ``forward_prefetch=True``
        matches DDP across a train -> eval (``no_grad``) -> train sequence,
        comparing both losses and gradients w.r.t. the input.
        """

        class Mnist(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout(0.25)
                self.dropout2 = nn.Dropout(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)
                self.ln = nn.LayerNorm(9216)

            def forward(self, x, y):
                x = self.conv1(x)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.ln(x)
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = torch.nn.functional.log_softmax(x, dim=1)
                loss = torch.nn.functional.cross_entropy(output, y)
                return loss

        model = Mnist().cuda()
        model1 = Mnist().cuda()
        model1.load_state_dict(model.state_dict())
        fsdp_model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            forward_prefetch=True,
            use_orig_params=True,
            auto_wrap_policy=ModuleWrapPolicy([nn.Linear, nn.Conv2d]),
        )
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model1,
        )

        fsdp_opt = torch.optim.SGD(fsdp_model.parameters(), lr=1e-4)
        ddp_opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

        seed = self.rank + 20231010
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        def _train_and_compare(num_iters: int = 5) -> None:
            """Runs ``num_iters`` training steps on both models with identical
            inputs and seeds, checking losses and input gradients match."""
            for i in range(num_iters):
                x = torch.randn(8, 1, 28, 28, device="cuda")
                y = torch.randint(low=0, high=9, size=(8,), device="cuda")
                losses = []
                input_grads = []
                for _model, _opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                    # Give each model its own leaf copy of ``x``: sharing one
                    # leaf would accumulate both backward passes into the same
                    # ``x.grad`` tensor, making the gradient check vacuous
                    inp = x.detach().clone().requires_grad_()
                    # Reseed so both models draw the same dropout masks
                    seed = self.rank + i
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)
                    loss = _model(inp, y).sum()
                    losses.append(loss)
                    loss.backward()
                    _opt.step()
                    input_grads.append(inp.grad)
                    _opt.zero_grad()
                self.assertTrue(torch.allclose(losses[0], losses[1]))
                self.assertTrue(torch.allclose(input_grads[0], input_grads[1]))

        _train_and_compare()

        with torch.no_grad():
            fsdp_model.eval()
            ddp_model.eval()
            for _ in range(5):
                x = torch.randn(8, 1, 28, 28, device="cuda")
                y = torch.randint(low=0, high=9, size=(8,), device="cuda")
                fsdp_loss = fsdp_model(x, y)
                ddp_loss = ddp_model(x, y)
                self.assertTrue(torch.allclose(fsdp_loss, ddp_loss))

        fsdp_model.train()
        ddp_model.train()
        _train_and_compare()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
    def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
        """Tests that an FSDP-wrapped submodule that receives no gradient
        keeps ``grad=None`` on its ``FlatParameter``.

        When ``use_second_layer=True``, ``b`` is involved in forward
        computation but does not receive grad in backward. Otherwise, ``b`` is
        not involved in forward computation.
        """

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x, y):
                out1 = self.a(x)
                if use_second_layer:
                    out2 = self.b(y)
                    return out1, out2
                else:
                    return out1

        fsdp = FSDP(
            MyModel().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
        )
        x = torch.randn(10, 10, device="cuda")
        y = torch.randn(10, 10, device="cuda")
        for _ in range(4):
            if use_second_layer:
                a, _ = fsdp(x, y)
            else:
                a = fsdp(x, y)
            loss = a.sum()
            loss.backward()

            # self.a receives grad, self.b does not
            a_grad = fsdp.module.a._handle.flat_param.grad
            b_grad = fsdp.module.b._handle.flat_param.grad
            self.assertIsNotNone(a_grad)
            self.assertIsNone(b_grad)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_not_all_outputs_used_in_loss(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_fsdp_not_all_outputs_used_in_loss,
        )

    def _test_fsdp_not_all_outputs_used_in_loss(
        self, sharding_strategy: ShardingStrategy
    ):
        """Trains a nested-FSDP model next to a local (non-FSDP) copy. The loss
        uses both outputs for the first 2 iterations and only the first output
        afterwards. Checks that the root reshards after backward and that the
        parameters stay equal to the local copy."""

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(4, 4)
                self.lin2 = nn.Linear(4, 4)

            def forward(self, x):
                a = self.lin1(x)
                b = self.lin2(x)
                return (a, b)

        def _check_resharded(fsdp_module):
            handle = fsdp_module._handle
            if not handle:
                return
            param = handle.flat_param
            if handle.uses_sharded_strategy:
                full_param = param._full_param_padded
                # A freed padded unsharded buffer has zero-size storage;
                # ``untyped_storage`` avoids the deprecated TypedStorage API
                self.assertEqual(full_param.untyped_storage().size(), 0)

            self.assertEqual(param.data_ptr(), param._local_shard.data_ptr())

        def _check_equal(local, fsdp):
            with FSDP.summon_full_params(fsdp):
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    torch.testing.assert_close(p1, p2)

        fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
        m = MyModule().cuda()
        m_local = deepcopy(m)
        prev_params = [p.clone() for p in m_local.parameters()]

        m.lin1 = fsdp_ctor(m.lin1)
        m = fsdp_ctor(m)
        _check_equal(m_local, m)

        opt = torch.optim.SGD(m.parameters(), lr=1e-3)
        opt_local = torch.optim.SGD(m_local.parameters(), lr=1e-3)

        for i in range(6):
            t = torch.ones(4, device="cuda")
            a, b = m(t)
            local_a, local_b = m_local(t)
            if i < 2:
                # use both params in loss computation. Later,
                # b will go unused and we check grads are the
                # same as local training.
                loss = (a @ b).sum()
                loss_local = (local_a @ local_b).sum()
            else:
                loss = a.sum()
                loss_local = local_a.sum()

            loss.backward()
            loss_local.backward()
            _check_resharded(m)
            opt.step()
            opt_local.step()
            _check_equal(m_local, m)
            # Ensure at least some change from previous params, otherwise
            # above check would be vacuously true.
            self.assertTrue(
                any(
                    not torch.equal(p1, p2)
                    for p1, p2 in zip(prev_params, m_local.parameters())
                )
            )
            prev_params = [p.clone() for p in m_local.parameters()]
            opt.zero_grad()
            opt_local.zero_grad()

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optim_overlap_no_use_orig_params_error(self):
        """Tests that forward raises when in-backward optimizers are attached
        to an FSDP model constructed with ``use_orig_params=False``."""
        fsdp_overlap = FSDP(
            MyModel().cuda(),
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=False,
        )
        optim_cls = torch.optim.SGD
        optim_kwargs = {"lr": 0.03}
        _apply_optimizer_in_backward(
            optimizer_class=optim_cls,
            params=fsdp_overlap.parameters(),
            optimizer_kwargs=optim_kwargs,
            register_hook=False,
        )

        inp = torch.randn(10, 10, device="cuda")
        with self.assertRaisesRegex(
            RuntimeError, "only supported with use_orig_params=True"
        ):
            fsdp_overlap(inp, inp)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optimizer_overlap(self):
        """Tests that an FSDP model whose optimizer runs in backward
        (``_apply_optimizer_in_backward``) tracks an FSDP model stepped with a
        regular optimizer. Runs with and without CPU parameter offload."""
        torch.manual_seed(0)
        for cpu_offload in [True, False]:
            offload = CPUOffload(offload_params=cpu_offload)
            model = MyModel().cuda()
            model_overlap = deepcopy(model)
            fsdp = FSDP(
                model.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            fsdp_overlap = FSDP(
                model_overlap.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            optim_cls = torch.optim.SGD
            optim_kwargs = {"lr": 0.03}
            _apply_optimizer_in_backward(
                optimizer_class=optim_cls,
                params=fsdp_overlap.parameters(),
                optimizer_kwargs=optim_kwargs,
                register_hook=False,
            )
            for p in fsdp_overlap.parameters():
                assert hasattr(p, "_in_backward_optimizers")
            optim = optim_cls(fsdp.parameters(), **optim_kwargs)

            # Verify params initially equal
            for p1, p2 in zip(fsdp.parameters(), fsdp_overlap.parameters()):
                self.assertEqual(p1, p2)

            with FSDP.summon_full_params(fsdp_overlap):
                fsdp_overlap_prev_params = [
                    (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                ]

            for i in range(6):
                inp = torch.randn(2, 2, device="cuda")
                with torch.no_grad():
                    inp_clone = inp.clone()
                fsdp(inp, inp).sum().backward()
                fsdp_overlap(inp_clone, inp_clone).sum().backward()

                optim.step()
                optim.zero_grad()

                # Overlapped optimizer FSDP module should have sharded_grad as None.
                for fsdp_unit in FSDP.fsdp_modules(fsdp_overlap):
                    handle = fsdp_unit._handle
                    if handle:
                        self.assertIsNone(
                            handle.sharded_grad,
                            "Overlapped FSDP sharded_grad is not None!",
                        )

                # Note: FSDP without optimizer overlap won't set sharded_grad to None until the next
                # pre-forward since it needs to run FSDP specific logic that picks up that set_to_none=True
                # has been called (or that the gradients have been otherwise set to None)

                # Verify parameters are different than prev iteration
                with FSDP.summon_full_params(fsdp_overlap, with_grads=True):
                    for (n, p), (n_prev, p_prev) in zip(
                        fsdp_overlap.named_parameters(), fsdp_overlap_prev_params
                    ):
                        self.assertNotEqual(
                            p,
                            p_prev,
                            f"{n_prev} Params at iter {i} same as previous iter!",
                        )

                # Verify overlap and non overlapped are the same
                with FSDP.summon_full_params(fsdp_overlap):
                    with FSDP.summon_full_params(fsdp):
                        for (n_overlap, p_overlap), (n, p) in zip(
                            fsdp_overlap.named_parameters(), fsdp.named_parameters()
                        ):
                            self.assertEqual(n_overlap, n)
                            self.assertEqual(
                                p,
                                p_overlap,
                                f"Rank {self.rank}: Params not equal at iteration {i}: {n_overlap} - {p} vs {p_overlap}",
                            )
                            self.assertIsNone(
                                p.grad, f"Expected param {n} grad to be None"
                            )
                            self.assertIsNone(
                                p_overlap.grad,
                                f"Expected param {n_overlap} grad to be None",
                            )

                    fsdp_overlap_prev_params = [
                        (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                    ]

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_training(self):
        """Tests FSDP training on CPU against a DDP reference, once per
        sharding strategy."""
        gloo_pg = dist.new_group(backend="gloo")
        # Hybrid strategies need an explicit (intra-node, inter-node) process
        # group pair. Treat all ranks as one "node": ``gloo_pg`` is the
        # intra-node group and each rank's inter-node group contains only
        # itself. ``new_group`` must be entered by every rank for every group,
        # in the same order, hence one group per rank, created on all ranks.
        singleton_pgs = [
            dist.new_group(ranks=[rank], backend="gloo")
            for rank in range(self.world_size)
        ]
        hybrid_pg = (gloo_pg, singleton_pgs[self.rank])
        hybrid_strategies = (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        )
        for ss in [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ]:
            torch.manual_seed(42)
            model = MyModel()
            ref_model = DDP(deepcopy(model), process_group=gloo_pg)
            model = FSDP(
                model,
                # Previously ``ss`` was never passed, so every iteration
                # silently exercised the default strategy
                sharding_strategy=ss,
                auto_wrap_policy=always_wrap_policy,
                process_group=hybrid_pg if ss in hybrid_strategies else gloo_pg,
                device_id=torch.device("cpu"),
            )
            ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
            optim = torch.optim.Adam(model.parameters(), lr=1e-2)
            torch.manual_seed(42 + self.rank)
            inp = torch.randn(2, 2)
            for _ in range(10):
                losses = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    loss = _model(inp, inp).sum()
                    losses.append(loss)
                    loss.backward()
                    _optim.step()
                    _optim.zero_grad()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_init_stays_on_cpu(self):
        # Move me to MT test once warning logging and backward collective issue
        # is resolved.
        """Tests that passing a CPU module to FSDP preserves that the wrapped
        module is on CPU after FSDP initialization, albeit after logging a
        warning, and that FSDP moves CPU input to GPU before the forward."""
        torch.cuda.set_device(self.rank)
        regex = "passed-in `module` is on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_NEVER,
            )
            fsdp_model = FSDP(nested_wrapped_module, self.process_group)
        devices = {p.device for p in fsdp_model.parameters()}
        self.assertEqual(1, len(devices))
        self.assertEqual(torch.device("cpu"), devices.pop())
        fsdp_model = fsdp_model.cuda()
        # Ensure fwd + backward can be performed after moving to CUDA.
        # CPU input also tests that input is correctly moved to appropriate
        # CUDA device.
        inp = fsdp_model.module.get_input(device=torch.device("cpu"))
        fsdp_model(*inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_cpu_init_with_sync_module_states(self):
        """
        Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error, even
        when the model has CPU buffers.
        """

        def init_nested_wrapped_module():
            return NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_NEVER,
            )

        with self.assertRaisesRegex(
            ValueError,
            "The module has CPU parameters or buffers when `sync_module_states=True`",
        ):
            FSDP(
                init_nested_wrapped_module(),
                self.process_group,
                sync_module_states=True,
            )

        # Check that `device_id` with `sync_module_states=True` works
        nested_wrapped_module = init_nested_wrapped_module()
        nested_wrapped_module.register_buffer(
            "buf", torch.ones((2, 2), device="cpu") * self.rank
        )
        nested_wrapped_module.module[0].register_buffer(
            "buf", torch.ones((3, 2), device="cpu") * self.rank
        )
        nested_wrapped_module = FSDP(
            nested_wrapped_module,
            self.process_group,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )
        # Each rank's buffers should be 0s since rank 0 is the source, and they
        # should be on GPU since we specified `device_id`
        self.assertEqual(
            nested_wrapped_module.buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
        )


class TestFSDPMiscMultiThread(FSDPTestMultiThread):
    """
    Miscellaneous FSDP tests run on the multi-threaded test harness with a
    world size of 2.

    Since the ranks are threads of one process (per the ``FSDPTestMultiThread``
    base), they share ``os.environ``; tests that set FSDP environment
    variables restore them afterwards.
    """

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        """The default process group shared by all ranks."""
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_namedtuple(self):
        """
        Tests that every tensor in a namedtuple returned from an FSDP forward
        has backward hooks attached.
        """

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(100, 100)

            def forward(self, x):
                return x

        m = MyModule().cuda()
        m = FSDP(m)
        t = torch.ones(1, device="cuda", requires_grad=True)

        MyOutputType = namedtuple(
            "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
        )

        inp = MyOutputType()
        out = m(inp)
        # Ensure hooks are registered
        for x in out:
            self.assertNotEqual([], list(x._backward_hooks.values()))

        # TODO: we should check backward() and param is resharded
        # as well, but this is blocked by
        # https://github.com/pytorch/pytorch/issues/83107 and
        # https://github.com/pytorch/pytorch/issues/83129

    @skip_if_lt_x_gpu(2)
    def test_device_id_auto_wrap(self):
        """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
        nested FSDP instances."""
        self.run_subtests(
            {"use_callable": [False, True]},
            self._test_device_id_auto_wrap,
        )

    def _test_device_id_auto_wrap(self, use_callable: bool):
        """
        Wraps a transformer with an auto wrap policy and ``device_id`` and
        checks that every FSDP instance computes on the current CUDA device.

        Args:
            use_callable: If ``True``, the policy is a ``functools.partial``
                of ``transformer_auto_wrap_policy``; otherwise it is a
                ``ModuleWrapPolicy`` over the same layer classes.
        """
        module_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
        if use_callable:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=module_classes,
            )
        else:
            auto_wrap_policy = ModuleWrapPolicy(module_classes)
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            self.assertEqual(
                fsdp_module.compute_device,
                torch.device("cuda", torch.cuda.current_device()),
            )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_cpu_offload(self):
        """
        Tests FSDP when specifying both ``device_id`` and parameter CPU
        offloading.
        """
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_fsdp_device_id_cpu_offload,
        )

    def _test_fsdp_device_id_cpu_offload(self, use_orig_params: bool):
        """
        Checks that, with parameter CPU offloading enabled, every FSDP
        handle's flat parameter stays on CPU even though ``device_id`` names
        a GPU.
        """

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Linear(10, 10),
                    nn.Linear(10, 10),
                )
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(self.seq(x))

        model = MyModel()
        # Choose a wrapping policy such that there are (1) nested FSDP
        # instances and (2) the parent FSDP instance has managed parameters
        auto_wrap_policy = ModuleWrapPolicy({nn.Sequential})
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            cpu_offload=CPUOffload(offload_params=True),
            device_id=torch.cuda.current_device(),
            use_orig_params=use_orig_params,
        )
        cpu_device = torch.device("cpu")
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.device, cpu_device)

    @skip_if_lt_x_gpu(2)
    def test_module_device_mismatches_device_id(self):
        """Tests that specifying a ``device_id`` argument to FSDP for a GPU
        module that does not match the GPU device ID raises an error."""
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
        context = (
            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
            if self.rank != 0
            else nullcontext()
        )
        with context:
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                # Move wrapped modules to CUDA before wrapping with FSDP
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                # Should raise error since rank 1 is given `device_id=0` when
                # the model is on cuda:1
                fsdp_kwargs={"device_id": 0},
            )

    @skip_if_lt_x_gpu(2)
    def test_cpu_gpu_module(self):
        """Tests a CPU + GPU module supported if device_id is passed
        in, errors if device_id is not.
        """
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(1, 1).cuda()
                self.b = nn.Linear(1, 1)

        cpu_gpu = CPUGPUModule()
        fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
        for param in fsdp.parameters():
            self.assertEqual(param.device, torch.device(torch.cuda.current_device()))

        # without device_id, we hit an error
        with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
            FSDP(CPUGPUModule())

    @skip_if_lt_x_gpu(2)
    def test_fsdp_ignored_module_meta(self):
        """
        Tests that the parameters of an ignored module built on the meta
        device stay on the meta device, both without and with a
        ``param_init_fn``.
        """
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)

        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        meta_device = torch.device("meta")
        self.assertEqual(meta_device, next(m.a.parameters()).device)

        # Test with param_init_fn
        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(
            m,
            device_id=torch.cuda.current_device(),
            ignored_modules=[m.a],
            use_orig_params=True,
            param_init_fn=lambda m: m.to_empty(
                device=torch.cuda.current_device(), recurse=False
            ),
        )
        self.assertEqual(meta_device, next(m.a.parameters()).device)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_no_move_ignored_params_and_bufs(self):
        """
        Tests that passing ``device_id`` does not move the parameters or
        buffers of an ignored module off CPU.
        """

        class CPUGPUModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)
                self.a.register_buffer("buf", torch.ones(1))

        m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        ignored_params = m.a.parameters()
        ignored_bufs = m.a.buffers()
        for t in chain(ignored_params, ignored_bufs):
            self.assertEqual(torch.device("cpu"), t.device)

    @skip_if_lt_x_gpu(2)
    def test_multigpu_module(self):
        """
        Module on multiple GPUs wrapped in FSDP should raise an error.
        """

        class MultiGPUModule(nn.Module):
            def __init__(self, rank):
                super().__init__()
                self.rank = rank
                self.a = nn.Linear(1, 1).cuda(self.rank)
                self.b = nn.Linear(1, 1).cuda((self.rank + 1) % dist.get_world_size())

        with self.assertRaisesRegex(
            RuntimeError, "FSDP only supports single device modules"
        ):
            FSDP(MultiGPUModule(self.rank))

    @skip_if_lt_x_gpu(2)
    def test_no_params(self):
        """
        Test that device_id and cpu init work if module has no params
        (they are effective noops, but ensure FSDP does not assume module
        has parameters during init)
        """
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
        # Each construction below only needs to succeed; the wrapped modules
        # are not used afterwards, so they are not bound to names.
        # Test CPU
        FSDP(nn.ReLU())
        # Test CUDA
        FSDP(nn.ReLU().cuda())
        # Test CPU + device_id
        FSDP(nn.ReLU(), device_id=torch.cuda.current_device())
        # For modules with no params, wrong device_id will raise error about
        # inconsistency between compute_device and device_id, since compute_device
        # is computed as torch.cuda.current_device when there are no params.
        no_params = nn.ReLU().cuda()
        context = (
            (
                self.assertRaisesRegex(
                    ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
                )
            )
            if self.rank != 0
            else nullcontext()
        )
        with context:
            FSDP(no_params, device_id=0)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts model from rank 0 to ensure it starts off with the same
        values.
        """

        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

        # sync_module_states also works with CPU module with device_id passed in
        m = MyModel(self.rank)
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

    @skip_if_lt_x_gpu(2)
    def test_homogeneous_attributes(self):
        """
        Tests that passing heterogeneous values for attributes designated as
        homogeneous raises an error.
        """
        # Manually construct this list but verify against the global list of
        # homogeneous attribute names
        all_attr_name_and_values = [
            ("_use_orig_params", False, True),
            ("limit_all_gathers", False, True),
            ("_use_full_prec_in_eval", False, True),
        ]
        self.assertEqual(
            [
                attr_name_and_values[0]
                for attr_name_and_values in all_attr_name_and_values
            ],
            HOMOGENEOUS_ATTR_NAMES,
        )

        # The `_use_full_prec_in_eval` subtest sets this environment variable.
        # Record its prior value now and restore it once all subtests finish
        # so the setting does not leak into later tests (or into subprocesses
        # they spawn, which inherit the environment).
        # NOTE(review): ranks are threads sharing `os.environ`; restoring only
        # after `run_subtests` returns assumes it synchronizes the ranks
        # between subtests -- confirm.
        env_var = "FSDP_USE_FULL_PREC_IN_EVAL"
        prev_env_value = os.environ.get(env_var)
        try:
            self.run_subtests(
                {"attr_name_and_values": all_attr_name_and_values},
                self._test_homogeneous_attributes,
            )
        finally:
            if prev_env_value is None:
                os.environ.pop(env_var, None)
            else:
                os.environ[env_var] = prev_env_value

    def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
        """
        Builds a model in which the inner FSDP instance (wrapping
        ``module[1]``) and the outer FSDP instance disagree on one homogeneous
        attribute. It then checks that the first forward raises a
        ``ValueError``.

        Args:
            attr_name_and_values: ``(attr_name, inner_value, outer_value)``.
                For every attribute except ``_use_full_prec_in_eval``, the
                values are passed as FSDP constructor kwargs named after the
                attribute, with leading underscores stripped.
        """
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )
        attr_name = attr_name_and_values[0]

        if "_use_full_prec_in_eval" == attr_name:
            # For this attribute the test toggles an environment variable
            # instead of passing a kwarg. The inner module is wrapped before
            # the variable is set and the outer module after, so the two
            # instances disagree.
            # NOTE(review): ranks share `os.environ`; if another rank sets the
            # variable before this rank wraps its inner module, both instances
            # would agree -- confirm this interleaving cannot happen.
            model.module[1] = FSDP(model.module[1])
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = "1"
            fsdp_model = FSDP(model)
        else:
            fsdp_kwargs_inner = {attr_name.lstrip("_"): attr_name_and_values[1]}
            fsdp_kwargs_outer = {attr_name.lstrip("_"): attr_name_and_values[2]}
            model.module[1] = FSDP(model.module[1], **fsdp_kwargs_inner)
            fsdp_model = FSDP(model, **fsdp_kwargs_outer)

        # Run a forward to trigger lazy initialization and the error
        with self.assertRaisesRegex(
            ValueError, f"Expects one homogeneous value for {attr_name}"
        ):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            fsdp_model(*inp)


class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
    """Miscellaneous FSDP tests run on the multi-threaded harness with a
    world size of 1."""

    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_world_size_1_sharding_strategy_warning(self):
        """
        Tests that FSDP issues a warning when it switches to using ``NO_SHARD``
        when the world size is 1.
        """
        warning_prefix = "FSDP is switching to use `NO_SHARD` instead of"
        # If the user already passes `NO_SHARD`, then there should not be a
        # warning
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # trigger all warnings
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
            for warning in w:
                self.assertTrue(
                    warning.category != UserWarning
                    or not str(warning.message).startswith(warning_prefix)
                )

        # Check that a warning is issued
        warning_suffix = " since the world size is 1."
        # - Pass `FULL_SHARD` or `None`
        expected_regex_full_shard = (
            warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda())
        # - Pass `SHARD_GRAD_OP`
        expected_regex_shard_grad_op = (
            warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
            FSDP(
                nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
            )

    @skip_if_lt_x_gpu(1)
    def test_training_device_mismatch_errors(self):
        """
        Tests that, when training starts, if FSDP parameters are not on the
        expected device, then an informative error is raised. This applies for
        both no parameter CPU offloading and parameter CPU offloading.
        """
        # Incorrectly not moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model)
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module unexpectedly has parameters on cpu. Make "
            "sure to move the module to cuda:0 before training.",
        ):
            fsdp_model(inp)

        # Incorrectly moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
        fsdp_model.to(torch.device("cuda"))
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module with parameter CPU offloading enabled has "
            "parameters on cuda:0. Make sure to not move the module from CPU "
            "when offloading parameters.",
        ):
            fsdp_model(inp)

    @skip_if_lt_x_gpu(2)
    def test_unsafe_setattr(self):
        """
        Tests that the environment variable for using unsafe setattr gates as
        expected.
        """
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_unsafe_setattr,
        )

    def _test_unsafe_setattr(self, use_orig_params: bool):
        """
        Checks whether a module's ``__setattr__`` override runs during an FSDP
        forward in three cases. With ``_FSDP_USE_UNSAFE_SETATTR`` unset it
        should run, with it set to ``"1"`` it should not, and with it set to
        ``"0"`` it should run again.

        The environment variable's prior value is restored on exit, so
        neither later subtests nor later tests see this test's setting.
        """
        called_setattr_override = False

        class SetattrLinear(nn.Module):
            def __init__(self, in_dim: int, out_dim: int, device: torch.device) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

            def __setattr__(self, name: str, value: Any) -> None:
                nonlocal called_setattr_override
                called_setattr_override = True
                return super().__setattr__(name, value)

        prev_env_value = os.environ.get(_FSDP_USE_UNSAFE_SETATTR)
        try:
            # Construct the FSDP module with the environment variable unset
            # (the default) and run forward, which triggers both unsharded and
            # sharded view setting
            os.environ.pop(_FSDP_USE_UNSAFE_SETATTR, None)
            module = SetattrLinear(5, 5, torch.device("cuda"))
            fsdp_module = FSDP(module, use_orig_params=use_orig_params)
            inp = torch.randn((8, 5), device=torch.device("cuda"))
            called_setattr_override = False
            fsdp_module(inp)
            self.assertTrue(called_setattr_override)

            # Repeat with unsafe setattr explicitly enabled
            os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
            module = SetattrLinear(5, 5, torch.device("cuda"))
            fsdp_module = FSDP(module, use_orig_params=use_orig_params)
            called_setattr_override = False
            fsdp_module(inp)
            self.assertFalse(called_setattr_override)

            # Repeat with unsafe setattr explicitly disabled
            os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
            module = SetattrLinear(5, 5, torch.device("cuda"))
            fsdp_module = FSDP(module, use_orig_params=use_orig_params)
            called_setattr_override = False
            fsdp_module(inp)
            self.assertTrue(called_setattr_override)
        finally:
            # World size is 1 here (a single rank), so restoring immediately
            # cannot race with another rank still reading the variable
            if prev_env_value is None:
                os.environ.pop(_FSDP_USE_UNSAFE_SETATTR, None)
            else:
                os.environ[_FSDP_USE_UNSAFE_SETATTR] = prev_env_value


# Expand any `@parametrize`-decorated tests into concrete test methods
instantiate_parametrized_tests(TestFSDPMiscMultiThread)
instantiate_parametrized_tests(TestFSDPMiscMultiProcess)

if __name__ == "__main__":
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.