# pytorch
#
# Fork: 0
# /
# _flat_param.py
# 2723 lines · 117.6 KB
1
import contextlib
2
import functools
3
import logging
4
import os
5
import warnings
6
from enum import auto, Enum
7
from itertools import accumulate, chain
8
from typing import (
9
    Any,
10
    Callable,
11
    cast,
12
    Dict,
13
    Generator,
14
    Iterator,
15
    List,
16
    NamedTuple,
17
    no_type_check,
18
    Optional,
19
    Sequence,
20
    Set,
21
    Tuple,
22
    Union,
23
)
24

25
import torch
26
import torch.distributed as dist
27
import torch.nn as nn
28
import torch.nn.functional as F
29
from torch import Tensor
30
from torch.distributed.fsdp._common_utils import (
31
    _FSDPDeviceHandle,
32
    _named_parameters_with_duplicates,
33
    _no_dispatch_record_stream,
34
    _set_fsdp_flattened,
35
    HandleTrainingState,
36
)
37
from torch.distributed.utils import (
38
    _alloc_storage,
39
    _data_ptr_allocated,
40
    _free_storage,
41
    _p_assert,
42
)
43
from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
44
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
45

46
from ._fsdp_extensions import (
47
    _ext_post_unflatten_transform,
48
    _ext_pre_flatten_transform,
49
    FSDPExtensions,
50
)
51

52
# Public API of this module.
__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]
60

61
log = logging.getLogger(__name__)


"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.

For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
runs the unshard/reshard on behalf of the fully sharded module by overriding
``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
``FullyShardedDataParallel`` constructor's ``module`` argument.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
or a submodule chosen by the provided wrapping policy.
"""

# NOTE: Every environment-variable flag below is opt-in: it is enabled only
# when the variable is set to exactly "1".

# Environment variable toggling whether to use unsafe `setattr()` for view
# setting in `_use_sharded_views()` and `_use_unsharded_views()`.
# We should use 'safe' by default since it respects method overrides, but for
# special cases such as for high CPU overhead or for intentionally bypassing
# checks in the overrides, we may use 'unsafe'.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"

# Environment variable toggling whether to check for parameter/gradient
# writeback in case their storages change after FSDP initialization.
# We should check by default since it prevents silent correctness errors, but
# since such changes are atypical, we may want to skip the check to save CPU
# overhead, especially since the check happens in the pre-forward and
# pre-backward each iteration.
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"

# Environment variable toggling whether, when the model is in ``.eval()`` mode,
# to run in full precision instead of the reduced (mixed) precision.
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"

# Fill value written into padding entries of tensors so that padding is easy
# to recognize when debugging.
_FLAT_PARAM_PADDING_VALUE = 42

# Environment variables for disabling the all-gather and reduce-scatter
# communication ops for ablation studies. Note that without these communication
# ops the training won't converge, and you probably need to disable correctness
# checks in your model.
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
113

114

115
# TODO: Define this for now to avoid circular imports. See if we can remove.
class HandleShardingStrategy(Enum):
    """
    Sharding strategy applied to a single ``FlatParamHandle``.

    NOTE(review): this presumably mirrors the public FSDP ``ShardingStrategy``
    enum (see the TODO above about circular imports) — confirm they stay in
    sync.
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


# Strategies that free (reshard) the unsharded flat parameter after forward.
RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
)
# Strategies that keep the unsharded flat parameter after forward.
# ``NO_SHARD`` belongs to neither tuple since it never shards.
NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)
132

133

134
class ParamInfo(NamedTuple):
    """Information for an original parameter."""

    # Attribute name of the parameter on ``module`` (not prefixed).
    param_name: str  # unprefixed
    # Submodule that registers the parameter.
    module: nn.Module
    # Name of ``module`` relative to the fully sharded module ("" for itself).
    module_name: str
140

141

142
class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    # The non-primary reference: attribute name, owning module, module name.
    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    # The primary owner's attribute name, module, and module name.
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str
158

159

160
class _ShardParamInfo(NamedTuple):
    """Shard-related information for an original parameter."""

    # Whether any part of the parameter lies in this rank's shard.
    in_shard: bool
    # Use to index into the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Use to get part of the parameter in the local shard from a flattened
    # version of the unsharded parameter, e.g.
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive
173

174

175
class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]
195

196

197
class _FlatParameterMeta(_ParameterMeta):
    """
    Metaclass of :class:`FlatParameter`.

    ``FlatParameter.__new__`` returns a plain ``nn.Parameter`` tagged with
    ``_is_flat_param = True``, so ``isinstance(t, FlatParameter)`` must be
    answered from that flag rather than from the real type.
    """

    # Make `isinstance(t, FlatParameter)` return True for custom tensor
    # instances that have the _is_flat_param flag for BC
    def __instancecheck__(self, instance):
        # NB: do NOT test the super implementation
        return isinstance(instance, torch.Tensor) and getattr(
            instance, "_is_flat_param", False
        )
205

206

207
class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`.

    It is composed of one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.
        - In the :class:`FullyShardedDataParallel` constructor, the parameter
        is initialized as unsharded and then sharded in-place.
        - At runtime, the parameter is lazily (re)-initialized. The sharded
        parameter data is saved in ``self._local_shard``, and a new ``Tensor``
        ``self._full_param_padded`` is created, which is the all-gather
        destination and owns the unsharded parameter storage thereafter. (See
        :meth:`FlatParamHandle.init_flat_param_attributes`.)
        - Throughout runtime, the parameter data changes storages as needed,
        e.g. to the sharded flat parameter, low precision sharded flat
        parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may have length longer than the other data
    structures, while the latter has the same length as the number of actual
    original parameters like the other per-parameter data structures.

    NOTE: This is not a real class; instead, you will always get a Parameter
    back out if you try to create one of these.  This is similar to the trick
    we implemented for Parameter to get it to work with subclasses; this
    is primarily so that FlatParameter supports combination with FakeTensor.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)

        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
            prefixed from the ``_fully_sharded_module``. The names are
            guaranteed to be unique in the subtree rooted at that module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _is_padding_mask (List[bool]): Mask parallel to
            ``_numels_with_padding``; ``True`` marks an alignment padding entry
            rather than an original parameter.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.

        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_handle (RemovableHandle):
            Flat parameter's post-backward hook handle. (Compile only)
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle. (Eager only)
        _post_backward_called (bool): Whether the flat parameter's
            post-backward hook has been called; used to modify the behavior of
            the post-backward callback.
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change
            in which case the post-backward hook does not run. This is relevant
            for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating if it is
            logically ``None`` or not; otherwise, ``None``. This does not
            include entries for padding. This mask is needed because only some
            of the parameters may have ``None`` gradient, in which case the
            flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    _unpadded_unsharded_size: torch.Size
    _padded_unsharded_size: torch.Size
    _sharded_size: torch.Size
    _num_params: int
    _param_infos: Tuple[ParamInfo, ...]
    _shapes: Tuple[torch.Size, ...]
    _fqns: Tuple[str, ...]
    _param_extensions: Tuple[Optional[Any], ...]
    _numels_with_padding: Tuple[int, ...]
    _numels: Tuple[int, ...]
    _shard_param_infos: Tuple[_ShardParamInfo, ...]
    _shared_param_infos: Tuple[SharedParamInfo, ...]
    _modules: Set[nn.Module]
    _shard_numel_padded: int
    _local_shard: Tensor
    _full_param_padded: Tensor
    _full_prec_full_param_padded: Tensor
    # Eager only
    _post_backward_hook_state: Tuple[Any, Any]
    # Compile only
    _post_backward_hook_handle: Any
    _post_backward_called: bool
    _mp_shard: Tensor
    _cpu_grad: Tensor
    _saved_grad_shard: Tensor
    _params: Optional[List[nn.Parameter]]
    _shared_params: Optional[List[nn.Parameter]]
    _tensors: Optional[List[Optional[Tensor]]]
    _is_grad_none_mask: Optional[List[bool]]

    _is_padding_mask: List[bool]

    def __new__(cls, data=None, requires_grad=True):
        # Build a plain `nn.Parameter` (not an instance of this class) and tag
        # it so that `_FlatParameterMeta.__instancecheck__` recognizes it.
        assert cls is FlatParameter, "subclasses FlatParameter not supported"
        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
        r._is_flat_param = True  # type: ignore[attr-defined]
        return r

    # NB: This is not a regular method, because FlatParameters are not actually
    # instances of this class (see __new__ above). It must therefore be invoked
    # through the class, passing the flat parameter explicitly as `self`, e.g.
    # `FlatParameter._init_metadata(flat_param, ...)`.
    @classmethod
    def _init_metadata(
        cls,
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Initialize attributes holding metadata about the original parameters comprising the flat parameter.

        We expose this method separate from the constructor to keep the
        constructor only responsible for the flat parameter's tensor data. This
        method should only be called once per model, while the constructor may
        be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        Args:
            self: The flat parameter (the ``nn.Parameter`` returned by
                :meth:`FlatParameter.__new__`) to annotate in place.
            param_infos: One :class:`ParamInfo` per original parameter.
            numels: Numel of every flattened entry, including alignment
                padding entries.
            shapes, fqns, param_extensions: One entry per original parameter.
            shared_param_infos: One :class:`SharedParamInfo` per non-primary
                shared parameter reference.
            params: If ``use_orig_params=True``, the flattened tensors
                including padding entries (padding is filtered out here);
                otherwise ``None``.
            shared_params: If ``use_orig_params=True``, the shared parameter
                variables; otherwise ``None``.
            is_padding_mask: Parallel to ``numels`` (and ``params``); ``True``
                marks an alignment padding entry.
            See also the Attributes in the class docstring.
        """
        assert len(param_infos) == len(shapes)
        assert len(param_infos) == len(fqns)
        assert len(param_infos) == len(param_extensions)
        # `numels` and `is_padding_mask` are parallel lists (both include the
        # padding entries); `zip()` below would silently truncate otherwise.
        assert len(numels) == len(is_padding_mask), (
            "Expects `numels` and `is_padding_mask` to have the same length "
            f"but got {len(numels)} and {len(is_padding_mask)}"
        )
        self._num_params = len(param_infos)
        self._param_infos = param_infos
        self._shapes = shapes
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask

        self._numels = tuple(
            numel
            for numel, is_padding in zip(numels, is_padding_mask)
            if not is_padding
        )
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params

        self._shared_param_infos = tuple(shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            assert len(params) == len(is_padding_mask), (
                "Expects `params` and `is_padding_mask` to have the same length "
                f"but got {len(params)} and {len(is_padding_mask)}"
            )
            # Keep only the original parameters, dropping padding tensors
            self._params = [
                param
                for param, is_padding in zip(params, is_padding_mask)
                if not is_padding
            ]
            self._shared_params = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none_mask = [False for _ in range(self._num_params)]
            self._tensors = [None for _ in range(self._num_params)]
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False
447

448

449
class FlatParamHandle:
450
    """
451
    A handle that manages a flat parameter (:class:`FlatParameter`).
452

453
    This includes sharding and view management.
454

455
    Args:
456
        params (Sequence[nn.Parameter]): The parameters to flatten into the
457
            flat parameter.
458
        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
459
        device (torch.device): The compute and communication device, which
460
            should be a non-CPU device. We refer to it as the compute device.
461
        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
462
            this handle's ``FlatParameter``.
463
        offload_params (bool): Whether to offload the handle's
464
            ``FlatParameter`` to CPU.
465
        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
466
            setting passed to the FSDP constructor.
467
        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
468
            precision setting passed to the FSDP constructor.
469
        keep_low_precision_grads (bool): Whether to keep gradients in low
470
            precision.
471
        use_orig_params (bool): If ``True``, then FSDP preserves the original
472
            parameter variables and returns them from ``named_parameters()``
473
            (e.g. to support different optimizer hyperparameters within one
474
            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
475
            parameters every iteration and returns the :class:`FlatParameter` s
476
            from ``named_parameters()``.
477
    """
478

479
    ##################
480
    # INITIALIZATION #
481
    ##################
482
    def __init__(
        self,
        params: Sequence[Union[nn.Parameter, Tensor]],
        fully_sharded_module: nn.Module,
        device: torch.device,
        sharding_strategy: HandleShardingStrategy,
        offload_params: bool,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
        keep_low_precision_grads: bool,
        process_group: dist.ProcessGroup,
        use_orig_params: bool,
        *,
        fsdp_extension: Optional[FSDPExtensions] = None,
    ):
        """
        Construct the handle and its ``FlatParameter`` from ``params``.

        See the class docstring for most arguments. In addition:
            process_group (dist.ProcessGroup): Group whose ``rank()`` and
                ``size()`` populate ``self.rank`` and ``self.world_size``.
            fsdp_extension (Optional[FSDPExtensions]): Optional extension
                passed to the pre-flatten transform of each parameter.

        Raises:
            ValueError: If ``params`` is empty.
        """
        super().__init__()
        params = list(params)
        if len(params) == 0:
            raise ValueError(
                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
            )
        self._init_setattr_fns()

        def _env_flag(name: str) -> bool:
            # FSDP's debug/ablation environment variables are opt-in via "1"
            return os.environ.get(name, "") == "1"

        self._skip_writeback_check = _env_flag(_FSDP_SKIP_WRITEBACK_CHECK)
        self._use_full_prec_in_eval = _env_flag(_FSDP_USE_FULL_PREC_IN_EVAL)
        self._use_fake_all_gather = _env_flag(_FSDP_USE_FAKE_ALL_GATHER)
        self._use_fake_reduce = _env_flag(_FSDP_USE_FAKE_REDUCE)
        if self._skip_writeback_check:
            _warn_skip_writeback_check(
                log,
                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
                "for parameter or gradient writeback. Changing parameter or "
                "gradient storages may lead to silent correctness errors.",
            )
        if self._use_fake_all_gather:
            _warn_use_fake_all_gather(
                log,
                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
                "all-gather ops. Your training will be incorrect, but "
                "can reveal how much time spent on all-gather ops.",
            )
        if self._use_fake_reduce:
            _warn_use_fake_reduce(
                log,
                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
                "reduce-scatter ops. Your training will be incorrect, but "
                "can reveal how much time spent on reduce-scatter ops.",
            )
        # Only align addresses for `use_orig_params=True` (for now)
        align_addresses = use_orig_params
        self._init_get_unflat_views_fn(align_addresses)
        self.device = device
        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
        self.process_group = process_group
        if self._use_fake_all_gather or self._use_fake_reduce:
            # Only created for the ablation modes above; mirrors the real
            # group's rank and size.
            self._fake_process_group = FakeProcessGroup(
                rank=process_group.rank(), world_size=process_group.size()
            )
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        self._sharding_strategy = sharding_strategy
        self._offload_params = offload_params
        self._use_orig_params = use_orig_params
        self._keep_low_precision_grads = keep_low_precision_grads
        self._training_state = HandleTrainingState.IDLE
        self._debug_level = dist.get_debug_level()
        self._fully_sharded_module = fully_sharded_module
        # For strategies that do not free after forward, we skip using sharded
        # views after forward since the unsharded data exists. We still switch
        # `self.flat_param` to point to the sharded flat parameter since what
        # it points to parameterizes behavior. We use the following attribute
        # to track which tensor data the parameters are unsharded views into.
        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
        # The index in the state's `all_handles`, which must be the
        # same across ranks for the execution order validation to work
        self._handle_index: Optional[int] = None
        # Index in handles_to_pre_forward_order
        self._pre_forward_order_index: Optional[int] = None
        # Index in `handles_post_forward_order`
        self._post_forward_index: Optional[int] = None
        # Used for guarding against mistargeted forward prefetches
        self._needs_pre_forward_unshard = False
        # Used for guarding against mistargeted backward prefetches
        self._needs_pre_backward_unshard = False
        # Was the handle prefetched? Set on successful _prefetch_handle and unshard
        self._prefetched = False
        # Optimistically assume a valid input `params` and set dtype attributes
        # before `_init_flat_param()`, which performs the actual validation
        self._orig_param_dtype = params[0].dtype
        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
        assert self._fwd_bwd_param_dtype is not None  # mypy
        self._aligned_numel = (
            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
            if align_addresses
            else 0
        )
        self._fsdp_extension = fsdp_extension
        self._init_flat_param_and_metadata(
            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
        )
        self._use_unsharded_views(as_params=False)

    def _init_setattr_fns(self):
        """
        Select the functions used to set tensors/parameters on modules.

        If the ``FSDP_USE_UNSAFE_SETATTR`` environment variable is ``"1"``,
        the unsafe variants (which bypass ``setattr`` overrides) are used;
        otherwise ``_safe_setattr_tensor_or_param`` is used for both.
        """
        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
        if use_unsafe_setattr:
            self._setattr_tensor = _unsafe_setattr_tensor
            self._setattr_param = _unsafe_setattr_param
        else:
            self._setattr_tensor = _safe_setattr_tensor_or_param
            self._setattr_param = _safe_setattr_tensor_or_param

    def _init_get_unflat_views_fn(self, align_addresses: bool):
599
        self._get_unflat_views = (
600
            self._get_unflat_views_aligned
601
            if align_addresses
602
            else self._get_unflat_views_unaligned
603
        )
604

605
    def _init_flat_param_and_metadata(
606
        self,
607
        params: List[Union[Tensor, nn.Parameter]],
608
        module: nn.Module,
609
        aligned_numel: int,
610
        use_orig_params: bool,
611
    ) -> None:
612
        """
613
        Initialize the ``FlatParameter`` and its metadata.
614

615
        NOTE: This should only be called once at construction time, after which
616
        the ``FlatParameter`` metadata is assumed to be static.
617

618
        NOTE: The elements of ``params`` should only be ``Tensor`` s when
619
        composing with ``DTensor`` -based tensor parallelism, in which case the
620
        elements may be ``DTensor`` local shards.
621
        """
622
        if len(params) == 0:
623
            raise ValueError("Expects non-empty `params`")
624
        if aligned_numel < 0:
625
            raise ValueError(
626
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
627
            )
628
        (
629
            dtype,
630
            flat_param_requires_grad,
631
            device,
632
        ) = self._validate_tensors_to_flatten(params)
633
        params_set = set(params)
634
        # For alignment padding, only `numels` gets strictly non-`None`
635
        # elements, and all other lists get `None` elements for padding.
636
        param_infos: List[ParamInfo] = []
637
        numels: List[int] = []
638
        shapes: List[torch.Size] = []
639
        fqns: List[str] = []
640
        shared_param_infos: List[SharedParamInfo] = []
641
        shared_param_memo: Dict[
642
            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
643
        ] = {}
644
        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
645
        shared_params: List[Union[Tensor, nn.Parameter]] = []
646
        param_extensions: List[Any] = []
647
        is_padding_mask: List[bool] = []
648
        total_numel = total_numel_without_padding = 0
649
        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
650
            for param_name, param in _named_parameters_with_duplicates(
651
                submodule, recurse=False
652
            ):
653
                if param not in params_set:
654
                    continue
655
                if param in shared_param_memo:  # shared reference
656
                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
657
                        param
658
                    ]
659
                    shared_params.append(param)
660
                    shared_param_infos.append(
661
                        SharedParamInfo(
662
                            param_name,
663
                            submodule,
664
                            submodule_name,
665
                            prim_param_name,
666
                            prim_module,
667
                            prim_module_name,
668
                        )
669
                    )
670
                else:
671
                    if aligned_numel > 0:
672
                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
673
                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
674
                            padding_tensor = _construct_padding_tensor(
675
                                numel_to_pad, dtype, False, device
676
                            )
677
                            params_to_flatten.append(padding_tensor)
678
                            is_padding_mask.append(True)
679
                            numels.append(numel_to_pad)
680
                            total_numel += numel_to_pad
681
                    transform_t, extension = _ext_pre_flatten_transform(
682
                        param,
683
                        self._fsdp_extension,
684
                    )
685
                    param = cast(nn.Parameter, transform_t)
686
                    param_extensions.append(extension)
687
                    shared_param_memo[param] = (submodule, submodule_name, param_name)
688
                    params_to_flatten.append(param)
689
                    is_padding_mask.append(False)
690
                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
691
                    numels.append(param.numel())
692
                    shapes.append(param.shape)
693
                    fqn = (
694
                        submodule_name + "." + param_name
695
                        if submodule_name
696
                        else param_name
697
                    )
698
                    fqns.append(fqn)
699
                    total_numel += param.numel()
700
                    total_numel_without_padding += param.numel()
701
        if len(params_to_flatten) == 0:
702
            raise ValueError(
703
                f"`params` were not found in `module`'s tree"
704
                f"params: {params}\nmodule: {module}"
705
            )
706
        if (
707
            self.rank == 0
708
            and aligned_numel > 0
709
            and total_numel != total_numel_without_padding
710
        ):
711
            log.info(
712
                "FSDP FlatParameter address alignment created "
713
                "%s numel of padding (%s vs. %s)",
714
                total_numel - total_numel_without_padding,
715
                total_numel,
716
                total_numel_without_padding,
717
            )
718
        if aligned_numel > 0:
719
            # Pad to be divisible by world size to avoid a copy for the
720
            # post-backward reduce-scatter
721
            numel_to_pad = self.world_size - (total_numel % self.world_size)
722
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
723
                if self.rank == 0:
724
                    log.info(
725
                        "FSDP FlatParameter world size divisibility created "
726
                        "%s numel of padding",
727
                        numel_to_pad,
728
                    )
729
                padding_tensor = _construct_padding_tensor(
730
                    numel_to_pad, dtype, False, device
731
                )
732
                params_to_flatten.append(padding_tensor)
733
                is_padding_mask.append(True)
734
                numels.append(numel_to_pad)
735
                total_numel += numel_to_pad
736
        # Pass `aligned_numel=0` since we already included padding tensors
737
        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
738
            params_to_flatten,
739
            aligned_numel=0,
740
            requires_grad=flat_param_requires_grad,
741
        )
742
        FlatParameter._init_metadata(
743
            self.flat_param,
744
            param_infos,
745
            numels,
746
            shapes,
747
            fqns,
748
            shared_param_infos,
749
            param_extensions,
750
            _convert_to_params(params_to_flatten) if use_orig_params else None,
751
            _convert_to_params(shared_params) if use_orig_params else None,
752
            is_padding_mask,
753
        )
754

755
    def _validate_tensors_to_flatten(
        self, tensors: List[Union[Tensor, nn.Parameter]]
    ) -> Tuple:
        """
        Check that ``tensors`` can be flattened together and return their shared metadata.

        Returns:
            Tuple: ``(dtype, flat_param_requires_grad, device)`` where
            ``flat_param_requires_grad`` is the logical OR of every tensor's
            ``requires_grad``.

        Raises:
            ValueError: If a tensor is already a ``FlatParameter``, the first
                tensor is not floating point, the dtypes or devices are not
                uniform, or ``requires_grad`` is not uniform while
                ``use_orig_params=False``.
        """
        common_dtype: Optional[torch.dtype] = None
        any_requires_grad: Optional[bool] = None
        common_device: Optional[torch.device] = None
        for tensor in tensors:
            if isinstance(tensor, FlatParameter):
                raise ValueError("Cannot flatten a `FlatParameter`")
            if common_dtype is None:
                # Only the first tensor needs the floating-point check: every
                # later tensor must match its dtype exactly
                if not tensor.is_floating_point():
                    raise ValueError("Cannot flatten integer dtype tensors")
            elif tensor.dtype != common_dtype:
                raise ValueError(
                    f"Must flatten tensors with uniform dtype but got {common_dtype} "
                    f"and {tensor.dtype}"
                )
            # For `use_orig_params=True`, permit non-uniform `requires_grad`
            requires_grad_mismatch = (
                not self._use_orig_params
                and any_requires_grad is not None
                and tensor.requires_grad != any_requires_grad
            )
            if requires_grad_mismatch:
                raise ValueError(
                    "Must flatten tensors with uniform `requires_grad` when "
                    "`use_orig_params=False`"
                )
            if common_device is not None and tensor.device != common_device:
                raise ValueError(
                    "Must flatten tensors on the same device but got both "
                    f"{common_device} and {tensor.device}"
                )
            common_dtype = tensor.dtype
            any_requires_grad = any_requires_grad or tensor.requires_grad
            common_device = tensor.device
        assert any_requires_grad is not None, "Requires non-empty `tensors` list"
        return common_dtype, any_requires_grad, common_device
793

794
    def flatten_tensors(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
    ) -> Tensor:
        """
        Flatten ``tensors`` into a single 1D tensor.

        When ``aligned_numel`` is positive, zero padding is inserted before a
        tensor whenever the running numel is not a multiple of
        ``aligned_numel``. A final padding tensor then makes the total numel
        divisible by the world size. When ``aligned_numel`` is 0, the tensors
        are simply concatenated.

        NOTE: The padding alignment algorithm must be kept in sync with
        :meth:`_init_flat_param_metadata`. The two are separate because
        initialization happens once, whereas this method may be called many
        times during training (e.g. for checkpointing).

        Raises:
            ValueError: If ``tensors`` is empty, ``aligned_numel`` is negative,
                or the tensors fail :meth:`_validate_tensors_to_flatten`.
        """
        if len(tensors) == 0:
            raise ValueError("Expects non-empty `tensors`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        dtype, _, device = self._validate_tensors_to_flatten(tensors)
        if aligned_numel == 0:
            # No alignment requested: plain concatenation
            return torch.cat(
                [torch.flatten(_detach_if_needed(t)) for t in tensors], dim=0
            )
        pieces: List[Tensor] = []
        running_numel = 0
        for t in tensors:
            gap = aligned_numel - (running_numel % aligned_numel)
            if 0 < gap < aligned_numel:
                pieces.append(_construct_padding_tensor(gap, dtype, False, device))
                running_numel += gap
            pieces.append(torch.flatten(_detach_if_needed(t)))
            running_numel += t.numel()
        # Make the total divisible by the world size
        tail = self.world_size - (running_numel % self.world_size)
        if 0 < tail < self.world_size:
            pieces.append(_construct_padding_tensor(tail, dtype, False, device))
        return torch.cat(pieces, dim=0)
843

844
    def flatten_tensors_into_flat_param(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
        requires_grad: bool,
    ) -> FlatParameter:
        """
        Flatten ``tensors`` with :meth:`flatten_tensors` and wrap the result in a ``FlatParameter``.

        Args:
            tensors (List[Tensor]): Tensors to flatten.
            aligned_numel (int): Alignment numel forwarded to
                :meth:`flatten_tensors` (0 disables padding).
            requires_grad (bool): ``requires_grad`` for the new parameter.
        """
        return FlatParameter(
            self.flatten_tensors(tensors, aligned_numel),
            requires_grad=requires_grad,
        )
852

853
    def _init_param_reduce_dtypes(
854
        self,
855
        mp_param_dtype: Optional[torch.dtype],
856
        mp_reduce_dtype: Optional[torch.dtype],
857
    ) -> None:
858
        """
859
        Initialize param and reduce dtypes.
860

861
        Precondition: ``self.flat_param`` is set. This ensures that this
862
        handle's parameters have a single dtype.
863

864
        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
865
        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
866
        is ``None``, then we assume the original parameter dtype. One special
867
        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
868
        is ``None``, in which case we assume the gradient reduction dtype
869
        matches the forward/backward parameter dtype.
870
        """
871
        # Save whether these dtypes were specified so that we permit the
872
        # parameter dtype to change up until the lazy initialization
873
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
874
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
875
        if (
876
            self._low_prec_param_dtype_specified
877
            and not self._low_prec_reduce_dtype_specified
878
        ):
879
            # Special case: infer gradient reduction mixed precision
880
            self._fwd_bwd_param_dtype = mp_param_dtype
881
            self._reduce_dtype = self._fwd_bwd_param_dtype
882
        else:
883
            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
884
            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
885
        assert self._fwd_bwd_param_dtype is not None
886
        assert self._reduce_dtype is not None
887

888
    ###################################
889
    # SHARD INITIALIZATION & METADATA #
890
    ###################################
891
    @torch.no_grad()
    def shard(self):
        """
        Shard the handle's ``FlatParameter``.

        For sharded strategies, this copies this rank's padded shard into new
        memory (via :meth:`_get_shard`). Unless TorchDynamo is compiling, it then
        frees the unsharded flat parameter's storage, and points
        ``self.flat_param`` at the shard. For non-sharded strategies, the data is
        left as is and the "shard" covers the whole flat parameter.

        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
        metadata attributes are set for all sharding strategies.
        """
        fp = self.flat_param
        if self.uses_sharded_strategy:
            _p_assert(
                fp.storage_offset() == 0,
                "The `FlatParameter` is not the sole occupant of its storage",
            )
            local_shard, pad_numel = FlatParamHandle._get_shard(
                fp, self.rank, self.world_size
            )
            compiling = (
                torch.distributed._functional_collectives.is_torchdynamo_compiling()
            )
            if not compiling and fp._typed_storage()._size() > 0:
                # Release the unsharded data now that the shard owns a copy
                fp._typed_storage()._resize_(0)
            fp.set_(local_shard)  # type: ignore[call-overload]
            shard_numel = local_shard.numel()
            first_idx = shard_numel * self.rank
            last_idx = first_idx + shard_numel - 1  # inclusive
            self._init_shard_metadata(pad_numel, first_idx, last_idx)
        else:
            # Single "shard" spanning the entire flat parameter without padding
            self._init_shard_metadata(0, 0, fp.numel() - 1)
        if self._use_orig_params:
            self._use_sharded_views()
924

925
    def _init_shard_metadata(
        self,
        numel_padded: int,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> None:
        """
        Initialize shard-related metadata for this rank's shard of the flat parameter.

        Sets ``_sharded_size``, ``_shard_param_infos``, and
        ``_shard_numel_padded`` on ``self.flat_param``.

        Args:
            numel_padded (int): Numel padded for this rank's sharded flat
                parameter.
            unsharded_start_idx (int): Start index in the unsharded flat
                parameter assigned to this rank.
            unsharded_end_idx (int): End index (inclusive) in the unsharded
                flat parameter assigned to this rank.

        Precondition: ``self.flat_param`` 's data is the sharded flat
        parameter.
        """
        fp = self.flat_param
        fp._sharded_size = fp.size()  # type: ignore[attr-defined]
        local_numel = fp.numel()  # already includes `numel_padded`
        _p_assert(
            0 <= unsharded_start_idx <= unsharded_end_idx,
            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
        )
        _p_assert(
            numel_padded <= local_numel,
            f"numel_padded: {numel_padded} "
            f"sharded_flat_param_numel: {local_numel}",
        )
        infos = self._get_shard_metadata(unsharded_start_idx, unsharded_end_idx)
        num_infos = len(infos)
        assert (
            num_infos == fp._num_params
        ), f"Expects length {fp._num_params} but got {num_infos}"
        fp._shard_param_infos = infos  # type: ignore[attr-defined]
        fp._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
967

968
    def _get_shard_metadata(
        self,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> Tuple[_ShardParamInfo, ...]:
        """
        Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).

        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the
        unsharded flat parameter specifying the shard.

        Returns:
            Tuple[_ShardParamInfo, ...]: One entry per non-padding element of
            the flat parameter, in flattening order. A parameter that does not
            overlap the interval gets ``_ShardParamInfo(False, None, None,
            None, None)``. Otherwise the entry holds the parameter's offset in
            the shard, its numel in the shard, and its inclusive intra-parameter
            start/end indices.
        """
        flat_param_offsets = self._get_flat_param_offsets()
        assert len(flat_param_offsets) == len(
            self.flat_param._numels_with_padding
        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
        shard_param_infos: List[_ShardParamInfo] = []
        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
        # into the unsharded flat parameter (inclusive) of the given parameter
        for (unsharded_param_start_idx, unsharded_param_end_idx), is_padding in zip(
            flat_param_offsets, self.flat_param._is_padding_mask
        ):
            if is_padding:
                # Alignment padding does not get a `_ShardParamInfo`
                continue
            in_sharded_flat_param = (
                unsharded_start_idx <= unsharded_param_end_idx
                and unsharded_end_idx >= unsharded_param_start_idx
            )
            if not in_sharded_flat_param:
                shard_param_infos.append(
                    _ShardParamInfo(False, None, None, None, None)
                )
                continue
            if unsharded_start_idx <= unsharded_param_start_idx:
                # The parameter begins inside the shard, so its data starts at
                # its first element and sits somewhere within the shard
                intra_param_start_idx = 0
                offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
            else:
                # The shard begins inside this parameter. This branch can only
                # happen once since the per-parameter offsets are contiguous and
                # non-overlapping, so the rank's unsharded start index can only
                # intersect one parameter
                intra_param_start_idx = unsharded_start_idx - unsharded_param_start_idx
                offset_in_shard = 0
            assert (
                offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
            ), (
                f"Invalid `offset_in_shard` of {offset_in_shard} for "
                f"sharded flat parameter with {sharded_flat_param_numel} numel"
            )
            intra_param_end_idx = (
                min(unsharded_param_end_idx, unsharded_end_idx)
                - unsharded_param_start_idx
            )
            numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
            shard_param_infos.append(
                _ShardParamInfo(
                    True,
                    offset_in_shard,
                    numel_in_shard,
                    intra_param_start_idx,
                    intra_param_end_idx,
                )
            )
        return tuple(shard_param_infos)
1030

1031
    @staticmethod
1032
    def _get_unpadded_shard(
1033
        tensor: Tensor,
1034
        rank: int,
1035
        world_size: int,
1036
    ) -> Tuple[Tensor, int]:
1037
        """
1038
        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.
1039

1040
        The returned value is a tuple of the shard of ``tensor`` without any
1041
        padding and the numel to pad for that shard.
1042

1043
        If ``tensor`` is already flattened or may be viewed in the flattened
1044
        shape (which is true in the expected usage), then this method does not
1045
        allocate any new tensor memory.
1046
        """
1047
        chunks = torch.flatten(tensor).chunk(world_size)
1048
        if len(chunks) < (rank + 1):
1049
            # This rank gets an empty chunk fully padded with zeros since there
1050
            # are not enough chunks across ranks
1051
            chunk = chunks[0].new_empty(0)
1052
        else:
1053
            chunk = chunks[rank]
1054
        numel_to_pad = chunks[0].numel() - chunk.numel()
1055
        assert (
1056
            numel_to_pad >= 0
1057
        ), "Chunk's size should be at most the first chunk's size"
1058
        return chunk, numel_to_pad
1059

1060
    @staticmethod
1061
    def _get_shard(
1062
        tensor: Tensor,
1063
        rank: int,
1064
        world_size: int,
1065
    ) -> Tuple[Tensor, int]:
1066
        """
1067
        Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard.
1068

1069
        This method allocates new memory (via :meth:`clone`) since the
1070
        unsharded ``tensor`` may be deallocated after this method returns.
1071
        """
1072
        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
1073
            tensor, rank, world_size
1074
        )
1075
        shard = chunk.clone()
1076
        if numel_to_pad > 0:
1077
            shard = F.pad(shard, [0, numel_to_pad])
1078
        return shard, numel_to_pad
1079

1080
    @staticmethod
1081
    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
1082
        """
1083
        Return the shape of ``tensor`` after sharding including padding.
1084

1085
        This requires ``tensor`` to have 1D shape and ensures that the returned
1086
        shape is 1D.
1087
        """
1088
        assert len(tensor.shape) == 1, f"{tensor.shape}"
1089
        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
1090
            tensor, rank, world_size
1091
        )
1092
        unpadded_sharded_size = unpadded_sharded_tensor.size()
1093
        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
1094
        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
1095

1096
    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
1097
        """
1098
        Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
1099

1100
        NOTE: The returned list includes elements for alignment padding.
1101
        """
1102
        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
1103
        starts = [0] + cumulative_sum[:-1]
1104
        ends = [end - 1 for end in cumulative_sum]  # inclusive
1105
        param_offsets = list(zip(starts, ends))
1106
        return param_offsets
1107

1108
    @no_type_check
    def shard_metadata(
        self,
    ) -> FlatParamShardMetadata:
        """
        Return the shard-related metadata specific to this rank's shard of the flat parameter.

        Only original parameters with data in this rank's shard are included.
        For each of them, the result holds its FQN, shape, numel, and inclusive
        intra-parameter ``(start, end)`` offsets.

        NOTE: The returned tuple does not include elements for alignment
        padding but does account for the padding.
        """
        fp = self.flat_param
        present = [
            (fqn, shape, numel, info)
            for fqn, shape, numel, info in zip(
                fp._fqns, fp._shapes, fp._numels, fp._shard_param_infos
            )
            if info.in_shard
        ]
        return FlatParamShardMetadata(
            tuple(fqn for fqn, _, _, _ in present),
            tuple(shape for _, shape, _, _ in present),
            tuple(numel for _, _, numel, _ in present),
            [
                (info.intra_param_start_idx, info.intra_param_end_idx)
                for _, _, _, info in present
            ],
        )
1145

1146
    @no_type_check
    @torch.no_grad()
    def init_flat_param_attributes(self) -> None:
        """
        Initialize some attributes on the handle's ``FlatParameter``.

        This should be called during lazy initialization since it requires the
        parameter to be on the compute device if not offloading to CPU and we
        want to give users the chance to move the parameter appropriately after
        the FSDP constructor.

        Attributes set on ``self.flat_param``:

        - ``_local_shard``: the current data, pinned when offloading to CPU.
        - ``_cpu_grad``: (CPU offloading only) pinned zeros shaped like
          ``_local_shard``.
        - ``_mp_shard``: (parameter mixed precision only) a buffer on the
          compute device in the forward/backward dtype, with storage freed.
        - ``_full_param_padded`` and ``_padded_unsharded_size``: (sharded
          strategies only) the padded unsharded buffer, with storage freed.
        - ``_full_prec_full_param_padded``: (sharded strategies with parameter
          mixed precision only) a full precision counterpart, storage freed.

        If the ``FlatParameter``'s dtype changed since construction, the saved
        dtypes that mixed precision did not pin are refreshed too.

        For each tensor attribute on the ``FlatParameter``, see the unshard and
        reshard methods in this class for the allocation and free pattern.
        """
        fp = self.flat_param
        current_dtype = fp.dtype
        if current_dtype != self._orig_param_dtype:
            # The user changed the parameter dtype after FSDP initialization;
            # dtypes specified as part of mixed precision take precedence.
            if not self._low_prec_param_dtype_specified:
                self._fwd_bwd_param_dtype = current_dtype
            # `reduce_dtype` is inferred from `param_dtype` when only the latter
            # is given, so refresh it only when neither was specified
            if not (
                self._low_prec_reduce_dtype_specified
                or self._low_prec_param_dtype_specified
            ):
                self._reduce_dtype = current_dtype
            self._orig_param_dtype = current_dtype
        cpu_device = torch.device("cpu")
        if self._offload_params:
            _p_assert(
                fp.device == cpu_device,
                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                f"offloading is enabled, not {fp.device}",
            )
            fp._local_shard = fp.data
            # Pin the memory for faster H2D transfer
            fp._local_shard = fp._local_shard.pin_memory()
            # Pre-allocate the sharded gradient on CPU to enable non-blocking
            # D2H transfer during the backward pass
            fp._cpu_grad = torch.zeros_like(
                fp._local_shard, device=cpu_device
            ).pin_memory()
        else:
            self._check_on_compute_device(self.flat_param)
            fp._local_shard = fp.data
        mixed_precision = self._uses_param_mixed_precision
        if mixed_precision:
            # Low precision sharded tensor on the compute device that is
            # all-gathered (sharded strategies) or used directly (`NO_SHARD`)
            fp._mp_shard = torch.empty_like(
                fp._local_shard,
                device=self.device,
                dtype=self._fwd_bwd_param_dtype,
            )
            _free_storage(fp._mp_shard)
        if not self.uses_sharded_strategy:
            return
        # Padded unsharded tensor: the all-gather destination that owns the
        # original parameter storages (low precision under param mixed precision)
        gather_dtype = self._fwd_bwd_param_dtype if mixed_precision else fp.dtype
        padded_numel = fp.numel() * self.world_size
        fp._full_param_padded = torch.empty(
            padded_numel,
            device=self.device,
            dtype=gather_dtype,
        )
        fp._padded_unsharded_size = fp._full_param_padded.size()
        _free_storage(fp._full_param_padded)
        if mixed_precision:
            # Full precision padded unsharded tensor for when full precision
            # is forced
            fp._full_prec_full_param_padded = torch.empty(
                padded_numel,
                device=self.device,
                dtype=fp.dtype,
            )
            _free_storage(fp._full_prec_full_param_padded)
1230

1231
    ###################
1232
    # UNSHARD/RESHARD #
1233
    ###################
1234
    def pre_unshard(self) -> bool:
        """
        Prepare ``self.flat_param`` for unsharding.

        Returns:
            bool: ``False`` if this is a no-op and ``True`` otherwise (the
            original-parameter writeback changed something, or the data was
            switched to the low precision shard or moved to the compute
            device).

        Postcondition: ``self.flat_param`` 's data is on the device for
        communication and is what should be all-gathered. This means that it
        matches the dtype of the expected unsharded parameter.
        """
        summoning_without_views = (
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
            and self._skipped_use_sharded_views
        )
        if summoning_without_views:
            # This path gives the unsharded flat parameter special semantics
            # (e.g. forced full precision); sharded views let it reuse the
            # existing handling for that
            self._use_sharded_views()
        did_work = False
        if self._use_orig_params and not self._skip_writeback_check:
            did_work = self._writeback_orig_params()
        already_unsharded = (
            self.uses_sharded_strategy
            and not self._offload_params
            and not self.needs_unshard()
        )
        if not already_unsharded:
            if self._uses_param_mixed_precision and not self._force_full_precision:
                self._use_low_precision_shard()
                did_work = True
            elif self._offload_params and self.flat_param.device != self.device:
                # NOTE: This creates a new tensor distinct from any attributes.
                self.flat_param_to(self.device, non_blocking=True)
                did_work = True
        self._check_on_compute_device(self.flat_param)
        return did_work
1268

1269
    def _use_low_precision_shard(self):
        """
        Point ``self.flat_param`` at the low precision shard on the compute device.

        Re-allocates ``_mp_shard`` to the size of ``_local_shard`` and copies
        ``_local_shard`` into it; the copy casts to the low precision dtype.
        """
        self._check_low_precision_shard()
        fp = self.flat_param
        mp_shard = fp._mp_shard  # type: ignore[attr-defined]
        local_shard = fp._local_shard  # type: ignore[attr-defined]
        _alloc_storage(mp_shard, local_shard.size())
        # `copy_()` performs the cast to the low precision dtype
        mp_shard.copy_(local_shard.to(self.device, non_blocking=True))
        # Invariant: `_mp_shard` is always on the compute device.
        fp.data = mp_shard
1284

1285
    def unshard(self):
        """
        Run the unshard logic.

        When an unshard is needed, this allocates the padded unsharded flat
        parameter, all-gathers into it, and switches to using it. Otherwise it
        only switches to the unsharded flat parameter, which is the padded
        buffer for sharded strategies and ``self.flat_param`` itself for
        ``NO_SHARD``.

        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if self.needs_unshard():
            destination = self._alloc_padded_unsharded_flat_param()
            gathered = self._all_gather_flat_param(destination)
            self._use_unsharded_flat_param(gathered)
            return
        # No all-gather required, but still switch to the unsharded flat
        # parameter
        if self.uses_sharded_strategy:
            target = self._get_padded_unsharded_flat_param()
        else:
            target = self.flat_param
        self._use_unsharded_flat_param(target)
1310

1311
    def needs_unshard(self) -> bool:
        """
        Return whether the handle's flat parameter needs to be unsharded.

        Always ``False`` for non-sharded strategies. Otherwise the answer is
        ``True`` unless the padded unsharded flat parameter's storage already
        matches its numel according to ``_same_storage_size``.
        """
        if self.uses_sharded_strategy:
            padded = self._get_padded_unsharded_flat_param()
            return not _same_storage_size(padded, padded.numel())
        return False
1320

1321
    def _alloc_padded_unsharded_flat_param(self):
        """
        Allocate storage for the *padded* unsharded flat parameter and return it.

        The unpadded unsharded flat parameter is always a view into this padded
        tensor. Which ``FlatParameter`` attribute holds the padded tensor
        depends on whether full precision is being forced (see
        :meth:`_get_padded_unsharded_flat_param`). The tensor's storage is
        checked to be freed before it is resized to ``_padded_unsharded_size``.
        """
        self._check_sharded_strategy()
        padded_buffer = self._get_padded_unsharded_flat_param()
        self._check_storage_freed(padded_buffer)
        target_size = self.flat_param._padded_unsharded_size  # type: ignore[attr-defined]
        _alloc_storage(padded_buffer, target_size)
        return padded_buffer
1336

1337
    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
1338
        """
1339
        Return a reference to the padded unsharded flat parameter depending on the calling context.
1340

1341
        This should only be called if using a sharded strategy.
1342
        """
1343
        self._check_sharded_strategy()
1344
        flat_param = self.flat_param
1345
        if self._force_full_precision and self._uses_param_mixed_precision:
1346
            # When parameter mixed precision is enabled, we use a different
1347
            # tensor as the all-gather destination to preserve the invariant
1348
            # that  `_full_param_padded` is in the low precision
1349
            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
1350
            _p_assert(
1351
                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
1352
                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
1353
            )
1354
            # For no-reshard-after-forward strategies, `_full_param_padded` may
1355
            # still be allocated from a previous forward. As we are forcing
1356
            # full precision here, the full-precision unsharded copy may be
1357
            # modified, invalidating the existing low-precision unsharded copy,
1358
            # so we should free it here to ensure a new all-gather for the next
1359
            # forward/backward computation to persist the modifications.
1360
            if flat_param._full_param_padded.untyped_storage().size() > 0:
1361
                _free_storage(flat_param._full_param_padded)
1362
        else:
1363
            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
1364
        return unsharded_flat_param
1365

1366
    def _all_gather_flat_param(
        self,
        padded_unsharded_flat_param: Tensor,
    ) -> Tensor:
        """
        All-gather the handle's sharded flat parameter into ``padded_unsharded_flat_param``.

        This only performs the collective. It does not switch ``self.flat_param``
        to the gathered tensor; the caller is responsible for that (e.g.
        :meth:`unshard` passes the result to :meth:`_use_unsharded_flat_param`).

        Args:
            padded_unsharded_flat_param (Tensor): Destination buffer, which must
                have exactly ``world_size`` times the sharded ``numel()``.

        Returns:
            Tensor: ``padded_unsharded_flat_param``, now holding every rank's
            shard.
        """
        _p_assert(
            hasattr(self, "process_group") and hasattr(self, "world_size"),
            "Expects a process group and world size to have been set via `shard()`",
        )
        sharded_flat_param = self.flat_param.data
        expected_numel = sharded_flat_param.numel() * self.world_size
        _p_assert(
            padded_unsharded_flat_param.numel() == expected_numel,
            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
        )

        pg = (
            self._fake_process_group
            if self._use_fake_all_gather
            else self.process_group
        )

        # HACK this should be handled by C10D
        if sharded_flat_param.is_cpu:
            # For CPU tensors, gather into per-rank chunks instead of using
            # `all_gather_into_tensor`. `torch.chunk` returns views, so the
            # gathered data lands directly in `padded_unsharded_flat_param`.
            tensor_list = list(
                torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
            )
            dist.all_gather(tensor_list, sharded_flat_param, group=pg)
        else:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )

        if self._offload_params:
            # In case of offloading, `flat_param.data` (i.e. sharded param) is
            # created on the pre-unshard stream. We need to hand it over to the
            # unshard stream for all-gather
            _no_dispatch_record_stream(
                sharded_flat_param,
                self._device_handle.current_stream(),  # unshard_stream
            )
        return padded_unsharded_flat_param
1414

1415
    def _use_unsharded_flat_param(
1416
        self,
1417
        padded_unsharded_flat_param: torch.Tensor,
1418
    ) -> None:
1419
        """
1420
        Switch to use the *unpadded* unsharded flat parameter.
1421

1422
        This is a view into the *padded* unsharded flat parameter.
1423
        """
1424
        unsharded_size = self.flat_param._unpadded_unsharded_size
1425
        flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
1426
        # slicing [:] is not visible to autograd because of .data
1427
        self.flat_param.data = flat_param_part
1428
        in_forward = self._training_state == HandleTrainingState.FORWARD
1429
        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
1430
        if self._use_orig_params:
1431
            if self._skipped_use_sharded_views and in_pre_backward:
1432
                # This call corresponds to the complementary pre-backward
1433
                # `_use_unsharded_views()` to the skipped pre-forward
1434
                # `_use_sharded_views()`, so we should skip this one too.
1435
                return
1436
            # We use `Tensor` views in the forward so that they are tracked by
1437
            # autograd. We use them in the pre-backward as well to support
1438
            # reentrant activation checkpointing, which needs the views to be
1439
            # tracked by autograd in the backward pass's recomputed forward.
1440
            self._use_unsharded_views(
1441
                as_params=(not in_forward and not in_pre_backward)
1442
            )
1443
        elif in_forward:
1444
            self._use_unsharded_views(as_params=False)
1445

1446
    def post_unshard(self):
        """
        Run the post-unshard logic.

        For sharded strategies with parameter mixed precision, the low
        precision shard is no longer needed after the all-gather and is freed.
        In all cases the flat parameter is checked to be on the compute device.
        """
        drop_low_precision_shard = (
            self._uses_param_mixed_precision and self.uses_sharded_strategy
        )
        if drop_low_precision_shard:
            self._free_low_precision_sharded_param()
        self._check_on_compute_device(self.flat_param)
1455

1456
    def _free_low_precision_sharded_param(self):
        """
        Free the storage of the low precision sharded flat parameter (``_mp_shard``).

        The tensor is first recorded on the current stream so its memory is
        not reused while queued work on that stream may still read it.
        """
        self._check_low_precision_shard()
        mp_shard = self.flat_param._mp_shard  # type: ignore[attr-defined]
        # `_mp_shard` is allocated in the pre-unshard stream. Sharded
        # strategies consume it in the unshard stream, which is the current
        # stream here. `NO_SHARD` consumes it in both the unshard and default
        # streams, and the current stream here is the default stream; since
        # the default stream waits for the unshard stream, recording only the
        # current stream suffices in both cases.
        _no_dispatch_record_stream(mp_shard, self._device_handle.current_stream())
        _free_storage(mp_shard)
1470

1471
    @torch.no_grad()
    def unshard_grad(self):
        """
        Unshard the handle's ``FlatParameter``'s gradient.

        For non-sharded strategies, this only points the original parameters'
        gradients at views of ``flat_param.grad``.

        For sharded strategies, an all-reduce first counts the ranks whose
        gradient is ``None``. If every rank's gradient is ``None``, the
        original parameters' gradients stay ``None`` as well. If only some
        ranks lack a gradient, those ranks contribute zeros. The sharded
        gradients are then all-gathered, and the unpadded prefix of the result
        becomes ``flat_param.grad``. The extra all-reduce is tolerable since
        this method is not meant to be used on the computation critical path.

        Postcondition (sharded strategies): ``_saved_grad_shard`` is defined
        and contains the value to set ``flat_param.grad`` after gradients are
        resharded.
        """
        if not self.uses_sharded_strategy:
            self._use_unsharded_grad_views()
            return
        flat_param = self.flat_param
        self._check_unsharded(flat_param)

        # Check if all ranks have a `None` gradient
        num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
        num_grad_none[0] = flat_param.grad is None
        dist.all_reduce(num_grad_none, group=self.process_group)
        if num_grad_none[0] == self.world_size:
            flat_param._saved_grad_shard = None  # type: ignore[assignment]
            self._use_unsharded_grad_views()
            return

        if flat_param.grad is None:
            # In the case that only some ranks have `None` gradient, we use
            # zeros to approximate as a best effort attempt
            if self._debug_level == dist.DebugLevel.INFO:
                warnings.warn(
                    f"[Rank {self.rank}] Only some but not all ranks have a "
                    "`None` `FlatParameter` gradient, so FSDP is using zeros to "
                    "approximate those ranks' sharded gradients being `None`"
                )
            flat_param._saved_grad_shard = None  # type: ignore[assignment]
            # Every rank must feed the same dtype into the all-gather below,
            # and the other ranks contribute their real sharded gradients.
            # Allocating with the global default dtype (float32) would
            # mismatch e.g. a bf16 model. Use the flat parameter's dtype
            # instead. NOTE(review): this assumes the other ranks' sharded
            # gradients share `flat_param.dtype`, which may not hold if
            # low-precision gradients are kept.
            sharded_grad = torch.zeros(
                flat_param._sharded_size,  # type: ignore[attr-defined]
                dtype=flat_param.dtype,
                device=self.device,
            )
        else:
            self._check_sharded(flat_param.grad)
            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        padded_unsharded_grad = torch.empty(
            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
            device=self.device,
            dtype=sharded_grad.dtype,
        )
        dist.all_gather_into_tensor(
            padded_unsharded_grad, sharded_grad, self.process_group
        )
        # Drop the trailing padding and expose the gradient in the unpadded
        # unsharded shape.
        unsharded_size = flat_param._unpadded_unsharded_size
        flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
            unsharded_size
        )
        self._use_unsharded_grad_views()
1528

1529
    def reshard_grad(self):
        """
        Restore the sharded gradient saved by :meth:`unshard_grad`.

        With ``use_orig_params``, the original parameters' gradients are first
        switched back to sharded views. For sharded strategies,
        ``flat_param.grad`` is then set back to ``_saved_grad_shard``, and
        that attribute is removed.
        """
        flat_param = self.flat_param
        if self._use_orig_params:
            self._use_sharded_grad_views()
        if self.uses_sharded_strategy:
            flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
            del flat_param._saved_grad_shard
1536

1537
    def prepare_gradient_for_backward(self):
1538
        """
1539
        Prepare the gradient for the backward computation.
1540

1541
        This is done by saving and clearing any existing sharded gradient
1542
        in ``.grad`` to enable computing a new unsharded gradient.
1543
        """
1544
        _p_assert(
1545
            self._training_state
1546
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
1547
            "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
1548
        )
1549
        flat_param = self.flat_param
1550
        if flat_param.grad is not None and (
1551
            flat_param.grad.size() != flat_param._unpadded_unsharded_size
1552
            or flat_param.grad.device != flat_param.device  # grad on CPU
1553
        ):
1554
            self._check_on_compute_device(self.flat_param)
1555
            grad_offloaded = flat_param.grad.device != self.device
1556
            _p_assert(
1557
                not grad_offloaded or self._offload_params,
1558
                f"Expects the sharded gradient to be on {self.device} "
1559
                f"but got {flat_param.grad.device}",
1560
            )
1561
            prev_iter_synced_gradients = (
1562
                flat_param.grad.size()
1563
                == flat_param._local_shard.size()  # type: ignore[attr-defined]
1564
            )
1565
            if prev_iter_synced_gradients:
1566
                # TODO (awgu): Gradient accumulation outside `no_sync()`
1567
                # does not work with CPU offloading. The issue should be
1568
                # that, in the post-backward hook, we cannot do an addition
1569
                # between a CPU tensor (the existing sharded gradient) and
1570
                # a GPU tensor (the new sharded gradient).
1571
                if not grad_offloaded:
1572
                    flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
1573
                    sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
1574
                else:
1575
                    _p_assert(
1576
                        hasattr(flat_param, "_cpu_grad"),
1577
                        "`_cpu_grad` should be defined if the gradient is on CPU",
1578
                    )
1579
                    sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
1580
                # If user specified to keep the gradient in low precision, then
1581
                # the gradient may still be of the low precision dtype if the
1582
                # user did not set the gradient to `None` after the previous
1583
                # backward, in which case FSDP should cast back to the full
1584
                # precision dtype so that FSDP can accumulate in that dtype in
1585
                # the post-backward hook and assign to `.grad` in that dtype in
1586
                # the post-backward callback.
1587
                local_shard_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
1588
                if (
1589
                    self._keep_low_precision_grads
1590
                    and sharded_grad.dtype != local_shard_dtype
1591
                ):
1592
                    sharded_grad.data = sharded_grad.to(local_shard_dtype)
1593
            else:
1594
                padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
1595
                _p_assert(
1596
                    flat_param.grad.size() == padded_unsharded_size,
1597
                    "Expects `.grad` to be the unsharded gradient in "
1598
                    f"`no_sync()` with size {padded_unsharded_size} "
1599
                    f"but got size {flat_param.grad.size()}",
1600
                )
1601
            flat_param.grad = None
1602

1603
    def prepare_gradient_for_optim(self):
1604
        """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
1605

1606
        def cast_grad_to_param_dtype_if_needed(flat_param):
1607
            # TODO (rohan-varma): test for full precision with keep_low_precision_grads
1608
            if not self._force_full_precision and self._keep_low_precision_grads:
1609
                _p_assert(flat_param.grad is not None, "Unexpected None grad!")
1610
                if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
1611
                    flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
1612
                    if self._use_orig_params:
1613
                        self._use_sharded_grad_views()
1614

1615
        flat_param = self.flat_param
1616
        # TODO (awgu): We should replace these conditional checks to encode
1617
        # the logical intention more directly.
1618
        if hasattr(flat_param, "_cpu_grad"):
1619
            # NOTE: This branch includes `NO_SHARD`.
1620
            self._check_sharded(flat_param)
1621
            self._check_on_cpu(flat_param)
1622
            flat_param.grad = flat_param._cpu_grad  # type: ignore[attr-defined]
1623
            cast_grad_to_param_dtype_if_needed(flat_param)
1624
        elif hasattr(flat_param, "_saved_grad_shard"):
1625
            self._check_sharded(flat_param)
1626
            self._check_on_compute_device(flat_param)
1627
            if flat_param._saved_grad_shard is not None:
1628
                self._check_on_compute_device(flat_param._saved_grad_shard)  # type: ignore[attr-defined]
1629
            # If no sharded gradient was computed this iteration, then there is
1630
            # no need to forward `_saved_grad_shard` to `grad`
1631
            if flat_param._post_backward_called:  # type: ignore[attr-defined]
1632
                flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
1633
                if flat_param.grad is not None:
1634
                    cast_grad_to_param_dtype_if_needed(flat_param)
1635
        else:
1636
            _p_assert(
1637
                not self.uses_sharded_strategy
1638
                or not flat_param._post_backward_called,  # type: ignore[attr-defined]
1639
                "All sharded parameters that received a gradient in the "
1640
                "post-backward should use `_saved_grad_shard`",
1641
            )
1642
        # Delete `_saved_grad_shard` since its existence indicates a previous
1643
        # gradient to accumulate with in the post-backward hook
1644
        if hasattr(flat_param, "_saved_grad_shard"):
1645
            delattr(flat_param, "_saved_grad_shard")
1646

1647
    @contextlib.contextmanager
    def to_cpu(self):
        """
        Temporarily move the unpadded unsharded flat parameter to CPU.

        On entry, the flat parameter is moved to CPU via :meth:`flat_param_to`
        and the padded unsharded buffer on the compute device is freed. On
        exit, a fresh padded buffer is allocated, the (possibly modified) CPU
        values are copied into its leading unpadded region, and the handle
        switches back to it. Only the unpadded portion is handled because there
        is no reason to copy the padding, and there is no use case for the
        sharded flat parameter.

        Precondition: ``self.flat_param`` 's data is the unpadded unsharded
        flat parameter on the compute device, and the handle uses a sharded
        strategy.
        Postcondition: Same as the precondition.
        """

        def assert_unpadded_unsharded_size() -> None:
            _p_assert(
                self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
                f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
            )

        self._check_sharded_strategy()
        assert_unpadded_unsharded_size()
        self._check_on_compute_device(self.flat_param)
        # Sanity check only (not required for correctness): the unpadded
        # parameter should be a view into the padded one, since this tensor is
        # only ever used internally.
        _p_assert(
            _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
            "Expects the unpadded parameter to be a view into the padded parameter",
        )
        self.flat_param_to(torch.device("cpu"))
        self._free_unsharded_flat_param()
        try:
            yield
        finally:
            assert_unpadded_unsharded_size()
            device_buffer = self._alloc_padded_unsharded_flat_param()
            # Copy the CPU values back into the compute-device buffer.
            device_buffer[: self.flat_param.numel()].copy_(self.flat_param)
            self._use_unsharded_flat_param(device_buffer)
1690

1691
    def reshard(self, free_unsharded_flat_param: bool):
        """
        Switch back to the sharded flat parameter, optionally freeing the unsharded one.

        Pointing ``flat_param`` at ``_local_shard`` also implicitly offloads
        the sharded flat parameter when CPU offload is enabled, since
        ``_local_shard`` then resides on CPU.

        Args:
            free_unsharded_flat_param (bool): Whether to free the padded
                unsharded flat parameter's storage after switching.
        """
        # The switch happens *before* the free. With `use_orig_params=True`,
        # `_use_sharded_views()` sets `param.data = ...`; doing that while the
        # parameter points at freed memory confuses external profiling tools
        # ("use-after-free"-type bugs).
        self._use_sharded_flat_param()
        if not free_unsharded_flat_param:
            return
        self._free_unsharded_flat_param()
1708

1709
    def post_reshard(self):
        """
        Free memory that is no longer needed once the handle has resharded.

        Only ``NO_SHARD`` with parameter mixed precision has anything to free
        here: its ``_mp_shard`` doubles as the low precision *unsharded* flat
        parameter, so it is not freed in the post-unshard and is freed now.
        This is skipped when full precision was forced, since the low precision
        shard was not used then.

        Precondition: ``self.flat_param`` 's data points to the full precision
        sharded flat parameter.
        """
        if not self._uses_param_mixed_precision:
            return
        if self.uses_sharded_strategy:
            return
        if self._force_full_precision:
            return
        self._free_low_precision_sharded_param()
1729

1730
    def _free_unsharded_flat_param(self):
        """
        Free the storage of the padded unsharded flat parameter.

        It is fine to call this when the storage is already unallocated. Which
        tensor is freed depends on context: if the unshard forced full
        precision, a different padded tensor was used (see
        :meth:`_get_padded_unsharded_flat_param`).
        """
        self._check_sharded_strategy()
        padded_buffer = self._get_padded_unsharded_flat_param()
        self._check_on_compute_device(padded_buffer)
        # Record on the current stream so the memory is not released until
        # the ops already queued on that stream have finished.
        current_stream = self._device_handle.current_stream()
        _no_dispatch_record_stream(padded_buffer, current_stream)
        _free_storage(padded_buffer)
1747

1748
    def _use_sharded_flat_param(self) -> None:
1749
        """Switches to using the sharded flat parameter."""
1750
        flat_param = self.flat_param
1751
        if self._use_orig_params:
1752
            in_forward = self._training_state == HandleTrainingState.FORWARD
1753
            skip_use_sharded_views = (
1754
                torch.is_grad_enabled()
1755
                and in_forward
1756
                and self._sharding_strategy
1757
                in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
1758
            )
1759
            # Only incur the extra `.data` call if needed
1760
            if skip_use_sharded_views:
1761
                unsharded_flat_param = flat_param.data
1762
        if self._offload_params:
1763
            device = flat_param._local_shard.device  # type: ignore[attr-defined]
1764
            _p_assert(
1765
                device == torch.device("cpu"),
1766
                f"Expects the local shard to be on CPU but got {device}",
1767
            )
1768
        flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
1769
        if self._use_orig_params:
1770
            if skip_use_sharded_views:  # type: ignore[possibly-undefined]
1771
                self._unsharded_flat_param_for_skipped_views = unsharded_flat_param  # type: ignore[possibly-undefined]
1772
            else:
1773
                self._use_sharded_views()
1774
            # For the post-forward reshard, we may try to use sharded gradient
1775
            # views (or unsharded gradient views if a gradient was accumulated
1776
            # in `no_sync()`), but for the post-backward reshard, we delay the
1777
            # call to after the reduce-scatter.
1778
            if (
1779
                in_forward  # type: ignore[possibly-undefined]
1780
                # Skip using gradient views if skipped using sharded views
1781
                # since exposing unsharded parameters with sharded gradients
1782
                # may be confusing to the user
1783
                and not self._skipped_use_sharded_views
1784
            ):
1785
                # TODO: Change `_unpadded_unsharded_size` if we change the
1786
                # gradient to be computed directly with padding.
1787
                accumulated_grad_in_no_sync = (
1788
                    flat_param.grad is not None
1789
                    and self.uses_sharded_strategy
1790
                    and flat_param.grad.shape == flat_param._unpadded_unsharded_size
1791
                )
1792
                if accumulated_grad_in_no_sync:
1793
                    self._use_unsharded_grad_views()
1794
                else:
1795
                    self._use_sharded_grad_views()
1796

1797
    #########
1798
    # VIEWS #
1799
    #########
1800
    @no_type_check
1801
    def _get_unflat_views_unaligned(
1802
        self,
1803
        tensor: Optional[torch.Tensor] = None,
1804
    ) -> Iterator[Tensor]:
1805
        """
1806
        Return unflattened ``Tensor`` views into ``tensor``.
1807

1808
        If `tensor`` is ``None``,  ``flat_param`` is used. The unflattening is based
1809
        on ``flat_param`` 's metadata.
1810

1811
        Examples for ``tensor`` include ``flat_param.grad`` or unsharded
1812
        tensor optimizer state.
1813
        """
1814
        flat_param = self.flat_param
1815
        if tensor is None:
1816
            tensor = flat_param
1817
        views = (
1818
            _ext_post_unflatten_transform(
1819
                subtensor.view(shape),
1820
                param_extension,
1821
                self._fsdp_extension,
1822
            )
1823
            for (subtensor, shape, param_extension) in zip(
1824
                torch.split(tensor, flat_param._numels, dim=0),
1825
                flat_param._shapes,
1826
                flat_param._param_extensions,
1827
            )
1828
        )
1829
        return views
1830

1831
    @no_type_check
1832
    def _get_unflat_views_aligned(
1833
        self,
1834
        tensor: Optional[Tensor] = None,
1835
    ) -> List[Tensor]:
1836
        """
1837
        Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
1838

1839
        This method has the same contract as :meth:`_get_unflat_views_unaligned`
1840
        except it checks for ``None`` placeholders representing padding for
1841
        alignment, which may incur slightly more CPU overhead.
1842
        """
1843
        flat_param = self.flat_param
1844
        if tensor is None:
1845
            tensor = flat_param
1846
        splits: List[Tensor] = torch.split(
1847
            tensor, flat_param._numels_with_padding, dim=0
1848
        )
1849
        idx = 0
1850
        views: List[Tensor] = []
1851
        for split, is_padding in zip(splits, flat_param._is_padding_mask):
1852
            if is_padding:
1853
                continue
1854
            views.append(
1855
                _ext_post_unflatten_transform(
1856
                    split.view(flat_param._shapes[idx]),
1857
                    flat_param._param_extensions[idx],
1858
                    self._fsdp_extension,
1859
                )
1860
            )
1861
            idx += 1
1862
        return views
1863

1864
    @no_type_check
    @torch.enable_grad()
    def _use_unsharded_views(self, as_params: bool) -> None:
        """
        Point the original parameter variables at views into the unsharded flat parameter.

        Args:
            as_params (bool): If ``True``, the originals are registered as
                ``nn.Parameter`` s; if ``False``, only as plain ``Tensor`` s.
                ``False`` should be used during forward/backward computation
                and when hiding the original parameters from
                :meth:`nn.Module.named_parameters`.

        Note:
            The caller may run under ``torch.no_grad()`` (e.g. when prefetching
            for the next forward). ``@torch.enable_grad()`` ensures each view
            gets a non-empty ``grad_fn``; otherwise ``_post_backward_hook``
            would not get called.
        """
        flat_param = self.flat_param
        self._check_unsharded(flat_param)
        views = self._get_unflat_views()
        from torch.distributed._tensor import DTensor

        use_orig_params = self._use_orig_params
        state = self._training_state
        in_forward = state == HandleTrainingState.FORWARD
        in_pre_backward = state == HandleTrainingState.BACKWARD_PRE

        for idx, (view, (param_name, module, _)) in enumerate(
            zip(views, flat_param._param_infos)
        ):
            if as_params:
                if use_orig_params and type(view) is not DTensor:
                    # Reuse the original parameter variable and swap its data.
                    orig_param = flat_param._params[idx]
                    self._setattr_param(module, param_name, orig_param)
                    orig_param.data = view
                else:
                    # Without `use_orig_params` there is no variable to reuse;
                    # with it, a `DTensor` view cannot be assigned through
                    # `param.data = view`. Either way, wrap in a new parameter.
                    self._setattr_param(
                        module,
                        param_name,
                        nn.Parameter(view, requires_grad=flat_param.requires_grad),
                    )
                continue

            # `as_params=False`: register plain tensors.
            tensor_var: Tensor = view
            if use_orig_params:
                if in_forward:
                    # Remember the forward's `Tensor` for the pre-backward.
                    flat_param._tensors[idx] = view
                elif in_pre_backward:
                    # Reuse the forward's `Tensor` variable so its autograd
                    # graph is preserved and the post-backward hook fires
                    # (e.g. for reentrant AC).
                    tensor_var = flat_param._tensors[idx]
                    tensor_var.data = view
            self._setattr_tensor(module, param_name, tensor_var)
            if use_orig_params and in_forward:
                module._parameters[param_name] = tensor_var

        for idx, (
            param_name,
            module,
            _,
            prim_param_name,
            prim_module,
            _,
        ) in enumerate(flat_param._shared_param_infos):
            prim_param: Union[Tensor, nn.Parameter] = getattr(
                prim_module, prim_param_name
            )
            _p_assert(
                not as_params or isinstance(prim_param, nn.Parameter),
                f"as_params={as_params} type(prim_param)={type(prim_param)}",
            )
            if not as_params:
                self._setattr_tensor(module, param_name, prim_param)
                if use_orig_params and in_forward:
                    module._parameters[param_name] = prim_param
            elif use_orig_params:
                shared_param = flat_param._shared_params[idx]
                self._setattr_param(module, param_name, shared_param)
                shared_param.data = prim_param
            else:
                self._setattr_param(module, param_name, prim_param)
1958

1959
    @no_type_check
1960
    def _use_unsharded_grad_views(self) -> None:
1961
        """
1962
        Unflatten the unsharded flat parameter's gradient.
1963

1964
        The original parameter variables' gradients are set to be views into
1965
        the unsharded flat parameter's gradient.
1966
        """
1967
        # Expects the gradient to be in `flat_param.grad`
1968
        if self.flat_param.grad is None:
1969
            for param in chain(self.flat_param._params, self.flat_param._shared_params):
1970
                param.grad = None
1971
            return
1972
        self._check_unsharded(self.flat_param.grad)
1973
        views = self._get_unflat_views(self.flat_param.grad)
1974
        for i, (view, (param_name, module, _)) in enumerate(
1975
            zip(views, self.flat_param._param_infos)
1976
        ):
1977
            _p_assert(
1978
                hasattr(module, param_name),
1979
                f"{self.flat_param._fqns[i]} is missing",
1980
            )
1981
            param = getattr(module, param_name)
1982
            if (
1983
                param.shape != view.shape
1984
                or param.dtype != view.dtype
1985
                or param.device != view.device
1986
            ):
1987
                # NOTE: This is a hack using `.data` to side step the check
1988
                # that parameter/gradient sizes/dtypes/devices match. From
1989
                # calling `reshard()`, `param` has the sharded size, has the
1990
                # full precision dtype, and if CPU offloading is enabled, is on
1991
                # CPU. Thus, one or more of the following cases can hold when
1992
                # in `no_sync()`, where `view` is the original parameter's
1993
                # gradient:
1994
                # 1. `view` can have the unsharded size.
1995
                # 2. `view` can have the parameter low precision dtype.
1996
                # 3. `view` can be on GPU.
1997
                if param.grad is None:
1998
                    param.grad = torch.empty_like(param)
1999
                param.grad.data = view
2000
            else:
2001
                param.grad = view
2002
        for i, (
2003
            param_name,
2004
            module,
2005
            module_name,
2006
            prim_param_name,
2007
            prim_module,
2008
            _,
2009
        ) in enumerate(self.flat_param._shared_param_infos):
2010
            _p_assert(
2011
                hasattr(module, param_name),
2012
                f"{module_name + '.' + param_name if module_name else param_name} is missing",
2013
            )  # did not save FQN info in `_shared_param_infos`
2014
            param = getattr(module, param_name)
2015
            prim_param = getattr(prim_module, prim_param_name)
2016
            if (
2017
                param.shape != prim_param.grad.shape
2018
                or param.dtype != prim_param.grad.dtype
2019
                or param.device != prim_param.grad.device
2020
            ):
2021
                # NOTE: This is the same hack to use `.data` to side step the
2022
                # size check.
2023
                if param.grad is None:
2024
                    param.grad = torch.empty_like(param)
2025
                param.grad.data = prim_param.grad
2026
            else:
2027
                param.grad = prim_param.grad
2028

2029
    @contextlib.contextmanager
    def unflatten_as_params(self) -> Generator:
        """
        Temporarily expose the original parameters as ``nn.Parameter`` views.

        Assumes the flat parameter is already unsharded. On entering the
        context, the original parameters become ``nn.Parameter`` views into
        the flat parameter; on leaving it (even if the body raises), they are
        switched back to plain ``Tensor`` views into the flat parameter.
        """
        use_views = self._use_unsharded_views
        use_views(as_params=True)
        try:
            yield
        finally:
            use_views(as_params=False)
2044

2045
    @no_type_check
2046
    @torch.no_grad()
2047
    def _use_sharded_views(self) -> None:
2048
        """
2049
        Set the original parameter variables' data to be flattened views into the sharded flat parameter.
2050

2051
        The views are kept as flattened to simplify the case where a parameter
2052
        is sharded across ranks. Parameters whose data is not present in the
2053
        sharded flat parameter have their data set to a size-0 empty tensor. We
2054
        do not delete them to ensure to preserve expected behaviors like model
2055
        printability. Parameters whose data is present must preserve their
2056
        variables to be passable to an optimizer.
2057
        """
2058
        self._unsharded_flat_param_for_skipped_views = None
2059
        if not self.uses_sharded_strategy:
2060
            # For `NO_SHARD`, use the *unflattened* unsharded views since we
2061
            # have the unsharded parameter
2062
            self._use_unsharded_views(as_params=True)
2063
            return
2064
        flat_param = self.flat_param
2065
        self._check_sharded(flat_param)
2066
        # Construct once and reuse for all parameters not in the local shard
2067
        size_0_empty_tensor = torch.empty(
2068
            0,
2069
            dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
2070
            device=self.flat_param.device,
2071
            requires_grad=False,
2072
        )
2073
        for param, shard_param_info, (param_name, module, _) in zip(
2074
            flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
2075
        ):
2076
            self._setattr_param(module, param_name, param)
2077
            if not shard_param_info.in_shard:
2078
                # Allow the original data to be freed via garbage collection
2079
                param.data = size_0_empty_tensor
2080
            else:
2081
                offset = shard_param_info.offset_in_shard
2082
                numel_in_shard = shard_param_info.numel_in_shard
2083
                param.data = flat_param[offset : offset + numel_in_shard]
2084
        assert self.flat_param._shared_params is not None
2085
        for i, (
2086
            param,
2087
            (param_name, module, _, prim_param_name, prim_module, _),
2088
        ) in enumerate(
2089
            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
2090
        ):
2091
            self._setattr_param(module, param_name, param)
2092
            prim_param = getattr(prim_module, prim_param_name)
2093
            param.data = prim_param  # could be both empty and non-empty
2094
        if self._training_state == HandleTrainingState.BACKWARD_POST:
2095
            # Clear the saved `Tensor`s since they are unneeded now
2096
            for i in range(len(self.flat_param._tensors)):
2097
                self.flat_param._tensors[i] = None
2098

2099
    @no_type_check
2100
    @torch.no_grad()
2101
    def _use_sharded_grad_views(self) -> None:
2102
        """
2103
        Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
2104

2105
        This is a no-op if there is no gradient.
2106

2107
        Parameters whose data is not present in the sharded flat parameter and
2108
        parameters with ``requires_grad=False`` have their gradients set to
2109
        ``None``. Since the gradient variables do not need to be preserved,
2110
        this method does not manipulate existing ``Tensor`` data directly and
2111
        creates new ``Tensor`` variables instead.
2112
        """
2113
        flat_param = self.flat_param
2114
        self._check_sharded(flat_param)
2115
        grad = self.sharded_grad
2116
        if grad is None:
2117
            for param in chain(flat_param._params, flat_param._shared_params):
2118
                param.grad = None
2119
            return
2120
        self._check_sharded(grad)
2121
        for param, shard_param_info, is_grad_none in zip(
2122
            flat_param._params,
2123
            flat_param._shard_param_infos,
2124
            flat_param._is_grad_none_mask,
2125
        ):
2126
            if not shard_param_info.in_shard:
2127
                param.grad = None
2128
            else:
2129
                numel_in_shard = shard_param_info.numel_in_shard
2130
                if param.requires_grad and not is_grad_none:
2131
                    offset = shard_param_info.offset_in_shard
2132
                    if self._keep_low_precision_grads or param.dtype != grad.dtype:
2133
                        # NOTE: This is a hack using `.data` to side step the
2134
                        # check that parameter/gradient dtypes match. Here,
2135
                        # `param` has full precision; `grad` has low precision.
2136
                        if param.grad is None:
2137
                            # `.grad` must have the same shape as `param`
2138
                            param.grad = torch.empty_like(param)
2139
                        param.grad.data = grad[
2140
                            offset : offset + numel_in_shard
2141
                        ].reshape(param.shape)
2142
                    else:
2143
                        param.grad = grad[offset : offset + numel_in_shard].reshape(
2144
                            param.shape
2145
                        )
2146
                else:
2147
                    param.grad = None
2148
        assert flat_param._shared_params is not None
2149
        for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
2150
            zip(flat_param._shared_params, flat_param._shared_param_infos)
2151
        ):
2152
            in_sharded_flat_param = hasattr(prim_module, prim_param_name)
2153
            if in_sharded_flat_param and param.requires_grad:
2154
                prim_param = getattr(prim_module, prim_param_name)
2155
                param.grad = prim_param.grad  # share the same reference
2156
            else:
2157
                param.grad = None
2158

2159
    @no_type_check
    @torch.no_grad()
    def _writeback_orig_params(self) -> bool:
        """
        Write back any parameters that changed storage to the handle's ``FlatParameter``.

        Iterates over the original parameters and writes back any parameters
        that changed storages (due to a non-inplace operator) to the handle's
        ``FlatParameter``. Gradients that changed storage are also written
        back into the flat gradient. A gradient set to ``None`` while
        ``flat_param.grad`` exists zeroes its slice of the flat gradient.
        This method preserves the ``FlatParameter``'s device even if an
        original parameter's device changes.

        Raises:
            RuntimeError: If an original parameter or gradient changes storages
                but no longer has the expected flattened shape.
            AssertionError: If sharded views were skipped and an original
                parameter variable or its storage changed (FSDP does not
                support changing parameters between forward and backward).
            NotImplementedError: If a shared parameter no longer aliases its
                primary parameter.

        Returns:
            ``True`` if some *parameter* writeback happened, and ``False``
            otherwise (gradient writeback alone does not count).
        """
        if (
            self.uses_sharded_strategy
            and not self.is_sharded(self.flat_param)
            and not self._skipped_use_sharded_views
        ):
            # Only sharded strategies may return early here; for `NO_SHARD`,
            # we may still need to writeback
            return False
        flat_param = self.flat_param
        wroteback = False
        if self._skipped_use_sharded_views and self.uses_sharded_strategy:
            # NOTE: We must use the unsharded flat parameter from which the
            # unsharded views were computed, not the one from the current
            # calling context (`_get_padded_unsharded_flat_param()`) since that
            # may be different (e.g. the model changed from train to eval).
            flat_param_tensor = self._unsharded_flat_param_for_skipped_views
            _p_assert(
                _data_ptr_allocated(flat_param_tensor),
                "If skipped using sharded views, the unsharded flat parameter "
                "should be allocated",
            )
        else:
            flat_param_tensor = flat_param
        # NOTE: Since this method is called in the pre-unshard, which is only
        # called during computation in the pre-forward or pre-backward, the
        # sharded gradient should be guaranteed to be in `.grad`, not in
        # `._saved_grad_shard`.
        flat_param_grad = (
            flat_param.grad
            if self.uses_sharded_strategy or not self._offload_params
            else flat_param._cpu_grad
        )
        for i, (
            param,
            (in_shard, offset_in_shard, numel_in_shard, _, _),
            (param_name, module, _),
        ) in enumerate(
            zip(
                flat_param._params,
                flat_param._shard_param_infos,
                flat_param._param_infos,
            )
        ):
            if not in_shard:
                continue
            if not hasattr(module, param_name):
                # Do not writeback if original parameters are deregistered
                # (e.g. during model checkpointing)
                continue

            # Check for parameter writeback
            if self._skipped_use_sharded_views:
                param = flat_param._tensors[i]
                _p_assert(
                    param is not None,
                    f"Expects to have saved tensor for {flat_param._fqns[i]}",
                )
            param_changed = getattr(module, param_name) is not param
            needs_param_writeback = (
                param_changed  # changed parameter variable itself
                or not _same_storage(param, flat_param_tensor)
            )
            # `needs_param_writeback` already covers `param_changed`
            if self._skipped_use_sharded_views and needs_param_writeback:
                raise AssertionError(
                    "FSDP does not support changing the parameters between "
                    f"forward and backward for {self._sharding_strategy}"
                )
            if param_changed:
                # NOTE: The gradient is not preserved after a parameter change.
                param = getattr(module, param_name)
                flat_param._params[i] = param
            if needs_param_writeback:
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    param, flat_param, i, expected_shape, offset_in_shard, True
                )
                wroteback = True

            # Check for gradient writeback
            if self._skipped_use_sharded_views:
                # Skip the writeback check because we do not expose gradients
                # when we skipped using sharded views
                continue
            if param.grad is None and flat_param.grad is not None:
                # A cleared original gradient zeroes its flat-gradient slice
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    None, flat_param.grad, i, expected_shape, offset_in_shard, False
                )
            elif param.grad is not None:
                # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
                # memory and owns the gradient storage, so it will never
                # require gradient writeback.
                if not self.uses_sharded_strategy and self._offload_params:
                    # Explicitly continue to handle the case of `no_sync()`,
                    # where `param.grad` is a view into the GPU gradient
                    # referenced by `flat_param.grad`, while `flat_param_grad`
                    # is `flat_param._cpu_grad`, which is on CPU
                    continue

                needs_grad_writeback = flat_param_grad is None or not _same_storage(
                    param.grad, flat_param_grad
                )
                if needs_grad_writeback:
                    if flat_param_grad is None:
                        flat_param_grad = torch.zeros_like(flat_param)
                    expected_shape = torch.Size([numel_in_shard])
                    self._writeback_tensor(
                        param.grad,
                        flat_param_grad,
                        i,
                        expected_shape,
                        offset_in_shard,
                        False,
                    )
                    flat_param.grad = flat_param_grad
                    flat_param_grad = flat_param.grad

        # TODO: If we want to handle shared parameters, we need to re-generate
        # the shared parameter data structures in case sharedness changed.
        for (
            param_name,
            module,
            _,
            prim_param_name,
            prim_module,
            _,
        ) in flat_param._shared_param_infos:
            if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
                raise NotImplementedError(
                    "Changing shared parameters is not supported yet"
                )
        return wroteback
2308

2309
    def _writeback_tensor(
2310
        self,
2311
        src_tensor: Optional[Tensor],
2312
        dst_tensor: Tensor,
2313
        tensor_index: int,
2314
        expected_shape: torch.Size,
2315
        offset: int,
2316
        is_param: bool,  # else gradient
2317
    ) -> None:
2318
        """
2319
        Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
2320

2321
        ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
2322
        ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
2323
        instead of copying. ``tensor_index`` gives the index of ``src_tensor``
2324
        in the metadata structures.
2325

2326
        Raises:
2327
            RuntimeError: If the ``src_tensor`` does not have the expected
2328
            shape.
2329
        """
2330
        _p_assert(
2331
            len(expected_shape) == 1,
2332
            f"Expects a 1D expected shape but got {expected_shape}",
2333
        )
2334
        if self._debug_level == dist.DebugLevel.INFO:
2335
            rank = self.rank if hasattr(self, "rank") else dist.get_rank()
2336
            src_shape = src_tensor.shape if src_tensor is not None else None
2337
            src_device = src_tensor.device if src_tensor is not None else None
2338
            warnings.warn(
2339
                f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
2340
                f"writeback in {self._training_state}\n"
2341
                f"expected shape={expected_shape} shape={src_shape} "
2342
                f"expected device={dst_tensor.device} device={src_device}"
2343
            )
2344
        if src_tensor is not None and src_tensor.shape != expected_shape:
2345
            # NOTE: Gradient shape mismatch is not possible in practice since
2346
            # the gradient shape is enforced to match that of the parameter and
2347
            # we already check for parameter shape mismatch.
2348
            raise RuntimeError(
2349
                f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
2350
                f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
2351
            )
2352
        if src_tensor is not None:
2353
            dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
2354
        else:
2355
            dst_tensor[offset : offset + expected_shape.numel()].zero_()
2356
            assert self.flat_param._is_grad_none_mask is not None
2357
            self.flat_param._is_grad_none_mask[tensor_index] = True
2358

2359
    def _reset_flat_param_grad_info_if_needed(self):
2360
        """
2361
        Reset ``flat_param.grad`` if needed.
2362

2363
        When ``use_orig_params=True``:
2364
        (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
2365
        original parameters' ``.grad`` are ``None``, and
2366
        (2) sets ``flat_param.requires_grad=False`` if *none* of the original
2367
        parameters require gradient.
2368
        For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
2369
        which case we want to free the gradients as soon after the
2370
        ``zero_grad()`` call as possible.
2371
        """
2372
        if not self._use_orig_params:
2373
            return
2374
        flat_param = self.flat_param
2375
        assert flat_param._params is not None  # mypy
2376
        all_grad_none = True
2377
        requires_grad = False
2378
        for param in flat_param._params:
2379
            all_grad_none &= param.grad is None
2380
            requires_grad |= param.requires_grad
2381
        if all_grad_none:
2382
            flat_param.grad = None
2383
        # As long as one parameter requires gradient, then the flat parameter
2384
        # must require gradient
2385
        flat_param.requires_grad = requires_grad
2386

2387
    def _deregister_orig_params(self):
2388
        for param_info in self.flat_param._param_infos:
2389
            param_name, module, _ = param_info
2390
            if hasattr(module, param_name):
2391
                delattr(module, param_name)
2392
        for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
2393
            if hasattr(module, param_name):
2394
                delattr(module, param_name)
2395

2396
    ###########
2397
    # HELPERS #
2398
    ###########
2399
    def flat_param_to(self, *args, **kwargs):
        """
        Apply ``.to(*args, **kwargs)`` to ``self.flat_param`` in place (by reassigning its ``.data``).

        With ``use_orig_params=True``, the original parameter views are then
        refreshed, because their storage may have changed. Sharded views are
        used if the flat parameter is currently sharded, and unsharded
        ``nn.Parameter`` views otherwise.
        """
        flat_param = self.flat_param
        flat_param.data = flat_param.to(*args, **kwargs)
        if not self._use_orig_params:
            return
        if self.is_sharded(self.flat_param):
            self._use_sharded_views()
        else:
            self._use_unsharded_views(as_params=True)
2408

2409
    def _get_modules(self) -> Set[nn.Module]:
2410
        """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
2411
        return {pi.module for pi in self.flat_param._param_infos}.union(
2412
            {spi.module for spi in self.flat_param._shared_param_infos}
2413
        )
2414

2415
    def is_sharded(self, tensor: Tensor) -> bool:
        """
        Return whether ``tensor`` is *currently* sharded, i.e. has the flat parameter's sharded size.

        Always ``False`` for ``NO_SHARD`` (for clarity) and before the handle
        is sharded, since ``_sharded_size`` is defined iff ``handle.shard()``
        has been called.
        """
        has_sharded_size = hasattr(self.flat_param, "_sharded_size")
        if has_sharded_size and self.uses_sharded_strategy:
            return tensor.size() == self.flat_param._sharded_size  # type: ignore[attr-defined]
        return False
2429

2430
    def param_module_names(self) -> Iterator[Tuple[str, str]]:
        """
        Yield ``(param_name, module_name)`` for each original parameter, then for each shared parameter.
        """
        flat_param = self.flat_param
        # Snapshot the shared entries up front, before yielding anything
        shared_names = [
            (param_name, module_name)
            for param_name, _, module_name, _, _, _ in flat_param._shared_param_infos
        ]
        for param_info in flat_param._param_infos:
            param_name, _, module_name = param_info  # type: ignore[misc]
            yield (param_name, module_name)
        yield from shared_names
2445

2446
    def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
        """Yield ``(param_name, module_name)`` for each shared parameter."""
        # Snapshot the entries up front, before yielding anything
        shared_names = [
            (param_name, module_name)
            for param_name, _, module_name, _, _, _ in self.flat_param._shared_param_infos
        ]
        yield from shared_names
2459

2460
    @property
2461
    def _fqns_in_shard(self) -> List[str]:
2462
        """Return the FQNs of the parameters present in this rank's shard."""
2463
        fqns_in_shard: List[str] = []
2464
        for fqn, shard_param_info in zip(
2465
            self.flat_param._fqns, self.flat_param._shard_param_infos  # type: ignore[attr-defined]
2466
        ):
2467
            if shard_param_info.in_shard:
2468
                fqns_in_shard.append(fqn)
2469
        return fqns_in_shard
2470

2471
    @property
2472
    def sharded_grad(self) -> Optional[Tensor]:
2473
        """Return the handle's sharded gradient."""
2474
        flat_param = self.flat_param
2475
        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
2476
        # - CPU offloading: `_cpu_grad`
2477
        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
2478
        # - No CPU offloading + `NO_SHARD`: `grad`
2479
        grad: Optional[Tensor]
2480
        if hasattr(flat_param, "_cpu_grad"):
2481
            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
2482
        elif hasattr(flat_param, "_saved_grad_shard"):
2483
            # In the post-backward hook, the sharded gradient is still in
2484
            # `_saved_grad_shard`.
2485
            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
2486
        else:
2487
            # If in IDLE or in FORWARD states, then there may be an
2488
            # (accumulated) gradient. If accessed in IDLE, then this should
2489
            # be due to re-registering the original parameters (e.g. in state
2490
            # dict load).
2491
            _p_assert(
2492
                flat_param.grad is None
2493
                or not self.uses_sharded_strategy
2494
                or self._training_state
2495
                in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
2496
                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
2497
                "unless in IDLE or FORWARD",
2498
            )
2499
            grad = flat_param.grad
2500
        return grad
2501

2502
    def _reset_is_grad_none(self) -> None:
2503
        """
2504
        Reset ``_is_grad_none_mask`` as needed.
2505

2506
        This method should only be
2507
        called in the post-backward after gradient computation, in which case
2508
        if a parameter requires gradient, then it will surely receive a
2509
        gradient and we may reset its mask entry to ``False``.
2510
        """
2511
        if not self._use_orig_params:
2512
            return
2513
        _p_assert(
2514
            self._training_state == HandleTrainingState.BACKWARD_POST,
2515
            "Expects to only be called in the post-backward after gradient computation",
2516
        )
2517
        flat_param = self.flat_param
2518
        assert flat_param._params is not None  # mypy
2519
        for i, param in enumerate(flat_param._params):  # type: ignore[arg-type]
2520
            # As long as the parameter requires gradient, it should receive a
2521
            # meaningful gradient (even if the gradient happens to be zeros)
2522
            if param.requires_grad:
2523
                assert flat_param._is_grad_none_mask is not None  # mypy
2524
                flat_param._is_grad_none_mask[i] = False
2525

2526
    #######################
2527
    # CHECKS & INVARIANTS #
2528
    #######################
2529
    def _check_sharded_strategy(self):
2530
        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
2531

2532
    def _check_on_compute_device(self, tensor: Tensor):
2533
        _p_assert(
2534
            tensor.device == self.device,
2535
            f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
2536
        )
2537

2538
    def _check_on_cpu(self, tensor: Tensor):
2539
        _p_assert(
2540
            tensor.device == torch.device("cpu"),
2541
            f"Expects tensor to be on CPU but got {tensor.device}",
2542
        )
2543

2544
    @staticmethod
    def _check_storage_freed(tensor: Tensor):
        """Assert that ``tensor``'s storage has been freed (storage size 0); skipped while compiling."""
        # Compile does not resize during trace, so only check eagerly
        if torch.distributed._functional_collectives.is_torchdynamo_compiling():
            return
        _p_assert(
            _same_storage_size(tensor, 0),
            "Expects storage to be freed but got storage with size > 0",
        )
2552

2553
    @staticmethod
    def _check_storage_allocated(tensor: Tensor):
        """Assert that ``tensor``'s storage is allocated, as judged by ``_storage_size_allocated``."""
        is_allocated = _storage_size_allocated(tensor)
        _p_assert(is_allocated, "Expects storage to be allocated")
2556

2557
    def _check_low_precision_shard(self):
2558
        _p_assert(
2559
            self._uses_param_mixed_precision,
2560
            "Not using low precision for parameters",
2561
        )
2562
        _p_assert(
2563
            getattr(self.flat_param, "_mp_shard", None) is not None,
2564
            "Expects `_mp_shard` to exist",
2565
        )
2566
        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
2567
        _p_assert(
2568
            device == self.device,
2569
            f"Expects the low precision shard to be on {self.device} but got {device}",
2570
        )
2571

2572
    def _check_unsharded(self, tensor: Tensor):
2573
        msg_prefix = "Expects tensor to be unsharded "
2574
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
2575
        unsharded_size = self.flat_param._unpadded_unsharded_size
2576
        _p_assert(
2577
            tensor.size() == unsharded_size,
2578
            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
2579
        )
2580

2581
    def _check_sharded(self, tensor: Tensor):
2582
        msg_prefix = "Expects tensor to be sharded "
2583
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
2584
        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
2585
        _p_assert(
2586
            tensor.size() == sharded_size,
2587
            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
2588
        )
2589

2590
    ##############
2591
    # PROPERTIES #
2592
    ##############
2593
    @property
    def uses_sharded_strategy(self) -> bool:
        """Return whether the handle's sharding strategy is anything other than ``NO_SHARD``."""
        strategy = self._sharding_strategy
        return strategy != HandleShardingStrategy.NO_SHARD
2596

2597
    @property
2598
    def _uses_param_mixed_precision(self) -> bool:
2599
        return self._fwd_bwd_param_dtype != self._orig_param_dtype
2600

2601
    @property
2602
    def _uses_reduce_mixed_precision(self) -> bool:
2603
        return self._reduce_dtype != self._orig_param_dtype
2604

2605
    @property
2606
    def _force_full_precision(self) -> bool:
2607
        return (
2608
            self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
2609
        ) and (
2610
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
2611
            or
2612
            # Also disable mixed precision in model eval mode, if configured
2613
            (not self._fully_sharded_module.training and self._use_full_prec_in_eval)
2614
        )
2615

2616
    @property
2617
    def _skipped_use_sharded_views(self) -> bool:
2618
        """
2619
        This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``.
2620

2621
        This returns if this handle is
2622
        currently in a state where it has skipped using sharded views, in which
2623
        case it can restore view invariants via ``_use_sharded_views()``.
2624
        """
2625
        return self._unsharded_flat_param_for_skipped_views is not None
2626

2627

2628
# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
2629
def _unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    """Install ``param`` on ``module`` as ``param_name`` without ``nn.Module`` checks.

    The parameter is written straight into ``module._parameters`` and then
    set as a plain instance attribute.
    """
    registered_params = module._parameters
    registered_params[param_name] = param
    # Resolve ``__setattr__`` past ``nn.Module`` in the MRO so neither
    # ``nn.Module.__setattr__`` nor any subclass override is invoked.
    base_setattr = super(nn.Module, module).__setattr__
    base_setattr(param_name, param)
2636

2637

2638
def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
    """Set ``tensor`` on ``module`` as ``param_name`` without ``nn.Module`` checks.

    If a parameter named ``param_name`` is registered, it is removed first.
    Then the tensor is stored as a plain instance attribute.
    """
    # Passing a default makes the removal a no-op when nothing is registered.
    module._parameters.pop(param_name, None)
    # Resolve ``__setattr__`` past ``nn.Module`` in the MRO so neither
    # ``nn.Module.__setattr__`` nor any subclass override is invoked.
    base_setattr = super(nn.Module, module).__setattr__
    base_setattr(param_name, tensor)
2643

2644

2645
def _safe_setattr_tensor_or_param(
    module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
    """Set ``tensor_or_param`` on ``module`` through the normal ``nn.Module`` path.

    Any existing attribute of that name is deleted first with ``delattr``.
    The new value is then assigned with ``setattr``, so ``nn.Module``'s
    ``__delattr__``/``__setattr__`` checks apply, unlike the unsafe variants.
    """
    already_present = hasattr(module, param_name)
    if already_present:
        delattr(module, param_name)
    setattr(module, param_name, tensor_or_param)
2652

2653

2654
def _convert_to_params(
    tensors: List[Union[torch.Tensor, nn.Parameter]]
) -> List[nn.Parameter]:
    """Return a new list in which every input is an ``nn.Parameter``.

    Inputs that are already ``nn.Parameter`` instances are passed through
    unchanged. Plain tensors are wrapped with ``nn.Parameter(...)``.
    """
    converted: List[nn.Parameter] = []
    for item in tensors:
        if not isinstance(item, nn.Parameter):
            item = nn.Parameter(item)
        converted.append(item)
    return converted
2658

2659

2660
def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
    """Return ``param_or_tensor.detach()`` for an ``nn.Parameter``.

    Any other tensor is returned as-is, without detaching.
    """
    if isinstance(param_or_tensor, nn.Parameter):
        return param_or_tensor.detach()
    return param_or_tensor
2666

2667

2668
def _get_aligned_numel(unsharded_dtype: torch.dtype):
    """Return how many ``unsharded_dtype`` elements make up the alignment unit.

    The alignment unit is 16 bytes. The result is computed as
    ``16 // element_size(unsharded_dtype)`` (integer division).
    """
    # NOTE: This alignment constraint comes from TorchInductor.
    alignment_in_bytes = 16
    return alignment_in_bytes // _get_dtype_size(unsharded_dtype)
2674

2675

2676
@functools.lru_cache(maxsize=8)
def _get_dtype_size(dtype):
    """Return the size in bytes of one element of ``dtype``.

    Results are memoized in an LRU cache holding up to 8 dtypes.
    """
    # Allocate a 0-dim tensor only to read its per-element byte size.
    probe = torch.empty((), dtype=dtype)
    return probe.element_size()
2679

2680

2681
def _construct_padding_tensor(
    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
):
    """Build a 1-D tensor of ``padding_numel`` elements equal to ``_FLAT_PARAM_PADDING_VALUE``.

    The tensor is created on ``device`` with ``dtype``. The value comes from
    multiplying a ``torch.ones`` tensor (created with ``requires_grad``) by
    the padding constant.
    """
    # NOTE: The fill value is a magic number chosen for debuggability only; it
    # should never be used in any user-facing computation.
    ones = torch.ones(
        (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
    )
    return ones * _FLAT_PARAM_PADDING_VALUE
2692

2693

2694
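# Use `lru_cache(1)` to only log the warning once. This assumes the same fixed
# warning message is passed in on every call: the cache keeps only the most
# recent `(log, warning)` pair, so alternating messages would be logged again.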
2696
@functools.lru_cache(maxsize=1)
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
    """Emit ``warning`` on ``log`` at WARNING level.

    Repeated calls with the same ``(log, warning)`` pair hit the size-1 cache
    and do not log again.
    """
    log.warning(warning)
2699

2700

2701
# Use `lru_cache(1)` to only log the warning once
2702
@functools.lru_cache(maxsize=1)
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
    """Emit ``warning`` on ``log`` at WARNING level.

    Repeated calls with the same ``(log, warning)`` pair hit the size-1 cache
    and do not log again.
    """
    log.warning(warning)
2705

2706

2707
# Use `lru_cache(1)` to only log the warning once
2708
@functools.lru_cache(maxsize=1)
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
    """Emit ``warning`` on ``log`` at WARNING level.

    Repeated calls with the same ``(log, warning)`` pair hit the size-1 cache
    and do not log again.
    """
    log.warning(warning)
2711

2712

2713
def _same_storage(a, b):
    """Return whether ``a`` and ``b`` have untyped storages with the same data pointer."""
    ptr_a = a.untyped_storage().data_ptr()
    ptr_b = b.untyped_storage().data_ptr()
    return ptr_a == ptr_b
2715

2716

2717
def _same_storage_size(a: torch.Tensor, b: int):
    """Return whether ``a``'s untyped storage holds exactly ``b`` elements.

    Elements are counted in units of ``a``'s element size, i.e.
    ``storage_bytes // a.element_size() == b``.
    """
    storage_numel = a.untyped_storage().size() // a.element_size()
    return storage_numel == b
2719

2720

2721
def _storage_size_allocated(tensor: Tensor):
    """Return whether ``tensor``'s untyped storage has a nonzero byte size."""
    return tensor.untyped_storage().size() > 0
2724

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.