# Owner(s): ["oncall: distributed"]

import io
import itertools
import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed._state_dict_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import (
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    MixedPrecision,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import FSDP_PREFIX
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    _broadcast_state_dict,
    _get_state_dict,
    _zero_model,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    get_full_params,
    SkipModel,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]
BUFFER_SHAPE = [5, 5]

NON_ROOT_FSDP_PREFIX = "non_fsdp_lin"

_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"]
_FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"]
_SUPPORTED_STATE_DICT_IMPLS = (
    _UNFLATTENED_STATE_DICT_IMPLS + _FLATTENED_STATE_DICT_IMPLS
)

STATE_DICT_MAPPING = {
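    # "state_dict" yields the full state dict (unflattened, unsharded tensors),
    # "local_state_dict" the flattened FlatParameter shards, and
    # "sharded_state_dict" unflattened parameters as ShardedTensors.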
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}


class Model(Module):
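    # A small nested model: an inner Linear (optionally FSDP-wrapped and/or
    # ignored) and an outer Linear, each optionally registering a persistent
    # and a non-persistent buffer (the latter never appears in state_dict).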
    def __init__(
        self,
        wrap_fsdp,
        register_buffers=False,
        ignore_inner=False,
        mixed_precision=False,
        process_group=None,
    ):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if register_buffers:
            self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
            self.inner.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )
        if wrap_fsdp:
            self.inner = FSDP(
                self.inner,
                ignored_modules=([self.inner] if ignore_inner else []),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
                if mixed_precision
                else None,
                process_group=process_group,
            )
        self.outer = Linear(*OUTER_SHAPE)
        if register_buffers:
            self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
            self.outer.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)


class TestDummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.net3 = self.net2
        self.random_parameter = nn.Parameter(torch.Tensor(10))
        self.shared_parameter = self.random_parameter

    def forward(self, x):
        return self.net3(self.net2(self.net1(x)))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 2)

    def _broadcast_state_dict(self, model, state_dict):
        # TODO (rohan-varma): remove model
        return _broadcast_state_dict(self.rank, state_dict)

    def _state_compare(self, model, model_new, assert_fn, state_generator="parameters"):
        state_base = list(getattr(model, state_generator)())
        state_new = list(getattr(model_new, state_generator)())
        # Regardless of `assert_fn`, the number of parameters should be the same
        self.assertEqual(len(state_base), len(state_new))
        assert_fn(state_base, state_new)

    def _compare_models(
        self, model, model_new, assert_fn, check_fp16=False, check_buffers=True
    ):
        assert assert_fn in (self.assertEqual, self.assertNotEqual)
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_new):
                self._state_compare(model, model_new, assert_fn)
                if check_buffers:
                    has_buffers = any(
                        len(list(m.buffers())) for m in (model, model_new)
                    )
                    if has_buffers:
                        self._state_compare(
                            model, model_new, assert_fn, state_generator="buffers"
                        )
                if check_fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)

    def _get_simple_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_simple_model(self, *fsdp_args, checkpoint_wrap=False, **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_multibuffer_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        full_p = torch.float32
        lin_mp = fsdp_kwargs.pop("mixed_precision", None)
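        # If mixed precision is requested for the Linear layers, the BatchNorm
        # submodule below still gets its own FSDP wrapper in full fp32 precision.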
        bn_mp = (
            MixedPrecision(param_dtype=full_p, reduce_dtype=full_p, buffer_dtype=full_p)
            if lin_mp
            else None
        )
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            bn1 = nn.BatchNorm1d(10).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                bn1 = checkpoint_wrapper(bn1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(
                FSDP(lin1, *fsdp_args, mixed_precision=lin_mp, **fsdp_kwargs),
                FSDP(bn1, *fsdp_args, mixed_precision=bn_mp, **fsdp_kwargs),
                lin2,
            )
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.BatchNorm1d(10).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):
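        # The returned root module is a plain nn.Module (only the two nested
        # submodules are FSDP instances), exercising the non-FSDP-root path.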
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2):
                super().__init__()
                self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2

            def forward(self, x):
                x = self.non_fsdp_lin(x)
                x = self.fsdp_1(x)
                x = self.fsdp_2(x)
                return x

        return FSDPContainer(
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
        )

    def _get_state_dict_mgr(
        self,
        model: nn.Module,
        state_dict_type: str,
        state_dict_rank0_and_offload: bool,
    ):
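        # Build the FSDP.state_dict_type context manager for the requested
        # flavor. rank0_only applies only to the full state dict; for the
        # local/sharded flavors only offload_to_cpu is configured here.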
        _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
        if state_dict_type == "state_dict":
            config = FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "local_state_dict":
            config = LocalStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "sharded_state_dict":
            config = ShardedStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        else:
            raise ValueError("Unsupported state_dict_type")
        return FSDP.state_dict_type(model, _state_dict_type, config)

    def _validate_state_dict_contents(
        self, model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None
    ):
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non FSDP portion can still have parameters on rank 0,
                # so bypass the check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(
                        fsdp_state_dict,
                        {},
                        f"Expected empty state_dict but got {fsdp_state_dict} on rank {dist.get_rank()}",
                    )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize(
        "checkpoint_wrap",
        ["source", "dest", "both", "source_after_wrap", "both_after_wrap"],
    )
    @parametrize("rank0_only_and_offload", [False, True])
    def test_fsdp_state_dict_with_activation_checkpoint(
        self, state_dict_type, checkpoint_wrap, rank0_only_and_offload
    ):
        """Tests saving the state dict, zeroing a target model's parameters, and
        loading the state dict, where the source and target models may have a
        checkpoint wrapper."""

        def apply_ac_to_linears(model) -> None:
            non_reentrant_wrapper = partial(
                checkpoint_wrapper,
                offload_to_cpu=False,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            apply_activation_checkpointing(
                model,
                checkpoint_wrapper_fn=non_reentrant_wrapper,
                check_fn=lambda submodule: isinstance(submodule, nn.Linear),
            )

        for model_call in [
            partial(self._get_simple_model),
            partial(self._get_simple_nested_model),
        ]:
            model = model_call(checkpoint_wrap=(checkpoint_wrap in ("source", "both")))
            if checkpoint_wrap in ("source_after_wrap", "both_after_wrap"):
                apply_ac_to_linears(model)
            with self._get_state_dict_mgr(
                model, state_dict_type, rank0_only_and_offload
            ):
                state_dict = _gather_state_dict(_get_state_dict(model, False, False))
                # Possibly wrap new model in activation checkpoint wrapper to test save/
                # load with this wrapper
                model_new = model_call(
                    checkpoint_wrap=(checkpoint_wrap in ("dest", "both"))
                )
                if checkpoint_wrap == "both_after_wrap":
                    apply_ac_to_linears(model_new)
                _zero_model(model_new)
                self._compare_models(model, model_new, self.assertNotEqual)
                if rank0_only_and_offload:
                    state_dict = self._broadcast_state_dict(model, state_dict)
                # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
                model_new.load_state_dict(state_dict, strict=True)
                self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("rank0_only_and_offload", [False, True])
    def test_state_dict_with_manual_ac_wrapper(
        self,
        state_dict_type: str,
        rank0_only_and_offload: bool,
    ):
        """
        Tests saving and loading a state dict for a model manually wrapped with
        ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
        wrapped before FSDP.

        TODO: Investigate why the test above does not cover everything in this
        test and de-duplicate afterwards.
        """
        if state_dict_type == "sharded_state_dict" and rank0_only_and_offload:
            return  # not supported
        model_ac = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        # Manually wrap FSDP without AC
        model_no_ac = deepcopy(model_ac)
        for i, layer in enumerate(model_no_ac.transformer.encoder.layers):
            model_no_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_no_ac.transformer.decoder.layers):
            model_no_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_no_ac.transformer = FSDP(model_no_ac.transformer)

        # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))`
        for i, layer in enumerate(model_ac.transformer.encoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_ac.transformer.decoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_ac.transformer = FSDP(model_ac.transformer)

        # Save, load, and compare the two models
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_no_ac = model_no_ac.state_dict()
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_ac = model_ac.state_dict()
        self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys())
        if rank0_only_and_offload:
            state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac)
            state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac)
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            model_no_ac.load_state_dict(state_dict_no_ac)
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            model_ac.load_state_dict(state_dict_ac)
        self._compare_models(model_ac, model_no_ac, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_with_shared_parameters(self, state_dict_type):
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        model_creator = partial(
            TransformerWithSharedParams.init,
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            {"auto_wrap_policy": auto_wrap_policy},
        )

        fsdp_model = model_creator()
        with self._get_state_dict_mgr(fsdp_model, state_dict_type, False):
            state_dict = fsdp_model.state_dict()

        new_model = model_creator()
        _zero_model(new_model, zero_buffers=True)
        with self._get_state_dict_mgr(new_model, state_dict_type, False):
            new_model.load_state_dict(state_dict)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool):
        """Tests saving a model checkpoint only on rank 0 and loading it only
        on rank 0 with ``sync_module_states=True`` to emulate the workflow to
        avoid redundant CPU memory usage."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": use_orig_params,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(
                fsdp_model.parameters(), fsdp_model.buffers()
            ):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(torch.ones_like(tensor))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FSDP.summon_full_params(fsdp_model):
            with FSDP.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_basic_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: bool,
        fp16: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs, such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        device = torch.device(self.rank)
        for model_call in [
            partial(
                self._get_non_fsdp_root_module,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_nested_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
        ]:
            model = model_call()
            if fp16:
                model.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model(inp).sum().backward()

            ctx = self._get_state_dict_mgr(
                model, state_dict_type, state_dict_rank0_and_offload
            )
            with ctx:
                fsdp_state_dict = _get_state_dict(
                    model, cpu_offload.offload_params, fp16
                )

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify fp16 is the type
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model_new(inp).sum().backward()

            # zero the model to ensure parameters are different.
            _zero_model(model_new, zero_buffers=True)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
            with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_buffers_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: bool,
        mixed_precision: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it for modules with persistent buffers, including
        in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model_call = partial(
            self._get_multibuffer_nested_model,
            cpu_offload=cpu_offload,
            use_orig_params=use_orig_params,
            mixed_precision=mixed_precision,
        )
        model = model_call()
        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, False)

        self._validate_state_dict_contents(
            model, fsdp_state_dict, state_dict_rank0_and_offload
        )

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()

        # zero the model to ensure parameters are different.
        _zero_model(model_new, zero_buffers=True)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)

        self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
        self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
    ):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
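                # local/sharded state dicts hold ShardedTensors; clone each
                # local shard so later training does not mutate the saved copy.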
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(
            model, state_dict, state_dict_rank0_and_offload
        )
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            state_dict = self._broadcast_state_dict(model, state_dict)

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(
        self,
        wrap_fsdp: bool,
        wrap_ddp: bool = True,
        register_buffers: bool = False,
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, register_buffers=register_buffers).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict type for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)

    def _dist_train(
        self, wrap_fsdp: bool, state_dict_type: str = "", move_to_cpu: bool = False
    ):
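        # Train for a few steps, then (for FSDP) round-trip the state dict
        # through a freshly zeroed FSDP model and return its full parameters
        # for comparison against the DDP baseline.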
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            if move_to_cpu:
                for key in list(state_dict.keys()):
                    tensor = state_dict[key]
                    if isinstance(tensor, torch.Tensor):
                        state_dict[key] = tensor.cpu()
                    else:
                        shards = tensor.local_shards()
                        if shards:
                            shards[0].tensor = shards[0].tensor.cpu()

            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        self.run_subtests(
            {"move_to_cpu": [True, False]},
            self._test_state_dict_save_load_flow,
            state_dict_type=state_dict_type,
        )

    def _test_state_dict_save_load_flow(self, state_dict_type, move_to_cpu):
        fsdp_params = self._dist_train(
            wrap_fsdp=True,
            state_dict_type=state_dict_type,
            move_to_cpu=move_to_cpu,
        )
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
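            # The flattened local state dict contains only FlatParameters: one
            # for the root FSDP wrapper and one for the wrapped inner module.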
            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(
                1, 10, requires_grad=True, device=torch.device("cuda")
            )
        else:
            in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FSDP.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(
                wrap_fsdp=False, wrap_ddp=False, register_buffers=True
            )

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)

        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k
                    for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
            with model.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
            ):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_state_dict_with_ignored_modules(
        self, state_dict_type, prefix, ignore_inner, mixed_precision
    ):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(
            wrap_fsdp=True,
            register_buffers=True,
            ignore_inner=ignore_inner,
            mixed_precision=mixed_precision,
        ).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        # expect fp16 model.inner.buffer with mixed_precisions
        # expect fp32 sd.inner.buffer after restoring to original precision
        # so skip AssertEqual
        if mixed_precision and not ignore_inner:
            buffer_to_buffer_name.pop(model.inner.buffer)

        fsdp_model = FSDP(
            model,
            ignored_modules=ignored_modules,
            mixed_precision=MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None,
        )
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd1 = _gather_state_dict(fsdp_model.state_dict(prefix=prefix_str))
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(
                tensor.data_ptr(),
                sd1[prefixed_tensor_name].data_ptr(),
                f"{prefixed_tensor_name}",
            )
        # should not apply mixed_precision to ignored buffers
        for buffer_name in buffer_to_buffer_name.values():
            prefixed_buffer_name = f"{prefix_str}{buffer_name}"
            self.assertTrue(prefixed_buffer_name in sd1)
            self.assertEqual(sd1[prefixed_buffer_name].dtype, torch.float32)
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str) :]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)

    @skip_if_lt_x_gpu(2)
    def test_local_state_dict_with_empty_ranks(self):
        class Model(Module):
            def __init__(self):
                super().__init__()
                self.my_tensor = torch.full((1,), 3.1415926)
                self.my_parameter = nn.Parameter(self.my_tensor)

            def forward(self, x):
                return self.my_parameter

        model = FSDP(Model().cuda())
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            out = model(None)
            out.backward()

            state_dict = deepcopy(model.state_dict())
            with torch.no_grad():
                with FSDP.summon_full_params(model):
                    self.assertEqual(model.my_parameter.item(), 3.1415926)
                    model.my_parameter.copy_(torch.full((1,), 1.75).cuda())
                    self.assertEqual(model.my_parameter.item(), 1.75)
            model.load_state_dict(state_dict)
            with FSDP.summon_full_params(model):
                self.assertEqual(model.my_parameter.item(), 3.1415926)

    @skip_if_lt_x_gpu(2)
    def test_torch_save_load(self):
        model = Model(wrap_fsdp=True).cuda()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()
            checkpoint = io.BytesIO()
            torch.save(state_dict, checkpoint)
            checkpoint.seek(0)
            state_dict_saved = torch.load(checkpoint)
            for k, v in state_dict_saved.items():
                if isinstance(v, ShardedTensor):
                    self.assertEqual(
                        v._local_shards[0].tensor, state_dict[k]._local_shards[0].tensor
                    )
                else:
                    self.assertEqual(v, state_dict[k])

    @skip_if_lt_x_gpu(2)
    def test_shared_module_and_shared_parameter(self):
        model = FSDP(TestDummyModel().cuda())
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            state_dict = model.state_dict()
            self.assertEqual(
                state_dict["random_parameter"], state_dict["shared_parameter"]
            )
            self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
            self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])

    @skip_if_lt_x_gpu(2)
    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
        model = self._get_simple_nested_model()
        sd = model.state_dict()
        # Create a missing key
        sd.pop(next(iter(sd.keys())))
        # Create an unexpected key
        sd["unexpected"] = torch.ones(1)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        assert len(missing) == 1
        assert len(unexpected) == 1
        self.assertTrue(FSDP_PREFIX not in missing[0])
        self.assertTrue(FSDP_PREFIX not in unexpected[0])

    @skip_if_lt_x_gpu(2)
    def test_sharded_load_multi_backend_pg(self):
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": True,
        }
        for load_cpu in [True, False]:
            with self.subTest(load_cpu=load_cpu):
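                # A "cpu:gloo,cuda:nccl" process group routes CPU tensors via
                # gloo and CUDA tensors via nccl, so the sharded load also
                # works when the state dict has been moved to CPU.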
                pg = dist.new_group(backend="cpu:gloo,cuda:nccl")
                fsdp_model = TransformerWithSharedParams.init(
                    pg,
                    FSDPInitMode.RECURSIVE,
                    CUDAInitMode.CUDA_BEFORE,
                    fsdp_kwargs,
                )
                FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
                sharded = fsdp_model.state_dict()
                param_copy = [t.clone().detach_() for t in fsdp_model.parameters()]
                with torch.no_grad():
                    for p in fsdp_model.parameters():
                        p.zero_()

                if load_cpu:
                    # Offload to CPU to simulate CPU state_dict load
                    for k, v in sharded.items():
                        sharded[k] = v.cpu()

                fsdp_model.load_state_dict(sharded)
                for p1, p2 in zip(param_copy, fsdp_model.parameters()):
                    self.assertEqual(p1, p2, f"not equal: {p1.sum()} vs {p2.sum()}")

    @skip_if_lt_x_gpu(2)
    def test_world_size_one(self):
        my_pg = None
        for i in range(self.world_size):
            pg = dist.new_group(ranks=[i])
            if i == self.rank:
                my_pg = pg

        model = TransformerWithSharedParams.init(
            my_pg,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = model.state_dict()
            model.load_state_dict(state_dict)

        dist.barrier()


class TestFSDPStateDict4GPUs(FSDPTest):
    @property
    def world_size(self):
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(4)
    def test_local_state_dict_reshard(self):
        """
        This test demonstrates the ability to reshard when using
        local_state_dict. Although we do not recommend that users use
        local_state_dict, there are still some corner cases where
        local_state_dict is the better solution.
        """
        model = FSDP(Model(wrap_fsdp=True)).cuda()
        optim = torch.optim.SGD(model.parameters(), lr=0.1)

        batch = torch.randn(4, 4, device=torch.cuda.current_device())
        output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()

        rank = dist.get_rank()
        new_pg = dist.new_group(ranks=[0, 1])
        resharded_state_dict = {}
        # Mimic resharding from 4 GPUs to 2 GPUs
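        # (all-gather each flat parameter, let ranks 0 and 1 each keep one of
        # two chunks, and rebuild a ShardedTensor over the 2-rank group).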
        for key, value in state_dict.items():
            if isinstance(value, ShardedTensor):
                full_flat_param = _all_gather_sharded_tensor(value)
                if rank < 2:
                    full_numel = full_flat_param.size()
                    chunks = full_flat_param.chunk(2)
                    flat_param = chunks[rank]
                    shard_offset = 0 if rank == 0 else chunks[0].numel()
                    local_shards = [
                        Shard.from_tensor_and_offsets(flat_param, [shard_offset], rank)
                    ]
                    sharded_tensor = init_from_local_shards(
                        local_shards, full_numel, process_group=new_pg
                    )
                    resharded_state_dict[key] = sharded_tensor
            else:
                if rank < 2:
                    resharded_state_dict[key] = value

        if rank < 2:
            model2 = FSDP(
                Model(wrap_fsdp=True, process_group=new_pg), process_group=new_pg
            ).cuda()
            with FSDP.state_dict_type(model2, StateDictType.LOCAL_STATE_DICT):
                model2.load_state_dict(resharded_state_dict)

        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            full_state_dict1 = model.state_dict()

        if rank < 2:
            with FSDP.state_dict_type(model2, StateDictType.FULL_STATE_DICT):
                full_state_dict2 = model2.state_dict()
            self.assertEqual(full_state_dict1, full_state_dict2)


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()