import contextlib
import functools
import logging
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    no_type_check,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _named_parameters_with_duplicates,
    _no_dispatch_record_stream,
    _set_fsdp_flattened,
    HandleTrainingState,
)
from torch.distributed.utils import (
    _alloc_storage,
    _data_ptr_allocated,
    _free_storage,
    _p_assert,
)
from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

from ._fsdp_extensions import (
    _ext_post_unflatten_transform,
    _ext_pre_flatten_transform,
    FSDPExtensions,
)

__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]

log = logging.getLogger(__name__)


"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.

For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
  runs the unshard/reshard on behalf of the fully sharded module by overriding
  ``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
  ``FullyShardedDataParallel`` constructor's ``module`` argument.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
  or a submodule chosen by the provided wrapping policy.
"""

# Environment variable toggling whether to use unsafe `setattr()` for view
# setting in `_use_sharded_views()` and `_use_unsharded_views()`
# We should use 'safe' by default since it respects method overrides, but for
# special cases such as high CPU overhead or for intentionally bypassing
# checks in the overrides, we may use 'unsafe'.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"

# Environment variable toggling whether to check for parameter/gradient
# writeback in case their storages change after FSDP initialization
# We should check by default since it prevents silent correctness errors, but
# since such changes are atypical, we may want to skip the check to save CPU
# overhead, especially since the check happens in the pre-forward and
# pre-backward each iteration.
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"

# Environment variable toggling whether, when the model is in ``.eval()`` mode,
# we run in full precision (fp32) or in the reduced precision.
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"

# Value to which padding tensors are set, for debuggability
_FLAT_PARAM_PADDING_VALUE = 42

# Environment variables for disabling the all-gather and reduce-scatter
# communication ops for ablation studies. Note that without these communication
# ops the training won't converge, and you probably need to disable correctness
# checks in your model.
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
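# Illustrative usage (editor's sketch, assumed command line): these toggles are
# read from the environment when a handle is constructed, e.g.
#     FSDP_SKIP_WRITEBACK_CHECK=1 FSDP_USE_FAKE_ALL_GATHER=1 python train.py
# where "1" enables the corresponding behavior and `train.py` is a hypothetical
# training script.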


# TODO: Define this for now to avoid circular imports. See if we can remove.
class HandleShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
)
NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)


class ParamInfo(NamedTuple):
    """Information for an original parameter."""

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str


class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str


class _ShardParamInfo(NamedTuple):
    """Shard-related information for an original parameter."""

    in_shard: bool
    # Use to index into the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Use to get part of the parameter in the local shard from a flattened
    # version of the unsharded parameter, e.g.
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive
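    # Illustrative example (assumed numbers, not from the original source):
    # suppose an original parameter occupies unsharded indices [10, 15] (6
    # numel) and this rank's shard covers unsharded indices [12, 20]. Its entry
    # would be
    #     _ShardParamInfo(in_shard=True, offset_in_shard=0, numel_in_shard=4,
    #                     intra_param_start_idx=2, intra_param_end_idx=5)
    # i.e. `flat_param[0:4]` holds `param.flatten()[2:6]` on this rank.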


class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]


class _FlatParameterMeta(_ParameterMeta):
    # Make `isinstance(t, FlatParameter)` return True for custom tensor
    # instances that have the _is_flat_param flag for BC
    def __instancecheck__(self, instance):
        # NB: do NOT test the super implementation
        return isinstance(instance, torch.Tensor) and getattr(
            instance, "_is_flat_param", False
        )


class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`.

    It is comprised of one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.
        - In the :class:`FullyShardedDataParallel` constructor, the parameter
          is initialized as unsharded and then sharded in-place.
        - At runtime, the parameter is lazily (re)-initialized. The sharded
          parameter data is saved in ``self._local_shard``, and a new ``Tensor``
          ``self._full_param_padded`` is created, which is the all-gather
          destination and owns the unsharded parameter storage thereafter. (See
          :meth:`FlatParamHandle.init_flat_param_attributes`.)
        - Throughout runtime, the parameter data changes storages as needed,
          e.g. to the sharded flat parameter, low precision sharded flat
          parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may have length longer than the other data
    structures, while the latter has the same length as the number of actual
    original parameters like the other per-parameter data structures.

    NOTE: This is not a real class; instead, you will always get a Parameter
    back out if you try to create one of these. This is similar to the trick
    we implemented for Parameter to get it to work with subclasses; this
    is primarily so that FlatParameter supports combination with FakeTensor.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)

        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
            prefixed from the ``_fully_sharded_module``. The names are
            guaranteed to be unique in the subtree rooted at that module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.

        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_handle (RemovableHandle):
            Flat parameter's post-backward hook handle. (Compile only)
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle. (Eager only)
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change;
            if it changed, the post-backward hook would not run. This is
            relevant for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating if it is
            logically ``None`` or not; otherwise, ``None``. This does not
            include entries for padding. This mask is needed because only some
            of the parameters may have ``None`` gradient, in which case the
            flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    _unpadded_unsharded_size: torch.Size
    _padded_unsharded_size: torch.Size
    _sharded_size: torch.Size
    _num_params: int
    _param_infos: Tuple[ParamInfo, ...]
    _shapes: Tuple[torch.Size, ...]
    _fqns: Tuple[str, ...]
    _param_extensions: Tuple[Optional[Any], ...]
    _numels_with_padding: Tuple[int, ...]
    _numels: Tuple[int, ...]
    _shard_param_infos: Tuple[_ShardParamInfo, ...]
    _shared_param_infos: Tuple[SharedParamInfo, ...]
    _modules: Set[nn.Module]
    _shard_numel_padded: int
    _local_shard: Tensor
    _full_param_padded: Tensor
    _full_prec_full_param_padded: Tensor
    # Eager only
    _post_backward_hook_state: Tuple[Any, Any]
    # Compile only
    _post_backward_hook_handle: Any
    _mp_shard: Tensor
    _cpu_grad: Tensor
    _saved_grad_shard: Tensor
    _params: Optional[List[nn.Parameter]]
    _shared_params: Optional[List[nn.Parameter]]
    _tensors: Optional[List[Optional[Tensor]]]
    _is_grad_none_mask: Optional[List[bool]]

    _is_padding_mask: List[bool]

    def __new__(cls, data=None, requires_grad=True):
        assert cls is FlatParameter, "subclasses FlatParameter not supported"
        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
        r._is_flat_param = True  # type: ignore[attr-defined]
        return r
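    # Illustrative behavior (editor's note, assumed usage): since `__new__`
    # returns a plain `nn.Parameter` tagged with `_is_flat_param`, e.g.
    #     p = FlatParameter(torch.empty(8), requires_grad=True)
    #     type(p) is nn.Parameter       # True
    #     isinstance(p, FlatParameter)  # True, via `_FlatParameterMeta`
    # `FlatParameter` acts as a tag plus metadata container rather than a
    # distinct tensor subclass, as described in the class docstring.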

    # NB: This is not a regular method, because FlatParameters are not actually
    # instances of this class (see __new__ above). So you must call this
    # indirectly through the classmethod.
    @classmethod
    def _init_metadata(
        cls,
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Initialize attributes holding metadata about the original parameters comprising the flat parameter.

        We expose this method separately from the constructor to keep the
        constructor only responsible for the flat parameter's tensor data. This
        method should only be called once per model, while the constructor may
        be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        Args:
            See the Attributes in the class docstring.
        """
        assert len(param_infos) == len(shapes)
        assert len(param_infos) == len(fqns)
        assert len(param_infos) == len(param_extensions)
        self._num_params = len(param_infos)
        self._param_infos = param_infos
        self._shapes = shapes
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask

        numels_without_padding: List[int] = []
        for numel, is_padding in zip(numels, is_padding_mask):
            if not is_padding:
                numels_without_padding.append(numel)
        self._numels = tuple(numels_without_padding)
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params

        self._shared_param_infos = tuple(shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            self._params = []
            for param, is_padding in zip(params, is_padding_mask):
                if not is_padding:
                    self._params.append(param)
            self._shared_params = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none_mask = [False for _ in range(self._num_params)]
            self._tensors = [None for _ in range(self._num_params)]
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False
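        # Illustrative example (assumed numbers, not from the original source):
        # for two original parameters with numels 3 and 5 and one alignment
        # padding tensor of numel 1 between them, this method records
        #     _numels_with_padding = (3, 1, 5)
        #     _is_padding_mask     = [False, True, False]
        #     _numels              = (3, 5)
        # so `_numels` lines up with the other per-parameter data structures
        # while `_numels_with_padding` matches the flat parameter's layout.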


class FlatParamHandle:
    """
    A handle that manages a flat parameter (:class:`FlatParameter`).

    This includes sharding and view management.

    Args:
        params (Sequence[nn.Parameter]): The parameters to flatten into the
            flat parameter.
        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
        device (torch.device): The compute and communication device, which
            should be a non-CPU device. We refer to it as the compute device.
        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
            this handle's ``FlatParameter``.
        offload_params (bool): Whether to offload the handle's
            ``FlatParameter`` to CPU.
        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
            setting passed to the FSDP constructor.
        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
            precision setting passed to the FSDP constructor.
        keep_low_precision_grads (bool): Whether to keep gradients in low
            precision.
        use_orig_params (bool): If ``True``, then FSDP preserves the original
            parameter variables and returns them from ``named_parameters()``
            (e.g. to support different optimizer hyperparameters within one
            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
            parameters every iteration and returns the :class:`FlatParameter` s
            from ``named_parameters()``.
    """

    ##################
    # INITIALIZATION #
    ##################
    def __init__(
        self,
        params: Sequence[Union[nn.Parameter, Tensor]],
        fully_sharded_module: nn.Module,
        device: torch.device,
        sharding_strategy: HandleShardingStrategy,
        offload_params: bool,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
        keep_low_precision_grads: bool,
        process_group: dist.ProcessGroup,
        use_orig_params: bool,
        *,
        fsdp_extension: Optional[FSDPExtensions] = None,
    ):
        super().__init__()
        params = list(params)
        if len(params) == 0:
            raise ValueError(
                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
            )
        self._init_setattr_fns()
        self._skip_writeback_check = (
            os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
        )
        self._use_full_prec_in_eval = (
            os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
        )
        self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
        self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
        if self._skip_writeback_check:
            _warn_skip_writeback_check(
                log,
                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
                "for parameter or gradient writeback. Changing parameter or "
                "gradient storages may lead to silent correctness errors.",
            )
        if self._use_fake_all_gather:
            _warn_use_fake_all_gather(
                log,
                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
                "all-gather ops. Your training will be incorrect, but "
                "can reveal how much time is spent on all-gather ops.",
            )
        if self._use_fake_reduce:
            _warn_use_fake_reduce(
                log,
                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
                "reduce-scatter ops. Your training will be incorrect, but "
                "can reveal how much time is spent on reduce-scatter ops.",
            )
        # Only align addresses for `use_orig_params=True` (for now)
        align_addresses = use_orig_params
        self._init_get_unflat_views_fn(align_addresses)
        self.device = device
        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
        self.process_group = process_group
        if self._use_fake_all_gather or self._use_fake_reduce:
            self._fake_process_group = FakeProcessGroup(
                rank=process_group.rank(), world_size=process_group.size()
            )
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        self._sharding_strategy = sharding_strategy
        self._offload_params = offload_params
        self._use_orig_params = use_orig_params
        self._keep_low_precision_grads = keep_low_precision_grads
        self._training_state = HandleTrainingState.IDLE
        self._debug_level = dist.get_debug_level()
        self._fully_sharded_module = fully_sharded_module
        # For strategies that do not free after forward, we skip using sharded
        # views after forward since the unsharded data exists. We still switch
        # `self.flat_param` to point to the sharded flat parameter since what
        # it points to parameterizes behavior. We use the following attribute
        # to track which tensor data the parameters are unsharded views into.
        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
        # The index in the state's `all_handles`, which must be the
        # same across ranks for the execution order validation to work
        self._handle_index: Optional[int] = None
        # Index in `handles_to_pre_forward_order`
        self._pre_forward_order_index: Optional[int] = None
        # Index in `handles_post_forward_order`
        self._post_forward_index: Optional[int] = None
        # Used for guarding against mistargeted forward prefetches
        self._needs_pre_forward_unshard = False
        # Used for guarding against mistargeted backward prefetches
        self._needs_pre_backward_unshard = False
        # Was the handle prefetched? Set on successful `_prefetch_handle` and unshard
        self._prefetched = False
        # Optimistically assume a valid input `params` and set dtype attributes
        # before `_init_flat_param_and_metadata()`, which performs the actual
        # validation
        self._orig_param_dtype = params[0].dtype
        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
        assert self._fwd_bwd_param_dtype is not None  # mypy
        self._aligned_numel = (
            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
            if align_addresses
            else 0
        )
        self._fsdp_extension = fsdp_extension
        self._init_flat_param_and_metadata(
            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
        )
        self._use_unsharded_views(as_params=False)

    def _init_setattr_fns(self):
        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
        if use_unsafe_setattr:
            self._setattr_tensor = _unsafe_setattr_tensor
            self._setattr_param = _unsafe_setattr_param
        else:
            self._setattr_tensor = _safe_setattr_tensor_or_param
            self._setattr_param = _safe_setattr_tensor_or_param

    def _init_get_unflat_views_fn(self, align_addresses: bool):
        self._get_unflat_views = (
            self._get_unflat_views_aligned
            if align_addresses
            else self._get_unflat_views_unaligned
        )

    def _init_flat_param_and_metadata(
        self,
        params: List[Union[Tensor, nn.Parameter]],
        module: nn.Module,
        aligned_numel: int,
        use_orig_params: bool,
    ) -> None:
        """
        Initialize the ``FlatParameter`` and its metadata.

        NOTE: This should only be called once at construction time, after which
        the ``FlatParameter`` metadata is assumed to be static.

        NOTE: The elements of ``params`` should only be ``Tensor`` s when
        composing with ``DTensor`` -based tensor parallelism, in which case the
        elements may be ``DTensor`` local shards.
        """
        if len(params) == 0:
            raise ValueError("Expects non-empty `params`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        (
            dtype,
            flat_param_requires_grad,
            device,
        ) = self._validate_tensors_to_flatten(params)
        params_set = set(params)
        # For alignment padding, only `numels` gets strictly non-`None`
        # elements, and all other lists get `None` elements for padding.
        param_infos: List[ParamInfo] = []
        numels: List[int] = []
        shapes: List[torch.Size] = []
        fqns: List[str] = []
        shared_param_infos: List[SharedParamInfo] = []
        shared_param_memo: Dict[
            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
        ] = {}
        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
        shared_params: List[Union[Tensor, nn.Parameter]] = []
        param_extensions: List[Any] = []
        is_padding_mask: List[bool] = []
        total_numel = total_numel_without_padding = 0
        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param not in params_set:
                    continue
                if param in shared_param_memo:  # shared reference
                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
                        param
                    ]
                    shared_params.append(param)
                    shared_param_infos.append(
                        SharedParamInfo(
                            param_name,
                            submodule,
                            submodule_name,
                            prim_param_name,
                            prim_module,
                            prim_module_name,
                        )
                    )
                else:
                    if aligned_numel > 0:
                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                            padding_tensor = _construct_padding_tensor(
                                numel_to_pad, dtype, False, device
                            )
                            params_to_flatten.append(padding_tensor)
                            is_padding_mask.append(True)
                            numels.append(numel_to_pad)
                            total_numel += numel_to_pad
                    transform_t, extension = _ext_pre_flatten_transform(
                        param,
                        self._fsdp_extension,
                    )
                    param = cast(nn.Parameter, transform_t)
                    param_extensions.append(extension)
                    shared_param_memo[param] = (submodule, submodule_name, param_name)
                    params_to_flatten.append(param)
                    is_padding_mask.append(False)
                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                    numels.append(param.numel())
                    shapes.append(param.shape)
                    fqn = (
                        submodule_name + "." + param_name
                        if submodule_name
                        else param_name
                    )
                    fqns.append(fqn)
                    total_numel += param.numel()
                    total_numel_without_padding += param.numel()
        if len(params_to_flatten) == 0:
            raise ValueError(
                f"`params` were not found in `module`'s tree\n"
                f"params: {params}\nmodule: {module}"
            )
        if (
            self.rank == 0
            and aligned_numel > 0
            and total_numel != total_numel_without_padding
        ):
            log.info(
                "FSDP FlatParameter address alignment created "
                "%s numel of padding (%s vs. %s)",
                total_numel - total_numel_without_padding,
                total_numel,
                total_numel_without_padding,
            )
        if aligned_numel > 0:
            # Pad to be divisible by world size to avoid a copy for the
            # post-backward reduce-scatter
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                if self.rank == 0:
                    log.info(
                        "FSDP FlatParameter world size divisibility created "
                        "%s numel of padding",
                        numel_to_pad,
                    )
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                params_to_flatten.append(padding_tensor)
                is_padding_mask.append(True)
                numels.append(numel_to_pad)
                total_numel += numel_to_pad
        # Pass `aligned_numel=0` since we already included padding tensors
        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
            params_to_flatten,
            aligned_numel=0,
            requires_grad=flat_param_requires_grad,
        )
        FlatParameter._init_metadata(
            self.flat_param,
            param_infos,
            numels,
            shapes,
            fqns,
            shared_param_infos,
            param_extensions,
            _convert_to_params(params_to_flatten) if use_orig_params else None,
            _convert_to_params(shared_params) if use_orig_params else None,
            is_padding_mask,
        )

    def _validate_tensors_to_flatten(
        self, tensors: List[Union[Tensor, nn.Parameter]]
    ) -> Tuple:
        """Validate the tensors to flatten and return any necessary metadata."""
        dtype: Optional[torch.dtype] = None
        # Return as the logical OR over each tensor's value
        flat_param_requires_grad: Optional[bool] = None
        device: Optional[torch.device] = None
        # For `use_orig_params=True`, permit non-uniform `requires_grad`
        for tensor in tensors:
            if isinstance(tensor, FlatParameter):
                raise ValueError("Cannot flatten a `FlatParameter`")
            if dtype is None and not tensor.is_floating_point():
                raise ValueError("Cannot flatten integer dtype tensors")
            if dtype is not None and tensor.dtype != dtype:
                raise ValueError(
                    f"Must flatten tensors with uniform dtype but got {dtype} "
                    f"and {tensor.dtype}"
                )
            if (
                not self._use_orig_params
                and flat_param_requires_grad is not None
                and tensor.requires_grad != flat_param_requires_grad
            ):
                raise ValueError(
                    "Must flatten tensors with uniform `requires_grad` when "
                    "`use_orig_params=False`"
                )
            if device is not None and tensor.device != device:
                raise ValueError(
                    "Must flatten tensors on the same device but got both "
                    f"{device} and {tensor.device}"
                )
            dtype = tensor.dtype
            flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
            device = tensor.device
        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
        return dtype, flat_param_requires_grad, device

    def flatten_tensors(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
    ) -> Tensor:
        """
        Flatten ``tensors`` into a single flat tensor.

        The flattening optionally includes padding if ``aligned_numel`` is
        greater than 0, where ``aligned_numel`` gives the numel required to
        have address alignment.

        NOTE: The padding alignment algorithm must be kept in sync with
        :meth:`_init_flat_param_and_metadata`. We separate the two methods
        because the initialization happens once, whereas this method may be
        called multiple times throughout training (e.g. for checkpointing).
        """
        if len(tensors) == 0:
            raise ValueError("Expects non-empty `tensors`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        dtype, _, device = self._validate_tensors_to_flatten(tensors)
        flat_tensors: List[Tensor] = []
        if aligned_numel > 0:
            total_numel = 0
            for tensor in tensors:
                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                    padding_tensor = _construct_padding_tensor(
                        numel_to_pad, dtype, False, device
                    )
                    flat_tensors.append(padding_tensor)
                    total_numel += numel_to_pad
                flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
                total_numel += tensor.numel()
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                flat_tensors.append(padding_tensor)
                total_numel += numel_to_pad
        else:
            flat_tensors = [
                torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
            ]
        return torch.cat(flat_tensors, dim=0)
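        # Illustrative example (assumed numbers, not from the original source):
        # with `aligned_numel=4`, `world_size=2`, and tensors of numel 3 and 5,
        # this method would concatenate
        #     [t0 (3), pad (1), t1 (5), pad (1)]  ->  flat tensor of numel 10
        # where the first pad restores 4-element alignment before `t1` and the
        # last pad makes the total numel divisible by the world size.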

    def flatten_tensors_into_flat_param(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
        requires_grad: bool,
    ) -> FlatParameter:
        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
        return FlatParameter(flat_param_data, requires_grad=requires_grad)

    def _init_param_reduce_dtypes(
        self,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
    ) -> None:
        """
        Initialize param and reduce dtypes.

        Precondition: ``self.flat_param`` is set. This ensures that this
        handle's parameters have a single dtype.

        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
        is ``None``, then we assume the original parameter dtype. One special
        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
        is ``None``, in which case we assume the gradient reduction dtype
        matches the forward/backward parameter dtype.
        """
        # Save whether these dtypes were specified so that we permit the
        # parameter dtype to change up until the lazy initialization
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if (
            self._low_prec_param_dtype_specified
            and not self._low_prec_reduce_dtype_specified
        ):
            # Special case: infer gradient reduction mixed precision
            self._fwd_bwd_param_dtype = mp_param_dtype
            self._reduce_dtype = self._fwd_bwd_param_dtype
        else:
            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
        assert self._fwd_bwd_param_dtype is not None
        assert self._reduce_dtype is not None
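        # Illustrative resolution (assumed dtypes, not from the original
        # source), with `orig` denoting the original parameter dtype:
        #     orig=fp32, mp_param=bf16, mp_reduce=None -> fwd/bwd=bf16, reduce=bf16
        #     orig=fp32, mp_param=None, mp_reduce=bf16 -> fwd/bwd=fp32, reduce=bf16
        #     orig=fp32, mp_param=None, mp_reduce=None -> fwd/bwd=fp32, reduce=fp32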

    ###################################
    # SHARD INITIALIZATION & METADATA #
    ###################################
    @torch.no_grad()
    def shard(self):
        """
        Shard the handle's ``FlatParameter``.

        This allocates new memory for the sharded flat parameter and frees the
        unsharded flat parameter's storage.

        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
        metadata attributes are set for all sharding strategies.
        """
        flat_param = self.flat_param
        if not self.uses_sharded_strategy:
            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
        else:
            _p_assert(
                flat_param.storage_offset() == 0,
                "The `FlatParameter` is not the sole occupant of its storage",
            )
            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
                flat_param, self.rank, self.world_size
            )
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                allocated = flat_param._typed_storage()._size() > 0
                if allocated:
                    flat_param._typed_storage()._resize_(0)
            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
            start_idx = sharded_flat_param.numel() * self.rank
            end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
            self._init_shard_metadata(numel_padded, start_idx, end_idx)
        if self._use_orig_params:
            self._use_sharded_views()

    def _init_shard_metadata(
        self,
        numel_padded: int,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> None:
        """
        Initialize shard-related metadata for this rank's shard of the flat parameter.

        This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``.

        Args:
            numel_padded (int): Numel padded for this rank's sharded flat
                parameter.
            unsharded_start_idx (int): Start index in the unsharded flat
                parameter assigned to this rank.
            unsharded_end_idx (int): End index (inclusive) in the unsharded
                flat parameter assigned to this rank.

        Precondition: ``self.flat_param`` 's data is the sharded flat
        parameter.
        """
        flat_param = self.flat_param
        flat_param._sharded_size = flat_param.size()  # type: ignore[attr-defined]
        sharded_flat_param_numel = flat_param.numel()  # includes `numel_padded`
        _p_assert(
            unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
        )
        _p_assert(
            numel_padded <= sharded_flat_param_numel,
            f"numel_padded: {numel_padded} "
            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
        )
        shard_param_infos = self._get_shard_metadata(
            unsharded_start_idx, unsharded_end_idx
        )
        assert (
            len(shard_param_infos) == flat_param._num_params
        ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
        flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
        flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]

    def _get_shard_metadata(
        self,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> Tuple[_ShardParamInfo, ...]:
        """
        Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).

        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the
        unsharded flat parameter specifying the shard.
        """
        flat_param_offsets = self._get_flat_param_offsets()
        assert len(flat_param_offsets) == len(
            self.flat_param._numels_with_padding
        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
        shard_param_infos: List[_ShardParamInfo] = []
        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
        # into the unsharded flat parameter (inclusive) of the given parameter
        for i, (
            (unsharded_param_start_idx, unsharded_param_end_idx),
            is_padding,
        ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
            if is_padding:
                continue
            in_sharded_flat_param = (
                unsharded_start_idx <= unsharded_param_end_idx
                and unsharded_end_idx >= unsharded_param_start_idx
            )
            if not in_sharded_flat_param:
                shard_param_info = _ShardParamInfo(False, None, None, None, None)
            else:
                if unsharded_start_idx <= unsharded_param_start_idx:
                    # This branch can only happen once since the rank's
                    # unsharded start index can only intersect one parameter
                    intra_param_start_idx = 0
                    offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
                else:
                    intra_param_start_idx = (
                        unsharded_start_idx - unsharded_param_start_idx
                    )
                    offset_in_shard = 0
                assert (
                    offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
                ), (
                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
                )
                intra_param_end_idx = (
                    min(unsharded_param_end_idx, unsharded_end_idx)
                    - unsharded_param_start_idx
                )
                numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
                shard_param_info = _ShardParamInfo(
                    True,
                    offset_in_shard,
                    numel_in_shard,
                    intra_param_start_idx,
                    intra_param_end_idx,
                )
            shard_param_infos.append(shard_param_info)
        return tuple(shard_param_infos)

    @staticmethod
    def _get_unpadded_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.

        The returned value is a tuple of the shard of ``tensor`` without any
        padding and the numel to pad for that shard.

        If ``tensor`` is already flattened or may be viewed in the flattened
        shape (which is true in the expected usage), then this method does not
        allocate any new tensor memory.
        """
        chunks = torch.flatten(tensor).chunk(world_size)
        if len(chunks) < (rank + 1):
            # This rank gets an empty chunk fully padded with zeros since there
            # are not enough chunks across ranks
            chunk = chunks[0].new_empty(0)
        else:
            chunk = chunks[rank]
        numel_to_pad = chunks[0].numel() - chunk.numel()
        assert (
            numel_to_pad >= 0
        ), "Chunk's size should be at most the first chunk's size"
        return chunk, numel_to_pad
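        # Illustrative example (assumed numbers, not from the original source):
        # for a 10-numel tensor with world_size=4, `torch.chunk` yields chunks
        # of sizes (3, 3, 3, 1), so rank 3 returns a 1-numel chunk with
        # numel_to_pad=2; for a 9-numel tensor, only 3 chunks exist, so rank 3
        # returns an empty chunk with numel_to_pad=3.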

    @staticmethod
    def _get_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard.

        This method allocates new memory (via :meth:`clone`) since the
        unsharded ``tensor`` may be deallocated after this method returns.
        """
        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        shard = chunk.clone()
        if numel_to_pad > 0:
            shard = F.pad(shard, [0, numel_to_pad])
        return shard, numel_to_pad

    @staticmethod
    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
        """
        Return the shape of ``tensor`` after sharding including padding.

        This requires ``tensor`` to have 1D shape and ensures that the returned
        shape is 1D.
        """
        assert len(tensor.shape) == 1, f"{tensor.shape}"
        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        unpadded_sharded_size = unpadded_sharded_tensor.size()
        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])

    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
        """
        Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).

        NOTE: The returned list includes elements for alignment padding.
        """
        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
        starts = [0] + cumulative_sum[:-1]
        ends = [end - 1 for end in cumulative_sum]  # inclusive
        param_offsets = list(zip(starts, ends))
        return param_offsets
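        # Illustrative example (assumed numbers, not from the original source):
        # for `_numels_with_padding = (3, 1, 5)`, the cumulative sums are
        # [3, 4, 9], giving inclusive offsets [(0, 2), (3, 3), (4, 8)], where
        # the middle entry corresponds to an alignment padding tensor.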

    @no_type_check
    def shard_metadata(
        self,
    ) -> FlatParamShardMetadata:
        """
        Return the shard-related metadata specific to this rank's shard of the flat parameter.

        NOTE: The returned tuple does not include elements for alignment
        padding but does account for the padding.
        """
        fqns_list = []
        shapes_list = []
        numels_list = []
        shard_param_offsets = []
        for fqn, shape, numel, shard_param_info in zip(
            self.flat_param._fqns,
            self.flat_param._shapes,
            self.flat_param._numels,
            self.flat_param._shard_param_infos,
        ):
            if not shard_param_info.in_shard:
                continue
            fqns_list.append(fqn)
            shapes_list.append(shape)
            numels_list.append(numel)
            shard_param_offsets.append(
                (
                    shard_param_info.intra_param_start_idx,
                    shard_param_info.intra_param_end_idx,
                )
            )
        return FlatParamShardMetadata(
            tuple(fqns_list),
            tuple(shapes_list),
            tuple(numels_list),
            shard_param_offsets,
        )

    @no_type_check
    @torch.no_grad()
    def init_flat_param_attributes(self) -> None:
        """
        This initializes some attributes on the handle's ``FlatParameter``.
        This should be called during lazy initialization since it requires the
        parameter to be on the compute device if not offloading to CPU and we
        want to give users the chance to move the parameter appropriately after
        the FSDP constructor.

        For each tensor attribute on the ``FlatParameter``, see the unshard and
        reshard methods in this class for the allocation and free pattern.
        """
        flat_param = self.flat_param
        if flat_param.dtype != self._orig_param_dtype:
            # Entering this branch means that the user changed the parameter
            # dtype after FSDP initialization, in which case we may need to
            # refresh some saved dtype attributes (dtypes specified as a part
            # of mixed precision take precedence).
            if not self._low_prec_param_dtype_specified:
                self._fwd_bwd_param_dtype = flat_param.dtype
            # For `reduce_dtype`, require `param_dtype` was not specified since
            # then we infer the `reduce_dtype` from the specified `param_dtype`
            if (
                not self._low_prec_reduce_dtype_specified
                and not self._low_prec_param_dtype_specified
            ):
                self._reduce_dtype = flat_param.dtype
            self._orig_param_dtype = flat_param.dtype
        cpu_device = torch.device("cpu")
        if self._offload_params:
            _p_assert(
                flat_param.device == cpu_device,
                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                f"offloading is enabled, not {flat_param.device}",
            )
        else:
            self._check_on_compute_device(self.flat_param)
        flat_param._local_shard = flat_param.data
        if self._offload_params:
            # Pin the memory for faster H2D transfer
            flat_param._local_shard = flat_param._local_shard.pin_memory()
            # Pre-allocate the sharded gradient on CPU to enable non-blocking
            # D2H transfer during the backward pass
            flat_param._cpu_grad = torch.zeros_like(
                flat_param._local_shard, device=cpu_device
            ).pin_memory()
        if self._uses_param_mixed_precision:
            # For parameter mixed precision, we maintain a low precision
            # sharded tensor on the compute device to be all-gathered (for
            # sharded strategies) or directly used (for `NO_SHARD`) for
            # computation.
            flat_param._mp_shard = torch.empty_like(
                flat_param._local_shard,
                device=self.device,
                dtype=self._fwd_bwd_param_dtype,
            )
            _free_storage(flat_param._mp_shard)
        if self.uses_sharded_strategy:
            # We maintain a padded unsharded tensor that serves as the
            # all-gather destination and owns the original parameter storages.
            unsharded_param_dtype = (
                self._fwd_bwd_param_dtype
                if self._uses_param_mixed_precision
                else flat_param.dtype
            )  # use low precision if parameter mixed precision is enabled
            padded_unsharded_numel = flat_param.numel() * self.world_size
            flat_param._full_param_padded = torch.empty(
                padded_unsharded_numel,
                device=self.device,
                dtype=unsharded_param_dtype,
            )
            flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
            _free_storage(flat_param._full_param_padded)

            if self._uses_param_mixed_precision:
                # For parameter mixed precision, we maintain a full precision
                # padded unsharded tensor for when we force full precision.
                flat_param._full_prec_full_param_padded = torch.empty(
                    padded_unsharded_numel,
                    device=self.device,
                    dtype=flat_param.dtype,  # full precision
                )
                _free_storage(flat_param._full_prec_full_param_padded)

    ###################
    # UNSHARD/RESHARD #
    ###################
    def pre_unshard(self) -> bool:
        """
        Return ``False`` if this is a no-op and ``True`` otherwise.

        Postcondition: ``self.flat_param`` 's data is on the device for
        communication and is what should be all-gathered. This means that it
        matches the dtype of the expected unsharded parameter.
        """
        if (
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
            and self._skipped_use_sharded_views
        ):
            # Since this path imposes special semantics for the unsharded flat
            # parameter (e.g. forcing full precision), use sharded views to
            # reuse the existing logic for that special handling
            self._use_sharded_views()
        ret = False
        if self._use_orig_params and not self._skip_writeback_check:
            ret = self._writeback_orig_params()
        if (
            self.uses_sharded_strategy
            and not self._offload_params
            and not self.needs_unshard()
        ):
            pass  # no-op
        elif self._uses_param_mixed_precision and not self._force_full_precision:
            self._use_low_precision_shard()
            ret = True
        elif self._offload_params and self.flat_param.device != self.device:
            # NOTE: This creates a new tensor distinct from any attributes.
            self.flat_param_to(self.device, non_blocking=True)
            ret = True
        self._check_on_compute_device(self.flat_param)
        return ret

    def _use_low_precision_shard(self):
        """Allocate on the compute device and switch to using the low precision sharded flat parameter."""
        self._check_low_precision_shard()
        flat_param = self.flat_param
        _alloc_storage(
            flat_param._mp_shard, flat_param._local_shard.size()  # type: ignore[attr-defined]
        )
        # `copy_()` implicitly casts to the low precision
        flat_param._mp_shard.copy_(  # type: ignore[attr-defined]
            flat_param._local_shard.to(  # type: ignore[attr-defined]
                self.device, non_blocking=True
            )
        )
        # Invariant: `_mp_shard` is always on the compute device.
        flat_param.data = flat_param._mp_shard  # type: ignore[attr-defined]

    def unshard(self):
        """
        Run the unshard logic.

        This includes all-gathering the flat parameter
        and switching to using the unsharded flat parameter. If the handle does
        not need unsharding, then this only switches to using the unsharded
        flat parameter. For ``NO_SHARD``, this is a no-op.

        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if not self.needs_unshard():
            # Even when not needing an unshard, we should switch to using
            # the unsharded flat parameter
            unsharded_flat_param = (
                self._get_padded_unsharded_flat_param()
                if self.uses_sharded_strategy
                else self.flat_param
            )
            self._use_unsharded_flat_param(unsharded_flat_param)
            return
        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
        self._use_unsharded_flat_param(padded_unsharded_flat_param)

    def needs_unshard(self) -> bool:
        """Return if the handle's flat parameter needs to be unsharded."""
        if not self.uses_sharded_strategy:
            return False
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        already_unsharded = _same_storage_size(
            unsharded_flat_param, unsharded_flat_param.numel()
        )
        return not already_unsharded

    def _alloc_padded_unsharded_flat_param(self):
        """
        Allocate the *padded* unsharded flat parameter.

        The unpadded unsharded
        flat parameter is always a view into the padded one. This padded
        parameter is saved to a different attribute on the ``FlatParameter``
        depending on if we force full precision.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        self._check_storage_freed(unsharded_flat_param)
        _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
        """
        Return a reference to the padded unsharded flat parameter depending on the calling context.

        This should only be called if using a sharded strategy.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        if self._force_full_precision and self._uses_param_mixed_precision:
            # When parameter mixed precision is enabled, we use a different
            # tensor as the all-gather destination to preserve the invariant
            # that `_full_param_padded` is in the low precision
            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
            _p_assert(
                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
            )
            # For no-reshard-after-forward strategies, `_full_param_padded` may
            # still be allocated from a previous forward. As we are forcing
            # full precision here, the full-precision unsharded copy may be
            # modified, invalidating the existing low-precision unsharded copy,
            # so we should free it here to ensure a new all-gather for the next
            # forward/backward computation to persist the modifications.
            if flat_param._full_param_padded.untyped_storage().size() > 0:
                _free_storage(flat_param._full_param_padded)
        else:
            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _all_gather_flat_param(
        self,
        padded_unsharded_flat_param: Tensor,
    ) -> Tensor:
        """
        All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

        Then switch to use the all-gathered tensor.
        """
        _p_assert(
            hasattr(self, "process_group") and hasattr(self, "world_size"),
            "Expects a process group and world size to have been set via `shard()`",
        )
        sharded_flat_param = self.flat_param.data
        expected_numel = sharded_flat_param.numel() * self.world_size
        _p_assert(
            padded_unsharded_flat_param.numel() == expected_numel,
            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
        )

        pg = (
            self._fake_process_group
            if self._use_fake_all_gather
            else self.process_group
        )

        # HACK this should be handled by C10D
        if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
            tensor_list = list(
                torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
            )
            work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
        else:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )

        if self._offload_params:
            # In case of offloading, `flat_param.data` (i.e. sharded param) is
            # created on the pre-unshard stream. We need to hand it over to the
            # unshard stream for all-gather
            _no_dispatch_record_stream(
                sharded_flat_param,
                self._device_handle.current_stream(),  # unshard_stream
            )
        return padded_unsharded_flat_param
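        # Illustrative sizing (assumed numbers, not from the original source):
        # for a sharded flat parameter of 1024 numel with world_size=8, the
        # all-gather destination above must have exactly 1024 * 8 = 8192 numel,
        # and rank r's shard lands at offsets [1024 * r, 1024 * (r + 1)).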
1414
1415def _use_unsharded_flat_param(
1416self,
1417padded_unsharded_flat_param: torch.Tensor,
1418) -> None:
1419"""
1420Switch to use the *unpadded* unsharded flat parameter.
1421
1422This is a view into the *padded* unsharded flat parameter.
1423"""
1424unsharded_size = self.flat_param._unpadded_unsharded_size
1425flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
1426# slicing [:] is not visible to autograd because of .data
1427self.flat_param.data = flat_param_part
1428in_forward = self._training_state == HandleTrainingState.FORWARD
1429in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
1430if self._use_orig_params:
1431if self._skipped_use_sharded_views and in_pre_backward:
1432# This call corresponds to the complementary pre-backward
1433# `_use_unsharded_views()` to the skipped pre-forward
1434# `_use_sharded_views()`, so we should skip this one too.
1435return
1436# We use `Tensor` views in the forward so that they are tracked by
1437# autograd. We use them in the pre-backward as well to support
1438# reentrant activation checkpointing, which needs the views to be
1439# tracked by autograd in the backward pass's recomputed forward.
1440self._use_unsharded_views(
1441as_params=(not in_forward and not in_pre_backward)
1442)
1443elif in_forward:
1444self._use_unsharded_views(as_params=False)
1445
1446def post_unshard(self):
1447"""
1448Run the post-unshard logic.
1449
1450This includes freeing the low precision shard if needed.
1451"""
1452if self._uses_param_mixed_precision and self.uses_sharded_strategy:
1453self._free_low_precision_sharded_param()
1454self._check_on_compute_device(self.flat_param)
1455
1456def _free_low_precision_sharded_param(self):
1457"""Frees the low precision sharded flat parameter."""
1458self._check_low_precision_shard()
1459# `_mp_shard` is allocated in the pre-unshard stream, consumed in the
1460# unshard stream for sharded strategies, and consumed in both the
1461# unshard and default streams for `NO_SHARD`. For sharded strategies,
1462# the current stream here is the unshard stream, and for `NO_SHARD`,
1463# it is the default stream. For `NO_SHARD`, only recording for the
1464# default stream suffices since the default stream waits for the
1465# unshard stream.
1466_no_dispatch_record_stream(
1467self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
1468)
1469_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
1470
1471@torch.no_grad()
1472def unshard_grad(self):
1473"""
1474Unshard the handle's ``FlatParameter``'s gradient.
1475
1476If all ranks have a ``None``
1477gradient, then all of the original parameters' gradients will be ``None`` as
1478well. This method performs an all-reduce and an all-gather. The additional
1479all-reduce is tolerable since this method is not meant to be used on the
1480computation critical path.
1481
1482Postcondition: ``_saved_grad_shard`` is defined and contains the value
1483to set ``flat_param.grad`` after gradients are resharded.
1484"""
1485if not self.uses_sharded_strategy:
1486self._use_unsharded_grad_views()
1487return
1488flat_param = self.flat_param
1489self._check_unsharded(flat_param)
1490
1491# Check if all ranks have a `None` gradient
1492num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
1493num_grad_none[0] = flat_param.grad is None
1494dist.all_reduce(num_grad_none, group=self.process_group)
1495if num_grad_none[0] == self.world_size:
1496flat_param._saved_grad_shard = None # type: ignore[assignment]
1497self._use_unsharded_grad_views()
1498return
1499
1500if flat_param.grad is None:
1501# In the case that only some ranks have a `None` gradient, we use
1502# zeros as a best-effort approximation of those ranks' sharded gradients
1503if self._debug_level == dist.DebugLevel.INFO:
1504warnings.warn(
1505f"[Rank {self.rank}] Only some but not all ranks have a "
1506"`None` `FlatParameter` gradient, so FSDP is using zeros to "
1507"approximate those ranks' sharded gradients being `None`"
1508)
1509flat_param._saved_grad_shard = None # type: ignore[assignment]
1510sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined]
1511else:
1512self._check_sharded(flat_param.grad)
1513flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined]
1514sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
1515padded_unsharded_grad = torch.empty(
1516flat_param._padded_unsharded_size, # type: ignore[attr-defined]
1517device=self.device,
1518dtype=sharded_grad.dtype,
1519)
1520dist.all_gather_into_tensor(
1521padded_unsharded_grad, sharded_grad, self.process_group
1522)
1523unsharded_size = self.flat_param._unpadded_unsharded_size
1524flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
1525unsharded_size
1526)
1527self._use_unsharded_grad_views()
1528
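# NOTE (editorial example): the "all ranks have `None` gradient" check above is
# a plain sum all-reduce over per-rank booleans. A standalone sketch (assumes
# an initialized process group; `my_grad` and `device` are placeholders):
#
#   has_no_grad = torch.zeros(1, dtype=torch.int32, device=device)
#   has_no_grad[0] = int(my_grad is None)
#   dist.all_reduce(has_no_grad)  # default reduce op is SUM
#   all_ranks_none = has_no_grad.item() == dist.get_world_size()
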
1529def reshard_grad(self):
1530if self._use_orig_params:
1531self._use_sharded_grad_views()
1532if not self.uses_sharded_strategy:
1533return
1534self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
1535delattr(self.flat_param, "_saved_grad_shard")
1536
1537def prepare_gradient_for_backward(self):
1538"""
1539Prepare the gradient for the backward computation.
1540
1541This is done by saving and clearing any existing sharded gradient
1542in ``.grad`` to enable computing a new unsharded gradient.
1543"""
1544_p_assert(
1545self._training_state
1546in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
1547"Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
1548)
1549flat_param = self.flat_param
1550if flat_param.grad is not None and (
1551flat_param.grad.size() != flat_param._unpadded_unsharded_size
1552or flat_param.grad.device != flat_param.device # grad on CPU
1553):
1554self._check_on_compute_device(self.flat_param)
1555grad_offloaded = flat_param.grad.device != self.device
1556_p_assert(
1557not grad_offloaded or self._offload_params,
1558f"Expects the sharded gradient to be on {self.device} "
1559f"but got {flat_param.grad.device}",
1560)
1561prev_iter_synced_gradients = (
1562flat_param.grad.size()
1563== flat_param._local_shard.size() # type: ignore[attr-defined]
1564)
1565if prev_iter_synced_gradients:
1566# TODO (awgu): Gradient accumulation outside `no_sync()`
1567# does not work with CPU offloading. The issue should be
1568# that, in the post-backward hook, we cannot do an addition
1569# between a CPU tensor (the existing sharded gradient) and
1570# a GPU tensor (the new sharded gradient).
1571if not grad_offloaded:
1572flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
1573sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
1574else:
1575_p_assert(
1576hasattr(flat_param, "_cpu_grad"),
1577"`_cpu_grad` should be defined if the gradient is on CPU",
1578)
1579sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined]
1580# If user specified to keep the gradient in low precision, then
1581# the gradient may still be of the low precision dtype if the
1582# user did not set the gradient to `None` after the previous
1583# backward, in which case FSDP should cast back to the full
1584# precision dtype so that FSDP can accumulate in that dtype in
1585# the post-backward hook and assign to `.grad` in that dtype in
1586# the post-backward callback.
1587local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined]
1588if (
1589self._keep_low_precision_grads
1590and sharded_grad.dtype != local_shard_dtype
1591):
1592sharded_grad.data = sharded_grad.to(local_shard_dtype)
1593else:
1594padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
1595_p_assert(
1596flat_param.grad.size() == padded_unsharded_size,
1597"Expects `.grad` to be the unsharded gradient in "
1598f"`no_sync()` with size {padded_unsharded_size} "
1599f"but got size {flat_param.grad.size()}",
1600)
1601flat_param.grad = None
1602
1603def prepare_gradient_for_optim(self):
1604"""Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
1605
1606def cast_grad_to_param_dtype_if_needed(flat_param):
1607# TODO (rohan-varma): test for full precision with keep_low_precision_grads
1608if not self._force_full_precision and self._keep_low_precision_grads:
1609_p_assert(flat_param.grad is not None, "Unexpected None grad!")
1610if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
1611flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
1612if self._use_orig_params:
1613self._use_sharded_grad_views()
1614
1615flat_param = self.flat_param
1616# TODO (awgu): We should replace these conditional checks to encode
1617# the logical intention more directly.
1618if hasattr(flat_param, "_cpu_grad"):
1619# NOTE: This branch includes `NO_SHARD`.
1620self._check_sharded(flat_param)
1621self._check_on_cpu(flat_param)
1622flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined]
1623cast_grad_to_param_dtype_if_needed(flat_param)
1624elif hasattr(flat_param, "_saved_grad_shard"):
1625self._check_sharded(flat_param)
1626self._check_on_compute_device(flat_param)
1627if flat_param._saved_grad_shard is not None:
1628self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
1629# If no sharded gradient was computed this iteration, then there is
1630# no need to forward `_saved_grad_shard` to `grad`
1631if flat_param._post_backward_called: # type: ignore[attr-defined]
1632flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
1633if flat_param.grad is not None:
1634cast_grad_to_param_dtype_if_needed(flat_param)
1635else:
1636_p_assert(
1637not self.uses_sharded_strategy
1638or not flat_param._post_backward_called, # type: ignore[attr-defined]
1639"All sharded parameters that received a gradient in the "
1640"post-backward should use `_saved_grad_shard`",
1641)
1642# Delete `_saved_grad_shard` since its existence indicates a previous
1643# gradient to accumulate with in the post-backward hook
1644if hasattr(flat_param, "_saved_grad_shard"):
1645delattr(flat_param, "_saved_grad_shard")
1646
1647@contextlib.contextmanager
1648def to_cpu(self):
1649"""
1650Move the unpadded unsharded flat parameter to CPU while in the context and move it back to the previous device upon exit.
1651
1652For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
1653since (1) there is no reason to include the padding in the copy and (2)
1654there is no use case for the sharded flat parameter.
1655
1656Precondition: ``self.flat_param`` 's data is the unpadded unsharded
1657flat parameter on the compute device, and the handle uses a sharded
1658strategy.
1659Postcondition: Same as the precondition.
1660"""
1661self._check_sharded_strategy()
1662_p_assert(
1663self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
1664f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
1665)
1666self._check_on_compute_device(self.flat_param)
1667# Check that the unpadded unsharded flat parameter is a view into the
1668# padded unsharded flat parameter as expected
1669# NOTE: This check is not strictly needed for correctness but is a
1670# useful sanity check since the tensor should only be used internally.
1671_p_assert(
1672_same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
1673"Expects the unpadded parameter to be a view into the padded parameter",
1674)
1675self.flat_param_to(torch.device("cpu"))
1676self._free_unsharded_flat_param()
1677try:
1678yield
1679finally:
1680_p_assert(
1681self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
1682f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
1683)
1684padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
1685# Copy from CPU to the compute device
1686padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
1687self.flat_param
1688)
1689self._use_unsharded_flat_param(padded_unsharded_flat_param)
1690
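# NOTE (editorial example): hypothetical usage of the ``to_cpu`` context
# manager above, assuming ``handle`` is a ``FlatParamHandle`` whose flat
# parameter is currently unpadded, unsharded, and on the compute device:
#
#   with handle.to_cpu():
#       # The flat parameter now lives on CPU, and the padded unsharded
#       # storage on the compute device has been freed.
#       ...
#   # On exit, the data is copied back into a re-allocated padded buffer and
#   # the unsharded flat parameter is used again.
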
1691def reshard(self, free_unsharded_flat_param: bool):
1692"""
1693Run the reshard logic.
1694
1695This includes freeing the unsharded flat
1696parameter if ``free_unsharded_flat_param`` and switching to using the
1697sharded flat parameter. Note that this also implicitly offloads
1698the sharded flat parameter (if CPU offload is enabled) by pointing
1699it to the ``_local_shard`` attribute which resides on CPU.
1700"""
1701# Switch to the sharded `FlatParameter` before freeing to prevent
1702# "use-after-free"-type bugs with external profiling tools, where for
1703# `use_orig_params=True`, the `param` does not point to valid memory
1704# when setting `param.data = ...` in `_use_sharded_views()`.
1705self._use_sharded_flat_param()
1706if free_unsharded_flat_param:
1707self._free_unsharded_flat_param()
1708
1709def post_reshard(self):
1710"""
1711Run the post-reshard logic.
1712
1713This includes freeing any memory that
1714can now be freed given that the ``FlatParameter`` points to the full
1715precision sharded flat parameter.
1716
1717Precondition: ``self.flat_param`` 's data points to the full precision
1718sharded flat parameter.
1719"""
1720# For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
1721# is also the low precision *unsharded* flat parameter. Hence, we delay
1722# the free until the reshard.
1723if (
1724self._uses_param_mixed_precision
1725and not self.uses_sharded_strategy
1726and not self._force_full_precision # did not use the low precision shard
1727):
1728self._free_low_precision_sharded_param()
1729
1730def _free_unsharded_flat_param(self):
1731"""
1732Free the padded unsharded flat parameter. We allow this function to be
1733called even when the storage is not allocated.
1734
1735The tensor to free depends on the calling
1736context since the unshard may have forced full
1737precision, in which case a different tensor is used.
1738"""
1739self._check_sharded_strategy()
1740unsharded_flat_param = self._get_padded_unsharded_flat_param()
1741self._check_on_compute_device(unsharded_flat_param)
1742# Do not free the memory until all ops in the current stream finish
1743_no_dispatch_record_stream(
1744unsharded_flat_param, self._device_handle.current_stream()
1745)
1746_free_storage(unsharded_flat_param)
1747
1748def _use_sharded_flat_param(self) -> None:
1749"""Switches to using the sharded flat parameter."""
1750flat_param = self.flat_param
1751if self._use_orig_params:
1752in_forward = self._training_state == HandleTrainingState.FORWARD
1753skip_use_sharded_views = (
1754torch.is_grad_enabled()
1755and in_forward
1756and self._sharding_strategy
1757in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
1758)
1759# Only incur the extra `.data` call if needed
1760if skip_use_sharded_views:
1761unsharded_flat_param = flat_param.data
1762if self._offload_params:
1763device = flat_param._local_shard.device # type: ignore[attr-defined]
1764_p_assert(
1765device == torch.device("cpu"),
1766f"Expects the local shard to be on CPU but got {device}",
1767)
1768flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
1769if self._use_orig_params:
1770if skip_use_sharded_views: # type: ignore[possibly-undefined]
1771self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
1772else:
1773self._use_sharded_views()
1774# For the post-forward reshard, we may try to use sharded gradient
1775# views (or unsharded gradient views if a gradient was accumulated
1776# in `no_sync()`), but for the post-backward reshard, we delay the
1777# call to after the reduce-scatter.
1778if (
1779in_forward # type: ignore[possibly-undefined]
1780# Skip using gradient views if skipped using sharded views
1781# since exposing unsharded parameters with sharded gradients
1782# may be confusing to the user
1783and not self._skipped_use_sharded_views
1784):
1785# TODO: Change `_unpadded_unsharded_size` if we change the
1786# gradient to be computed directly with padding.
1787accumulated_grad_in_no_sync = (
1788flat_param.grad is not None
1789and self.uses_sharded_strategy
1790and flat_param.grad.shape == flat_param._unpadded_unsharded_size
1791)
1792if accumulated_grad_in_no_sync:
1793self._use_unsharded_grad_views()
1794else:
1795self._use_sharded_grad_views()
1796
1797#########
1798# VIEWS #
1799#########
1800@no_type_check
1801def _get_unflat_views_unaligned(
1802self,
1803tensor: Optional[torch.Tensor] = None,
1804) -> Iterator[Tensor]:
1805"""
1806Return unflattened ``Tensor`` views into ``tensor``.
1807
1808If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is based
1809on ``flat_param`` 's metadata.
1810
1811Examples for ``tensor`` include ``flat_param.grad`` or unsharded
1812tensor optimizer state.
1813"""
1814flat_param = self.flat_param
1815if tensor is None:
1816tensor = flat_param
1817views = (
1818_ext_post_unflatten_transform(
1819subtensor.view(shape),
1820param_extension,
1821self._fsdp_extension,
1822)
1823for (subtensor, shape, param_extension) in zip(
1824torch.split(tensor, flat_param._numels, dim=0),
1825flat_param._shapes,
1826flat_param._param_extensions,
1827)
1828)
1829return views
1830
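# NOTE (editorial example): the generator above is the inverse of flattening.
# A minimal sketch with plain tensors and no FSDP extensions (hypothetical
# shapes/numels):
#
#   shapes, numels = [(2, 3), (4,)], [6, 4]
#   flat = torch.arange(10.0)
#   views = [
#       chunk.view(shape)
#       for chunk, shape in zip(torch.split(flat, numels, dim=0), shapes)
#   ]
#   # views[0] is a 2x3 view of flat[:6]; views[1] is a 1-D view of flat[6:].
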
1831@no_type_check
1832def _get_unflat_views_aligned(
1833self,
1834tensor: Optional[Tensor] = None,
1835) -> List[Tensor]:
1836"""
1837Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
1838
1839This method has the same contract as :meth:`_get_unflat_views_unaligned`
1840except it checks for ``None`` placeholders representing padding for
1841alignment, which may incur slightly more CPU overhead.
1842"""
1843flat_param = self.flat_param
1844if tensor is None:
1845tensor = flat_param
1846splits: List[Tensor] = torch.split(
1847tensor, flat_param._numels_with_padding, dim=0
1848)
1849idx = 0
1850views: List[Tensor] = []
1851for split, is_padding in zip(splits, flat_param._is_padding_mask):
1852if is_padding:
1853continue
1854views.append(
1855_ext_post_unflatten_transform(
1856split.view(flat_param._shapes[idx]),
1857flat_param._param_extensions[idx],
1858self._fsdp_extension,
1859)
1860)
1861idx += 1
1862return views
1863
1864@no_type_check
1865@torch.enable_grad()
1866def _use_unsharded_views(self, as_params: bool) -> None:
1867"""
1868Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
1869
1870Args:
1871as_params (bool): If ``True``, then registers the original
1872parameters as ``nn.Parameter`` s; if ``False``, then registers
1873the original parameters only as ``Tensor`` s. ``False`` should
1874be used during forward/backward computation and when hiding the
1875original parameters from :meth:`nn.Module.named_parameters`.
1876
1877Note:
1878When prefetching for the next forward, the current forward may be
1879annotated with ``@torch.no_grad()``.
1880``@torch.enable_grad()`` ensures a non-empty ``view.grad_fn``;
1881otherwise, ``_post_backward_hook`` will not get called.
1882"""
1883flat_param = self.flat_param
1884self._check_unsharded(flat_param)
1885views = self._get_unflat_views()
1886from torch.distributed._tensor import DTensor
1887
1888for i, (view, (param_name, module, _)) in enumerate(
1889zip(views, flat_param._param_infos)
1890):
1891if self._use_orig_params and as_params:
1892if type(view) is DTensor:
1893# A `DTensor` `view` is not compatible with assigning
1894# `param.data = view`, so we cannot preserve the parameter
1895# variable.
1896self._setattr_param(
1897module,
1898param_name,
1899nn.Parameter(view, requires_grad=flat_param.requires_grad),
1900)
1901continue
1902param = self.flat_param._params[i]
1903self._setattr_param(module, param_name, param)
1904param.data = view
1905elif as_params:
1906self._setattr_param(
1907module,
1908param_name,
1909nn.Parameter(view, requires_grad=flat_param.requires_grad),
1910)
1911else: # `as_params=False`
1912param_var: Tensor = view
1913if self._use_orig_params:
1914if self._training_state == HandleTrainingState.FORWARD:
1915# Save the `Tensor` for the pre-backward
1916self.flat_param._tensors[i] = view
1917elif self._training_state == HandleTrainingState.BACKWARD_PRE:
1918# Use the saved `Tensor` variable from the forward to
1919# preserve the autograd graph so that the post-backward
1920# hook fires (e.g. for reentrant AC)
1921tensor = self.flat_param._tensors[i]
1922tensor.data = view
1923param_var = tensor
1924self._setattr_tensor(module, param_name, param_var)
1925if (
1926self._use_orig_params
1927and self._training_state == HandleTrainingState.FORWARD
1928):
1929module._parameters[param_name] = param_var
1930for i, (
1931param_name,
1932module,
1933_,
1934prim_param_name,
1935prim_module,
1936_,
1937) in enumerate(self.flat_param._shared_param_infos):
1938prim_param: Union[Tensor, nn.Parameter] = getattr(
1939prim_module, prim_param_name
1940)
1941_p_assert(
1942not as_params or isinstance(prim_param, nn.Parameter),
1943f"as_params={as_params} type(prim_param)={type(prim_param)}",
1944)
1945if self._use_orig_params and as_params:
1946shared_param = self.flat_param._shared_params[i]
1947self._setattr_param(module, param_name, shared_param)
1948shared_param.data = prim_param
1949elif as_params:
1950self._setattr_param(module, param_name, prim_param)
1951else:
1952self._setattr_tensor(module, param_name, prim_param)
1953if (
1954self._use_orig_params
1955and self._training_state == HandleTrainingState.FORWARD
1956):
1957module._parameters[param_name] = prim_param
1958
1959@no_type_check
1960def _use_unsharded_grad_views(self) -> None:
1961"""
1962Unflatten the unsharded flat parameter's gradient.
1963
1964The original parameter variables' gradients are set to be views into
1965the unsharded flat parameter's gradient.
1966"""
1967# Expects the gradient to be in `flat_param.grad`
1968if self.flat_param.grad is None:
1969for param in chain(self.flat_param._params, self.flat_param._shared_params):
1970param.grad = None
1971return
1972self._check_unsharded(self.flat_param.grad)
1973views = self._get_unflat_views(self.flat_param.grad)
1974for i, (view, (param_name, module, _)) in enumerate(
1975zip(views, self.flat_param._param_infos)
1976):
1977_p_assert(
1978hasattr(module, param_name),
1979f"{self.flat_param._fqns[i]} is missing",
1980)
1981param = getattr(module, param_name)
1982if (
1983param.shape != view.shape
1984or param.dtype != view.dtype
1985or param.device != view.device
1986):
1987# NOTE: This is a hack using `.data` to side step the check
1988# that parameter/gradient sizes/dtypes/devices match. From
1989# calling `reshard()`, `param` has the sharded size, has the
1990# full precision dtype, and if CPU offloading is enabled, is on
1991# CPU. Thus, one or more of the following cases can hold when
1992# in `no_sync()`, where `view` is the original parameter's
1993# gradient:
1994# 1. `view` can have the unsharded size.
1995# 2. `view` can have the parameter low precision dtype.
1996# 3. `view` can be on GPU.
1997if param.grad is None:
1998param.grad = torch.empty_like(param)
1999param.grad.data = view
2000else:
2001param.grad = view
2002for i, (
2003param_name,
2004module,
2005module_name,
2006prim_param_name,
2007prim_module,
2008_,
2009) in enumerate(self.flat_param._shared_param_infos):
2010_p_assert(
2011hasattr(module, param_name),
2012f"{module_name + '.' + param_name if module_name else param_name} is missing",
2013) # did not save FQN info in `_shared_param_infos`
2014param = getattr(module, param_name)
2015prim_param = getattr(prim_module, prim_param_name)
2016if (
2017param.shape != prim_param.grad.shape
2018or param.dtype != prim_param.grad.dtype
2019or param.device != prim_param.grad.device
2020):
2021# NOTE: This is the same hack to use `.data` to side step the
2022# size check.
2023if param.grad is None:
2024param.grad = torch.empty_like(param)
2025param.grad.data = prim_param.grad
2026else:
2027param.grad = prim_param.grad
2028
2029@contextlib.contextmanager
2030def unflatten_as_params(self) -> Generator:
2031"""
2032Unflatten the original parameters.
2033
2034The function assumes that the flat parameter is unsharded. When in the context,
2035unflattens the original parameters as ``nn.Parameter`` views into the
2036flat parameter, and after the context, restores the original parameters
2037as ``Tensor`` views into the flat parameter.
2038"""
2039self._use_unsharded_views(as_params=True)
2040try:
2041yield
2042finally:
2043self._use_unsharded_views(as_params=False)
2044
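# NOTE (editorial example): hypothetical usage of ``unflatten_as_params``,
# assuming ``handle`` owns the parameters of ``module`` and the flat parameter
# is currently unsharded:
#
#   with handle.unflatten_as_params():
#       for name, p in module.named_parameters():
#           ...  # original parameters appear as `nn.Parameter` views
#   # After the context, they are restored as plain `Tensor` views.
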
2045@no_type_check
2046@torch.no_grad()
2047def _use_sharded_views(self) -> None:
2048"""
2049Set the original parameter variables' data to be flattened views into the sharded flat parameter.
2050
2051The views are kept as flattened to simplify the case where a parameter
2052is sharded across ranks. Parameters whose data is not present in the
2053sharded flat parameter have their data set to a size-0 empty tensor. We
2054do not delete them, in order to preserve expected behaviors like model
2055printability. Parameters whose data is present must preserve their
2056variables to be passable to an optimizer.
2057"""
2058self._unsharded_flat_param_for_skipped_views = None
2059if not self.uses_sharded_strategy:
2060# For `NO_SHARD`, use the *unflattened* unsharded views since we
2061# have the unsharded parameter
2062self._use_unsharded_views(as_params=True)
2063return
2064flat_param = self.flat_param
2065self._check_sharded(flat_param)
2066# Construct once and reuse for all parameters not in the local shard
2067size_0_empty_tensor = torch.empty(
20680,
2069dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
2070device=self.flat_param.device,
2071requires_grad=False,
2072)
2073for param, shard_param_info, (param_name, module, _) in zip(
2074flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
2075):
2076self._setattr_param(module, param_name, param)
2077if not shard_param_info.in_shard:
2078# Allow the original data to be freed via garbage collection
2079param.data = size_0_empty_tensor
2080else:
2081offset = shard_param_info.offset_in_shard
2082numel_in_shard = shard_param_info.numel_in_shard
2083param.data = flat_param[offset : offset + numel_in_shard]
2084assert self.flat_param._shared_params is not None
2085for i, (
2086param,
2087(param_name, module, _, prim_param_name, prim_module, _),
2088) in enumerate(
2089zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
2090):
2091self._setattr_param(module, param_name, param)
2092prim_param = getattr(prim_module, prim_param_name)
2093param.data = prim_param # could be both empty and non-empty
2094if self._training_state == HandleTrainingState.BACKWARD_POST:
2095# Clear the saved `Tensor`s since they are unneeded now
2096for i in range(len(self.flat_param._tensors)):
2097self.flat_param._tensors[i] = None
2098
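# NOTE (editorial example): the size-0 placeholder above keeps a parameter
# variable alive (and printable) even when its data lives on another rank. A
# standalone sketch of the `.data` swap:
#
#   p = torch.nn.Parameter(torch.zeros(3))
#   p.data = torch.empty(0, dtype=p.dtype, device=p.device)
#   print(p.shape)  # torch.Size([0]); the variable itself is preserved
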
2099@no_type_check
2100@torch.no_grad()
2101def _use_sharded_grad_views(self) -> None:
2102"""
2103Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
2104
2105This is a no-op if there is no gradient.
2106
2107Parameters whose data is not present in the sharded flat parameter and
2108parameters with ``requires_grad=False`` have their gradients set to
2109``None``. Since the gradient variables do not need to be preserved,
2110this method does not manipulate existing ``Tensor`` data directly and
2111creates new ``Tensor`` variables instead.
2112"""
2113flat_param = self.flat_param
2114self._check_sharded(flat_param)
2115grad = self.sharded_grad
2116if grad is None:
2117for param in chain(flat_param._params, flat_param._shared_params):
2118param.grad = None
2119return
2120self._check_sharded(grad)
2121for param, shard_param_info, is_grad_none in zip(
2122flat_param._params,
2123flat_param._shard_param_infos,
2124flat_param._is_grad_none_mask,
2125):
2126if not shard_param_info.in_shard:
2127param.grad = None
2128else:
2129numel_in_shard = shard_param_info.numel_in_shard
2130if param.requires_grad and not is_grad_none:
2131offset = shard_param_info.offset_in_shard
2132if self._keep_low_precision_grads or param.dtype != grad.dtype:
2133# NOTE: This is a hack using `.data` to side step the
2134# check that parameter/gradient dtypes match. Here,
2135# `param` has full precision; `grad` has low precision.
2136if param.grad is None:
2137# `.grad` must have the same shape as `param`
2138param.grad = torch.empty_like(param)
2139param.grad.data = grad[
2140offset : offset + numel_in_shard
2141].reshape(param.shape)
2142else:
2143param.grad = grad[offset : offset + numel_in_shard].reshape(
2144param.shape
2145)
2146else:
2147param.grad = None
2148assert flat_param._shared_params is not None
2149for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
2150zip(flat_param._shared_params, flat_param._shared_param_infos)
2151):
2152in_sharded_flat_param = hasattr(prim_module, prim_param_name)
2153if in_sharded_flat_param and param.requires_grad:
2154prim_param = getattr(prim_module, prim_param_name)
2155param.grad = prim_param.grad # share the same reference
2156else:
2157param.grad = None
2158
2159@no_type_check
2160@torch.no_grad()
2161def _writeback_orig_params(self) -> bool:
2162"""
2163Write back any parameters that changed storage to the handle's ``FlatParameter``.
2164
2165Iterates over the original parameters and writes back any parameters
2166that changed storages (due to a non-inplace operator) to the handle's
2167``FlatParameter``. This method preserves the ``FlatParameter` 's
2168device even if an original parameter's device changes.
2169
2170Raises:
2171RuntimeError: If an original parameter or gradient changes storages
2172but no longer has the expected flattened shape.
2173Returns: ``True`` if some writeback happened, and ``False`` otherwise.
2174"""
2175if (
2176self.uses_sharded_strategy
2177and not self.is_sharded(self.flat_param)
2178and not self._skipped_use_sharded_views
2179):
2180# For `NO_SHARD`, we may still need to writeback
2181return False
2182flat_param = self.flat_param
2183wroteback = False
2184if self._skipped_use_sharded_views and self.uses_sharded_strategy:
2185# NOTE: We must use the unsharded flat parameter from which the
2186# unsharded views were computed, not the one from the current
2187# calling context (`_get_padded_unsharded_flat_param()`) since that
2188# may be different (e.g. the model changed from train to eval).
2189flat_param_tensor = self._unsharded_flat_param_for_skipped_views
2190_p_assert(
2191_data_ptr_allocated(flat_param_tensor),
2192"If skipped using sharded views, the unsharded flat parameter "
2193"should be allocated",
2194)
2195else:
2196flat_param_tensor = flat_param
2197# NOTE: Since this method is called in the pre-unshard, which is only
2198# called during computation in the pre-forward or pre-backward, the
2199# sharded gradient should be guaranteed to be in `.grad`, not in
2200# `._saved_grad_shard`.
2201flat_param_grad = (
2202flat_param.grad
2203if self.uses_sharded_strategy or not self._offload_params
2204else flat_param._cpu_grad
2205)
2206for i, (
2207param,
2208(in_shard, offset_in_shard, numel_in_shard, _, _),
2209(param_name, module, _),
2210) in enumerate(
2211zip(
2212flat_param._params,
2213flat_param._shard_param_infos,
2214flat_param._param_infos,
2215)
2216):
2217if not in_shard:
2218continue
2219if not hasattr(module, param_name):
2220# Do not writeback if original parameters are deregistered
2221# (e.g. during model checkpointing)
2222continue
2223
2224# Check for parameter writeback
2225if self._skipped_use_sharded_views:
2226param = flat_param._tensors[i]
2227_p_assert(
2228param is not None,
2229f"Expects to have saved tensor for {flat_param._fqns[i]}",
2230)
2231param_changed = getattr(module, param_name) is not param
2232needs_param_writeback = (
2233param_changed # changed parameter variable itself
2234or not _same_storage(param, flat_param_tensor)
2235)
2236if self._skipped_use_sharded_views and (
2237param_changed or needs_param_writeback
2238):
2239raise AssertionError(
2240"FSDP does not support changing the parameters between "
2241f"forward and backward for {self._sharding_strategy}"
2242)
2243if param_changed:
2244# NOTE: The gradient is not preserved after a parameter change.
2245param = getattr(module, param_name)
2246flat_param._params[i] = param
2247if needs_param_writeback:
2248expected_shape = torch.Size([numel_in_shard])
2249self._writeback_tensor(
2250param, flat_param, i, expected_shape, offset_in_shard, True
2251)
2252wroteback = True
2253
2254# Check for gradient writeback
2255if self._skipped_use_sharded_views:
2256# Skip the writeback check because we do not expose gradients
2257# when we skipped using sharded views
2258continue
2259if param.grad is None and flat_param.grad is not None:
2260expected_shape = torch.Size([numel_in_shard])
2261self._writeback_tensor(
2262None, flat_param.grad, i, expected_shape, offset_in_shard, False
2263)
2264elif param.grad is not None:
2265# For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
2266# memory and owns the gradient storage, so it will never
2267# require gradient writeback.
2268if not self.uses_sharded_strategy and self._offload_params:
2269# Explicitly continue to handle the case of `no_sync()`,
2270# where `param.grad` is a view into the GPU gradient
2271# referenced by `flat_param.grad`, while `flat_param_grad`
2272# is `flat_param._cpu_grad`, which is on CPU
2273continue
2274
2275needs_grad_writeback = flat_param_grad is None or not _same_storage(
2276param.grad, flat_param_grad
2277)
2278if needs_grad_writeback:
2279if flat_param_grad is None:
2280flat_param_grad = torch.zeros_like(flat_param)
2281expected_shape = torch.Size([numel_in_shard])
2282self._writeback_tensor(
2283param.grad,
2284flat_param_grad,
2285i,
2286expected_shape,
2287offset_in_shard,
2288False,
2289)
2290flat_param.grad = flat_param_grad
2291flat_param_grad = flat_param.grad
2292
2293# TODO: If we want to handle shared parameters, we need to re-generate
2294# the shared parameter data structures in case sharedness changed.
2295for i, (
2296param_name,
2297module,
2298_,
2299prim_param_name,
2300prim_module,
2301_,
2302) in enumerate(flat_param._shared_param_infos):
2303if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
2304raise NotImplementedError(
2305"Changing shared parameters is not supported yet"
2306)
2307return wroteback
2308
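# NOTE (editorial example): "changed storage" above is detected by comparing
# storage data pointers (see `_same_storage` near the end of this file). A
# standalone sketch of a change that would require writeback:
#
#   w = torch.nn.Parameter(torch.zeros(4))
#   old_view = w.data                      # shares the original storage
#   w.data = torch.ones(4)                 # out-of-place swap -> new storage
#   assert old_view.untyped_storage().data_ptr() != w.untyped_storage().data_ptr()
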
2309def _writeback_tensor(
2310self,
2311src_tensor: Optional[Tensor],
2312dst_tensor: Tensor,
2313tensor_index: int,
2314expected_shape: torch.Size,
2315offset: int,
2316is_param: bool, # else gradient
2317) -> None:
2318"""
2319Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
2320
2321``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
2322``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
2323instead of copying. ``tensor_index`` gives the index of ``src_tensor``
2324in the metadata structures.
2325
2326Raises:
2327RuntimeError: If the ``src_tensor`` does not have the expected
2328shape.
2329"""
2330_p_assert(
2331len(expected_shape) == 1,
2332f"Expects a 1D expected shape but got {expected_shape}",
2333)
2334if self._debug_level == dist.DebugLevel.INFO:
2335rank = self.rank if hasattr(self, "rank") else dist.get_rank()
2336src_shape = src_tensor.shape if src_tensor is not None else None
2337src_device = src_tensor.device if src_tensor is not None else None
2338warnings.warn(
2339f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
2340f"writeback in {self._training_state}\n"
2341f"expected shape={expected_shape} shape={src_shape} "
2342f"expected device={dst_tensor.device} device={src_device}"
2343)
2344if src_tensor is not None and src_tensor.shape != expected_shape:
2345# NOTE: Gradient shape mismatch is not possible in practice since
2346# the gradient shape is enforced to match that of the parameter and
2347# we already check for parameter shape mismatch.
2348raise RuntimeError(
2349f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
2350f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
2351)
2352if src_tensor is not None:
2353dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
2354else:
2355dst_tensor[offset : offset + expected_shape.numel()].zero_()
2356assert self.flat_param._is_grad_none_mask is not None
2357self.flat_param._is_grad_none_mask[tensor_index] = True
2358
2359def _reset_flat_param_grad_info_if_needed(self):
2360"""
2361Reset ``flat_param.grad`` if needed.
2362
2363When ``use_orig_params=True``:
2364(1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
2365original parameters' ``.grad`` are ``None``, and
2366(2) sets ``flat_param.requires_grad=False`` if *none* of the original
2367parameters require gradient.
2368For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
2369which case we want to free the gradients as soon after the
2370``zero_grad()`` call as possible.
2371"""
2372if not self._use_orig_params:
2373return
2374flat_param = self.flat_param
2375assert flat_param._params is not None # mypy
2376all_grad_none = True
2377requires_grad = False
2378for param in flat_param._params:
2379all_grad_none &= param.grad is None
2380requires_grad |= param.requires_grad
2381if all_grad_none:
2382flat_param.grad = None
2383# As long as one parameter requires gradient, then the flat parameter
2384# must require gradient
2385flat_param.requires_grad = requires_grad
2386
2387def _deregister_orig_params(self):
2388for param_info in self.flat_param._param_infos:
2389param_name, module, _ = param_info
2390if hasattr(module, param_name):
2391delattr(module, param_name)
2392for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
2393if hasattr(module, param_name):
2394delattr(module, param_name)
2395
2396###########
2397# HELPERS #
2398###########
2399def flat_param_to(self, *args, **kwargs):
2400"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
2401self.flat_param.data = self.flat_param.to(*args, **kwargs)
2402if self._use_orig_params:
2403# Refresh the views because their storage may have changed
2404if self.is_sharded(self.flat_param):
2405self._use_sharded_views()
2406else:
2407self._use_unsharded_views(as_params=True)
2408
2409def _get_modules(self) -> Set[nn.Module]:
2410"""Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
2411return {pi.module for pi in self.flat_param._param_infos}.union(
2412{spi.module for spi in self.flat_param._shared_param_infos}
2413)
2414
2415def is_sharded(self, tensor: Tensor) -> bool:
2416"""
2417Return whether ``tensor`` is *currently* sharded.
2418
2419For ``NO_SHARD``, we choose to have this always return ``False`` for clarity.
2420"""
2421if (
2422not hasattr(self.flat_param, "_sharded_size")
2423or not self.uses_sharded_strategy
2424):
2425# `_sharded_size` is defined iff `handle.shard()` has been called
2426return False
2427sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
2428return tensor.size() == sharded_size
2429
2430def param_module_names(self) -> Iterator[Tuple[str, str]]:
2431shared_param_infos = [
2432ParamInfo(param_name, module, module_name)
2433for (
2434param_name,
2435module,
2436module_name,
2437_,
2438_,
2439_,
2440) in self.flat_param._shared_param_infos
2441]
2442for param_info in chain(self.flat_param._param_infos, shared_param_infos):
2443param_name, _, module_name = param_info # type: ignore[misc]
2444yield (param_name, module_name)
2445
2446def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
2447for param_name, _, module_name in [
2448ParamInfo(param_name, module, module_name)
2449for (
2450param_name,
2451module,
2452module_name,
2453_,
2454_,
2455_,
2456) in self.flat_param._shared_param_infos
2457]:
2458yield (param_name, module_name)
2459
2460@property
2461def _fqns_in_shard(self) -> List[str]:
2462"""Return the FQNs of the parameters present in this rank's shard."""
2463fqns_in_shard: List[str] = []
2464for fqn, shard_param_info in zip(
2465self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
2466):
2467if shard_param_info.in_shard:
2468fqns_in_shard.append(fqn)
2469return fqns_in_shard
2470
2471@property
2472def sharded_grad(self) -> Optional[Tensor]:
2473"""Return the handle's sharded gradient."""
2474flat_param = self.flat_param
2475# Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
2476# - CPU offloading: `_cpu_grad`
2477# - No CPU offloading + sharded strategies: `_saved_grad_shard`
2478# - No CPU offloading + `NO_SHARD`: `grad`
2479grad: Optional[Tensor]
2480if hasattr(flat_param, "_cpu_grad"):
2481grad = flat_param._cpu_grad # type: ignore[attr-defined]
2482elif hasattr(flat_param, "_saved_grad_shard"):
2483# In the post-backward hook, the sharded gradient is still in
2484# `_saved_grad_shard`.
2485grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
2486else:
2487# If in IDLE or in FORWARD states, then there may be an
2488# (accumulated) gradient. If accessed in IDLE, then this should
2489# be due to re-registering the original parameters (e.g. in state
2490# dict load).
2491_p_assert(
2492flat_param.grad is None
2493or not self.uses_sharded_strategy
2494or self._training_state
2495in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
2496"Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
2497"unless in IDLE or FORWARD",
2498)
2499grad = flat_param.grad
2500return grad
2501
2502def _reset_is_grad_none(self) -> None:
2503"""
2504Reset ``_is_grad_none_mask`` as needed.
2505
2506This method should only be
2507called in the post-backward after gradient computation. In that case, if
2508a parameter requires gradient, then it will surely receive a gradient,
2509and we may reset its mask entry to ``False``.
2510"""
2511if not self._use_orig_params:
2512return
2513_p_assert(
2514self._training_state == HandleTrainingState.BACKWARD_POST,
2515"Expects to only be called in the post-backward after gradient computation",
2516)
2517flat_param = self.flat_param
2518assert flat_param._params is not None # mypy
2519for i, param in enumerate(flat_param._params): # type: ignore[arg-type]
2520# As long as the parameter requires gradient, it should receive a
2521# meaningful gradient (even if the gradient happens to be zeros)
2522if param.requires_grad:
2523assert flat_param._is_grad_none_mask is not None # mypy
2524flat_param._is_grad_none_mask[i] = False
2525
2526#######################
2527# CHECKS & INVARIANTS #
2528#######################
2529def _check_sharded_strategy(self):
2530_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
2531
2532def _check_on_compute_device(self, tensor: Tensor):
2533_p_assert(
2534tensor.device == self.device,
2535f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
2536)
2537
2538def _check_on_cpu(self, tensor: Tensor):
2539_p_assert(
2540tensor.device == torch.device("cpu"),
2541f"Expects tensor to be on CPU but got {tensor.device}",
2542)
2543
2544@staticmethod
2545def _check_storage_freed(tensor: Tensor):
2546# Compile does not resize during trace
2547if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
2548_p_assert(
2549_same_storage_size(tensor, 0),
2550"Expects storage to be freed but got storage with size > 0",
2551)
2552
2553@staticmethod
2554def _check_storage_allocated(tensor: Tensor):
2555_p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
2556
2557def _check_low_precision_shard(self):
2558_p_assert(
2559self._uses_param_mixed_precision,
2560"Not using low precision for parameters",
2561)
2562_p_assert(
2563getattr(self.flat_param, "_mp_shard", None) is not None,
2564"Expects `_mp_shard` to exist",
2565)
2566device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
2567_p_assert(
2568device == self.device,
2569f"Expects the low precision shard to be on {self.device} but got {device}",
2570)
2571
2572def _check_unsharded(self, tensor: Tensor):
2573msg_prefix = "Expects tensor to be unsharded "
2574_p_assert(tensor is not None, msg_prefix + "but got `None`")
2575unsharded_size = self.flat_param._unpadded_unsharded_size
2576_p_assert(
2577tensor.size() == unsharded_size,
2578msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
2579)
2580
2581def _check_sharded(self, tensor: Tensor):
2582msg_prefix = "Expects tensor to be sharded "
2583_p_assert(tensor is not None, msg_prefix + "but got `None`")
2584sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
2585_p_assert(
2586tensor.size() == sharded_size,
2587msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
2588)
2589
2590##############
2591# PROPERTIES #
2592##############
2593@property
2594def uses_sharded_strategy(self) -> bool:
2595return self._sharding_strategy != HandleShardingStrategy.NO_SHARD
2596
2597@property
2598def _uses_param_mixed_precision(self) -> bool:
2599return self._fwd_bwd_param_dtype != self._orig_param_dtype
2600
2601@property
2602def _uses_reduce_mixed_precision(self) -> bool:
2603return self._reduce_dtype != self._orig_param_dtype
2604
2605@property
2606def _force_full_precision(self) -> bool:
2607return (
2608self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
2609) and (
2610self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
2611or
2612# Also disable mixed precision in model eval mode, if configured
2613(not self._fully_sharded_module.training and self._use_full_prec_in_eval)
2614)
2615
2616@property
2617def _skipped_use_sharded_views(self) -> bool:
2618"""
2619This property is used for sharding strategies that do not free the unsharded flat parameter after forward when ``use_orig_params=True``.
2620
2621This returns whether this handle is
2622currently in a state where it has skipped using sharded views, in which
2623case it can restore view invariants via ``_use_sharded_views()``.
2624"""
2625return self._unsharded_flat_param_for_skipped_views is not None
2626
2627
2628# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
2629def _unsafe_setattr_param(
2630module: nn.Module, param_name: str, param: nn.Parameter
2631) -> None:
2632module._parameters[param_name] = param
2633# This bypasses any overrides in case `module` is an instance of an
2634# `nn.Module` subclass
2635super(nn.Module, module).__setattr__(param_name, param)
2636
2637
2638def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
2639module._parameters.pop(param_name, None)
2640# This bypasses any overrides in case `module` is an instance of an
2641# `nn.Module` subclass
2642super(nn.Module, module).__setattr__(param_name, tensor)
2643
2644
2645def _safe_setattr_tensor_or_param(
2646module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
2647):
2648# Call `delattr()` and `setattr()` to go through `nn.Module` checks
2649if hasattr(module, param_name):
2650delattr(module, param_name)
2651setattr(module, param_name, tensor_or_param)
2652
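# NOTE (editorial example): the "safe" helper above goes through
# `nn.Module.__setattr__`, so subclass overrides and module bookkeeping run,
# whereas the unsafe helpers write into `module._parameters` directly. A small
# sketch (hypothetical module `m`):
#
#   m = torch.nn.Linear(2, 2)
#   _safe_setattr_tensor_or_param(m, "weight", nn.Parameter(torch.zeros(2, 2)))
#   # vs. `_unsafe_setattr_param(m, "weight", ...)`, which bypasses the
#   # `nn.Module.__setattr__` checks entirely.
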
2653
2654def _convert_to_params(
2655tensors: List[Union[torch.Tensor, nn.Parameter]]
2656) -> List[nn.Parameter]:
2657return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
2658
2659
2660def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
2661return (
2662param_or_tensor.detach()
2663if isinstance(param_or_tensor, nn.Parameter)
2664else param_or_tensor
2665)
2666
2667
2668def _get_aligned_numel(unsharded_dtype: torch.dtype):
2669# NOTE: This alignment constraint comes from TorchInductor.
2670ALIGNMENT = 16 # bytes
2671unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
2672aligned_numel = ALIGNMENT // unsharded_dtype_size
2673return aligned_numel
2674
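# NOTE (editorial example): with the 16-byte alignment above, a 4-byte dtype
# such as `torch.float32` aligns to 4 elements and a 2-byte dtype such as
# `torch.float16` aligns to 8:
#
#   assert _get_aligned_numel(torch.float32) == 4
#   assert _get_aligned_numel(torch.float16) == 8
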
2675
2676@functools.lru_cache(8)
2677def _get_dtype_size(dtype):
2678return torch.empty((), dtype=dtype).element_size()
2679
2680
2681def _construct_padding_tensor(
2682padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
2683):
2684# NOTE: Set the padding value as a magic number for debuggability. The
2685# value itself should never be used in any user-facing computation.
2686return (
2687torch.ones(
2688(padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
2689)
2690* _FLAT_PARAM_PADDING_VALUE
2691)
2692
2693
2694# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
2695# message is passed in)
2696@functools.lru_cache(1)
2697def _warn_skip_writeback_check(log: logging.Logger, warning: str):
2698log.warning(warning)
2699
2700
2701# Use `lru_cache(1)` to only log the warning once
2702@functools.lru_cache(1)
2703def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
2704log.warning(warning)
2705
2706
2707# Use `lru_cache(1)` to only log the warning once
2708@functools.lru_cache(1)
2709def _warn_use_fake_reduce(log: logging.Logger, warning: str):
2710log.warning(warning)
2711
2712
2713def _same_storage(a, b):
2714return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
2715
2716
2717def _same_storage_size(a: torch.Tensor, b: int):
2718return a.untyped_storage().size() // a.element_size() == b
2719
2720
2721def _storage_size_allocated(tensor: Tensor):
2722storage_size: int = tensor.untyped_storage().size()
2723return storage_size > 0
2724