import functools
import logging
from enum import auto, Enum
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _assert_in_training_states,
    _FSDPState,
    _get_module_fsdp_state,
    _is_composable,
    _log_post_backward_hook,
    _no_dispatch_record_stream,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
    HandleTrainingState,
    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.utils import (
    _apply_to_tensors,
    _cast_forward_inputs,
    _p_assert,
    _to_kwargs,
)
from torch.utils import _pytree as pytree

log = logging.getLogger(__name__)

# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
    "_use_full_prec_in_eval",
)


class _PrefetchMode(Enum):
    BACKWARD = auto()
    FORWARD = auto()


def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Returns a tuple containing:
    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
    ``module`` without any duplicates and following the ``module.modules()``
    traversal order (which is assumed to be depth-first).
    2. A corresponding list of the root modules owning the states in the first
    list.

    This is similar to :func:`_get_fsdp_states_with_modules` except that we
    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
    the FSDP root in case lazy initialization has not yet happened.
    """
    fsdp_root_states: List[_FSDPState] = []
    fsdp_root_modules: List[nn.Module] = []
    visited_fsdp_states: Set[_FSDPState] = set()
    # NOTE: This function assumes that `module.modules()` proceeds top-down.
    for submodule in module.modules():
        optional_state = _get_module_fsdp_state(submodule)
        if (
            optional_state is not None
            and optional_state not in visited_fsdp_states
            and _is_fsdp_root(optional_state, submodule)
        ):
            visited_fsdp_states.add(optional_state)
            fsdp_root_states.append(optional_state)
            fsdp_root_modules.append(submodule)
    return fsdp_root_states, fsdp_root_modules


def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """See :func:`_get_fsdp_root_states_with_modules`."""
    fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
    return fsdp_root_states


def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Returns if ``state`` corresponds to that of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same. For
    the non-wrapper code path, ``state`` should be ``module`` 's state.
    """
    # Force a lazy initialization to determine the FSDP root
    _lazy_init(state, module)
    assert state._is_root is not None  # mypy
    return state._is_root


@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.
    """
    if state._is_root is not None:
        return  # no-op: already lazily initialized
    if not state._device_handle.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state


def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Checks that all ``FlatParameter``s in ``module`` 's tree managed by
    ``state`` are on the expected device for *lazy initialization*.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        if (
            not handle._offload_params
            and handle.flat_param.device != state.compute_device
        ):
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{handle.flat_param.device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
        elif handle._offload_params and handle.flat_param.device != cpu_device:
            raise RuntimeError(
                "An FSDP-managed module with parameter CPU offloading enabled "
                f"has parameters on {handle.flat_param.device}. Make sure to "
                f"not move the module from CPU when offloading parameters."
            )


@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Shares data structure state from the ``root_state`` to all FSDP states in
    ``root_module`` 's module tree, and initializes handle attributes. These
    are done together to require a single loop over the states.
    """
    handle = root_state._handle
    if handle:
        handle.init_flat_param_attributes()
    attr_name_to_values: Dict[str, Set[Any]] = {}
    for attr_name in HOMOGENEOUS_ATTR_NAMES:
        attr_name_to_values[attr_name] = set()
    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
    # Update _has_optim_in_backward for each handle.
    for handle in root_state._all_handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_in_backward_optimizers"):
            raise RuntimeError(
                "FSDP optimizer in backward only supported with use_orig_params=True!"
            )
        handle._has_optim_in_backward = flat_param._params is not None and any(
            hasattr(param, "_in_backward_optimizers") for param in flat_param._params
        )
        if handle._has_optim_in_backward:
            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
    for fsdp_state in root_state._all_fsdp_states:
        for attr_name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, attr_name),
                f"FSDP state missing attribute {attr_name}",
            )
            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
        if fsdp_state is root_state:
            continue
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        _p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
        fsdp_state._unshard_stream = root_state._unshard_stream
        fsdp_state._post_backward_stream = root_state._post_backward_stream
        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
        fsdp_state._default_stream = root_state._default_stream
        fsdp_state._exec_order_data = root_state._exec_order_data
        fsdp_state._free_event_queue = root_state._free_event_queue
        if fsdp_state._fsdp_extension is not None:
            fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
        handle = fsdp_state._handle
        if handle:
            handle.init_flat_param_attributes()
    for attr_name, attr_values in attr_name_to_values.items():
        if len(attr_values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {attr_name} but got {attr_values}"
            )


@no_type_check
def _init_streams(
    state: _FSDPState,
) -> None:
    """
    Initializes CUDA streams for overlapping communication, computation, and
    data transfers. The streams should be shared across FSDP instances.
    """
    assert state._is_root
    assert state._device_handle.is_available()
    uses_hybrid_sharding = any(
        fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
        for fsdp_state in state._all_fsdp_states
    )
    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
    # preserve the default priority of 0 otherwise
    high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0
    # Default stream for computation
    state._default_stream = state._device_handle.current_stream()
    if state._fsdp_extension is not None:
        # set the compute stream to the FSDP extension
        state._fsdp_extension.compute_stream = state._default_stream

    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves
    state._unshard_stream = state._device_handle.Stream(priority=high_priority)
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation
    state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast)
    state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority)
    # Stream to run HSDP's all-reduce as async (if using HSDP)
    state._all_reduce_stream = (
        state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream
    )


@no_type_check
def _unshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
) -> None:
    """
    Unshards ``handle`` 's ``FlatParameter``. If the handle is in
    :meth:`summon_full_params` and is using mixed precision, then it is
    forced to full precision.

    Postcondition: handle's ``FlatParameter`` 's data is the padded
    unsharded flat parameter on the compute device.
    """
    if not handle:
        return
    with state._device_handle.stream(pre_unshard_stream):
        ran_pre_unshard = handle.pre_unshard()
    if ran_pre_unshard:
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        event = state._free_event_queue.dequeue_if_needed()
        if event:
            with torch.profiler.record_function(
                "FullyShardedDataParallel.rate_limiter"
            ):
                event.synchronize()
    with state._device_handle.stream(unshard_stream):
        handle.unshard()
        handle.post_unshard()


@no_type_check
def _reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    free_unsharded_flat_param: bool,
):
    """
    Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
    free the handle's padded unsharded flat parameter.
    """
    handle.reshard(free_unsharded_flat_param)
    if state.limit_all_gathers and free_unsharded_flat_param:
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            # We don't run an event queue for freeing under torch compile atm
            # But maybe we need to? TODO(voz): Look into this
            free_event = state._device_handle.Event()
            free_event.record()
            state._free_event_queue.enqueue(free_event)
    handle.post_reshard()
    # Flat parameter freed or not, we always have to "unshard" the parameter
    # upon next access to get its shape correct.
    handle._prefetched = False


def _unshard_grads(
    handle: Optional[FlatParamHandle],
) -> None:
    if handle:
        handle.unshard_grad()


def _reshard_grads(
    handle: Optional[FlatParamHandle],
) -> None:
    if handle:
        handle.reshard_grad()


@no_type_check
def _pre_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    unshard_fn: Callable,
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    also converts forward ``args`` and ``kwargs`` to the given precision.

    Args:
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            # For both checkpoint implementations, we do not need to re-cast
            # inputs here since they will be checkpointed in the low precision
            # either by AC or normally by autograd as long as the AC region is
            # nested within FSDP
            return args, kwargs
        state.training_state = TrainingState.FORWARD_BACKWARD
        state._exec_order_data.record_pre_forward(handle, module.training)
        if handle:
            handle._training_state = HandleTrainingState.FORWARD
        if unshard_fn is not None:
            unshard_fn(state, handle)
        # Register post-backward hooks to reshard the parameters and reduce-scatter
        # their gradients. They must be re-registered every forward pass in case
        # the `grad_fn` is mutated.
        _register_post_backward_hook(state, handle)
        # We have to reallocate the _cpu_grad if optimizer overlap
        # set the grad to None in the backward pass.
        if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
            handle.flat_param._cpu_grad = torch.zeros_like(
                handle.flat_param._local_shard, device=torch.device("cpu")
            ).pin_memory()

        should_cast_forward_inputs = (
            state._handle and not state._handle._force_full_precision
        )

        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
            # Recursively convert args and kwargs to specified precision.
            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
        return args, kwargs


@no_type_check
def _pre_forward_unshard(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """Unshards parameters in the pre-forward."""
    if not handle:
        return
    # If the handle has been prefetched, then there is no need to call
    # `_unshard()` again
    if not handle._prefetched:
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._needs_pre_forward_unshard = False
    # Don't wait during trace
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        state._device_handle.current_stream().wait_stream(state._unshard_stream)
    with torch.profiler.record_function(
        "FullyShardedDataParallel._pre_forward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)


@no_type_check
def _post_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    reshard_fn: Callable,
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
    parameter.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            return output

        state._exec_order_data.record_post_forward(handle)
        if reshard_fn is not None:
            reshard_fn(state, handle)
        # Register pre-backward hooks to unshard the flat parameters for the
        # gradient computation (if needed)
        output = _register_pre_backward_hooks(state, module, output, handle)
        state.training_state = TrainingState.IDLE
        if handle:
            handle._training_state = HandleTrainingState.IDLE
        return output


@no_type_check
def _post_forward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> None:
    """Reshards parameters in the post-forward."""
    if not handle:
        return
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
    # with the intention that they are immediately used for backward
    # computation (though this may not be true)
    free_unsharded_flat_param = (
        not state._is_root
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
    )
    _reshard(state, handle, free_unsharded_flat_param)


@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> None:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns directly.

    Args:
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then this method does not do anything.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
        _lazy_init(state, module)
        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
        if not state._is_root:
            # Always cast forward inputs in the root of this local FSDP unit
            # for mixed precision, as this is where mixed precision could be
            # configured. This is more useful for auto wrapping, which is
            # recommended in the composable path. For manual wrapping, casting
            # forward inputs on each local FSDP unit root would add some
            # overhead, so it is not enabled for the model wrapper path right
            # now, where manual wrapping is more broadly used.
            if _is_composable(state):
                return _root_cast_forward_input(state, module, args, kwargs)
            return args, kwargs

        # We cast buffers back to full precision if we're forcing full
        # precision. Disjointly, we check if buffers are in full precision and
        # if we should cast them back to lower precision, which happens when
        # exiting eval() mode.
        handle = state._handle
        if handle:
            should_cast_buffers_to_full_prec = handle._force_full_precision
        else:
            should_cast_buffers_to_full_prec = True

        if should_cast_buffers_to_full_prec:
            _cast_buffers_to_dtype_and_device(
                buffers=dict(module.named_buffers()).values(),
                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
                device=state.compute_device,
            )
            # This flag is only set when we cast buffers to full precision, to
            # avoid the CPU overhead that can stem from retrieving all buffers
            # and their types in the following else branch.
            state._needs_buffer_dtype_restore_check = True
        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
            # Check if buffers are in full precision and we need to cast them
            # back down.
            (
                buffers,
                buffer_dtypes_for_computation,
            ) = _get_buffers_and_dtypes_for_computation(state, module)
            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
                if any(
                    buffer.dtype != buffer_dtype_for_computation
                    for buffer, buffer_dtype_for_computation in zip(
                        buffers, buffer_dtypes_for_computation
                    )
                ):
                    # Assume we have to cast everything if there is one mismatch
                    _cast_buffers_to_dtype_and_device(
                        buffers, buffer_dtypes_for_computation, state.compute_device
                    )
            # We don't have to check this again until we cast buffers to full
            # precision again.
            state._needs_buffer_dtype_restore_check = False

        if state.forward_prefetch:
            handles = []
            for fsdp_state in state._all_fsdp_states:
                if fsdp_state._handle:
                    handles.append(fsdp_state._handle)
            for handle in handles:
                handle._needs_pre_forward_unshard = True
                handle._prefetched = False
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
            state._pre_unshard_stream,
        )
        _reset_flat_param_grad_info_if_needed(state._all_handles)

        # Prepares the forward inputs by moving them to ``compute_device``
        # TODO: Do not use the side stream for tensor copies for now;
        # investigate the perf with/without it.
        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
            args_tuple, kwargs_tuple = _to_kwargs(
                args, kwargs, state.compute_device, False
            )
        args = args_tuple[0]
        kwargs = kwargs_tuple[0]

        return _root_cast_forward_input(state, module, args, kwargs)


@no_type_check
def _root_cast_forward_input(
    state: _FSDPState, module: torch.nn.Module, args, kwargs
) -> Tuple[Any, Any]:
    if state._handle:
        force_full_precision = not state._handle._force_full_precision
    else:
        force_full_precision = True

    should_cast_forward_inputs = (
        (module.training or not state._use_full_prec_in_eval) and force_full_precision
    ) and state.mixed_precision.cast_root_forward_inputs

    if should_cast_forward_inputs:
        input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)

    return args, kwargs


@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    handle: FlatParamHandle,
    grad,
    *unused: Any,
) -> Any:
    """
    Prepares ``handle`` 's ``FlatParameter`` s for gradient computation.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
    """
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if (
        handle
        and hasattr(handle, "_ran_pre_backward_hook")
        and handle._ran_pre_backward_hook
    ):
        log.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
        return grad

    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _reset_flat_param_grad_info_if_needed(state._all_handles)
        elif handle:
            allowed_states = [TrainingState.IDLE]
            if _is_composable(state):
                allowed_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, allowed_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not handle:
            return grad
        handle._training_state = HandleTrainingState.BACKWARD_PRE

        if handle._needs_pre_backward_unshard:
            # If the handle has been prefetched, then there is no need to
            # call `_unshard()` again
            if not handle._prefetched:
                _unshard(
                    state,
                    handle,
                    state._unshard_stream,
                    state._pre_unshard_stream,
                )
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._device_handle.current_stream().wait_stream(state._unshard_stream)

        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        handle._needs_pre_backward_unshard = False
        with torch.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_prefetch"
        ):
            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
        handle.prepare_gradient_for_backward()
        handle._ran_pre_backward_hook = True
        return grad


@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    flat_param,
    *unused: Any,
):
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
    unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).
    """
    _log_post_backward_hook(state, handle, log)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # For multiple applications of reentrant AC across submodules sharing
        # the same `FlatParameter`, the post-backward hook may run multiple
        # times in one backward, in which case we permit the state to already
        # be in `BACKWARD_POST`.
        _p_assert(
            handle._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST

        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")

        _post_backward_reshard(state, handle)
        if not state._sync_gradients:
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return

        # Wait for all ops in the current stream (e.g. gradient computation) to
        # finish before reduce-scattering the gradient
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            state._post_backward_stream.wait_stream(
                state._device_handle.current_stream()
            )

        with state._device_handle.stream(state._post_backward_stream):
            autograd_computed_grad = flat_param.grad.data
            if (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
                # If we are forcing full precision but communicating grads
                # (i.e. model.eval() + full precision in eval was configured),
                # don't downcast gradient.
                and not handle._force_full_precision
            ):
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            if handle.uses_sharded_strategy:
                _reduce_grad(state, handle)
            else:
                _reduce_grad_no_shard(state, handle)
            # Since the unsharded gradient is produced in the computation
            # stream and consumed in the post-backward stream, inform the
            # caching allocator (before it goes out of scope)
            _no_dispatch_record_stream(
                autograd_computed_grad, state._post_backward_stream
            )


def _post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    with torch.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook_reshard_only"
    ):
        # `_pre_backward_hook` may not get executed
        # if forward output does not require grad
        # overwrite IDLE state for post-backward prefetching
        state.training_state = TrainingState.FORWARD_BACKWARD
        handle._training_state = HandleTrainingState.BACKWARD_POST
        _post_backward_reshard(state, handle)


def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    free_unsharded_flat_param = _should_free_in_backward(state, handle)
    _reshard(state, handle, free_unsharded_flat_param)

    # TODO: Post-backward prefetching does not support the multiple handles
    # per module case since the post-backward hook runs per handle, not per
    # group of handles.
    with torch.profiler.record_function(
        "FullyShardedDataParallel._post_backward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)


@no_type_check
def _should_free_in_backward(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> bool:
    """
    Returns whether FSDP should free the unsharded flat parameter in the
    post-backward or not.
    """
    if not handle.uses_sharded_strategy:
        return False
    # If not syncing gradients, then we do not free for strategies that do not
    # reshard after forward as a *heuristic* to trade off higher memory for
    # higher throughput.
    return (
        state._sync_gradients
        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
    )


@no_type_check
def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For sharded strategies, this runs gradient reduction, sharded gradient
    accumulation if needed, and the post-reduction callback.
    """
    flat_param = handle.flat_param
    uses_hybrid_sharded_strategy = handle._sharding_strategy in (
        HandleShardingStrategy.HYBRID_SHARD,
        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    )
    # We clear `.grad` to permit multiple backwards. This avoids a race where
    # the second backward pass computation runs ahead of the first backward
    # pass reduction, which is possible since the reduction is issued in a
    # separate stream and is async and would result in reducing the wrong
    # gradient.
    unsharded_grad = flat_param.grad.data
    flat_param.grad = None
    padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(
        state, unsharded_grad
    )
    if state._comm_hook is None:  # default path
        _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
        pg = (
            handle._fake_process_group
            if handle._use_fake_reduce
            else state.process_group
        )
        dist.reduce_scatter_tensor(
            new_sharded_grad,
            padded_unsharded_grad,
            group=pg,
        )
        if uses_hybrid_sharded_strategy:
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._all_reduce_stream.wait_stream(state._post_backward_stream)
            with state._device_handle.stream(state._all_reduce_stream):
                # Since the new sharded gradient is produced in the post-
                # backward stream and consumed in the all-reduce stream,
                # inform the caching allocator
                _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream)
                dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
                _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
                grad_to_offload = _accumulate_sharded_grad(
                    state, handle, new_sharded_grad
                )
                _post_reduce_grad_callback(state, handle, grad_to_offload)
                return
        _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
    else:
        state._comm_hook(
            state._comm_hook_state, padded_unsharded_grad, new_sharded_grad
        )
        # NOTE: HSDP variants do not support communication hook.
    grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad)
    _post_reduce_grad_callback(state, handle, grad_to_offload)


@no_type_check
def _get_reduce_scatter_tensors(
    state: _FSDPState, unsharded_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the input and output tensors to reduce-scatter, respectively.
    """
    chunks = list(unsharded_grad.chunk(state.world_size))
    numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()
    padded_unsharded_grad = (
        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
    )
    new_sharded_grad = torch.empty_like(chunks[0])  # padded
    return padded_unsharded_grad, new_sharded_grad
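

# The following is a minimal, never-called sketch (not part of the FSDP
# runtime; the function name and concrete sizes are hypothetical) that walks
# through the padding arithmetic in `_get_reduce_scatter_tensors` above: the
# unsharded gradient is right-padded with zeros so that it chunks evenly
# across the ranks, and the reduce-scatter output buffer matches one padded
# chunk per rank.
def _example_reduce_scatter_padding() -> None:
    world_size = 4
    unsharded_grad = torch.arange(10, dtype=torch.float32)
    # `chunk` may produce a smaller final chunk: sizes here are (3, 3, 3, 1)
    chunks = list(unsharded_grad.chunk(world_size))
    numel_to_pad = world_size * chunks[0].numel() - unsharded_grad.numel()  # 2
    padded_unsharded_grad = (
        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
    )
    # After zero-padding, the input splits evenly across ranks
    assert padded_unsharded_grad.numel() == world_size * chunks[0].numel()  # 12
    # The reduce-scatter output buffer holds exactly one padded chunk
    new_sharded_grad = torch.empty_like(chunks[0])
    assert new_sharded_grad.numel() * world_size == padded_unsharded_grad.numel()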


@no_type_check
def _accumulate_sharded_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    sharded_grad: torch.Tensor,
) -> torch.Tensor:
    """
    Accumulates the reduce-scattered sharded gradient with any existing sharded
    gradient if needed, returning the gradient to offload (if CPU offloading is
    enabled).
    """
    flat_param = handle.flat_param
    _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
    # Save the sharded gradient in `_saved_grad_shard` to support gradient
    # accumulation -- for multiple backwards, the gradient reductions may
    # happen in arbitrary order
    accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
    if accumulate_grad:
        _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
        flat_param._saved_grad_shard += sharded_grad
    else:
        flat_param._saved_grad_shard = sharded_grad
    grad_to_offload = flat_param._saved_grad_shard
    return grad_to_offload


@no_type_check
def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For no-shard, this runs gradient reduction (which directly covers any
    gradient accumulation implicitly) and the post-reduction callback.
    """
    flat_param = handle.flat_param
    if state._comm_hook is None:  # default path
        _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
        dist.all_reduce(flat_param.grad, group=state.process_group)
        _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
    else:
        state._comm_hook(state._comm_hook_state, flat_param.grad)
    # For `NO_SHARD`, we can keep the low precision gradients by simply
    # omitting the cast altogether
    if not handle._keep_low_precision_grads:
        _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
    grad_to_offload = flat_param.grad.data
    _post_reduce_grad_callback(state, handle, grad_to_offload)


@no_type_check
def _post_reduce_grad_callback(
    state: _FSDPState,
    handle: FlatParamHandle,
    # Additional arguments needed for the callback logic
    grad_to_offload: torch.Tensor,
):
    """
    This callback captures any logic to run after the gradient reduction
    finishes. Currently, this offloads the gradient to CPU if CPU offloading is
    enabled and uses sharded gradient views if ``use_orig_params=True``.
    """
    _offload_grad(state, handle, grad_to_offload)
    _post_backward_use_sharded_grad_views(handle)


@no_type_check
def _offload_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    grad_to_offload: torch.Tensor,
):
    if not handle._offload_params:
        return
    # Offload the gradient to CPU to ensure parameters and gradients are on the
    # same device as required by the optimizer
    # TODO: Investigate why `NO_SHARD` breaks correctness when using
    # `non_blocking=True` here.
    # TODO (rohan-varma): When CPU offload and optimizer overlap,
    # non_blocking=True won't work since the copy may not have finished before
    # the optimizer step executes on CPU. If we want to use non_blocking=True
    # here, we'll have to synchronize before using the result on CPU.
    non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
    handle.flat_param._cpu_grad.copy_(
        grad_to_offload.detach(), non_blocking=non_blocking
    )  # synchronized in the post-backward callback
    # Since the gradient being offloaded may have been produced in the
    # computation stream and is being consumed here in the post-backward
    # stream, inform the caching allocator
    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)


@no_type_check
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
    if not handle._use_orig_params:
        return
    # Since the handle's `FlatParameter` completed its gradient computation, we
    # should reset the gradient noneness mask
    handle._reset_is_grad_none()
    # Delay using sharded gradient views until after the reduce-scatter instead
    # of immediately after resharding
    handle._use_sharded_grad_views()
    if handle._has_optim_in_backward:
        handle.prepare_gradient_for_optim()
        for orig_param in handle.flat_param._params:
            # Check for `None` gradient to filter parameters not in the rank
            if orig_param.grad is not None and hasattr(
                orig_param, "_in_backward_optimizers"
            ):
                # TODO (rohan-varma): For CPU offload, this unfortunately
                # operates on CPU because the parameters and gradients have
                # already been offloaded. We should run this on GPU after
                # refactoring.
                for optim in orig_param._in_backward_optimizers:
                    optim.step()

                optim.zero_grad(set_to_none=True)
        handle._reset_flat_param_grad_info_if_needed()
        if handle._offload_params:
            handle.flat_param._cpu_grad = None


def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
    if div_factor > 1:
        tensor.div_(div_factor)
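

# A minimal, never-called sketch (hypothetical name and factor values; not
# part of the FSDP runtime) of how the pre/post divide factors used with
# `_div_if_needed` compose in the default comm path above: the world-size
# averaging is split into a division before the reduction and a division
# after, with predivide * postdivide == world_size, which reduces the risk of
# overflow/underflow when reducing in low precision.
def _example_pre_post_divide() -> None:
    world_size = 16
    predivide = 4.0  # assumed here; FSDP derives its own factor from world_size
    postdivide = world_size / predivide
    grad = torch.full((4,), 2.0)
    _div_if_needed(grad, predivide)
    # ... the all-reduce or reduce-scatter would run here ...
    _div_if_needed(grad, postdivide)
    # Net effect is the usual division by world_size
    assert torch.allclose(grad, torch.full((4,), 2.0 / world_size))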


@no_type_check
def _cast_grad_to_param_dtype(
    state: _FSDPState,
    sharded_grad: torch.Tensor,
    param: FlatParameter,
):
    """
    Casts ``sharded_grad`` back to the full parameter dtype so that the
    optimizer step runs with that dtype. This performs an actual cast if
    1. parameters were in reduced precision during the forward since then
    gradients would be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in
    reduced precision for communication.
    However, if a low precision communication hook is registered, then this
    dtype cast happens in the hook instead.
    """
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
        low_prec_grad_data = sharded_grad.data
        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
        # Since for `NO_SHARD`, the gradient is produced in the computation
        # stream and consumed here in the post-backward stream, inform the
        # caching allocator; for the sharded strategies, the gradient is
        # produced in the post-backward stream, so this `record_stream()`
        # should be a no-op
        _no_dispatch_record_stream(
            low_prec_grad_data, state._device_handle.current_stream()
        )


def _check_grad_to_accumulate(
    new_sharded_grad: torch.Tensor,
    accumulated_grad: torch.Tensor,
) -> None:
    _p_assert(
        accumulated_grad.shape == new_sharded_grad.shape,
        "Shape mismatch when accumulating gradients: "
        f"existing gradient shape={accumulated_grad.shape} "
        f"new gradient shape={new_sharded_grad.shape}",
    )
    _p_assert(
        accumulated_grad.device == new_sharded_grad.device,
        "Device mismatch when accumulating gradients: "
        f"existing gradient device={accumulated_grad.device} "
        f"new gradient device={new_sharded_grad.device}",
    )


@no_type_check
def _low_precision_hook_enabled(state: _FSDPState) -> bool:
    return state._comm_hook in LOW_PRECISION_HOOKS


@no_type_check
@torch.no_grad()
def _post_backward_final_callback(
    state: _FSDPState,
    module: nn.Module,
):
    """
    This waits for the post-backward to finish and performs some final cleanup.
    This runs at the end of the entire backward pass and should only be called
    on the root FSDP instance.
    """
    _p_assert(
        state._is_root,
        "The post-backward callback should only be called on the root FSDP instance",
    )
    root_state = state

    if root_state._sync_gradients:
        current_stream = state._device_handle.current_stream()
        # TODO (rohan-varma): this also waits for the overlapped optimizer step
        # to finish since it currently runs in the post-backward stream. That
        # can be pushed to the next forward if run in a different stream
        current_stream.wait_stream(root_state._post_backward_stream)
        if root_state._all_reduce_stream is not current_stream:  # uses HSDP
            current_stream.wait_stream(root_state._all_reduce_stream)
        if root_state.cpu_offload.offload_params:
            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
            # post-backward hooks to finish explicitly since CPU gradients do
            # not automatically synchronize with the GPU
            state._device_handle.current_stream().synchronize()
    root_state._exec_order_data.next_iter()

    for fsdp_state in state._all_fsdp_states:
        _catch_all_reshard(fsdp_state)
        _finalize_params(fsdp_state)
        fsdp_state.training_state = TrainingState.IDLE
        handle = fsdp_state._handle
        if handle:
            handle._ran_pre_backward_hook = False
            handle._needs_pre_backward_unshard = False
            handle._post_forward_index = None
            handle._training_state = HandleTrainingState.IDLE
            handle._prefetched = False
    # Reset for cases like one forward and multiple backwards
    root_state._post_backward_callback_queued = False


@no_type_check
def _catch_all_reshard(
    state: _FSDPState,
) -> None:
    """
    Reshards the parameters that may not have been resharded in the
    post-backward hook. This can happen when a module's output is used in the
    forward pass, meaning that its pre-backward hook runs (unsharding the
    parameter), but the post-backward hook does not run because the output was
    not used in the loss computation corresponding to this backward pass.
    """
    # Wrap with a try-except to provide a more informative traceback if an
    # error is raised
    try:
        if state._handle:
            # TODO: This already-resharded check is brittle:
            # https://github.com/pytorch/pytorch/issues/83956
            already_resharded = (
                state._handle.flat_param.data_ptr()
                == state._handle.flat_param._local_shard.data_ptr()
                # If FSDP skipped using sharded views, then the flat parameter
                # still points to the sharded data, so we need to reshard to
                # use sharded views
                and not state._handle._skipped_use_sharded_views
            )
            if already_resharded:
                return
            free_unsharded_flat_param = _should_free_in_backward(state, state._handle)
            _reshard(state, state._handle, free_unsharded_flat_param)
    except Exception as e:
        _p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
        )
        raise e


@no_type_check
def _finalize_params(
    state: _FSDPState,
) -> None:
    """Finalizes the parameters before the next iteration."""
    handle = state._handle
    if not handle:
        return
    flat_param = handle.flat_param
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        if hasattr(flat_param, "_post_backward_hook_handle"):
            pbhs_handle = flat_param._post_backward_hook_handle
            pbhs_handle.remove()
            del flat_param._post_backward_hook_handle
    else:
        if hasattr(flat_param, "_post_backward_hook_state"):
            post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
            expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
            _p_assert(
                post_backward_hook_state_len == expected_post_backward_hook_state_len,
                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
            )
            flat_param._post_backward_hook_state[-1].remove()
            delattr(flat_param, "_post_backward_hook_state")
    if flat_param.requires_grad:
        if not state._sync_gradients:
            # Preserve the gradient accumulation state if not synchronizing
            # gradients: `.grad` remains the unsharded gradient from prior
            # `no_sync()` iterations, and `_saved_grad_shard` remains the
            # sharded gradient from the last synchronized iteration
            return
        if not handle._has_optim_in_backward:
            handle.prepare_gradient_for_optim()
        _p_assert(
            hasattr(flat_param, "_post_backward_called"),
            "Expects `_post_backward_called` to be set on the `FlatParameter`",
        )
        flat_param._post_backward_called = False


@no_type_check
def _prefetch_handle(
    state: _FSDPState,
    current_handle: Optional[FlatParamHandle],
    prefetch_mode: _PrefetchMode,
) -> None:
    """
    Prefetches the next handle if needed (without synchronization). An empty
    handle cannot prefetch.
    """
    if not current_handle:
        return
    handle = _get_handle_to_prefetch(state, current_handle)
    if not handle:
        return
    # Temporarily emulate the training state while calling `_unshard` to
    # ensure the correct `as_params` for `_use_unsharded_views()`
    prev_training_state = handle._training_state
    if prefetch_mode == _PrefetchMode.BACKWARD:
        handle._training_state = HandleTrainingState.BACKWARD_PRE
    elif prefetch_mode == _PrefetchMode.FORWARD:
        handle._training_state = HandleTrainingState.FORWARD
    else:
        raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
    # Prefetch the next set of handles without synchronizing to allow
    # the sync to happen as late as possible to maximize overlap
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._training_state = prev_training_state
    handle._prefetched = True


@no_type_check
def _get_handle_to_prefetch(
    state: _FSDPState,
    current_handle: FlatParamHandle,
) -> FlatParamHandle:
    """
    Returns the handle to prefetch for the next module(s), or ``None`` if
    there is nothing to prefetch, where ``current_handle`` represents the
    current module.

    "Prefetching" refers to running the unshard logic early (without
    synchronization), and the "next" modules depend on the recorded execution
    order and the current training state.
    """
    training_state = _get_training_state(current_handle)
    valid_training_states = (
        HandleTrainingState.BACKWARD_PRE,
        HandleTrainingState.BACKWARD_POST,
        HandleTrainingState.FORWARD,
    )
    _p_assert(
        training_state in valid_training_states,
        f"Prefetching is only supported in {valid_training_states} but "
        f"currently in {training_state}",
    )
    eod = state._exec_order_data
    target_handle: Optional[FlatParamHandle] = None
    if (
        training_state == HandleTrainingState.BACKWARD_PRE
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
    ) or (
        training_state == HandleTrainingState.BACKWARD_POST
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
    ):
        target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
        if (
            target_handle_candidate
            and target_handle_candidate._needs_pre_backward_unshard
            and not target_handle_candidate._prefetched
        ):
            target_handle = target_handle_candidate
        else:
            target_handle = None
    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
        target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
        if (
            target_handle_candidate
            and target_handle_candidate._needs_pre_forward_unshard
            and not target_handle_candidate._prefetched
        ):
            target_handle = target_handle_candidate
        else:
            target_handle = None

    return target_handle


def _get_training_state(
    handle: FlatParamHandle,
) -> HandleTrainingState:
    """Returns the training state of ``handle``."""
    _p_assert(handle, "Expects a non-empty handle")
    return handle._training_state


@no_type_check
def _register_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a pre-forward hook on ``module``.
    """
    for forward_handle in state._pre_forward_handles:
        forward_handle.remove()
    state._pre_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _pre_forward, state, module_param_handle, _pre_forward_unshard
    )
    state._pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_post_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a post-forward hook on ``module``. Even if the module has no
    handles, we should register the hook since it will register the module's
    pre-backward hook.
    """
    for forward_handle in state._post_forward_handles:
        forward_handle.remove()
    state._post_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _post_forward,
        state,
        module_param_handle,
        _post_forward_reshard,
    )
    state._post_forward_handles.append(module.register_forward_hook(hook))


@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers root pre-forward hook on ``module``, which should be the local
    FSDP root.

    NOTE: For the current composable FSDP design, each application of
    ``fully_shard()`` to a module indicates that that module is the local
    FSDP root. We may remove this assumption in the future, in which case we
    will need to register this root pre-forward hook on any candidate module
    that may be the local FSDP root.
    """
    for forward_handle in state._root_pre_forward_handles:
        forward_handle.remove()
    state._root_pre_forward_handles.clear()
    hook = functools.partial(_root_pre_forward, state)
    state._root_pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_pre_backward_hooks(
    state: _FSDPState,
    module: nn.Module,
    outputs: Any,
    handle: FlatParamHandle,
) -> None:
    """
    Registers pre-backward hooks on the tensors that require gradients in the
    forward pass outputs ``outputs``, which were computed using the
    ``FlatParameter`` of ``handle``.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs
    if state._is_root:
        state._post_backward_callback_queued = False  # only defined on the root

    if handle:
        handle._needs_pre_backward_unshard = False
        # Since these handles' `FlatParameter`s participated in a forward, we
        # conservatively assume that they will be used in the backward
        handle._ran_pre_backward_hook = False

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(
                functools.partial(_pre_backward_hook, state, module, handle)
            )
            if handle:
                handle._needs_pre_backward_unshard = True
        return t

    return _apply_to_tensors(_register_hook, outputs)


def _register_post_backward_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """
    Registers post-backward hooks on the ``FlatParameter`` s'
    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object represents the last function that finalizes
    the ``FlatParameter`` 's gradient, so it only runs after its entire
    gradient computation has finished.

    We register the post-backward hook only once in the *first* forward that a
    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.

    NOTE: We follow this heuristic to prefer the *first* forward to target the
    parameter mixed precision case, where there are *separate*
    ``AccumulateGrad`` objects across the different forwards. (Without
    parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If
    we instead prefer the *last* forward, then the hook runs early.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
        if already_registered or not flat_param.requires_grad:
            return
        hook = functools.partial(_post_backward_hook, state, handle)
        hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined]
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or not flat_param.requires_grad:
            return
        # Get the `AccumulateGrad` object
        temp_flat_param = flat_param.expand_as(flat_param)
        _p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
        assert acc_grad is not None
        hook_handle = acc_grad.register_hook(
            functools.partial(_post_backward_hook, state, handle)
        )
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]


def _register_post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Registers post-backward hooks to reshard flat parameters that do not
    require gradient. We register these using multi-post-grad hooks on the
    input activations to ensure that all gradients that may depend on the
    parameters have been computed before resharding.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
    # where each flat parameter requires gradient
    inp_tensors: Optional[List[torch.Tensor]] = None
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")

    if already_registered or flat_param.requires_grad:
        return
    if inp_tensors is None:
        args_flat = pytree.arg_tree_leaves(*args, **kwargs)
        inp_tensors = [
            obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
        ]
    assert inp_tensors is not None  # mypy
    hook_handle = register_multi_grad_hook(
        inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
    )
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined, assignment]
    else:
        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]


@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Registers the post-backward final callback that runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward.
    """
    _p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    if state._post_backward_callback_queued:
        return
    _assert_in_training_states(state, [TrainingState.IDLE])
    # Trace does not need this callback
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        state._post_backward_callback_queued = True
        Variable._execution_engine.queue_callback(
            functools.partial(_post_backward_final_callback, state, module)
        )


def _wait_for_computation_stream(
    computation_stream: torch.Stream,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
):
    """
    Has the unshard and pre-unshard streams wait for the computation stream.
    For example, this should be called in the FSDP root's pre-forward to
    respect optimizer step computation.
    """
    # Tracing does not need to wait
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        return
    unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
    # Having the pre-all-gather stream wait for the current stream even if we
    # do not leverage the pre-all-gather stream is tolerable since this only
    # runs once per iteration
    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]


def _reset_flat_param_grad_info_if_needed(
    handles: List[FlatParamHandle],
):
    """
    Clears the original parameters' gradients if needed. This method's CPU
    overhead is minimal, so we may call it throughout FSDP methods, which serve
    as callsites to free the gradient memory earlier.
    """
    if not isinstance(handles, list):
        handles = [handles]
    for handle in handles:
        if handle._use_orig_params:
            handle._reset_flat_param_grad_info_if_needed()


@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Returns all buffers in the module tree rooted at ``root_module`` and a
    corresponding list of the buffer dtypes for computation. Each buffer dtype
    is either ``None`` if buffer mixed precision is not enabled or the buffer
    low precision dtype otherwise.
    """
    _p_assert(state._is_root, "Expects the root to cast buffers")
    buffers: List[torch.Tensor] = []
    buffer_dtypes: List[Optional[torch.dtype]] = []
    visited_buffers: Set[torch.Tensor] = set()
    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
    # instance's mixed precision setting for each buffer
    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
        root_module
    )
    for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
        for buffer_name, buffer in fsdp_module.named_buffers():
            if buffer in visited_buffers:
                continue
            visited_buffers.add(buffer)
            if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names:
                continue
            buffers.append(buffer)
            buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
    return buffers, buffer_dtypes


@no_type_check
def _get_orig_buffer_dtypes(
    state: _FSDPState,
    buffer_names: List[str],
) -> List[torch.dtype]:
    """
    Returns the original buffer types of the given buffer names.
    """
    buffer_dtypes: List[torch.dtype] = []
    for buffer_name in buffer_names:
        _p_assert(
            buffer_name in state._buffer_name_to_orig_dtype,
            f"{buffer_name} is missing from pre-computed dict on rank "
            f"{state.rank}, which only has keys "
            f"{state._buffer_name_to_orig_dtype.keys()}",
        )
        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
    return buffer_dtypes


def _cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: List[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    """
    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
    corresponding buffer is only moved to ``device``.
    """
    _p_assert(
        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
        f"`buffer_dtypes` is specified but got {len(buffers)} and "
        f"{len(buffer_dtypes)}",
    )
    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
        if not torch.is_floating_point(buffer) or buffer_dtype is None:
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
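

# A minimal, never-called sketch (hypothetical name; not part of the FSDP
# runtime) showing the behavior of `_cast_buffers_to_dtype_and_device` above:
# floating-point buffers are cast to the given low-precision dtype and moved,
# while non-floating-point buffers (e.g. integer step counters) are only
# moved, since casting them would corrupt their values.
def _example_cast_buffers() -> None:
    buffers = [
        torch.zeros(2, dtype=torch.float32),  # floating point: cast + move
        torch.zeros(2, dtype=torch.int64),  # non-floating point: move only
    ]
    _cast_buffers_to_dtype_and_device(
        buffers, [torch.float16, torch.float16], torch.device("cpu")
    )
    assert buffers[0].dtype == torch.float16
    assert buffers[1].dtype == torch.int64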