# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import (
    _FSDP_SKIP_WRITEBACK_CHECK,
    _FSDP_USE_FULL_PREC_IN_EVAL,
)
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
    """Tests multiple parameter groups."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
        """
        Constructs separate parameter groups for weights, biases, and other
        parameters.
        """
        param_groups = [
            {"params": [], "weight_decay": 0.1, "lr": 1e-2},
            {"params": [], "weight_decay": 0.01, "lr": 1e-3},
            {"params": []},
        ]
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                param_groups[0]["params"].append(param)
            elif "bias" in param_name:
                param_groups[1]["params"].append(param)
            else:
                param_groups[2]["params"].append(param)
        return param_groups
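
    # For illustration (parameter names here are hypothetical): a parameter
    # whose name contains "weight" (e.g. "layers.0.linear1.weight") lands in
    # the first group (weight_decay=0.1, lr=1e-2), one whose name contains
    # "bias" lands in the second group (weight_decay=0.01, lr=1e-3), and any
    # parameter containing neither lands in the default third group.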

    def _get_optim(
        self,
        model: nn.Module,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
    ) -> torch.optim.Optimizer:
        """
        Constructs an Adam optimizer with three parameter groups, one for
        weights, one for biases, and one for everything else, each with
        different weight decay and learning rates.
        """
        param_groups = self._get_param_groups(model)
        return optim_class(param_groups, lr=5e-3, foreach=multi_tensor)

    def _get_ddp_transformer(self, find_unused_params: bool) -> DDP:
        """Returns a transformer with shared parameters wrapped with DDP."""
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(
            model,
            device_ids=[self.rank],
            find_unused_parameters=find_unused_params,
        )
        return ddp_model

    def _get_fsdp_transformer_and_optim(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer]:
        """
        Returns a transformer with shared parameters wrapped with FSDP and a
        corresponding optimizer.
        """
        # Each transformer layer has multiple linear layers, so this policy, in
        # combination with the parameter group construction, ensures different
        # hyperparameter settings within one `FlatParameter`
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": backward_prefetch,
            "cpu_offload": cpu_offload,
        }
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            cuda_init_mode,
            deterministic=True,
        )
        if init_optim_before_wrap:
            fsdp_optim = self._get_optim(model, optim_class, multi_tensor)
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
        else:
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
            fsdp_optim = self._get_optim(fsdp_model, optim_class, multi_tensor)
        if (
            cuda_init_mode == CUDAInitMode.CUDA_AFTER
            and not fsdp_model.cpu_offload.offload_params
        ):
            fsdp_model = fsdp_model.cuda()
        return fsdp_model, fsdp_optim
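
    # NOTE: When initializing on CPU with `CUDAInitMode.CUDA_AFTER`, the wrapped
    # model is moved to GPU above only if parameters are not CPU-offloaded,
    # presumably because offloaded parameters are meant to stay on CPU outside
    # of unsharding.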

    def _check_train_parity(
        self,
        ddp_model: DDP,
        ddp_optim: torch.optim.Optimizer,
        fsdp_model: FSDP,
        fsdp_optim: torch.optim.Optimizer,
        set_to_none: bool,
        num_iters: int = 10,
    ):
        """Checks training parity between DDP and FSDP."""
        device = torch.device("cuda")
        for i in range(num_iters):
            iter_losses = []
            for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
                module = model.module
                # Test two different `zero_grad()` timings
                if i % 2 == 0:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-forward
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                if i % 2 == 1:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-backward
                module.run_backward(loss)
                # Perform the DDP optimizer step on CPU to match FSDP if needed
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(torch.device("cpu"))
                optim.step()
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(device)
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

    def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(), fsdp_model.named_parameters()
            ):
                # Allow for FSDP prefixes
                self.assertEqual(n1, clean_tensor_name(n2))
                torch.testing.assert_close(p1, p2)
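
    # `clean_tensor_name` strips FSDP-added prefixes (e.g. the
    # `_fsdp_wrapped_module.` segments inserted by wrapping) so that the FSDP
    # FQNs can be compared against the DDP module's original parameter names.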

    def _get_sharding_strategy_from_str(
        self, sharding_strategy_str: str
    ) -> ShardingStrategy:
        if sharding_strategy_str == "no_shard":
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif sharding_strategy_str == "shard_grad_op":
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif sharding_strategy_str == "full_shard":
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise ValueError(f"Invalid string: {sharding_strategy_str}")
        return sharding_strategy

    @skip_if_lt_x_gpu(2)
    def test_fsdp_compile(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "skip_fsdp_guards": [True, False],
            },
            self._test_fsdp_compile,
        )

    def _test_fsdp_compile(
        self, sharding_strategy: ShardingStrategy, skip_fsdp_guards: bool
    ):
        torch._dynamo.config.skip_fsdp_guards = skip_fsdp_guards
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
            "cpu_offload": CPUOffload(False),
        }
        base_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ref_model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        model = torch.compile(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        for i in range(10):
            losses = []
            inp = ref_model.get_input(torch.device("cuda"))
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad()
                loss = _model(*inp).sum()
                losses.append(loss)
                loss.backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])
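
    # NOTE: `skip_fsdp_guards` toggles a Dynamo config knob that, as the name
    # suggests, skips installing guards on FSDP-managed modules; the subtest
    # runs both settings to check that compiled FSDP matches eager FSDP
    # loss-for-loss either way.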

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings.
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        self.run_subtests(
            {
                "cuda_init_mode": [
                    CUDAInitMode.CUDA_BEFORE,
                    CUDAInitMode.CUDA_AFTER,
                ],
                "init_optim_before_wrap": [False, True],
                "optim_class": [torch.optim.AdamW],
                "multi_tensor": [False, True],
                "set_to_none": [False, True],
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
                "skip_writeback_check": [False, True],
            },
            self._test_diff_hyperparams,
            cpu_offload=CPUOffload(offload_params=False),
            sharding_strategy=sharding_strategy,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings with CPU offloading enabled. This is
        separate from :meth:`test_diff_hyperparams` because CPU offloading has
        issues with some specific subtest configurations (e.g., running with
        ``offload_params=False`` followed by ``True`` but not vice versa).
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        for skip_writeback_check in (False, True):
            self._test_diff_hyperparams(
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                init_optim_before_wrap=False,
                optim_class=torch.optim.Adam,
                multi_tensor=False,
                set_to_none=False,
                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                cpu_offload=CPUOffload(offload_params=True),
                sharding_strategy=sharding_strategy,
                skip_writeback_check=skip_writeback_check,
            )

    def _test_diff_hyperparams(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        set_to_none: bool,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
        sharding_strategy: ShardingStrategy,
        skip_writeback_check: bool,
    ):
        """
        Args:
            init_optim_before_wrap (bool): If ``True``, initializes the
                FSDP optimizer before wrapping the model with FSDP; otherwise,
                initializes the FSDP optimizer after wrapping the model with
                FSDP. We permit both forms of initialization to give users
                flexibility.
        """
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            return  # not supported
        if skip_writeback_check:
            os.environ[_FSDP_SKIP_WRITEBACK_CHECK] = "1"
        ddp_model = self._get_ddp_transformer(find_unused_params=False)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=cuda_init_mode,
            init_optim_before_wrap=init_optim_before_wrap,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            cpu_offload=cpu_offload,
        )
        self._check_train_parity(
            ddp_model, ddp_optim, fsdp_model, fsdp_optim, set_to_none
        )
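
    # NOTE: Setting the `_FSDP_SKIP_WRITEBACK_CHECK` environment variable to "1"
    # disables FSDP's check for out-of-band changes to the original parameters
    # (the writeback behavior exercised in `TestFSDPUseOrigParamsWriteback`
    # below); DDP parity should hold with the check enabled or disabled.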

    @skip_if_lt_x_gpu(2)
    def test_diff_trainability(self):
        """
        Tests FSDP parity with DDP when using multiple parameter groups and
        freezing the parameters in one parameter group.
        """
        self.run_subtests(
            {
                "multi_tensor": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_diff_trainability,
        )

    def _test_diff_trainability(
        self,
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
    ):
        optim_class = torch.optim.Adam
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        # Freeze all biases (which happen to be in the same parameter group)
        for param_name, param in ddp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        for param_name, param in fsdp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        self._check_train_parity(ddp_model, ddp_optim, fsdp_model, fsdp_optim, False)

    @skip_if_lt_x_gpu(2)
    def test_multiple_optimizers(self):
        """
        Tests using two optimizers where only one sets gradients to ``None``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ]
            },
            self._test_multiple_optimizers,
        )

    def _test_multiple_optimizers(self, sharding_strategy: ShardingStrategy):
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_param_groups = self._get_param_groups(ddp_model)
        assert len(ddp_param_groups) == 3, f"{len(ddp_param_groups)}"
        (
            fsdp_model,
            _,
        ) = self._get_fsdp_transformer_and_optim(  # ignore returned optimizer
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=torch.optim.Adam,  # ignored
            multi_tensor=False,  # ignored
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        fsdp_param_groups = self._get_param_groups(fsdp_model)
        assert len(fsdp_param_groups) == 3, f"{len(fsdp_param_groups)}"
        ddp_optims = []
        fsdp_optims = []
        # For the transformer model, every parameter is either a weight or a
        # bias, so we only use the first two parameter groups. Moreover, we use
        # Adam and AdamW in particular since they both use bias correction
        # dependent on the step, which is incremented even if a parameter has a
        # zero gradient but not if the gradient is `None`. This is to test that
        # we are differentiating between a zero and `None` gradient correctly.
        optim_ctors = [
            functools.partial(torch.optim.Adam, lr=5e-3),
            functools.partial(torch.optim.AdamW, lr=1e-2),
        ]
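
        # Concretely: a parameter whose `.grad` is a zero tensor still has its
        # Adam/AdamW `step` incremented (changing the bias correction), while a
        # parameter whose `.grad` is `None` is skipped by the optimizer
        # entirely, so the parity checks below would catch FSDP conflating the
        # two cases.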

        for optim_ctor, ddp_param_group, fsdp_param_group in zip(
            optim_ctors,
            ddp_param_groups[:2],
            fsdp_param_groups[:2],
        ):
            ddp_optims.append(optim_ctor(ddp_param_group["params"]))
            fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
        device = torch.device("cuda")

        # Check that there exists a `FlatParameter` that has both a weight and
        # a bias in this rank's shard
        has_both = False
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            handle = fsdp_module._handle
            if not handle:
                continue
            flat_param = handle.flat_param
            assert flat_param._params is not None
            has_weight = False
            has_bias = False
            for param, fqn in zip(flat_param._params, flat_param._fqns):
                if "weight" in fqn and param.numel() > 0:
                    has_weight = True
                elif "bias" in fqn and param.numel() > 0:
                    has_bias = True
            has_both |= has_weight and has_bias
        assert has_both, (
            f"Rank {self.rank} does not have a `FlatParameter` with both a "
            "weight and a bias in its shard, meaning that this test is vacuous"
        )

        # Run one iteration to generate gradients
        def run_iter():
            iter_losses = []
            for model, optims in ((ddp_model, ddp_optims), (fsdp_model, fsdp_optims)):
                module = model.module
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                module.run_backward(loss)
                for optim in optims:
                    optim.step()
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
            self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        run_iter()

        # Only set the weights' gradients to None
        ddp_optims[0].zero_grad(set_to_none=True)
        fsdp_optims[0].zero_grad(set_to_none=True)
        inp = ddp_model.module.get_input(device)
        ddp_output = ddp_model(*inp)
        fsdp_output = fsdp_model(*inp)

        # Check that FSDP correctly exposes gradients even after forward
        # (namely, `None` for weights and non-`None` for biases)
        if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
            # Skip the check since we do not expose the gradients after forward
            # for these strategies
            return
        for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
            ddp_model.module.named_parameters(),
            fsdp_model.named_parameters(),
        ):
            self.assertEqual(ddp_n, clean_tensor_name(fsdp_n))
            if fsdp_p.numel() == 0:
                # Not in this rank's shard
                self.assertTrue(fsdp_p.grad is None)
                continue
            if ddp_p.grad is None:
                self.assertTrue(fsdp_p.grad is None)
            else:
                self.assertEqual(ddp_p.flatten(), fsdp_p.flatten())
                self.assertEqual(ddp_p.grad.flatten(), fsdp_p.grad.flatten())
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Finish the iteration (backward pass and optimizer step)
        ddp_loss = ddp_model.module.get_loss(inp, ddp_output).to(device)
        fsdp_loss = fsdp_model.module.get_loss(inp, fsdp_output).to(device)
        ddp_model.module.run_backward(ddp_loss)
        fsdp_model.module.run_backward(fsdp_loss)
        for optim in itertools.chain(ddp_optims, fsdp_optims):
            optim.step()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Run one more iteration to confirm bias corrections are correct
        run_iter()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)


class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
    """Tests the unshard/reshard flow."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_fsdp_models_and_optims(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
        """
        Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False``
        and ``True``, respectively.
        """
        LR = 1e-2
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "cpu_offload": cpu_offload,
            "use_orig_params": False,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
        fsdp_kwargs["use_orig_params"] = True
        fsdp_model_orig_params = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim_orig_params = torch.optim.Adam(
            fsdp_model_orig_params.parameters(), foreach=False, lr=LR
        )
        return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params
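
    # NOTE: Both models are built with `deterministic=True` so that the
    # `use_orig_params=False` baseline and the `use_orig_params=True` test model
    # start from identical parameters; any later divergence should then come
    # from the unshard/reshard logic under test.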

    def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:
        """Checks that two FSDP instances have the same model parameters."""
        with FSDP.summon_full_params(fsdp1), FSDP.summon_full_params(fsdp2):
            for (n1, p1), (n2, p2) in zip(
                fsdp1.named_parameters(),
                fsdp2.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    def _get_fsdp_parity_subtest_config(self):
        return {
            "sharding_strategy": [
                ShardingStrategy.NO_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.FULL_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_multiple_forward(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running multiple forward passes before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_multiple_forward,
            cpu_offload=cpu_offload,
        )

    @skip_if_lt_x_gpu(2)
    def _test_multiple_forward(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            inp1 = fsdp_model.get_input(device)
            _inp2 = fsdp_model.get_input(device)
            inp2 = tuple(
                t + torch.ones_like(t) for t in _inp2
            )  # make different from `inp1`
            # For these loss lists: elem 0 is baseline; elem 1 is test
            losses1 = []
            losses2 = []
            losses = []
            for _model, _optim in (fsdp_model, optim), (
                fsdp_model_orig_params,
                optim_orig_params,
            ):
                _optim.zero_grad()
                loss1 = _model(*inp1)
                losses1.append(loss1)
                loss2 = _model(*inp2)
                losses2.append(loss2)
                loss = (loss1 + loss2).sum()
                losses.append(loss)
                _model.run_backward(loss)
                _optim.step()
            self.assertEqual(losses1[0], losses1[1])
            self.assertEqual(losses2[0], losses2[1])
            self.assertEqual(losses[0], losses[1])
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_summon_between_two_forwards(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running a forward pass, :meth:`summon_full_params()`, and another
        forward pass before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_summon_between_two_forwards,
            cpu_offload=cpu_offload,
        )

    def _test_summon_between_two_forwards(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            optim.zero_grad()
            optim_orig_params.zero_grad()

            inp1 = fsdp_model.get_input(device)
            loss1 = fsdp_model(*inp1)
            loss_orig_params1 = fsdp_model_orig_params(*inp1)
            self.assertEqual(loss1, loss_orig_params1)

            # Calls into `summon_full_params()`
            self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

            inp2 = fsdp_model.get_input(device)
            loss2 = fsdp_model(*inp2)
            loss_orig_params2 = fsdp_model_orig_params(*inp2)
            self.assertEqual(loss2, loss_orig_params2)

            loss = (loss1 + loss2).sum()
            loss_orig_params = (loss_orig_params1 + loss_orig_params2).sum()
            fsdp_model.run_backward(loss)
            fsdp_model_orig_params.run_backward(loss_orig_params)
            optim.step()
            optim_orig_params.step()
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)


class TestFSDPUseOrigParamsParamAccess(FSDPTest):
    """Tests original parameter access."""

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code the FSDP sharding
        # strategy to check sharded parameter parity
        return 2

    @skip_if_lt_x_gpu(2)
    def test_access_params_after_forward(self):
        """
        Tests accessing the original parameters after the forward pass but
        before the backward pass. Notably, this is not supported when
        ``use_orig_params=False``. However, for ``True``, FSDP exposes the
        (flattened) sharded original parameters, making it possible.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.NO_SHARD,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ],
            },
            self._test_access_params_after_forward,
        )

    def _test_access_params_after_forward(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        # NOTE: This test needs to be changed if the FSDP sharding algorithm
        # changes. It is still valuable until such a change to sanity check the
        # `use_orig_params=True` implementation.
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(42)
                # 5 * 5 = 25 numel -> pad to 26 -> 13 on each rank
                self.lin1 = nn.Linear(5, 5, bias=False)
                # 5 * 7 + (1) + 7 = 43 numel -> pad to 44 -> 22 on each rank,
                # where the (1) is from intra-`FlatParameter` alignment padding
                # 22 of weight on rank 0; 13 of weight, 1 alignment padding,
                # and 7 of bias on rank 1
                self.lin2 = nn.Linear(5, 7)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = self.lin1(x)
                z = nn.functional.relu(z)
                z = self.lin2(z)
                return z

            def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
                return (torch.randn((2, 5)).to(device),)

            def get_loss(self, inp, out):
                return out.sum()
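
        # Putting the comments in `Model.__init__` together, the expected
        # per-rank shard layout for world size 2 is:
        #   lin1: 25 numel -> padded to 26 -> 13 elements on each rank
        #   lin2: 43 numel (incl. 1 alignment pad) -> padded to 44 -> 22 per
        #     rank, with rank 0 holding 22 weight elements and rank 1 holding
        #     the remaining 13 weight elements, the alignment padding, and the
        #     7 bias elements
        # `check_parameter_parity` below slices the DDP parameters accordingly.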

        def check_parameter_parity(
            ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
        ):
            assert self.rank in (
                0,
                1,
            ), f"Expects world size of 2 but got {self.world_size}"
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, clean_tensor_name(n2))
                if sharding_strategy == ShardingStrategy.NO_SHARD:
                    # For `NO_SHARD`, do nothing since the original parameters
                    # are unflattened
                    pass
                elif (
                    between_fwd_and_bwd
                    and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    # For no reshard after forward strategies, do nothing since
                    # FSDP did not use sharded views after forward
                    pass
                # Otherwise, case on the parameter (see the model definition)
                elif n1 == "lin1.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:13]
                    elif self.rank == 1:
                        p1 = p1.flatten()[13:]
                elif n1 == "lin2.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:22]
                    elif self.rank == 1:
                        p1 = p1.flatten()[22:]
                elif n1 == "lin2.bias":
                    if self.rank == 0:
                        p1 = torch.empty(0, device=p1.device)
                    elif self.rank == 1:
                        p1 = p1.flatten()
                torch.testing.assert_close(p1, p2)

        ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
        fsdp_model = FSDP(
            Model().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=True,
        )
        LR = 1e-2
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
        device = torch.device("cuda")

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)

        ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
        fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
        ddp_loss.backward()
        fsdp_loss.backward()
        ddp_optim.step()
        fsdp_optim.step()
        check_parameter_parity(ddp_model, fsdp_model, False)

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)


class TestFSDPUseOrigParamsWriteback(FSDPTest):
    """Tests parameter and gradient writeback."""

    class Model(nn.Module):
        def __init__(self, device: torch.device):
            super().__init__()
            torch.manual_seed(42)
            self.lin1 = nn.Linear(5, 5, bias=True, device=device)
            self.lin2 = nn.Linear(5, 7, bias=True, device=device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            z = self.lin1(x)
            z = nn.functional.relu(z)
            z = self.lin2(z)
            return z

        def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
            return (torch.randn((2, 5)).to(device),)

        def get_loss(self, inp, out):
            return out.sum()

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code the FSDP sharding
        # strategy
        return 2

    def _check_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_param_writeback(self):
        """Tests that changes to the original parameters are written back."""
        self.run_subtests(
            {
                "change_first_weight": [True, False],  # first vs. second `weight`
                "change_data": [True, False],  # change `.data` vs. variable itself
            },
            self._test_param_writeback,
        )

    def _test_param_writeback(self, change_first_weight: bool, change_data: bool):
        def transform_param(param: nn.Parameter) -> nn.Parameter:
            return nn.Parameter(torch.ones_like(param) * 2)

        # Check that the writeback propagates
        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight:
            if change_data:
                ddp.lin1.weight.data = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight.data = transform_param(fsdp.lin1.weight)
            else:
                ddp.lin1.weight = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight = transform_param(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.data = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight.data = transform_param(fsdp.lin2.weight)
            else:
                ddp.lin2.weight = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight = transform_param(fsdp.lin2.weight)
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback
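
    # The two `change_data` branches exercise different writeback situations:
    # mutating `param.data` keeps the same `nn.Parameter` variable (only its
    # storage changes), while reassigning the attribute swaps in a brand-new
    # variable. `summon_full_params()` is expected to detect and write back
    # both kinds of change.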

    @skip_if_lt_x_gpu(2)
    def test_grad_writeback(self):
        """
        Tests that changes to the original parameters' gradients are written
        back.
        """
        self.run_subtests(
            {
                "change_first_weight_grad": [False, True],
                "change_data": [False, True],  # change `.data` vs. variable itself
                "set_to_none": [False, True],
            },
            self._test_grad_writeback,
        )

    def _test_grad_writeback(
        self,
        change_first_weight_grad: bool,
        change_data: bool,
        set_to_none: bool,
    ):
        if change_data and set_to_none:
            return  # not well-defined

        def transform_grad(param: nn.Parameter) -> nn.Parameter:
            return None if set_to_none else torch.ones_like(param) * 2

        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        LR = 1e-2
        # TODO: If we add `summon_full_params(with_grads=True)`, then replace
        # the following. For now, we use the optimizer step as a surrogate for
        # checking that gradients were written back.
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)

        # Generate an initial gradient
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()

        # Change the gradient through the original parameters
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight_grad:
            if change_data:
                ddp.lin1.weight.grad.data = transform_grad(ddp.lin1.weight)
                if fsdp.lin1.weight.grad is not None:
                    fsdp.lin1.weight.grad.data = transform_grad(fsdp.lin1.weight)
            else:
                ddp.lin1.weight.grad = transform_grad(ddp.lin1.weight)
                fsdp.lin1.weight.grad = transform_grad(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.grad.data = transform_grad(ddp.lin2.weight)
                if fsdp.lin2.weight.grad is not None:
                    fsdp.lin2.weight.grad.data = transform_grad(fsdp.lin2.weight)
            else:
                ddp.lin2.weight.grad = transform_grad(ddp.lin2.weight)
                fsdp.lin2.weight.grad = transform_grad(fsdp.lin2.weight)
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

        # Intentionally do not zero the gradient to check writeback
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_writeback_shape_mismatch(self):
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        # Check that writing back with mismatched shape errors
        fsdp = fsdp_model.module  # for brevity
        assert self.rank in (0, 1), f"Expects world size of 2 but got {self.world_size}"
        with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
            # Change the gradient to a new one with 1 added to each dimension
            # to force a shape mismatch when writing back
            if self.rank == 0:
                # Change `lin1.weight.grad` since it exists on rank 0
                lin1_weight_shape = list(fsdp.lin1.weight.shape)
                for dim_index in range(len(lin1_weight_shape)):
                    lin1_weight_shape[dim_index] += 1
                fsdp.lin1.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                    )
                )
                fsdp.lin1.weight.grad = torch.randn(
                    torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                )
            elif self.rank == 1:
                # Change `lin2.weight.grad` since it exists (partially) on rank 1
                lin2_weight_shape = list(fsdp.lin2.weight.shape)
                for dim_index in range(len(lin2_weight_shape)):
                    lin2_weight_shape[dim_index] += 1
                fsdp.lin2.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                    )
                )
                fsdp.lin2.weight.grad = torch.randn(
                    torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                )
            with FSDP.summon_full_params(fsdp_model):  # triggers a writeback
                ...

    @skip_if_lt_x_gpu(2)
    def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "use_orig_params": True,
        }
        fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)

        # Test changing the parameter storage to no longer be a view into the
        # flat parameter
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
        assert_msg = (
            "FSDP does not support changing the parameters between forward and backward"
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

        # Test changing the parameter variable itself
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
            fsdp_model.lin1.weight.clone()
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_no_reshard_and_mixed_precision(self):
        """
        Tests that writeback does not falsely get triggered for a few
        configurations (exercising the sharded view skipping logic):
        - Train forward -> full-precision unshard -> train forward
        - Train forward -> eval forward
        - Train forward/backward -> eval forward -> model checkpoint
        """
        self.run_subtests(
            {"use_full_prec_in_eval": [False, True]},
            self._test_no_reshard_and_mixed_precision,
        )

    def _test_no_reshard_and_mixed_precision(self, use_full_prec_in_eval: bool):
        if use_full_prec_in_eval:
            os.environ[_FSDP_USE_FULL_PREC_IN_EVAL] = "1"
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "mixed_precision": MixedPrecision(param_dtype=torch.float16),
            "use_orig_params": True,
        }

        # Train forward -> full-precision unshard -> train forward
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        fsdp_model(*inp)
        with FSDP.summon_full_params(fsdp_model):
            ...
        fsdp_model(*inp).sum()

        # Train forward -> eval forward
        fsdp_model.train()
        fsdp_model(*inp)
        fsdp_model.eval()
        fsdp_model(*inp)

        # Train forward/backward -> eval forward -> model checkpoint
        fsdp_model.train()
        fsdp_model(*inp).sum().backward()
        fsdp_model.eval()
        fsdp_model(*inp)
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            sd = fsdp_model.state_dict()
            fsdp_model.load_state_dict(sd)
        fsdp_model(*inp).sum().backward()


class TestFSDPUseOrigParamsFQNs(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_named_parameters_in_forward(self):
        """
        Tests that calling ``named_parameters()`` during forward returns FQNs
        and ``Tensor`` s corresponding to the original parameters.
        """
        param_shapes = [None, None]
        assert_equal_fn = self.assertEqual

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                nonlocal param_shapes
                # Allow for FSDP prefixes
                param_names = [
                    clean_tensor_name(tup[0]) for tup in self.named_parameters()
                ]
                params = [tup[1] for tup in self.named_parameters()]
                assert (
                    param_shapes[0] is not None and param_shapes[1] is not None
                ), "`param_sizes` should be set"
                assert_equal_fn(
                    param_names,
                    [
                        "lin.weight",
                        "lin.bias",
                    ],
                )
                assert_equal_fn(params[0].shape, param_shapes[0])
                assert_equal_fn(params[1].shape, param_shapes[1])
                return self.lin(x)

        model = Model().cuda()
        # Save the *unsharded* original parameter shapes and check the shapes
        # match in the forward pass
        param_shapes[0] = model.lin.weight.shape
        param_shapes[1] = model.lin.bias.shape
        fsdp_model = FSDP(model, use_orig_params=True)
        inp = torch.randn((2, 5), device=torch.device("cuda"))
        fsdp_model(inp)
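
        # The shapes recorded above (before wrapping) are the *unsharded* ones,
        # so this checks that, inside `forward()`, `named_parameters()` yields
        # the clean original FQNs together with the unsharded original
        # parameters.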


class TestFSDPUseOrigParamsNoSync(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_no_sync_correctness(self):
        """
        Tests a basic ``no_sync()`` setup by comparing ``use_orig_params=True``
        against ``use_orig_params=False``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_no_sync_correctness,
        )

    def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(7, 1, bias=False, device="cuda")
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
        }
        model_use_flat_params = FSDP(
            copy.deepcopy(model), use_orig_params=False, **fsdp_kwargs
        )
        model_use_orig_params = FSDP(model, use_orig_params=True, **fsdp_kwargs)
        optim_use_flat_params = torch.optim.AdamW(
            model_use_flat_params.parameters(), foreach=True
        )
        optim_use_orig_params = torch.optim.AdamW(
            model_use_orig_params.parameters(), foreach=True
        )

        def _check_param_grad_parity(
            _baseline_model: nn.Module,
            _test_model: nn.Module,
        ):
            """
            This assumes that the model is ``nn.Linear(7, 1, bias=False)``
            (i.e. with a single 1D weight parameter) to be able to directly
            compare the baseline and test models. On rank 1, the baseline
            includes 1 element of padding.
            """
            self.assertEqual(len(list(_baseline_model.parameters())), 1)
            self.assertEqual(len(list(_test_model.parameters())), 1)
            for flat_param, orig_param in zip(
                _baseline_model.parameters(), _test_model.parameters()
            ):
                # Baseline is permitted to have padding
                self.assertGreaterEqual(flat_param.numel(), orig_param.numel())
                unpadded_param_numel = orig_param.numel()
                # For `NO_SHARD`, `use_orig_params=True` presents unflattened
                # parameters, while `False` presents flattened ones
                torch.testing.assert_close(
                    flat_param[:unpadded_param_numel], orig_param.flatten()
                )
                # Gradient numel is different if right after `no_sync()` since
                # the gradient is unsharded, while the parameter is sharded
                unpadded_grad_numel = orig_param.grad.numel()
                # For `use_orig_params=False`, the unsharded gradient is
                # flattened, while for `True`, it is unflattened
                torch.testing.assert_close(
                    flat_param.grad[:unpadded_grad_numel].reshape(
                        orig_param.grad.shape
                    ),
                    orig_param.grad,
                )
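
        # Worked example of the padding mentioned in the docstring above: the
        # weight of `nn.Linear(7, 1, bias=False)` has 7 elements, which the
        # sharded strategies pad to 8 to split evenly across 2 ranks (4 per
        # rank), leaving 1 padding element at the end of rank 1's shard; hence
        # the slicing to `orig_param.numel()` before comparing.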

        inp = torch.randn((2, 7), device="cuda")
        grad = torch.randn((2, 1), device="cuda")

        # Compute some reference gradients using one forward/backward
        out_use_flat_params = model_use_flat_params(inp)
        out_use_orig_params = model_use_orig_params(inp)
        torch.testing.assert_close(out_use_flat_params, out_use_orig_params)
        out_use_flat_params.backward(grad)
        out_use_orig_params.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)
        ref_grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        ref_grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]

        # Run a forward/backward in `no_sync()`
        optim_use_flat_params.zero_grad(set_to_none=True)
        optim_use_orig_params.zero_grad(set_to_none=True)
        for model in (model_use_flat_params, model_use_orig_params):
            with model.no_sync():
                out = model(inp)
                out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Run a forward/backward outside `no_sync()`
        for model in (model_use_flat_params, model_use_orig_params):
            out = model(inp)
            out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Check that, since we accumulated gradients across 2 iterations, the
        # new gradients are 2x the reference gradients
        grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]
        for grad, ref_grad in zip(grads_use_flat_params, ref_grads_use_flat_params):
            torch.testing.assert_close(grad, 2 * ref_grad)
        for grad, ref_grad in zip(grads_use_orig_params, ref_grads_use_orig_params):
            torch.testing.assert_close(grad, 2 * ref_grad)

    @skip_if_lt_x_gpu(2)
    def test_no_sync_mixed_precision(self):
        """
        Tests that dtypes are as expected when using ``no_sync()`` with
        ``use_orig_params=True`` and parameter mixed precision.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_no_sync_mixed_precision,
        )

    def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(3, 3, device="cuda")
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
        )
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
            "use_orig_params": True,
        }
        fsdp_model = FSDP(model, **fsdp_kwargs)
        inp = torch.randn((2, 3), device="cuda")
        with fsdp_model.no_sync():
            # For each of these `no_sync()` backward passes, check that the
            # gradients are in the low precision parameter dtype (FP16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
        # For the backward pass outside `no_sync()`, check that the gradients
        # are cast to the full precision in preparation for the optimizer step
        fsdp_model(inp).sum().backward()
        for param in fsdp_model.parameters():
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float32)


class TestFSDPUseOrigParamsInit(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_non_uniform_requires_grad(self):
        model = nn.Sequential(
            nn.Linear(3, 3, device="cuda"),
            nn.Linear(3, 3, device="cuda"),
        )
        # Freeze biases only and flatten both weights and biases into the same
        # `FlatParameter` to exercise non-uniform `requires_grad`
        model[0].bias.requires_grad = False
        model[1].bias.requires_grad = False
        fsdp_model = FSDP(model, use_orig_params=True)
        self.assertTrue(fsdp_model[0].weight.requires_grad)
        self.assertFalse(fsdp_model[0].bias.requires_grad)
        self.assertTrue(fsdp_model[1].weight.requires_grad)
        self.assertFalse(fsdp_model[1].bias.requires_grad)
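
        # With `use_orig_params=True`, FSDP preserves each original parameter's
        # `requires_grad` even though frozen and trainable originals are
        # flattened into the same `FlatParameter`, which is exactly what the
        # assertions above check.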


# Define this to be large enough to trigger stack corruption
NUM_SIZE0_TENSORS = 1000


class TestMultiTensorApply(TestCase):
    def test_multi_tensor_apply_size0_tensors_cpu(self):
        size0_tensors = [torch.empty(0, device="cpu") for _ in range(NUM_SIZE0_TENSORS)]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_multi_tensor_apply_size0_tensors_cuda(self):
        size0_tensors = [
            torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
        ]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)


instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups)
instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard)
instantiate_parametrized_tests(TestFSDPUseOrigParamsParamAccess)
instantiate_parametrized_tests(TestFSDPUseOrigParamsFQNs)
instantiate_parametrized_tests(TestFSDPUseOrigParamsNoSync)

if __name__ == "__main__":
    run_tests()