# test_fsdp_optim_state.py
1
# Owner(s): ["oncall: distributed"]
2

3
import bisect
4
import sys
5
from copy import deepcopy
6
from enum import auto, Enum
7
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
8

9
import torch
10
import torch.nn as nn
11
from torch import distributed as dist
12
from torch.distributed._shard.sharded_tensor import ShardedTensor
13
from torch.distributed._state_dict_utils import _gather_state_dict
14
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
15
    _CHECKPOINT_WRAPPED_MODULE,
16
    apply_activation_checkpointing,
17
)
18
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
19
from torch.distributed.fsdp.api import ShardingStrategy
20
from torch.distributed.fsdp.fully_sharded_data_parallel import (
21
    FullOptimStateDictConfig,
22
    FullStateDictConfig,
23
    OptimStateKeyType,
24
    ShardedOptimStateDictConfig,
25
    ShardedStateDictConfig,
26
    StateDictSettings,
27
    StateDictType,
28
)
29
from torch.distributed.optim import _NamedOptimizer
30
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
31
from torch.testing._internal.common_fsdp import (
32
    CUDAInitMode,
33
    FSDPInitMode,
34
    FSDPTest,
35
    TransformerWithSharedParams,
36
)
37
from torch.testing._internal.common_utils import (
38
    instantiate_parametrized_tests,
39
    parametrize,
40
    run_tests,
41
    TEST_WITH_DEV_DBG_ASAN,
42
)
43

44
STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT]
45

46
if not dist.is_available():
47
    print("Distributed not available, skipping tests", file=sys.stderr)
48
    sys.exit(0)
49

50
if TEST_WITH_DEV_DBG_ASAN:
51
    print(
52
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
53
        file=sys.stderr,
54
    )
55
    sys.exit(0)
56

57

58
class _OSDCommMethod(Enum):
59
    """Method for communicating the optimizer state dict for internal tests."""
60

61
    BROADCAST_OBJECT_LIST = auto()
62
    SCATTER_FULL_OSD = auto()
63
    FLATTEN_SHARDED_OSD = auto()
64
    OPTIM_STATE_DICT = auto()
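    # Roughly, each method above maps onto the FSDP APIs that
    # `_test_load_optim_state` exercises for it (a sketch of the pairing,
    # matching the dispatch in that helper):
    #
    #   BROADCAST_OBJECT_LIST -> FSDP.full_optim_state_dict(...), broadcast via
    #                            dist.broadcast_object_list(), then
    #                            FSDP.shard_full_optim_state_dict(...)
    #   SCATTER_FULL_OSD      -> FSDP.full_optim_state_dict(...), then
    #                            FSDP.scatter_full_optim_state_dict(...) fed
    #                            the full dict only on rank 0
    #   FLATTEN_SHARDED_OSD   -> FSDP.sharded_optim_state_dict(...), then
    #                            FSDP.flatten_sharded_optim_state_dict(...)
    #   OPTIM_STATE_DICT      -> FSDP.optim_state_dict(...), then
    #                            FSDP.optim_state_dict_to_load(...)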
65

66

67
class _ModelClass(Enum):
68
    """Different model type to test."""
69

70
    NESTED = auto()
71
    TRANSFORMER = auto()
72

73

74
class Bias(torch.nn.Module):
75
    """This module applies a 1D additive bias with dimension ``dim``."""
76

77
    def __init__(self, dim: int) -> None:
78
        super().__init__()
79
        assert dim > 0
80
        torch.manual_seed(0)
81
        self.bias = torch.nn.Parameter(torch.randn((dim,)))
82

83
    def forward(self, x):
84
        return x + self.bias
85

86

87
class BlockA(torch.nn.Module):
88
    """
89
    Used to define interesting nested structure for FSDP wrapping.
90
    BlockA
91
        Bias0
92
            bias
93
        weight
94
        Bias1
95
            bias
96
    """
97

98
    def __init__(self, in_dim: int, out_dim: int) -> None:
99
        super().__init__()
100
        assert all(v > 0 for v in (in_dim, out_dim))
101
        torch.manual_seed(0)
102
        self.bias_module0 = Bias(out_dim)
103
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
104
        self.bias_module1 = Bias(out_dim)
105
        self.relu = torch.nn.ReLU()
106

107
    def forward(self, x):
108
        x = x @ self.weight
109
        x = self.bias_module0(x)
110
        x = self.relu(x)  # ensure biases have different gradients
111
        x = self.bias_module1(x)
112
        return x
113

114

115
class BlockB(torch.nn.Module):
116
    """
117
    Used to define interesting nested structure for FSDP wrapping.
118
    BlockB
119
        weight
120
        Bias
121
            bias
122
        Bias
123
            bias
124
    """
125

126
    def __init__(self, in_dim: int, out_dim: int) -> None:
127
        super().__init__()
128
        assert all(v > 0 for v in (in_dim, out_dim))
129
        torch.manual_seed(0)
130
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
131
        self.bias_module0 = Bias(out_dim)
132
        self.bias_module1 = Bias(out_dim)
133
        self.relu = torch.nn.ReLU()
134

135
    def forward(self, x):
136
        x = x @ self.weight
137
        x = self.bias_module0(x)
138
        x = self.relu(x)  # ensure biases have different gradients
139
        x = self.bias_module1(x)
140
        return x
141

142

143
class NestedModel(torch.nn.Module):
144
    def __init__(self) -> None:
145
        super().__init__()
146
        self.block0 = BlockB(5, 3)
147
        self.block1 = BlockB(3, 7)
148
        self.bias = torch.nn.Parameter(torch.randn((5,)))
149
        self.block2 = torch.nn.Sequential(
150
            BlockA(7, 9),
151
            BlockA(9, 9),
152
            BlockB(9, 5),
153
        )
154
        self.relu = torch.nn.ReLU()
155

156
    def forward(self, x) -> torch.Tensor:
157
        x = self.relu(self.block0(x))
158
        x = self.relu(self.block1(x))
159
        x = self.relu(self.block2(x))
160
        x = x + self.bias
161
        return x
162

163
    def get_input(self, device):
164
        BATCH_SIZE = 8
165
        return (torch.randn((BATCH_SIZE, 5)).to(device),)
166

167
    def get_loss(self, inp, output):
168
        return output.sum()
169

170
    def run_backward(self, loss):
171
        loss.backward()
172

173
    @staticmethod
174
    def wrap(
175
        model: torch.nn.Module,
176
        group: Optional[dist.ProcessGroup] = None,
177
        ignore_modules: bool = False,
178
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
179
    ) -> torch.nn.Module:
180
        if fsdp_kwargs is None:
181
            fsdp_kwargs = {}
182
        # Flatten Bias0; then flatten weight and Bias1 together into `block1`
183
        model.block1.bias_module0 = FSDP(
184
            model.block1.bias_module0,
185
            process_group=group,
186
            **fsdp_kwargs,
187
        )
188
        model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs)
189
        # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]`
190
        model.block2[1].bias_module0 = FSDP(
191
            model.block2[1].bias_module0,
192
            process_group=group,
193
            **fsdp_kwargs,
194
        )
195
        model.block2[1].bias_module1 = FSDP(
196
            model.block2[1].bias_module1,
197
            process_group=group,
198
            **fsdp_kwargs,
199
        )
200
        model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs)
201
        # Flatten weight, Bias, bias into `block2[2]`
202
        ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None
203
        model.block2[2] = FSDP(
204
            model.block2[2],
205
            process_group=group,
206
            ignored_modules=ignored_modules,
207
            **fsdp_kwargs,
208
        )
209
        return model
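    # After `wrap`, the FSDP instances (and hence flat parameters) are:
    # `block1.bias_module0`, `block1`, `block2[1].bias_module0`,
    # `block2[1].bias_module1`, `block2[1]`, and `block2[2]`; `block0`,
    # `block2[0]`, and the top-level `bias` are left outside of any FSDP
    # instance since the root module itself is not wrapped. A minimal usage
    # sketch (assuming an initialized process group and a CUDA device):
    #
    #   model = NestedModel.wrap(NestedModel().cuda())
    #   optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    #   full_osd = FSDP.full_optim_state_dict(model, optim)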
210

211
    @staticmethod
212
    def wrap_alt(
213
        model: torch.nn.Module,
214
        group: Optional[dist.ProcessGroup] = None,
215
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
216
    ) -> torch.nn.Module:
217
        if fsdp_kwargs is None:
218
            fsdp_kwargs = {}
219
        model.block0.bias_module0 = FSDP(
220
            model.block0.bias_module0,
221
            process_group=group,
222
            **fsdp_kwargs,
223
        )
224
        model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs)
225
        return model
226

227
    @staticmethod
228
    def wrap_with_unmanaged_params(
229
        model,
230
        add_to_fsdp_module: bool,
231
        group=None,
232
    ) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
233
        """Registers unmanaged parameters before wrapping with :meth:`wrap`."""
234
        device = next(model.parameters()).device
235
        unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
236
        # Either register the parameter to a module to be wrapped with FSDP
237
        # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`)
238
        register_module = model.block2[2] if add_to_fsdp_module else model
239
        register_module.register_parameter(
240
            "unmanaged_param",
241
            unmanaged_param,
242
        )
243
        # For simplicity, we only add a single unmanaged parameter, but it should
244
        # be easy to generalize if needed
245
        return NestedModel.wrap(model, group), [unmanaged_param]
246

247
    @staticmethod
248
    def add_unmanaged_param_entry(osd, unmanaged_param, step) -> None:
249
        """Adds an entry for the unmanaged parameter ``unmanaged_param``
250
        assuming an Adam optimizer and a single parameter group."""
251
        # The unmanaged parameters should be passed to this method in
252
        # `model.parameters()` order since their parameter IDs will be assigned
253
        # in order of the skipped IDs
254
        # Assign a parameter ID to the unmanaged parameter
255
        unmanaged_param_id = -1
256
        param_ids = osd["param_groups"][0]["params"]
257
        for i in range(1, len(param_ids)):
258
            diff = param_ids[i] - param_ids[i - 1]
259
            if diff != 1:
260
                assert diff > 1, f"Invalid IDs: {param_ids[i - 1]} {param_ids[i]}"
261
                unmanaged_param_id = param_ids[i - 1] + 1
262
                break
263
        if unmanaged_param_id == -1:
264
            unmanaged_param_id = len(param_ids)  # last ID skipped
265
        assert unmanaged_param_id >= 0, "One parameter ID should be skipped"
266
        # Add a state entry for the unmanaged parameter
267
        state_device = next(iter(next(iter(osd["state"].values())).values())).device
268
        osd["state"][unmanaged_param_id] = {
269
            "step": torch.tensor(float(step), device=state_device),
270
            "exp_avg": torch.randn(unmanaged_param.shape, device=state_device),
271
            "exp_avg_sq": torch.randn(unmanaged_param.shape, device=state_device),
272
        }
273
        # Insert the ID into the parameter group in order
274
        bisect.insort(osd["param_groups"][0]["params"], unmanaged_param_id)
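    # A small worked example of the ID assignment above (illustrative values):
    # if the optimizer assigned the FSDP-managed parameters IDs [0, 1, 3, 4]
    # because one unmanaged parameter was skipped, the gap at 2 is detected and
    # the unmanaged parameter's state is inserted under ID 2; if the managed
    # IDs are contiguous, e.g. [0, 1, 2], the unmanaged parameter takes the
    # next ID, 3.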
275

276
    # NOTE: We exclude `self.bias` from either parameter group to test the
277
    # case where the optimizer input does not include all model parameters
278
    def param_group0(self) -> List[torch.nn.Parameter]:
279
        # Use `block1`'s parameters for the first parameter group to deviate
280
        # from the `model.parameters()` order
281
        return list(self.block1.parameters())
282

283
    def param_group1(self) -> List[torch.nn.Parameter]:
284
        # Deviate from the `model.parameters()` order further by rearranging
285
        # `block2`'s parameters to be before `block0`'s parameters
286
        return list(self.block2.parameters()) + list(self.block0.parameters())
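

# A minimal, self-contained sketch of the save/shard/load round trip that the
# tests below exercise for `NestedModel`. It assumes that the default process
# group has already been initialized and that CUDA is available (as `FSDPTest`
# arranges for the tests); the function is illustrative and is not called by
# any test in this file.
def _example_nested_full_osd_round_trip(num_iters: int = 3):
    device = torch.device("cuda")
    model = NestedModel.wrap(NestedModel().to(device))
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(num_iters):
        optim.zero_grad()
        inp = model.get_input(device)
        loss = model(*inp).sum()
        loss.backward()
        optim.step()
    # Consolidate the optimizer state (on all ranks here for simplicity), then
    # re-shard it according to the same wrapped model and load it back.
    full_osd = FSDP.full_optim_state_dict(model, optim, rank0_only=False)
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model, optim=optim)
    optim.load_state_dict(sharded_osd)
    return model, optim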
287

288

289
# Simple and boring model to test the interface and some corner cases that do
290
# not require a complicated wrapping strategy.
291
class TestDummyModel(torch.nn.Module):
292
    def __init__(self, no_grad: bool = False):
293
        super().__init__()
294
        torch.manual_seed(0)
295
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
296
        self.net1[0].weight.requires_grad = not no_grad
297
        self.net1[0].bias.requires_grad = not no_grad
298
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
299
        self.net3 = nn.Linear(32, 64)
300
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))
301

302
    def forward(self, x):
303
        return self.net4(self.net3(self.net2(self.net1(x))))
304

305
    def get_input(self):
306
        return torch.rand(8, 8, device="cuda")
307

308

309
class TestFSDPOptimState(FSDPTest):
310
    def __init__(self, *args, **kwargs):
311
        super().__init__(*args, **kwargs)
312
        self._model_class = {
313
            _ModelClass.NESTED: self._init_nested_model,
314
            _ModelClass.TRANSFORMER: self._init_transformer_model,
315
        }
316

317
    def _init_nested_model(
318
        self,
319
        wrap: bool,
320
        wrap_alt: bool = False,  # ignored if `wrap=False`
321
        device: torch.device = torch.device("cuda"),
322
        group=None,
323
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
324
        use_multiple_param_groups: bool = False,
325
        use_diff_optim_inputs: bool = False,
326
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
327
    ):
328
        model = NestedModel().to(device)
329
        if wrap:
330
            model = (
331
                NestedModel.wrap_alt(model, group, fsdp_kwargs)
332
                if wrap_alt
333
                else NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs)
334
            )
335
        if not use_multiple_param_groups:
336
            optim_input = list(model.parameters())
337
        else:
338
            optim_input = [
339
                {"params": model.param_group0()},
340
                {"params": model.param_group1(), "weight_decay": 0.9},
341
            ]
342
        # Use a reversed parameter order for the optimizer input on odd ranks
343
        if use_diff_optim_inputs and self.rank % 2 == 1:
344
            if isinstance(optim_input[0], dict):
345
                for param_group in optim_input:
346
                    param_group["params"] = list(reversed(param_group["params"]))
347
            else:
348
                optim_input = list(reversed(optim_input))
349
        optim = optim_class(optim_input, lr=0.01)
350
        return model, optim, optim_input
351

352
    def _init_transformer_model(
353
        self,
354
        wrap: bool,
355
        device: torch.device = torch.device("cuda"),
356
        group=None,
357
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
358
        use_multiple_param_groups: bool = False,
359
        use_diff_optim_inputs: bool = False,
360
    ):
361
        if use_multiple_param_groups or use_diff_optim_inputs:
362
            # Keep these as arguments for parity with `_init_nested_model()`;
363
            # these settings are not implemented since the transformer is
364
            # wrapped with FSDP at the top-level, which means that there is
365
            # only a single flat parameter, making these booleans vacuous
366
            raise NotImplementedError()
367
        if group is None:
368
            group = dist.distributed_c10d._get_default_group()
369
        model = TransformerWithSharedParams.init(
370
            group,
371
            FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP,
372
            CUDAInitMode.CUDA_BEFORE,
373
            deterministic=True,
374
        )
375
        optim = optim_class(model.parameters(), lr=0.01)
376
        return model, optim, None
377

378
    def _step_model(
379
        self,
380
        model: torch.nn.Module,
381
        optim: torch.optim.Optimizer,
382
        device: torch.device = torch.device("cuda"),
383
        num_iters: int = 1,
384
    ) -> List[float]:
385
        """Performs a forward pass, backward pass, and optimizer step
386
        ``num_iters``-many times, and returns the per-iteration losses."""
387
        torch.manual_seed(0)  # set seed for determinism
388
        losses = []
389
        module = getattr(model, "module", model)
390
        for _ in range(num_iters):
391
            optim.zero_grad()
392
            inp = module.get_input(device)
393
            output = model(*inp)
394
            loss = module.get_loss(inp, output).to(device)
395
            losses.append(loss.item())
396
            module.run_backward(loss)
397
            optim.step()
398
        return losses
399

400
    def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
401
        """Broadcasts the full optimizer state dict in place of using
402
        ``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
403
        obj_list = [full_osd]
404
        dist.broadcast_object_list(
405
            obj_list,
406
            src=0,
407
            group=group,
408
        )
409
        full_osd = obj_list[0]
410
        return full_osd
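    # The broadcast above stands in for the more typical checkpointing flow,
    # sketched here (the path name is illustrative):
    #
    #   if self.rank == 0:
    #       torch.save(full_osd, "full_osd.pt")
    #   dist.barrier()
    #   full_osd = torch.load("full_osd.pt")
    #   sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model, optim=optim)
    #   optim.load_state_dict(sharded_osd)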
411

412
    def _are_equal_states(
413
        self,
414
        state1: Dict[str, Any],
415
        state2: Dict[str, Any],
416
    ) -> bool:
417
        """Checks if ``state1`` and ``state2`` contain the same mappings."""
418
        if set(state1.keys()) != set(state2.keys()):
419
            return False
420
        for state_name, value1 in state1.items():
421
            value2 = state2[state_name]
422
            if type(value1) != type(value2):
423
                return False
424
            if torch.is_tensor(value1):  # tensor state
425
                assert torch.is_tensor(value2)
426
                # Check the values on CPU to be device-agnostic
427
                value1 = value1.cpu()
428
                value2 = value2.cpu()
429
                if value1.shape != value2.shape or not torch.all(
430
                    torch.isclose(value1, value2)
431
                ):
432
                    return False
433
            else:  # non-tensor state
434
                if value1 != value2:
435
                    return False
436
        return True
437

438
    def _check_same_state(
439
        self,
440
        fsdp_osd,
441
        ref_osd,
442
        check_same_param_keys: bool,
443
    ):
444
        """Checks that ``full_osd`` and ``ref_osd`` have the same "state" part.
445
        If ``check_same_param_keys=True``, then checks that the parameter keys
446
        match (e.g. when both should be parameter names), and does not check
447
        the parameter keys otherwise."""
448
        assert "state" in ref_osd
449
        self.assertTrue("state" in fsdp_osd)
450
        ref_osd_state = ref_osd["state"]
451
        fsdp_osd_state = {
452
            k: _gather_state_dict(v) for k, v in fsdp_osd["state"].items()
453
        }
454

455
        if check_same_param_keys:
456
            # Check parameter keys are the same first for earlier erroring
457
            ref_osd_param_ids = set(ref_osd_state.keys())
458
            fsdp_osd_param_ids = set(fsdp_osd_state.keys())
459
            self.assertTrue(
460
                ref_osd_param_ids == fsdp_osd_param_ids,
461
                f"Rank {self.rank}: {(ref_osd_param_ids, fsdp_osd_param_ids)}",
462
            )
463
            # Check state values are the same
464
            for param_id, param_state in fsdp_osd_state.items():
465
                for state_name, value in param_state.items():
466
                    ref_value = ref_osd_state[param_id][state_name]
467
                    self.assertEqual(value, ref_value)
468
            return
469
        # Otherwise, only require the parameter keys to be isomorphic (e.g.
470
        # between IDs and names)
471
        ref_osd_states = list(ref_osd_state.values())
472
        fsdp_osd_states = list(fsdp_osd_state.values())
473
        self.assertEqual(len(ref_osd_states), len(fsdp_osd_states))
474
        # Use brute-force quadratic-time comparison since it is hard to
475
        # hash a tensor by value instead of by object
476
        for fsdp_osd_state in fsdp_osd_states:
477
            # Check for at least one match (may be > 1 in toy edge cases, e.g.
478
            # multiple biases); nonetheless, each having >= 1 match and the two
479
            # lists having equal length imply that the list contents are equal
480
            self.assertTrue(
481
                any(
482
                    self._are_equal_states(fsdp_osd_state, ref_osd_state)
483
                    for ref_osd_state in ref_osd_states
484
                )
485
            )
486

487
    def _check_same_param_groups(
488
        self,
489
        full_osd,
490
        ref_osd,
491
        check_same_param_keys: bool,
492
    ):
493
        """Checks that ``full_osd`` and ``ref_osd`` have the same
494
        "param_groups" part. If ``check_same_param_keys=True`, then checks that
495
        the parameter keys match (e.g. when both should be parameter names),
496
        and does not check the parameter keys otherwise."""
497
        assert "param_groups" in ref_osd
498
        self.assertTrue("param_groups" in full_osd)
499
        ref_osd_param_groups = ref_osd["param_groups"]
500
        full_osd_param_groups = full_osd["param_groups"]
501
        self.assertEqual(len(full_osd_param_groups), len(ref_osd_param_groups))
502
        for full_osd_pg, ref_osd_pg in zip(
503
            full_osd_param_groups,
504
            ref_osd_param_groups,
505
        ):
506
            self.assertEqual(
507
                set(full_osd_pg.keys()),
508
                set(ref_osd_pg.keys()),
509
            )
510
            for name, full_osd_value in full_osd_pg.items():
511
                if name == "params" and not check_same_param_keys:
512
                    continue
513
                self.assertEqual(full_osd_value, ref_osd_pg[name])
514

515
    @skip_if_lt_x_gpu(2)
516
    @parametrize("state_dict_type", STATE_DICT_TYPES)
517
    @parametrize("use_multiple_param_groups", [False, True])
518
    @parametrize("rank0_only", [False, True])
519
    @parametrize("use_diff_optim_inputs", [False, True])
520
    def test_optim_state_dict_nested(
521
        self,
522
        state_dict_type: StateDictType,
523
        use_multiple_param_groups: bool,
524
        rank0_only: bool,
525
        use_diff_optim_inputs: bool,
526
    ) -> None:
527
        """
528
        Tests :meth:`full_optim_state_dict` and :meth:`sharded_optim_state_dict`
529
        by comparing the returned dict for an FSDP-wrapped model with that of
530
        an equivalent non-wrapped model.
531

532
        The test checks the equivalence excluding the parameter keys since the
533
        FSDP and normal optimizer state dicts key by names and IDs,
534
        respectively. This means that the test can pass even if parameter keys
535
        are incorrectly mapped to values. Their correct mapping is tested in
536
        other tests that exercise the save/load workflow.
537
        """
538
        self.run_subtests(
539
            {"use_optim_input": [False, True]},
540
            self._test_optim_state_dict_nested,
541
            state_dict_type=state_dict_type,
542
            use_multiple_param_groups=use_multiple_param_groups,
543
            rank0_only=rank0_only,
544
            use_diff_optim_inputs=use_diff_optim_inputs,
545
        )
546

547
    def _test_optim_state_dict_nested(
548
        self,
549
        state_dict_type: StateDictType,
550
        use_multiple_param_groups: bool,
551
        rank0_only: bool,
552
        use_diff_optim_inputs: bool,
553
        use_optim_input: bool,
554
    ) -> None:
555
        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
556
            return  # not supported
557
        NUM_ITERS = 3
558
        model1, optim1, optim_input = self._init_nested_model(
559
            wrap=True,
560
            use_multiple_param_groups=use_multiple_param_groups,
561
            use_diff_optim_inputs=use_diff_optim_inputs,
562
        )
563
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
564
        if state_dict_type == StateDictType.FULL_STATE_DICT:
565
            if use_optim_input:
566
                fsdp_osd = FSDP.full_optim_state_dict(
567
                    model1,
568
                    optim1,
569
                    optim_input,
570
                    rank0_only=rank0_only,
571
                )
572
            else:
573
                fsdp_osd = FSDP.full_optim_state_dict(
574
                    model1,
575
                    optim1,
576
                    rank0_only=rank0_only,
577
                )
578
        else:
579
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
580
        # Non-target ranks get an empty state dict
581
        if rank0_only and self.rank != 0:
582
            self.assertEqual(len(fsdp_osd), 0)
583
            return
584
        model2, optim2, _ = self._init_nested_model(
585
            wrap=False,
586
            use_multiple_param_groups=use_multiple_param_groups,
587
            use_diff_optim_inputs=use_diff_optim_inputs,
588
        )
589
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
590
        ref_osd = optim2.state_dict()
591
        # Check the losses to eliminate model drift as a source of error
592
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
593
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
594
        # Do not check the parameter keys since the full/sharded optimizer state
595
        # dict uses parameter names, while the non-wrapped equivalent uses
596
        # parameter IDs
597
        check_same_param_keys = False
598
        self._check_same_param_groups(
599
            fsdp_osd,
600
            ref_osd,
601
            check_same_param_keys=check_same_param_keys,
602
        )
603
        self._check_same_state(
604
            fsdp_osd,
605
            ref_osd,
606
            check_same_param_keys=check_same_param_keys,
607
        )
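    # For reference, the key difference described in the docstring above looks
    # roughly like this (illustrative entries): the FSDP full/sharded optimizer
    # state dict is keyed by fully qualified parameter names,
    #
    #   fsdp_osd["state"]["block1.weight"]["exp_avg"]
    #
    # while the non-wrapped optimizer's state dict is keyed by integer
    # parameter IDs,
    #
    #   ref_osd["state"][0]["exp_avg"]
    #
    # `FSDP.rekey_optim_state_dict` (tested further below) converts between the
    # two key types.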
608

609
    @skip_if_lt_x_gpu(2)
610
    def test_full_optim_state_dict_keys(self):
611
        """Tests that the parameter keys returned by
612
        :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
613
        full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
614
        instances and ignored modules."""
615
        device = torch.device("cuda")
616
        model = NestedModel().to(device)
617
        wrapped_model = NestedModel.wrap(model, ignore_modules=True)
618
        # Add checkpointing to ensure optim_state_dict and state_dict strip out
619
        # checkpointing prefixes.
620
        apply_activation_checkpointing(
621
            model, check_fn=lambda module: isinstance(module, torch.nn.Sequential)
622
        )
623
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
624
        self._step_model(model, optim, device)
625
        optim_state_dict = FSDP.full_optim_state_dict(
626
            wrapped_model, optim, rank0_only=False
627
        )
628
        with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
629
            state_dict = wrapped_model.state_dict()
630
        self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
631
        # Check that checkpointing prefix was indeed stripped.
632
        for key in optim_state_dict["state"]:
633
            self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, key)
634

635
    @skip_if_lt_x_gpu(2)
636
    def test_full_optim_state_dict_nested_invalid(self):
637
        """Tests that :meth:`full_optim_state_dict` raises an error when
638
        nonzero ranks are missing the optimizer state for parameters on rank
639
        0."""
640
        device = torch.device("cuda")
641
        model = NestedModel.wrap(NestedModel().to(device), None)
642
        optim_input = list(model.parameters())
643
        if self.rank != 0:
644
            # Exclude a parameter so that nonzero ranks are missing state
645
            optim_input = optim_input[:-1]
646
        optim = torch.optim.Adam(optim_input, lr=1e-3)
647
        self._step_model(model, optim, num_iters=3)
648
        error_regex = (
649
            "FSDP currently requires each rank to have at least the "
650
            "optimizer states needed by rank 0's optimizer but some ranks "
651
            "are missing some of those states"
652
        )
653
        with self.assertRaisesRegex(RuntimeError, error_regex):
654
            FSDP.full_optim_state_dict(model, optim)
655

656
    @skip_if_lt_x_gpu(2)
657
    @parametrize("use_multiple_param_groups", [False, True])
658
    @parametrize("wrap_alt", [False, True])
659
    @parametrize("use_diff_optim_inputs", [False, True])
660
    def test_shard_full_optim_state_dict_nested(
661
        self,
662
        use_multiple_param_groups: bool,
663
        wrap_alt: bool,
664
        use_diff_optim_inputs: bool,
665
    ):
666
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
667
        with nested FSDP instances."""
668
        self.run_subtests(
669
            {"use_optim_input": [False, True]},
670
            self._test_load_optim_state,
671
            model_class=_ModelClass.NESTED,
672
            use_multiple_param_groups=use_multiple_param_groups,
673
            halve_world_size=False,
674
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
675
            use_diff_optim_inputs=use_diff_optim_inputs,
676
            wrap_alt=wrap_alt,
677
            num_iters=3,
678
        )
679

680
        self._test_load_optim_state_with_optim_state_dict(
681
            _ModelClass.NESTED,
682
            state_dict_settings=StateDictSettings(
683
                StateDictType.FULL_STATE_DICT,
684
                FullStateDictConfig(),
685
                FullOptimStateDictConfig(),
686
            ),
687
            use_multiple_param_groups=False,
688
            halve_world_size=False,
689
            use_diff_optim_inputs=use_diff_optim_inputs,
690
            wrap_alt=wrap_alt,
691
            num_iters=3,
692
        )
693

694
    @skip_if_lt_x_gpu(2)
695
    def test_shard_full_optim_state_dict_nested_halve_world_size(self):
696
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
697
        with nested FSDP instances when loading into a new process group with
698
        halved world size."""
699
        # To save CI costs, we test with the "harder" settings:
700
        use_multiple_param_groups = True
701
        use_diff_optim_inputs = True
702
        wrap_alt = True
703
        self.run_subtests(
704
            {"use_optim_input": [False, True]},
705
            self._test_load_optim_state,
706
            model_class=_ModelClass.NESTED,
707
            use_multiple_param_groups=use_multiple_param_groups,
708
            halve_world_size=True,
709
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
710
            use_diff_optim_inputs=use_diff_optim_inputs,
711
            wrap_alt=wrap_alt,
712
            num_iters=3,
713
        )
714

715
        self._test_load_optim_state_with_optim_state_dict(
716
            _ModelClass.NESTED,
717
            state_dict_settings=StateDictSettings(
718
                StateDictType.FULL_STATE_DICT,
719
                FullStateDictConfig(),
720
                FullOptimStateDictConfig(),
721
            ),
722
            use_multiple_param_groups=use_multiple_param_groups,
723
            halve_world_size=True,
724
            use_diff_optim_inputs=use_diff_optim_inputs,
725
            wrap_alt=wrap_alt,
726
            num_iters=3,
727
        )
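    # A rough sketch of what loading into a halved world size means here,
    # mirroring `_test_load_optim_state` below (names are illustrative):
    #
    #   new_group = dist.new_group(
    #       ranks=[r for r in range(self.world_size) if r % 2 == 0]
    #   )
    #   # Ranks outside `new_group` return early; the remaining ranks rebuild
    #   # the model with `process_group=new_group` and re-shard the full
    #   # optimizer state dict that was saved under the original world size:
    #   sharded_osd = FSDP.shard_full_optim_state_dict(
    #       full_osd, new_model, optim=new_optim
    #   )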
728

729
    @skip_if_lt_x_gpu(2)
730
    def test_shard_full_optim_state_dict_transformer(self) -> None:
731
        """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root
732
        transformer model with shared parameters."""
733
        self.run_subtests(
734
            {"use_optim_input": [False, True]},
735
            self._test_load_optim_state,
736
            model_class=_ModelClass.TRANSFORMER,
737
            use_multiple_param_groups=False,
738
            halve_world_size=True,
739
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
740
            use_diff_optim_inputs=False,
741
            num_iters=3,
742
        )
743

744
        self._test_load_optim_state_with_optim_state_dict(
745
            _ModelClass.TRANSFORMER,
746
            state_dict_settings=StateDictSettings(
747
                StateDictType.FULL_STATE_DICT,
748
                FullStateDictConfig(),
749
                FullOptimStateDictConfig(),
750
            ),
751
            use_multiple_param_groups=False,
752
            halve_world_size=True,
753
            use_diff_optim_inputs=False,
754
            num_iters=3,
755
        )
756

757
    @skip_if_lt_x_gpu(2)
758
    @parametrize("use_multiple_param_groups", [False, True])
759
    @parametrize("wrap_alt", [False, True])
760
    @parametrize("use_diff_optim_inputs", [False, True])
761
    def test_scatter_full_optim_state_dict_nested(
762
        self,
763
        use_multiple_param_groups: bool,
764
        wrap_alt: bool,
765
        use_diff_optim_inputs: bool,
766
    ):
767
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
768
        model with nested FSDP instances."""
769
        self.run_subtests(
770
            {"use_optim_input": [False, True]},
771
            self._test_load_optim_state,
772
            model_class=_ModelClass.NESTED,
773
            use_multiple_param_groups=use_multiple_param_groups,
774
            halve_world_size=False,
775
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
776
            use_diff_optim_inputs=use_diff_optim_inputs,
777
            wrap_alt=wrap_alt,
778
            num_iters=3,
779
        )
780

781
        self._test_load_optim_state_with_optim_state_dict(
782
            _ModelClass.NESTED,
783
            state_dict_settings=StateDictSettings(
784
                StateDictType.FULL_STATE_DICT,
785
                FullStateDictConfig(),
786
                FullOptimStateDictConfig(rank0_only=True),
787
            ),
788
            use_multiple_param_groups=use_multiple_param_groups,
789
            halve_world_size=False,
790
            use_diff_optim_inputs=use_diff_optim_inputs,
791
            wrap_alt=wrap_alt,
792
            num_iters=3,
793
        )
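    # A minimal sketch of the scatter-based flow exercised here (mirroring
    # `_test_load_optim_state`): only rank 0 needs to hold the full optimizer
    # state dict; every other rank passes `None` and receives its own shard.
    #
    #   full_osd = FSDP.full_optim_state_dict(model, optim)  # nonempty on rank 0
    #   sharded_osd = FSDP.scatter_full_optim_state_dict(
    #       full_osd if self.rank == 0 else None, model, optim=optim
    #   )
    #   optim.load_state_dict(sharded_osd)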
794

795
    @skip_if_lt_x_gpu(2)
796
    def test_scatter_full_optim_state_dict_nested_halve_world_size(self):
797
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
798
        model with nested FSDP instances when loading into a new process group
799
        with halved world size."""
800
        # To save CI costs, we test with the "harder" settings:
801
        use_multiple_param_groups = True
802
        use_diff_optim_inputs = True
803
        wrap_alt = True
804
        self.run_subtests(
805
            {"use_optim_input": [False, True]},
806
            self._test_load_optim_state,
807
            model_class=_ModelClass.NESTED,
808
            use_multiple_param_groups=use_multiple_param_groups,
809
            halve_world_size=True,
810
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
811
            use_diff_optim_inputs=use_diff_optim_inputs,
812
            wrap_alt=wrap_alt,
813
            num_iters=3,
814
        )
815

816
        self._test_load_optim_state_with_optim_state_dict(
817
            _ModelClass.NESTED,
818
            state_dict_settings=StateDictSettings(
819
                StateDictType.FULL_STATE_DICT,
820
                FullStateDictConfig(),
821
                FullOptimStateDictConfig(rank0_only=True),
822
            ),
823
            use_multiple_param_groups=use_multiple_param_groups,
824
            halve_world_size=True,
825
            use_diff_optim_inputs=use_diff_optim_inputs,
826
            wrap_alt=wrap_alt,
827
            num_iters=3,
828
        )
829

830
    @skip_if_lt_x_gpu(2)
831
    def test_scatter_full_optim_state_dict_transformer(self) -> None:
832
        """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root
833
        transformer model with shared parameters."""
834
        self.run_subtests(
835
            {"use_optim_input": [False, True]},
836
            self._test_load_optim_state,
837
            model_class=_ModelClass.TRANSFORMER,
838
            use_multiple_param_groups=False,
839
            halve_world_size=True,
840
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
841
            use_diff_optim_inputs=False,
842
            num_iters=3,
843
        )
844

845
        self._test_load_optim_state_with_optim_state_dict(
846
            _ModelClass.TRANSFORMER,
847
            state_dict_settings=StateDictSettings(
848
                StateDictType.FULL_STATE_DICT,
849
                FullStateDictConfig(),
850
                FullOptimStateDictConfig(rank0_only=True),
851
            ),
852
            use_multiple_param_groups=False,
853
            halve_world_size=True,
854
            use_diff_optim_inputs=False,
855
            num_iters=3,
856
        )
857

858
    @skip_if_lt_x_gpu(2)
859
    def test_flatten_sharded_optim_state_dict_nested(self) -> None:
860
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
861
        nested model."""
862
        self._test_load_optim_state(
863
            _ModelClass.NESTED,
864
            use_multiple_param_groups=False,
865
            halve_world_size=False,
866
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
867
            use_diff_optim_inputs=False,
868
            use_optim_input=False,
869
            wrap_alt=True,
870
            num_iters=3,
871
        )
872

873
        self._test_load_optim_state_with_optim_state_dict(
874
            _ModelClass.NESTED,
875
            state_dict_settings=StateDictSettings(
876
                StateDictType.SHARDED_STATE_DICT,
877
                ShardedStateDictConfig(),
878
                ShardedOptimStateDictConfig(),
879
            ),
880
            use_multiple_param_groups=False,
881
            halve_world_size=False,
882
            use_diff_optim_inputs=False,
883
            wrap_alt=True,
884
            num_iters=3,
885
        )
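    # A minimal sketch of the sharded flow exercised here (mirroring
    # `_test_load_optim_state`): each rank saves only its own shards (as
    # `ShardedTensor`s) and flattens them back for loading, so a full,
    # unsharded optimizer state dict is never materialized on any single rank.
    #
    #   sharded_osd = FSDP.sharded_optim_state_dict(model, optim)
    #   flattened_osd = FSDP.flatten_sharded_optim_state_dict(
    #       sharded_osd, model, optim=optim
    #   )
    #   optim.load_state_dict(flattened_osd)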
886

887
    @skip_if_lt_x_gpu(2)
888
    def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
889
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
890
        transformer model."""
891
        self._test_load_optim_state(
892
            _ModelClass.TRANSFORMER,
893
            use_multiple_param_groups=False,
894
            halve_world_size=False,
895
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
896
            use_diff_optim_inputs=False,
897
            use_optim_input=False,
898
            num_iters=3,
899
        )
900

901
        self._test_load_optim_state_with_optim_state_dict(
902
            _ModelClass.TRANSFORMER,
903
            state_dict_settings=StateDictSettings(
904
                StateDictType.SHARDED_STATE_DICT,
905
                ShardedStateDictConfig(),
906
                ShardedOptimStateDictConfig(),
907
            ),
908
            use_multiple_param_groups=False,
909
            halve_world_size=False,
910
            use_diff_optim_inputs=False,
911
            num_iters=3,
912
        )
913

914
    @skip_if_lt_x_gpu(2)
915
    def test_use_orig_params(self) -> None:
916
        """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
917
        self.run_subtests(
918
            {
919
                "halve_world_size": [True, False],
920
                "wrap_alt": [True, False],
921
            },
922
            self._test_load_optim_state_with_optim_state_dict,
923
            model_class=_ModelClass.NESTED,
924
            state_dict_settings=StateDictSettings(
925
                StateDictType.FULL_STATE_DICT,
926
                FullStateDictConfig(),
927
                FullOptimStateDictConfig(),
928
            ),
929
            use_multiple_param_groups=False,
930
            use_diff_optim_inputs=False,
931
            num_iters=3,
932
            fsdp_kwargs={"use_orig_params": True},
933
        )
934

935
        self.run_subtests(
936
            {
937
                "halve_world_size": [True, False],
938
                "wrap_alt": [True, False],
939
            },
940
            self._test_load_optim_state_with_optim_state_dict,
941
            model_class=_ModelClass.NESTED,
942
            state_dict_settings=StateDictSettings(
943
                StateDictType.FULL_STATE_DICT,
944
                FullStateDictConfig(),
945
                FullOptimStateDictConfig(rank0_only=True),
946
            ),
947
            use_multiple_param_groups=False,
948
            use_diff_optim_inputs=False,
949
            num_iters=3,
950
            fsdp_kwargs={"use_orig_params": True},
951
        )
952

953
        self.run_subtests(
954
            {
955
                "wrap_alt": [True, False],
956
            },
957
            self._test_load_optim_state_with_optim_state_dict,
958
            model_class=_ModelClass.NESTED,
959
            state_dict_settings=StateDictSettings(
960
                StateDictType.SHARDED_STATE_DICT,
961
                ShardedStateDictConfig(),
962
                ShardedOptimStateDictConfig(),
963
            ),
964
            use_multiple_param_groups=False,
965
            # We cannot test halve_world_size with SHARDED_STATE_DICT.
966
            halve_world_size=False,
967
            use_diff_optim_inputs=False,
968
            num_iters=3,
969
            fsdp_kwargs={"use_orig_params": True},
970
        )
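    # A minimal sketch of the `FSDP.optim_state_dict` round trip exercised by
    # the subtests above (cf. the `OPTIM_STATE_DICT` branch of
    # `_test_load_optim_state` below); the state dict type and configs are
    # controlled by the `StateDictSettings` passed in:
    #
    #   osd = FSDP.optim_state_dict(model, optim)
    #   load_osd = FSDP.optim_state_dict_to_load(model, optim, osd)
    #   optim.load_state_dict(load_osd)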
971

972
    def _test_load_optim_state(
973
        self,
974
        model_class: _ModelClass,
975
        use_multiple_param_groups: bool,
976
        halve_world_size: bool,
977
        osd_comm_method: _OSDCommMethod,
978
        use_diff_optim_inputs: bool,
979
        use_optim_input: bool,
980
        num_iters: int,
981
        **new_model_kwargs,
982
    ):
983
        """
984
        (1) Runs a model with full world size for K iterations to generate a
985
        full/sharded optimizer state dict;
986
        (2) initializes a model with halved world size and possibly different
987
        FSDP wrapping scheme (based on ``new_model_kwargs``);
988
        (3) loads the full/sharded optimizer state dict from (1) according to the
989
        halved-world-size model;
990
        (4) runs the halved-world-size model for K iterations; and
991
        (5) checks that the sharded optimizer state dict from (3) matches the
992
        halved-world-size model's local optimizer state dict, meaning that the
993
        former could have equivalently been loaded into the local optimizer.
994
        """
995
        initializer = self._model_class[model_class]
996
        if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
997
            osd_method = FSDP.optim_state_dict
998
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
999
            osd_method = FSDP.sharded_optim_state_dict
1000
        else:
1001
            osd_method = FSDP.full_optim_state_dict
1002

1003
        # First, run a wrapped model with full world size for a few iterations
1004
        model1, optim1, optim_input1 = initializer(
1005
            wrap=True,
1006
            use_multiple_param_groups=use_multiple_param_groups,
1007
        )
1008
        self._step_model(model1, optim1, num_iters=num_iters)
1009
        fsdp_osd1 = (
1010
            osd_method(model1, optim1, optim_input1)
1011
            if use_optim_input
1012
            else osd_method(model1, optim1)
1013
        )
1014
        if halve_world_size:
1015
            # Create a new process group with halved world size
1016
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
1017
            new_group = dist.new_group(ranks=new_group_ranks)
1018
            if self.rank not in new_group_ranks:
1019
                return
1020
        else:
1021
            # Continue using the same group and hence world size
1022
            new_group = dist.distributed_c10d._get_default_group()
1023
        # Second, run a wrapped model with (possibly) halved world size and
1024
        # (possibly) differing `optim_input` across ranks
1025
        model2, optim2, optim_input2 = initializer(
1026
            wrap=True,
1027
            group=new_group,
1028
            use_multiple_param_groups=use_multiple_param_groups,
1029
            use_diff_optim_inputs=use_diff_optim_inputs,
1030
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
1031
        )
1032
        self._step_model(model2, optim2, num_iters=num_iters)
1033
        fsdp_osd2 = (
1034
            osd_method(model2, optim2, optim_input2, group=new_group)
1035
            if use_optim_input
1036
            else osd_method(model2, optim2, group=new_group)
1037
        )
1038
        # Compute two sharded optim state dicts: (1) for the first model
1039
        # according to the second model and (2) for the second model according
1040
        # to the second model
1041
        if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
1042
            fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group)
1043
            sharded_osd1 = (
1044
                FSDP.shard_full_optim_state_dict(
1045
                    fsdp_osd1, model2, optim_input=optim_input2
1046
                )
1047
                if use_optim_input
1048
                else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2)
1049
            )
1050
            fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group)
1051
            sharded_osd2 = (
1052
                FSDP.shard_full_optim_state_dict(
1053
                    fsdp_osd2, model2, optim_input=optim_input2
1054
                )
1055
                if use_optim_input
1056
                else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2)
1057
            )
1058
        elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
1059
            sharded_osd1 = (
1060
                FSDP.scatter_full_optim_state_dict(
1061
                    fsdp_osd1 if self.rank == 0 else None,
1062
                    model2,
1063
                    optim_input=optim_input2,
1064
                    group=new_group,
1065
                )
1066
                if use_optim_input
1067
                else FSDP.scatter_full_optim_state_dict(
1068
                    fsdp_osd1 if self.rank == 0 else None,
1069
                    model2,
1070
                    optim=optim2,
1071
                    group=new_group,
1072
                )
1073
            )
1074
            sharded_osd2 = (
1075
                FSDP.scatter_full_optim_state_dict(
1076
                    fsdp_osd2 if self.rank == 0 else None,
1077
                    model2,
1078
                    optim_input=optim_input2,
1079
                    group=new_group,
1080
                )
1081
                if use_optim_input
1082
                else FSDP.scatter_full_optim_state_dict(
1083
                    fsdp_osd2 if self.rank == 0 else None,
1084
                    model2,
1085
                    optim=optim2,
1086
                    group=new_group,
1087
                )
1088
            )
1089
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
1090
            sharded_osd1 = FSDP.flatten_sharded_optim_state_dict(
1091
                fsdp_osd1,
1092
                model2,
1093
                optim=optim2,
1094
            )
1095
            sharded_osd2 = FSDP.flatten_sharded_optim_state_dict(
1096
                fsdp_osd2,
1097
                model2,
1098
                optim=optim2,
1099
            )
1100
        elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
1101
            sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1)
1102
            sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2)
1103

1104
        # As a sanity check, check that sharding the second model's full/sharded
1105
        # optimizer state dict according to itself is equivalent to its local
1106
        # optimizer's state dict
1107
        local_osd2 = optim2.state_dict()
1108
        check_same_param_keys = True  # should all have matching parameter IDs
1109
        self._check_same_param_groups(
1110
            sharded_osd2,
1111
            local_osd2,
1112
            check_same_param_keys=check_same_param_keys,
1113
        )
1114
        self._check_same_state(
1115
            sharded_osd2,
1116
            local_osd2,
1117
            check_same_param_keys=check_same_param_keys,
1118
        )
1119
        # Check that sharding the first model's full/sharded optimizer state dict
1120
        # according to the second model is equivalent to the second model's
1121
        # local optimizer state dict
1122
        self._check_same_param_groups(
1123
            sharded_osd1,
1124
            local_osd2,
1125
            check_same_param_keys=check_same_param_keys,
1126
        )
1127
        self._check_same_state(
1128
            sharded_osd1,
1129
            local_osd2,
1130
            check_same_param_keys=check_same_param_keys,
1131
        )
1132
        # As a sanity check, check that we can load and run a few iterations
1133
        optim2.load_state_dict(sharded_osd2)
1134
        self._step_model(model2, optim2, num_iters=num_iters)
1135

1136
    @skip_if_lt_x_gpu(2)
1137
    @parametrize("state_dict_type", STATE_DICT_TYPES)
1138
    @parametrize("add_to_fsdp_module", [False, True])
1139
    def test_shard_full_optim_state_dict_unmanaged_params(
1140
        self,
1141
        state_dict_type: StateDictType,
1142
        add_to_fsdp_module: bool,
1143
    ):
1144
        """
1145
        Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
1146
        parameters.
1147
          - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
1148
          added to a module to be wrapped with FSDP, in which case there should
1149
          be an error since we require that all unflattened parameters
1150
          comprising a flat parameter have the same scalar state (e.g. Adam
1151
          "step") but the added parameter is missing its entry.
1152
          - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
1153
          added to a module not to be wrapped with FSDP, in which case there
1154
          should be no error (emulating model parallel use cases where some
1155
          parameters may be managed externally to FSDP).
1156
        We do not separately test unmanaged parameters for
1157
        :meth:`scatter_full_optim_state_dict` and :meth:`flatten_sharded_optim_state_dict`
1158
        to save CI cost since they call into the same subroutine
1159
        :meth:`_flatten_optim_state_dict`.
1160
        """
1161
        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
1162
            use_optim_input = [False]
1163
        else:
1164
            use_optim_input = [False, True]
1165
        self.run_subtests(
1166
            {"use_optim_input": use_optim_input},
1167
            self._test_shard_full_optim_state_dict_unmanaged_params,
1168
            state_dict_type=state_dict_type,
1169
            add_to_fsdp_module=add_to_fsdp_module,
1170
        )
1171

1172
    def _test_shard_full_optim_state_dict_unmanaged_params(
1173
        self,
1174
        state_dict_type: StateDictType,
1175
        add_to_fsdp_module: bool,
1176
        use_optim_input: bool,
1177
    ):
1178
        NUM_ITERS = 1
1179
        # Create a normal wrapped model
1180
        model, optim, optim_input = self._init_nested_model(wrap=True)
1181
        self._step_model(model, optim, num_iters=NUM_ITERS)
1182

1183
        if state_dict_type == StateDictType.FULL_STATE_DICT:
1184
            fsdp_osd = (
1185
                FSDP.full_optim_state_dict(model, optim, optim_input, rank0_only=False)
1186
                if use_optim_input
1187
                else FSDP.full_optim_state_dict(model, optim, rank0_only=False)
1188
            )  # save on all ranks to avoid having to broadcast from rank 0
1189
        else:
1190
            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim)
1191
        # Create a new model with the same structure but additional unmanaged
1192
        # parameters, representing the model for which we want to load
1193
        device = torch.device("cuda")
1194
        model = NestedModel().to(device)
1195
        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
1196
            model,
1197
            add_to_fsdp_module,
1198
        )
1199
        optim_input = list(model.parameters())
1200
        optim = torch.optim.Adam(optim_input, lr=1e-3)
1201
        if add_to_fsdp_module:
1202
            # If we add the unmanaged parameters to a module wrapped with FSDP,
1203
            # then the flat parameter will be comprised of some unflattened
1204
            # parameters with zero-dimensional tensor state (i.e. Adam "step")
1205
            # and others without (i.e. the unmanaged parameters), which
1206
            # triggers an error that is raised to ensure correctness
1207
            error_prefix = (
1208
                "^(All unflattened parameters comprising a "
1209
                "single flat parameter must have scalar state with the "
1210
                "same value and dtype)"
1211
            )
1212
            with self.assertRaisesRegex(ValueError, error_prefix):
1213
                if state_dict_type == StateDictType.FULL_STATE_DICT:
1214
                    (
1215
                        FSDP.shard_full_optim_state_dict(
1216
                            fsdp_osd, model, optim_input=optim_input
1217
                        )
1218
                        if use_optim_input
1219
                        else FSDP.shard_full_optim_state_dict(
1220
                            fsdp_osd, model, optim=optim
1221
                        )
1222
                    )
1223
                else:
1224
                    FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim)
1225
        else:
1226
            # If we add the unmanaged parameters to a module not wrapped with
1227
            # FSDP, then we simply ignore them without erroring to enable
1228
            # model parallelism use cases, where some parameters are managed
1229
            # externally to FSDP
1230
            if state_dict_type == StateDictType.FULL_STATE_DICT:
1231
                flattened_osd = (
1232
                    FSDP.shard_full_optim_state_dict(
1233
                        fsdp_osd, model, optim_input=optim_input
1234
                    )
1235
                    if use_optim_input
1236
                    else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim)
1237
                )
1238
            else:
1239
                flattened_osd = FSDP.flatten_sharded_optim_state_dict(
1240
                    fsdp_osd, model, optim=optim
1241
                )
1242
            # Add entries for the unmanaged parameters to be able to load
1243
            for unmanaged_param in unmanaged_params:
1244
                NestedModel.add_unmanaged_param_entry(
1245
                    flattened_osd,
1246
                    unmanaged_param,
1247
                    NUM_ITERS,
1248
                )
1249
            # Check that we can load the optimizer state dict
1250
            optim.load_state_dict(flattened_osd)
1251

1252
    @skip_if_lt_x_gpu(2)
1253
    @parametrize("state_dict_type", STATE_DICT_TYPES)
1254
    @parametrize("use_multiple_param_groups", [False, True])
1255
    def test_rekey_optim_state_dict_to_ids(
1256
        self,
1257
        state_dict_type: StateDictType,
1258
        use_multiple_param_groups: bool,
1259
    ):
1260
        """Tests :meth:`rekey_optim_state_dict` with the new keys being
1261
        parameter IDs by checking that a wrapped model (i.e. with FSDP modules)
1262
        can rekey its optimizer state dict to match that of an equivalent
1263
        non-wrapped model (i.e. without FSDP modules)."""
        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
            use_optim_input = [False]
        else:
            use_optim_input = [False, True]
        self.run_subtests(
            {"use_optim_input": use_optim_input},
            self._test_rekey_optim_state_dict_to_ids,
            state_dict_type=state_dict_type,
            use_multiple_param_groups=use_multiple_param_groups,
        )

    @skip_if_lt_x_gpu(2)
    def _test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = (
                FSDP.full_optim_state_dict(model1, optim1, optim_input1)
                if use_optim_input
                else FSDP.full_optim_state_dict(model1, optim1)
            )
            # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
            # have the full state dict
            fsdp_osd = self._broadcast_full_osd(fsdp_osd)
        else:
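            # The sharded optim state dict keeps each rank's local shards (as
            # `ShardedTensor`s), so no broadcast is needed for this path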
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the wrapped model's optimizer state dict using parameter IDs
        # according to the non-wrapped model
        rekeyed_osd = (
            FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim_input=optim_input2,
            )
            if use_optim_input
            else FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim=optim2,
            )
        )
        # Check that the re-keyed dict and actual dict are the same
        osd = optim2.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        if state_dict_type != StateDictType.SHARDED_STATE_DICT:
            optim2.load_state_dict(rekeyed_osd)
            self._step_model(model2, optim2, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_rekey_optim_state_dict_to_names(self):
        """Tests :meth:`rekey_optim_state_dict` with the new keys being
        parameter names by checking that a non-wrapped model (i.e. without FSDP
        modules) can rekey its optimizer state dict to match the expected
        output of :meth:`full_optim_state_dict`, hence be sharded using
        :meth:`shard_full_optim_state_dict`, and finally match the per-rank
        optimizer state dict of a wrapped model (i.e. with FSDP modules)."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_rekey_optim_state_dict_to_names,
            use_multiple_param_groups=False,
        )

    def _test_rekey_optim_state_dict_to_names(
        self,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the non-wrapped model's optimizer state dict using parameter
        # names (still according to itself)
        osd2 = optim2.state_dict()
        rekeyed_osd = (
            FSDP.rekey_optim_state_dict(
                osd2,
                OptimStateKeyType.PARAM_NAME,
                model2,
                optim_input=optim_input2,
            )
            if use_optim_input
            else FSDP.rekey_optim_state_dict(
                osd2,
                OptimStateKeyType.PARAM_NAME,
                model2,
                optim=optim2,
            )
        )
        # Shard the non-wrapped model's re-keyed optimizer state dict, which
        # maps back to (flattened) parameter IDs
        sharded_osd = (
            FSDP.shard_full_optim_state_dict(
                rekeyed_osd,
                model1,
                optim_input=optim_input1,
            )
            if use_optim_input
            else FSDP.shard_full_optim_state_dict(
                rekeyed_osd,
                model1,
                optim=optim1,
            )
        )
        # Check that this sharded optimizer state dict matches the wrapped
        # model's per-rank optimizer state dict
        osd1 = optim1.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim1.load_state_dict(sharded_osd)
        self._step_model(model1, optim1, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_optim_input_warning(self):
        """Tests that passing the ``optim_input`` argument into optimizer state
        checkpointing APIs issues a warning."""

        def should_check_method(method_name: str):
            # Check every method except the sharded optim state dict methods,
            # which do not accept `optim_input`
            return method_name not in (
                "sharded_optim_state_dict",
                "flatten_sharded_optim_state_dict",
            )

        def get_warning_context():
            warning_regex = "`optim_input` argument is deprecated"
            return self.assertWarnsRegex(
                expected_warning=UserWarning, expected_regex=warning_regex
            )

        self._run_on_all_optim_state_apis(
            should_check_method, get_warning_context, fsdp_kwargs=None
        )

    def _run_on_all_optim_state_apis(
        self,
        should_check_method_fn: Callable[[str], bool],
        context_fn: Callable,
        fsdp_kwargs: Optional[Dict[str, Any]],
    ):
        """
        Runs through all optimizer state checkpointing APIs with a context
        manager instantiated by ``context_fn``. Certain APIs can be skipped
        via ``should_check_method_fn``, which gets passed the string name of
        the method.
        """
        wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=False,
            fsdp_kwargs=fsdp_kwargs,
        )
        self._step_model(wrapped_model, wrapped_optim, num_iters=2)

        # Sharded optim state dict
        if should_check_method_fn("sharded_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.sharded_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                )
        if "fsdp_osd" not in locals():
            fsdp_osd = {}  # may not be defined due to previous method erroring
        if should_check_method_fn("flatten_sharded_optim_state_dict"):
            with context_fn():
                FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    wrapped_optim,
                )
        # Full optim state dict
        if should_check_method_fn("full_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.full_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                    optim_input=wrapped_optim_input,
                    rank0_only=False,
                )
        if should_check_method_fn("shard_full_optim_state_dict"):
            with context_fn():
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        if should_check_method_fn("scatter_full_optim_state_dict"):
            with context_fn():
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        # Rekey optim state dict
        (
            nonwrapped_model,
            nonwrapped_optim,
            nonwrapped_optim_input,
        ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                rekeyed_osd = FSDP.rekey_optim_state_dict(
                    fsdp_osd,  # from `full_optim_state_dict()`
                    OptimStateKeyType.PARAM_ID,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )
        self._step_model(nonwrapped_model, nonwrapped_optim, num_iters=2)
        osd = nonwrapped_optim.state_dict()
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                FSDP.rekey_optim_state_dict(
                    osd,
                    OptimStateKeyType.PARAM_NAME,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType):
        """
        Tests saving and loading an optim state dict for Adam optimizer (i.e.
        any optimizer with a "step" key in its state) when the first parameter
        does not have optimizer state (e.g. unused or frozen).
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5)
                self.lin2 = nn.Linear(5, 5)
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Do not use `lin1`, which is the parameter passed to the
                # optimizer and the one checked for "step" state to see if it
                # is tensor or float
                return self.relu(self.lin2(x))

        model = Model().cuda()
        model.lin1 = FSDP(model.lin1)
        model.lin2 = FSDP(model.lin2)
        fsdp_model = FSDP(model)
        optim = torch.optim.Adam(
            fsdp_model.parameters(), lr=1e-2
        )  # or any optimizer with "step"

        # Run an iteration to construct optimizer state
        device = torch.device("cuda")
        inp = torch.randn((2, 5), device=device)
        loss = fsdp_model(inp).sum()
        loss.backward()
        optim.step()

        # Check that save and load do not error
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False)
            flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model)
        elif state_dict_type == StateDictType.SHARDED_STATE_DICT:
            fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim)
            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd, fsdp_model, optim
            )
        optim.load_state_dict(flattened_osd)
        # `__setstate__()` will check the 0th parameter to see if "step" is
        # represented as a tensor or float, so it is imperative that its state
        # is non-empty.

        # Run an iteration as a sanity check
        inp = torch.randn((2, 5), device=device)
        loss = fsdp_model(inp).sum()
        loss.backward()
        optim.step()

    @skip_if_lt_x_gpu(2)
    def test_compatible_with_trec(self):
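        # "trec" presumably refers to TorchRec-style usage: `FakeMPModel` keeps
        # a rank-specific `sparse` submodule outside of FSDP (as in model
        # parallelism) while the shared `dense` part is FSDP-wrapped, and the
        # test checks that `optim_state_dict` also handles `_NamedOptimizer`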
        class DenseModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
                self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
                self.net3 = nn.Linear(32, 64)
                self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

            def forward(self, x):
                return self.net4(self.net3(self.net2(self.net1(x))))

        class FakeMPModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(0)
                self.dense = FSDP(DenseModel().cuda(), use_orig_params=True)
                if dist.get_rank() == 0:
                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
                else:
                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

            def forward(self, x):
                if dist.get_rank() == 0:
                    sparse = self.sparse0(x)
                else:
                    sparse = self.sparse1(x)
                dist.all_reduce(sparse)
                return self.dense(sparse)

        models = [FakeMPModel().cuda(), FakeMPModel().cuda()]
        optims = [
            torch.optim.Adam(models[0].parameters(), lr=1e-2),
            _NamedOptimizer(
                models[1].named_parameters(),
                torch.optim.Adam,
                [{"params": models[1].parameters()}],
                models[1],
                lr=1e-2,
            ),
        ]
        state_dicts = []

        # Train one batch and check that the optim state dicts are the same.
        batch = torch.rand(5, 8, device=torch.device("cuda"))
        for model, optim in zip(models, optims):
            # Eagerly initialize the states
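            # Assigning zero gradients and stepping once creates optimizer
            # state (e.g. Adam's "step"/"exp_avg"/"exp_avg_sq") for every
            # parameter without changing the parameters, so both optimizers
            # start from the same state structure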
            for param in model.parameters():
                if param.requires_grad:
                    t = torch.zeros_like(param)
                    param.grad = torch.autograd.Variable(t)
            optim.step()
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            state_dicts.append(deepcopy(FSDP.optim_state_dict(model, optim)))

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

        # Make optims[1] have a different state.
        for i in range(5):
            batch = torch.rand(5, 8).cuda()
            loss = models[1](batch).sum()
            loss.backward()
            optims[1].step()

        # Load the state back to see if load_optim_state_dict works.
        state_dict_to_load = FSDP.optim_state_dict_to_load(
            models[1], optims[1], state_dicts[1], is_named_optimizer=True
        )
        optims[1].load_state_dict(state_dict_to_load)
        state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1])

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

    @skip_if_lt_x_gpu(2)
    def test_optim_state_without_param_groups(self):
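        # Checks that `FSDP.optim_state_dict()` can consume an optimizer state
        # dict whose "param_groups" key was removed and that the original param
        # groups can be re-attached before loading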
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(0)
                self.net1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU())

            def forward(self, x):
                return self.net1(x)

        model = FSDP(SimpleModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        # Train one step to save original optimizer state dict and original optimizer param groups.
        batch = torch.rand(3, 2, device=torch.device("cuda"))
        for param in model.parameters():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        optim.step()
        loss = model(batch).sum()
        loss.backward()

        original_osd = deepcopy(optim.state_dict())
        original_osd_no_param_groups = deepcopy(original_osd)
        # manually remove param_groups from optimizer state dict
        original_param_groups = deepcopy(
            original_osd_no_param_groups.pop("param_groups")
        )
        # passing the osd without param_groups to FSDP
        original_fsdp_optim_state_dict = deepcopy(
            FSDP.optim_state_dict(
                model, optim, optim_state_dict=original_osd_no_param_groups
            )
        )
        # check the state_dict sharded by FSDP does not contain param_groups.
        self.assertEqual(None, original_fsdp_optim_state_dict.get("param_groups"))

        # train another step to change the optimizer state.
        for param in model.parameters():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        optim.step()
        loss = model(batch).sum()
        loss.backward()

        state_dict_to_load = FSDP.optim_state_dict_to_load(
            model, optim, original_fsdp_optim_state_dict
        )
        # manually add param_groups to state_dict_to_load before loading the optimizer state
        state_dict_to_load["param_groups"] = original_param_groups
        optim.load_state_dict(state_dict_to_load)
        self.assertEqual(original_osd, optim.state_dict())

        fsdp_optim_state = FSDP.optim_state_dict(model, optim)
        self._check_same_state(
            original_fsdp_optim_state_dict, fsdp_optim_state, check_same_param_keys=True
        )
        self.assertEqual(original_param_groups, optim.state_dict()["param_groups"])

    @skip_if_lt_x_gpu(2)
    def test_with_empty_optimizer_state(self):
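        # Before any `optim.step()`, the optimizer "state" is empty, and
        # `FSDP.optim_state_dict()` should return an equally empty "state"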
        model = FSDP(TestDummyModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        state_dict = optim.state_dict()
        gathered_state_dict = FSDP.optim_state_dict(model, optim)
        self.assertEqual(gathered_state_dict["state"], state_dict["state"])

    def _test_load_optim_state_with_optim_state_dict(
        self,
        model_class: _ModelClass,
        state_dict_settings: StateDictSettings,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        use_diff_optim_inputs: bool,
        num_iters: int,
        **new_model_kwargs,
    ):
        """
        (1) Runs a model with full world size for K iterations to generate a
        full/sharded optimizer state dict;
        (2) initializes a model with halved world size and possibly different
        FSDP wrapping scheme (based on ``new_model_kwargs``);
        (3) loads the full/sharded optimizer state dict from (1) according to the
        halved-world-size model;
        (4) runs the halved-world-size model for K iterations; and
        (5) checks that the sharded optimizer state dict from (3) matches the
        halved-world-size model's local optimizer state dict, meaning that the
        former could have equivalently been loaded into the local optimizer.
        """
        initializer = self._model_class[model_class]

        # First, run a wrapped model with full world size for a few iterations
        model1, optim1, optim_input1 = initializer(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        FSDP.set_state_dict_type(
            model1,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model1, optim1, num_iters=num_iters)
        fsdp_osd1 = FSDP.optim_state_dict(model1, optim1)
        if halve_world_size:
            # Create a new process group with halved world size
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            new_group = dist.new_group(ranks=new_group_ranks)
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True,
            group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        FSDP.set_state_dict_type(
            model2,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model2, optim2, num_iters=num_iters)
        fsdp_osd2 = FSDP.optim_state_dict(model2, optim2, group=new_group)
        # Compute two sharded optim state dicts: (1) for the first model
        # according to the second model and (2) for the second model according
        # to the second model
        sharded_osd2 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd2, group=new_group
        )

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        self._check_same_param_groups(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        sharded_osd1 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd1, group=new_group
        )
        self._check_same_param_groups(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim2.load_state_dict(sharded_osd2)
        self._step_model(model2, optim2, num_iters=num_iters)

    @skip_if_lt_x_gpu(2)
    def test_interface_arguments(self):
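        # Exercises the keyword arguments of the optim state dict interface:
        # passing `optim_state_dict=...` to `optim_state_dict()`, loading with
        # `load_directly=True`, and the sharded/full `StateDictConfig` options
        # (`offload_to_cpu`, `rank0_only`) checked below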
        model = FSDP(TestDummyModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def step():
            loss = model(model.get_input())
            loss.backward(loss)
            optim.step()

        step()
        original_osd = deepcopy(optim.state_dict())
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        self._check_same_state(
            FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
        )
        step()
        osd_to_load = FSDP.optim_state_dict_to_load(
            model, optim, osd, load_directly=True
        )
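        # With `load_directly=True`, `optim_state_dict_to_load()` also calls
        # `optim.load_state_dict()` internally, so the optimizer state should
        # already be restored to `original_osd` here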
        self._check_same_state(
            optim.state_dict(), original_osd, check_same_param_keys=True
        )

        # Test the default setting.
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        for state in osd["state"].values():
            for s in state.values():
                self.assertFalse(isinstance(s, ShardedTensor))
                self.assertFalse(s.is_cuda)

        # Test sharded state_dict without offload_to_cpu
        with FSDP.state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(),
            ShardedOptimStateDictConfig(offload_to_cpu=False),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            for state in osd["state"].values():
                for s in state.values():
                    if s.dim() == 0:
                        continue
                    self.assertTrue(isinstance(s, ShardedTensor))
                    if s._local_shards[0]:
                        self.assertTrue(s._local_shards[0].tensor.is_cuda)

        # Test full state_dict with rank0_only
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(
                offload_to_cpu=True,
                rank0_only=True,
            ),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            if dist.get_rank() > 0:
                self.assertEqual(osd, {})
            else:
                for state in osd["state"].values():
                    for s in state.values():
                        if s.dim() == 0:
                            continue
                        self.assertFalse(s.is_cuda)
                        self.assertFalse(isinstance(s, ShardedTensor))

    @skip_if_lt_x_gpu(2)
    def test_state_dict_with_none_tensor_state(self):
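        # Non-tensor values (a float and a `None`) added to the optimizer state
        # should round-trip unchanged through `optim_state_dict()` /
        # `optim_state_dict_to_load()`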
        def _run_test(use_orig_params, optimizer_has_tensor_state):
            model = FSDP(TestDummyModel().cuda(), use_orig_params=use_orig_params)
            optimizer_cls = (
                torch.optim.Adam if optimizer_has_tensor_state else torch.optim.SGD
            )
            optim = optimizer_cls(model.parameters(), lr=1e-2)

            def step():
                loss = model(model.get_input())
                loss.backward(loss)
                optim.step()

            step()
            original_osd = deepcopy(optim.state_dict())
            for state in original_osd["state"].values():
                # Add customized value
                state["value1"] = 2.74
                state["value2"] = None

            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
            for state in osd_to_load["state"].values():
                self.assertEqual(state["value1"], 2.74)
                self.assertEqual(state["value2"], None)

        self.run_subtests(
            {
                "use_orig_params": [False, True],
                "optimizer_has_tensor_state": [False, True],
            },
            _run_test,
        )

    @skip_if_lt_x_gpu(2)
    def test_with_no_shard(self):
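        # With `ShardingStrategy.NO_SHARD`, the optimizer state dict saved via
        # `optim_state_dict()` and converted back via `optim_state_dict_to_load()`
        # should match the local `optim.state_dict()` exactly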
        def _run_test(use_orig_params: bool) -> None:
            model = FSDP(
                TestDummyModel().cuda(),
                sharding_strategy=ShardingStrategy.NO_SHARD,
                use_orig_params=use_orig_params,
            )
            optim = torch.optim.Adam(model.parameters(), lr=1e-2)

            def step():
                loss = model(model.get_input())
                loss.backward(loss)
                optim.step()

            step()

            original_osd = deepcopy(optim.state_dict())

            osd = FSDP.optim_state_dict(model, optim)
            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
            optim.load_state_dict(osd_to_load)

            new_osd = optim.state_dict()

            self.assertEqual(original_osd, new_osd)

        self.run_subtests({"use_orig_params": [False, True]}, _run_test)

    @skip_if_lt_x_gpu(2)
    def test_no_grad(self):
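        # Alternate `requires_grad` on `net1` across iterations so that some
        # parameters have no gradient (and hence no new optimizer state) in a
        # given step, and check that the save/load round trip still matches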
        model = TestDummyModel(no_grad=True).cuda()
        fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        for i in range(5):
            if i % 2 == 1:
                fsdp_model.net1[0].weight.requires_grad = True
                fsdp_model.net1[0].bias.requires_grad = True
            else:
                fsdp_model.net1[0].weight.requires_grad = False
                fsdp_model.net1[0].bias.requires_grad = False
            batch = fsdp_model.get_input()
            loss = fsdp_model(batch).sum()
            loss.backward()
            fsdp_optim.step()
            orig_state_dict = deepcopy(fsdp_optim.state_dict())
            optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
            FSDP.optim_state_dict_to_load(
                fsdp_model,
                fsdp_optim,
                FSDP.optim_state_dict(fsdp_model, fsdp_optim),
                load_directly=True,
            )

            self._check_same_state(
                fsdp_optim.state_dict(),
                orig_state_dict,
                check_same_param_keys=True,
            )


instantiate_parametrized_tests(TestFSDPOptimState)

if __name__ == "__main__":
    run_tests()