pytorch

Форк
0
/
_runtime_utils.py 
1630 строк · 64.4 Кб
1
import functools
2
import logging
3
from enum import auto, Enum
4
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
5

6
import torch
7
import torch.distributed as dist
8
import torch.distributed.fsdp._traversal_utils as traversal_utils
9
import torch.nn as nn
10
import torch.nn.functional as F
11
from torch.autograd import Variable
12
from torch.autograd.graph import register_multi_grad_hook
13
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
14
from torch.distributed.fsdp._common_utils import (
15
    _assert_in_training_states,
16
    _FSDPState,
17
    _get_module_fsdp_state,
18
    _is_composable,
19
    _log_post_backward_hook,
20
    _no_dispatch_record_stream,
21
    clean_tensor_name,
22
    TrainingState,
23
)
24
from torch.distributed.fsdp._flat_param import (
25
    FlatParameter,
26
    FlatParamHandle,
27
    HandleShardingStrategy,
28
    HandleTrainingState,
29
    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
30
)
31
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
32
from torch.distributed.fsdp.api import BackwardPrefetch
33
from torch.distributed.utils import (
34
    _apply_to_tensors,
35
    _cast_forward_inputs,
36
    _p_assert,
37
    _to_kwargs,
38
)
39
from torch.utils import _pytree as pytree
40

41
log = logging.getLogger(__name__)

# Names of the attributes that must hold one single value across all FSDP
# states sharing a root. "process_group" is deliberately left out so that
# hybrid sharding and MoE setups remain possible.
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
    "_use_full_prec_in_eval",
)
49

50

51
class _PrefetchMode(Enum):
    """Selects which pass a handle prefetch is issued for."""

    # Passed to `_prefetch_handle` from the pre-backward hook
    BACKWARD = auto()
    # Passed to `_prefetch_handle` from the pre-forward unshard
    FORWARD = auto()
54

55

56
def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Collects the root FSDP states found in ``module`` 's module tree together
    with the modules that own them.

    The walk follows the ``module.modules()`` order and reports each state at
    most once. This mirrors :func:`_get_fsdp_states_with_modules`, except that
    every candidate goes through :func:`_is_fsdp_root`, which forces a lazy
    initialization so that the root is known even if lazy initialization has
    not happened yet.

    Returns:
        Tuple[List[_FSDPState], List[nn.Module]]: The root states and, at the
        same indices, the modules owning those states.
    """
    root_states: List[_FSDPState] = []
    root_modules: List[nn.Module] = []
    seen_states: Set[_FSDPState] = set()
    # NOTE: This relies on `module.modules()` visiting parents before their
    # children (top-down).
    for candidate in module.modules():
        candidate_state = _get_module_fsdp_state(candidate)
        if candidate_state is None or candidate_state in seen_states:
            continue
        # Only consulted for unseen states since it may trigger lazy init
        if not _is_fsdp_root(candidate_state, candidate):
            continue
        seen_states.add(candidate_state)
        root_states.append(candidate_state)
        root_modules.append(candidate)
    return root_states, root_modules
86

87

88
def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """
    Returns only the list of root states computed by
    :func:`_get_fsdp_root_states_with_modules`, dropping the owning modules.
    """
    return _get_fsdp_root_states_with_modules(module)[0]
92

93

94
def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Tells whether ``state`` is the state of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same
    object. For the non-wrapper code path, ``state`` should be the state of
    ``module``.
    """
    # Lazy initialization is what settles `_is_root`, so trigger it first
    _lazy_init(state, module)
    is_root = state._is_root
    assert is_root is not None  # mypy
    return is_root
105

106

107
@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.

    Args:
        state (_FSDPState): State to initialize; treated as the root if its
            ``_is_root`` is still ``None``.
        root_module (nn.Module): Module whose tree is traversed to find the
            FSDP states and handles to initialize.

    Returns:
        _FSDPState: ``state`` itself, both when initialization ran and when it
        was skipped because ``state._is_root`` was already set.

    Raises:
        RuntimeError: If the state's device handle reports that the device is
            not available.
    """
    if state._is_root is not None:
        # No-op: already lazily initialized. Still return the state so that
        # the return value matches the declared return type on every path.
        return state
    if not state._device_handle.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
    # Streams are created after `_all_fsdp_states` is set because
    # `_init_streams` reads it
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state
140

141

142
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Verifies, for lazy initialization, that every ``FlatParameter`` reachable
    through the FSDP handles in ``module`` 's tree lives on the expected
    device: CPU when the handle offloads parameters, and
    ``state.compute_device`` otherwise.

    Raises:
        RuntimeError: On the first handle whose flat parameter is on an
            unexpected device.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        actual_device = handle.flat_param.device
        if handle._offload_params:
            # Offloaded parameters must still be on CPU
            if actual_device != cpu_device:
                raise RuntimeError(
                    "An FSDP-managed module with parameter CPU offloading enabled "
                    f"has parameters on {actual_device}. Make sure to "
                    "not move the module from CPU when offloading parameters."
                )
        elif actual_device != state.compute_device:
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{actual_device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
164

165

166
@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Propagates the root's shared data structures to every FSDP state recorded
    in ``root_state._all_fsdp_states`` and initializes the flat parameter
    attributes of each state's handle. Both are done in the same pass so that
    the states are only looped over once.

    Along the way this marks non-root states with ``_is_root = False``, sets
    ``_has_optim_in_backward`` on every handle, and checks that each attribute
    in ``HOMOGENEOUS_ATTR_NAMES`` has a single value across all states.

    Raises:
        RuntimeError: If a flat parameter itself carries
            ``_in_backward_optimizers``.
        ValueError: If an attribute in ``HOMOGENEOUS_ATTR_NAMES`` does not have
            exactly one distinct value across the states.
    """
    root_handle = root_state._handle
    if root_handle:
        root_handle.init_flat_param_attributes()
    observed_values: Dict[str, Set[Any]] = {
        name: set() for name in HOMOGENEOUS_ATTR_NAMES
    }
    # Share the reference (no copy)
    root_state._all_handles = root_state._exec_order_data.all_handles
    # Update `_has_optim_in_backward` for each handle
    for handle in root_state._all_handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_in_backward_optimizers"):
            raise RuntimeError(
                "FSDP optimizer in backward only supported with use_orig_params=True!"
            )
        orig_params = flat_param._params
        handle._has_optim_in_backward = orig_params is not None and any(
            hasattr(param, "_in_backward_optimizers") for param in orig_params
        )
        if handle._has_optim_in_backward:
            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
    # Attributes that non-root states take over from the root, in the order
    # they are assigned
    shared_attr_names = (
        "_unshard_stream",
        "_post_backward_stream",
        "_pre_unshard_stream",
        "_all_reduce_stream",
        "_default_stream",
        "_exec_order_data",
        "_free_event_queue",
    )
    for fsdp_state in root_state._all_fsdp_states:
        for name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, name),
                f"FSDP state missing attribute {name}",
            )
            observed_values[name].add(getattr(fsdp_state, name))
        if fsdp_state is root_state:
            continue
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference): both `None` and `False` are accepted
        _p_assert(
            not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
        for shared_name in shared_attr_names:
            setattr(fsdp_state, shared_name, getattr(root_state, shared_name))
        extension = fsdp_state._fsdp_extension
        if extension is not None:
            extension.compute_stream = root_state._default_stream
        nested_handle = fsdp_state._handle
        if nested_handle:
            nested_handle.init_flat_param_attributes()
    for name, values in observed_values.items():
        if len(values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {name} but got {values}"
            )
230

231

232
@no_type_check
def _init_streams(
    state: _FSDPState,
) -> None:
    """
    Creates the device streams used to overlap communication, computation,
    and data transfers, and stores them on ``state``. The streams are meant to
    be shared across FSDP instances.

    Expects ``state`` to be the root, its device to be available, and
    ``state._all_fsdp_states`` to be populated already.
    """
    assert state._is_root
    device_handle = state._device_handle
    assert device_handle.is_available()
    uses_hybrid_sharding = any(
        fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
        for fsdp_state in state._all_fsdp_states
    )
    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
    # preserve the default priority of 0 otherwise
    if state.limit_all_gathers and uses_hybrid_sharding:
        high_priority = -1
    else:
        high_priority = 0

    # Default stream for computation
    default_stream = device_handle.current_stream()
    state._default_stream = default_stream
    if state._fsdp_extension is not None:
        # Hand the compute stream to the FSDP extension
        state._fsdp_extension.compute_stream = default_stream

    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves
    state._unshard_stream = device_handle.Stream(priority=high_priority)
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation
    state._post_backward_stream = device_handle.Stream(priority=high_priority)
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast)
    state._pre_unshard_stream = device_handle.Stream(priority=high_priority)
    # Stream to run HSDP's all-reduce as async (if using HSDP); otherwise the
    # default stream is reused
    if uses_hybrid_sharding:
        state._all_reduce_stream = device_handle.Stream()
    else:
        state._all_reduce_stream = default_stream
268

269

270
@no_type_check
271
def _unshard(
272
    state: _FSDPState,
273
    handle: FlatParamHandle,
274
    unshard_stream: torch.Stream,
275
    pre_unshard_stream: torch.Stream,
276
) -> None:
277
    """
278
    Unshards the handles in ``handles``. If the handles are in
279
    :meth:`summon_full_params` and are using mixed precision, then they are
280
    forced to full precision.
281

282
    Postcondition: handle's ``FlatParameter`` 's data is the padded
283
    unsharded flat parameter on the compute device.
284
    """
285
    if not handle:
286
        return
287
    with state._device_handle.stream(pre_unshard_stream):
288
        ran_pre_unshard = handle.pre_unshard()
289
    if ran_pre_unshard:
290
        unshard_stream.wait_stream(pre_unshard_stream)
291
    if state.limit_all_gathers:
292
        event = state._free_event_queue.dequeue_if_needed()
293
        if event:
294
            with torch.profiler.record_function(
295
                "FullyShardedDataParallel.rate_limiter"
296
            ):
297
                event.synchronize()
298
    with state._device_handle.stream(unshard_stream):
299
        handle.unshard()
300
        handle.post_unshard()
301

302

303
@no_type_check
def _reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    free_unsharded_flat_param: bool,
):
    """
    Reshards ``handle``, where ``free_unsharded_flat_param`` says whether the
    handle's padded unsharded flat parameter should be freed.

    When the unsharded flat parameter is freed and ``state.limit_all_gathers``
    is set, an event is recorded and enqueued on ``state._free_event_queue``
    (skipped while compiling with TorchDynamo).
    """
    handle.reshard(free_unsharded_flat_param)
    # We don't run an event queue for freeing under torch compile atm
    # But maybe we need to? TODO(voz): Look into this
    if (
        state.limit_all_gathers
        and free_unsharded_flat_param
        and not torch.distributed._functional_collectives.is_torchdynamo_compiling()
    ):
        event = state._device_handle.Event()
        event.record()
        state._free_event_queue.enqueue(event)
    handle.post_reshard()
    # Flat parameter freed or not, we always have to "unshard" the parameter
    # upon next access to get its shape correct.
    handle._prefetched = False
325

326

327
def _unshard_grads(
328
    handle: Optional[FlatParamHandle],
329
) -> None:
330
    if handle:
331
        handle.unshard_grad()
332

333

334
def _reshard_grads(
335
    handle: Optional[FlatParamHandle],
336
) -> None:
337
    if handle:
338
        handle.reshard_grad()
339

340

341
@no_type_check
def _pre_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    unshard_fn: Callable,
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    may also convert forward ``args`` and ``kwargs`` to the mixed precision
    parameter dtype.

    Args:
        state (_FSDPState): State whose training state, execution order data,
            and mixed precision config are used.
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, if any.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.

    Returns:
        Tuple[Tuple[Any, ...], Dict[str, Any]]: The forward ``args`` and
        ``kwargs``, cast if casting applied and unchanged otherwise.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            # For both checkpoint implementations, we do not need to re-cast
            # inputs here since they will be checkpointed in the low precision
            # either by AC or normally by autograd as long as the AC region is
            # nested within FSDP
            return args, kwargs

        state.training_state = TrainingState.FORWARD_BACKWARD
        state._exec_order_data.record_pre_forward(handle, module.training)
        if handle:
            handle._training_state = HandleTrainingState.FORWARD
        if unshard_fn is not None:
            unshard_fn(state, handle)
        # Register post-backward hooks to reshard the parameters and
        # reduce-scatter their gradients. They must be re-registered every
        # forward pass in case the `grad_fn` is mutated.
        _register_post_backward_hook(state, handle)

        # We have to reallocate the `_cpu_grad` if optimizer overlap set the
        # grad to `None` in the backward pass.
        if handle and handle._offload_params:
            flat_param = handle.flat_param
            if flat_param._cpu_grad is None:
                cpu_zeros = torch.zeros_like(
                    flat_param._local_shard, device=torch.device("cpu")
                )
                flat_param._cpu_grad = cpu_zeros.pin_memory()

        # Casting needs the state's own handle, and that handle must not be
        # forcing full precision
        state_handle = state._handle
        if (
            state_handle
            and not state_handle._force_full_precision
            and state.mixed_precision.cast_forward_inputs
        ):
            # Recursively convert args and kwargs to specified precision.
            param_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
            args, kwargs = _cast_forward_inputs(param_dtype, *args, **kwargs)
        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
        return args, kwargs
402

403

404
@no_type_check
405
def _pre_forward_unshard(
406
    state: _FSDPState,
407
    handle: Optional[FlatParamHandle],
408
) -> None:
409
    """Unshards parameters in the pre-forward."""
410
    if not handle:
411
        return
412
    # If the handles have been prefetched, then there is no need to call
413
    # `_unshard()` again
414
    if not handle._prefetched:
415
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
416
    handle._needs_pre_forward_unshard = False
417
    # Don't wait during trace
418
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
419
        state._device_handle.current_stream().wait_stream(state._unshard_stream)
420
    with torch.profiler.record_function(
421
        "FullyShardedDataParallel._pre_forward_prefetch"
422
    ):
423
        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)
424

425

426
@no_type_check
def _post_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    reshard_fn: Callable,
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        state (_FSDPState): State whose execution order data and training
            state are updated.
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, if any.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Returns:
        Any: ``output`` as returned by :func:`_register_pre_backward_hooks`, or
        the unmodified ``output`` when the logic is skipped.

    Postcondition (as stated by the original notes; depends on
    ``reshard_fn``): Each ``FlatParameter`` 's data points to the sharded flat
    parameter.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
        # recomputed forward
        in_recomputed_forward = (
            handle and handle._training_state == HandleTrainingState.BACKWARD_PRE
        )
        if in_recomputed_forward:
            return output

        state._exec_order_data.record_post_forward(handle)
        if reshard_fn is not None:
            reshard_fn(state, handle)
        # Register pre-backward hooks to unshard the flat parameters for the
        # gradient computation (if needed)
        hooked_output = _register_pre_backward_hooks(state, module, output, handle)
        state.training_state = TrainingState.IDLE
        if handle:
            handle._training_state = HandleTrainingState.IDLE
        return hooked_output
472

473

474
@no_type_check
475
def _post_forward_reshard(
476
    state: _FSDPState,
477
    handle: FlatParamHandle,
478
) -> None:
479
    """Reshards parameters in the post-forward."""
480
    if not handle:
481
        return
482
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
483
    # with the intention that they are immediately used for backward
484
    # computation (though this may not be true)
485
    free_unsharded_flat_param = (
486
        not state._is_root
487
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
488
    )
489
    _reshard(state, handle, free_unsharded_flat_param)
490

491

492
@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> Tuple[Any, Any]:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns early without
    running the root-only logic.

    Args:
        state (_FSDPState): State of the FSDP instance this is called on.
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then the root-only logic is skipped.
        args: Module forward ``args``.
        kwargs: Module forward ``kwargs``.

    Returns:
        Tuple[Any, Any]: The forward ``args`` and ``kwargs``. For the root,
        they are first passed through ``_to_kwargs`` with
        ``state.compute_device`` and then through
        :func:`_root_cast_forward_input`. For a non-root, they are passed
        through :func:`_root_cast_forward_input` on the composable path and
        returned unchanged otherwise.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
        _lazy_init(state, module)
        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
        if not state._is_root:
            # Always cast forward inputs in the root of this local FSDP unit for mixed
            # precision, as this is where mixed precision could be configed.
            # This is more useful for auto wrapping that is recommended in composable path.
            # For manual wrapping, cast forward inputs on each local FSDP unit root will
            # increase some overhead, so not turned on for model wrapper path right now where
            # manual wrapping is more broadly used.
            if _is_composable(state):
                return _root_cast_forward_input(state, module, args, kwargs)
            return args, kwargs

        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
        # are in full precision and if we should cast them back to lower precision, which happens when
        # exiting eval() mode.
        handle = state._handle
        if handle:
            should_cast_buffers_to_full_prec = handle._force_full_precision
        else:
            should_cast_buffers_to_full_prec = True

        if should_cast_buffers_to_full_prec:
            _cast_buffers_to_dtype_and_device(
                buffers=dict(module.named_buffers()).values(),
                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
                device=state.compute_device,
            )
            # This flag is only set when we cast buffers to full precision, to avoid the
            # CPU overhead that can stem from retrieving all buffers and their types in the
            # following else branch.
            state._needs_buffer_dtype_restore_check = True
        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
            # Check if buffers are in full precision and we need to cast them
            # back down.
            (
                buffers,
                buffer_dtypes_for_computation,
            ) = _get_buffers_and_dtypes_for_computation(state, module)
            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
                if any(
                    buffer.dtype != buffer_dtype_for_computation
                    for buffer, buffer_dtype_for_computation in zip(
                        buffers, buffer_dtypes_for_computation
                    )
                ):
                    # Assume we have to cast everything if there is one mismatch
                    _cast_buffers_to_dtype_and_device(
                        buffers, buffer_dtypes_for_computation, state.compute_device
                    )
            # We don't have to check this again until we cast buffers to full precision again.
            state._needs_buffer_dtype_restore_check = False

        if state.forward_prefetch:
            # Reset the per-handle prefetch bookkeeping for this forward pass
            for fsdp_state in state._all_fsdp_states:
                prefetch_handle = fsdp_state._handle
                if prefetch_handle:
                    prefetch_handle._needs_pre_forward_unshard = True
                    prefetch_handle._prefetched = False
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
            state._pre_unshard_stream,
        )
        _reset_flat_param_grad_info_if_needed(state._all_handles)

        # Prepares the forward inputs by moving them to ``compute_device``
        # TODO: Do not use the side stream for tensor copies for now; investigate
        # the perf with/without it.
        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
            args_tuple, kwargs_tuple = _to_kwargs(
                args, kwargs, state.compute_device, False
            )
        # Only the first entry of each returned sequence is used
        args = args_tuple[0]
        kwargs = kwargs_tuple[0]

        return _root_cast_forward_input(state, module, args, kwargs)
589

590

591
@no_type_check
592
def _root_cast_forward_input(
593
    state: _FSDPState, module: torch.nn.Module, args, kwargs
594
) -> Tuple[Any, Any]:
595
    if state._handle:
596
        force_full_precision = not state._handle._force_full_precision
597
    else:
598
        force_full_precision = True
599

600
    should_cast_forward_inputs = (
601
        (module.training or not state._use_full_prec_in_eval) and force_full_precision
602
    ) and state.mixed_precision.cast_root_forward_inputs
603

604
    if should_cast_forward_inputs:
605
        input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
606
        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
607

608
    return args, kwargs
609

610

611
@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    handle: FlatParamHandle,
    grad,
    *unused: Any,
) -> Any:
    """
    Prepares ``handle`` 's ``FlatParameter`` for gradient computation.

    Args:
        state (_FSDPState): State owning ``handle``; its training state is
            moved to ``FORWARD_BACKWARD``.
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
        handle (FlatParamHandle): Handle to prepare; may be falsy, in which
            case only the non-per-handle logic runs.
        grad: Value received by the hook; always returned as is.
        *unused (Any): Ignored extra hook arguments.

    Returns:
        Any: ``grad``, unmodified.
    """
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if handle and getattr(handle, "_ran_pre_backward_hook", False):
        log.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
        return grad

    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _reset_flat_param_grad_info_if_needed(state._all_handles)
        elif handle:
            permitted_states = [TrainingState.IDLE]
            if _is_composable(state):
                permitted_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, permitted_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not handle:
            return grad
        handle._training_state = HandleTrainingState.BACKWARD_PRE

        if handle._needs_pre_backward_unshard:
            # If the handle has been prefetched, then there is no need to
            # call `_unshard()` again
            if not handle._prefetched:
                unshard_stream = state._unshard_stream
                _unshard(state, handle, unshard_stream, state._pre_unshard_stream)
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                current_stream = state._device_handle.current_stream()
                current_stream.wait_stream(state._unshard_stream)

        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard this handle
        handle._needs_pre_backward_unshard = False
        with torch.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_prefetch"
        ):
            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
        handle.prepare_gradient_for_backward()
        handle._ran_pre_backward_hook = True
        return grad
680

681

682
@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    flat_param,
    *unused: Any,
):
    """
    Reshards ``handle`` 's ``FlatParameter`` and reduces its gradient.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
    unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle whose ``FlatParameter`` gradient is
            processed.
        flat_param: Not used as passed; the ``FlatParameter`` is always
            re-read from ``handle.flat_param``.
        *unused (Any): Extra positional arguments; ignored.

    Raises:
        RuntimeError: If the gradient itself requires gradient.
    """
    _log_post_backward_hook(state, handle, log)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # With reentrant AC applied several times across submodules that share
        # one `FlatParameter`, this hook can fire more than once per backward,
        # so a handle already in `BACKWARD_POST` is accepted as well.
        permitted_handle_states = (
            HandleTrainingState.BACKWARD_PRE,
            HandleTrainingState.BACKWARD_POST,
        )
        _p_assert(
            handle._training_state in permitted_handle_states,
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST

        # Nothing to reshard or reduce if no gradient was produced
        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")

        _post_backward_reshard(state, handle)
        if not state._sync_gradients:
            # No reduction when gradients are not synchronized; only expose
            # the unsharded gradient views for the original parameters
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return

        # The gradient computation issued in the current stream must finish
        # before the reduction starts in the post-backward stream (the wait is
        # skipped while tracing)
        is_compiling = (
            torch.distributed._functional_collectives.is_torchdynamo_compiling()
        )
        if not is_compiling:
            state._post_backward_stream.wait_stream(
                state._device_handle.current_stream()
            )

        with state._device_handle.stream(state._post_backward_stream):
            computed_grad = flat_param.grad.data
            # Do not downcast when a low precision hook handles the cast or
            # when full precision is being forced (i.e. model.eval() + full
            # precision in eval was configured)
            should_cast_to_reduce_dtype = (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
                and not handle._force_full_precision
            )
            if should_cast_to_reduce_dtype:
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            reduce_fn = (
                _reduce_grad if handle.uses_sharded_strategy else _reduce_grad_no_shard
            )
            reduce_fn(state, handle)
            # The unsharded gradient was produced in the computation stream
            # and consumed in the post-backward stream, so tell the caching
            # allocator before the local reference goes out of scope
            _no_dispatch_record_stream(computed_grad, state._post_backward_stream)
758

759

760
def _post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Post-backward hook that only reshards ``handle`` through
    ``_post_backward_reshard`` and performs no gradient reduction.

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle to reshard.
        *unused (Any): Extra positional arguments; ignored.
    """
    profiler_label = "FullyShardedDataParallel._post_backward_hook_reshard_only"
    with torch.profiler.record_function(profiler_label):
        # The pre-backward hook may never have run (when the forward output
        # does not require grad), so overwrite the `IDLE` state here to allow
        # post-backward prefetching
        state.training_state = TrainingState.FORWARD_BACKWARD
        handle._training_state = HandleTrainingState.BACKWARD_POST
        _post_backward_reshard(state, handle)
774

775

776
def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Reshards ``handle`` after backward and then issues the backward prefetch.

    Whether the unsharded flat parameter is freed is decided by
    ``_should_free_in_backward``.

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle to reshard.
        *unused (Any): Extra positional arguments; ignored.
    """
    _reshard(state, handle, _should_free_in_backward(state, handle))

    # TODO: Post-backward prefetching does not support the multiple handles
    # per module case since the post-backward hook runs per handle, not per
    # group of handles.
    prefetch_label = "FullyShardedDataParallel._post_backward_prefetch"
    with torch.profiler.record_function(prefetch_label):
        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
791

792

793
@no_type_check
794
def _should_free_in_backward(
795
    state: _FSDPState,
796
    handle: FlatParamHandle,
797
) -> bool:
798
    """
799
    Returns whether FSDP should free the unsharded flat parameter in the
800
    post-backward or not.
801
    """
802
    if not handle.uses_sharded_strategy:
803
        return False
804
    # If not syncing gradients, then we do not free for strategies that do not
805
    # reshard after forward as a *heuristic* to tradeoff higher memory for
806
    # higher throughput.
807
    return (
808
        state._sync_gradients
809
        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
810
    )
811

812

813
@no_type_check
def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For sharded strategies, this runs gradient reduction, sharded gradient
    accumulation if needed, and the post-reduction callback.

    Args:
        state (_FSDPState): FSDP state providing the communication hook,
            process groups, division factors, and streams.
        handle (FlatParamHandle): Handle whose ``FlatParameter`` holds the
            unsharded gradient in ``.grad``.
    """
    flat_param = handle.flat_param
    is_hybrid_strategy = handle._sharding_strategy in (
        HandleShardingStrategy.HYBRID_SHARD,
        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    )
    # `.grad` is cleared to permit multiple backwards. The reduction is issued
    # asynchronously in a separate stream, so without this a second backward's
    # computation could run ahead of the first backward's reduction and the
    # wrong gradient would be reduced.
    unsharded_grad = flat_param.grad.data
    flat_param.grad = None
    reduce_scatter_input, reduce_scatter_output = _get_reduce_scatter_tensors(
        state, unsharded_grad
    )
    if state._comm_hook is not None:
        # NOTE: HSDP variants do not support communication hook.
        state._comm_hook(
            state._comm_hook_state, reduce_scatter_input, reduce_scatter_output
        )
    else:  # default path
        _div_if_needed(reduce_scatter_input, state._gradient_predivide_factor)
        if handle._use_fake_reduce:
            pg = handle._fake_process_group
        else:
            pg = state.process_group
        dist.reduce_scatter_tensor(
            reduce_scatter_output,
            reduce_scatter_input,
            group=pg,
        )
        if is_hybrid_strategy:
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._all_reduce_stream.wait_stream(state._post_backward_stream)
            with state._device_handle.stream(state._all_reduce_stream):
                # The new sharded gradient is produced in the post-backward
                # stream and consumed in the all-reduce stream, so inform the
                # caching allocator
                _no_dispatch_record_stream(
                    reduce_scatter_output, state._all_reduce_stream
                )
                dist.all_reduce(reduce_scatter_output, group=state._inter_node_pg)
                _div_if_needed(
                    reduce_scatter_output, state._gradient_postdivide_factor
                )
                hybrid_grad_to_offload = _accumulate_sharded_grad(
                    state, handle, reduce_scatter_output
                )
                _post_reduce_grad_callback(state, handle, hybrid_grad_to_offload)
                return
        _div_if_needed(reduce_scatter_output, state._gradient_postdivide_factor)
    grad_to_offload = _accumulate_sharded_grad(state, handle, reduce_scatter_output)
    _post_reduce_grad_callback(state, handle, grad_to_offload)
870

871

872
@no_type_check
def _get_reduce_scatter_tensors(
    state: _FSDPState, unsharded_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Builds the input and output tensors for the gradient reduce-scatter.

    Args:
        state (_FSDPState): FSDP state; only ``world_size`` is read.
        unsharded_grad (torch.Tensor): Unsharded gradient to reduce-scatter.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The reduce-scatter input, which is
        ``unsharded_grad`` right-padded so that its numel equals
        ``world_size`` times the first chunk's numel (or ``unsharded_grad``
        itself when no padding is needed), and an uninitialized output tensor
        shaped like the first chunk.
    """
    world_size = state.world_size
    first_chunk = unsharded_grad.chunk(world_size)[0]
    numel_to_pad = world_size * first_chunk.numel() - unsharded_grad.numel()
    if numel_to_pad > 0:
        reduce_scatter_input = F.pad(unsharded_grad, [0, numel_to_pad])
    else:
        reduce_scatter_input = unsharded_grad
    reduce_scatter_output = torch.empty_like(first_chunk)  # padded
    return reduce_scatter_input, reduce_scatter_output
886

887

888
@no_type_check
def _accumulate_sharded_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    sharded_grad: torch.Tensor,
) -> torch.Tensor:
    """
    Stores or accumulates the reduce-scattered sharded gradient on the
    ``FlatParameter`` and returns the gradient to offload (if CPU offloading
    is enabled).

    Args:
        state (_FSDPState): FSDP state, forwarded to the dtype cast.
        handle (FlatParamHandle): Handle owning the ``FlatParameter``.
        sharded_grad (torch.Tensor): Newly reduced sharded gradient.

    Returns:
        torch.Tensor: The ``FlatParameter`` 's ``_saved_grad_shard`` after the
        update.
    """
    flat_param = handle.flat_param
    _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
    # The sharded gradient lives in `_saved_grad_shard` to support gradient
    # accumulation -- for multiple backwards, the gradient reductions may
    # happen in arbitrary order
    if not hasattr(flat_param, "_saved_grad_shard"):
        flat_param._saved_grad_shard = sharded_grad
    else:
        _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
        flat_param._saved_grad_shard += sharded_grad
    return flat_param._saved_grad_shard
912

913

914
@no_type_check
def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For no-shard, this runs gradient reduction (which directly covers any
    gradient accumulation implicitly) and the post-reduction callback.

    Args:
        state (_FSDPState): FSDP state providing the communication hook,
            process group, and division factors.
        handle (FlatParamHandle): Handle whose ``FlatParameter`` gradient is
            reduced in place.
    """
    flat_param = handle.flat_param
    if state._comm_hook is not None:
        state._comm_hook(state._comm_hook_state, flat_param.grad)
    else:  # default path
        _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
        dist.all_reduce(flat_param.grad, group=state.process_group)
        _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
    # For `NO_SHARD`, keeping the low precision gradients only requires
    # skipping the cast
    keep_low_precision = handle._keep_low_precision_grads
    if not keep_low_precision:
        _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
    _post_reduce_grad_callback(state, handle, flat_param.grad.data)
933

934

935
@no_type_check
def _post_reduce_grad_callback(
    state: _FSDPState,
    handle: FlatParamHandle,
    # Additional arguments needed for the callback logic
    grad_to_offload: torch.Tensor,
):
    """
    Runs the logic that follows the gradient reduction: first
    ``_offload_grad`` (offloads the gradient to CPU if CPU offloading is
    enabled), then ``_post_backward_use_sharded_grad_views`` (uses sharded
    gradient views if ``use_orig_params=True``).

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle whose gradient was reduced.
        grad_to_offload (torch.Tensor): Gradient passed to the offload step.
    """
    _offload_grad(state, handle, grad_to_offload=grad_to_offload)
    _post_backward_use_sharded_grad_views(handle)
949

950

951
@no_type_check
952
def _offload_grad(
953
    state: _FSDPState,
954
    handle: FlatParamHandle,
955
    grad_to_offload: torch.Tensor,
956
):
957
    if not handle._offload_params:
958
        return
959
    # Offload the gradient to CPU to ensure parameters and gradients are on the
960
    # same device as required by the optimizer
961
    # TODO: Investigate why `NO_SHARD` breaks correctness when using
962
    # `non_blocking=True` here.
963
    # TODO (rohan-varma): When CPU offload and optimizer overlap,
964
    # non_blocking=True won't work since the copy may have not finished before
965
    # the optimizer step executes on CPU. If we want to use non-blocking=True
966
    # here, we'll have to synchronize before using result on CPU.
967
    non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
968
    handle.flat_param._cpu_grad.copy_(
969
        grad_to_offload.detach(), non_blocking=non_blocking
970
    )  # synchronized in the post-backward callback
971
    # Since the gradient being offloaded may have been produced in the
972
    # computation stream and is being consumed here in the post-backward
973
    # stream, inform the caching allocator
974
    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)
975

976

977
@no_type_check
978
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
979
    if not handle._use_orig_params:
980
        return
981
    # Since the handle's `FlatParameter` completed its gradient computation, we
982
    # should reset the gradient noneness mask
983
    handle._reset_is_grad_none()
984
    # Delay using sharded gradient views until after the reduce-scatter instead
985
    # of immediately after resharding
986
    handle._use_sharded_grad_views()
987
    if handle._has_optim_in_backward:
988
        handle.prepare_gradient_for_optim()
989
        for orig_param in handle.flat_param._params:
990
            # Check for `None` gradient to filter parameters not in the rank
991
            if orig_param.grad is not None and hasattr(
992
                orig_param, "_in_backward_optimizers"
993
            ):
994
                # TODO (rohan-varma): For CPU offload, this unfortunately
995
                # operates on CPU because the parameters and gradients have
996
                # already been offloaded. We should run this on GPU after
997
                # refactoring.
998
                for optim in orig_param._in_backward_optimizers:
999
                    optim.step()
1000

1001
                optim.zero_grad(set_to_none=True)
1002
        handle._reset_flat_param_grad_info_if_needed()
1003
        if handle._offload_params:
1004
            handle.flat_param._cpu_grad = None
1005

1006

1007
def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
    """
    Divides ``tensor`` in place by ``div_factor`` when the factor is greater
    than 1; otherwise leaves ``tensor`` untouched.

    Args:
        tensor (torch.Tensor): Tensor to divide in place.
        div_factor (float): Division factor.
    """
    needs_division = div_factor > 1
    if not needs_division:
        return
    tensor.div_(div_factor)
1010

1011

1012
@no_type_check
def _cast_grad_to_param_dtype(
    state: _FSDPState,
    sharded_grad: torch.Tensor,
    param: FlatParameter,
):
    """
    Casts ``sharded_grad`` back to the full parameter dtype so that the
    optimizer step runs with that dtype. An actual cast happens if
    1. parameters were in reduced precision during the forward, since
    gradients would then be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in reduced
    precision for communication.
    If a low precision communication hook is registered, the dtype cast
    happens in the hook instead and nothing is done here.

    Args:
        state (_FSDPState): FSDP state; must be in ``FORWARD_BACKWARD``.
        sharded_grad (torch.Tensor): Gradient whose ``.data`` is replaced by
            the casted tensor when a cast is needed.
        param (FlatParameter): Parameter providing the target dtype.
    """
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    if _low_precision_hook_enabled(state) or sharded_grad.dtype == param.dtype:
        return
    low_precision_data = sharded_grad.data
    sharded_grad.data = low_precision_data.to(dtype=param.dtype)
    # For `NO_SHARD`, the gradient is produced in the computation stream and
    # consumed here in the post-backward stream, so inform the caching
    # allocator; for the sharded strategies, the gradient is produced in the
    # post-backward stream, so this `record_stream()` should be a no-op
    _no_dispatch_record_stream(
        low_precision_data, state._device_handle.current_stream()
    )
1040

1041

1042
def _check_grad_to_accumulate(
1043
    new_sharded_grad: torch.Tensor,
1044
    accumulated_grad: torch.Tensor,
1045
) -> None:
1046
    _p_assert(
1047
        accumulated_grad.shape == new_sharded_grad.shape,
1048
        "Shape mismatch when accumulating gradients: "
1049
        f"existing gradient shape={accumulated_grad.shape} "
1050
        f"new gradient shape={new_sharded_grad.shape}",
1051
    )
1052
    _p_assert(
1053
        accumulated_grad.device == new_sharded_grad.device,
1054
        "Device mismatch when accumulating gradients: "
1055
        f"existing gradient device={accumulated_grad.device} "
1056
        f"new gradient device={new_sharded_grad.device}",
1057
    )
1058

1059

1060
@no_type_check
def _low_precision_hook_enabled(state: _FSDPState) -> bool:
    """
    Returns whether the communication hook registered on ``state`` is one of
    the hooks listed in ``LOW_PRECISION_HOOKS``.
    """
    registered_hook = state._comm_hook
    return registered_hook in LOW_PRECISION_HOOKS
1063

1064

1065
@no_type_check
@torch.no_grad()
def _post_backward_final_callback(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Waits for the post-backward work to finish and performs final cleanup.

    This runs at the end of the entire backward pass and should only be
    called on the root FSDP instance.

    Args:
        state (_FSDPState): Root FSDP state.
        module (nn.Module): Not used by this function.
    """
    _p_assert(
        state._is_root,
        "The post-backward callback should only be called on the root FSDP instance",
    )

    if state._sync_gradients:
        compute_stream = state._device_handle.current_stream()
        # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
        # since it currently runs in the post-backward stream. That can be
        # pushed to the next forward if run in a different stream
        compute_stream.wait_stream(state._post_backward_stream)
        if state._all_reduce_stream is not compute_stream:  # uses HSDP
            compute_stream.wait_stream(state._all_reduce_stream)
        if state.cpu_offload.offload_params:
            # The non-blocking GPU -> CPU sharded gradient copies issued by
            # the post-backward hooks must be waited on explicitly since CPU
            # gradients do not automatically synchronize with the GPU
            state._device_handle.current_stream().synchronize()
    state._exec_order_data.next_iter()

    for fsdp_state in state._all_fsdp_states:
        _catch_all_reshard(fsdp_state)
        _finalize_params(fsdp_state)
        fsdp_state.training_state = TrainingState.IDLE
        handle = fsdp_state._handle
        if not handle:
            continue
        # Reset the per-iteration bookkeeping on the handle
        handle._ran_pre_backward_hook = False
        handle._needs_pre_backward_unshard = False
        handle._post_forward_index = None
        handle._training_state = HandleTrainingState.IDLE
        handle._prefetched = False
    # Reset for cases like one forward and multiple backwards
    state._post_backward_callback_queued = False
1110

1111

1112
@no_type_check
def _catch_all_reshard(
    state: _FSDPState,
) -> None:
    """
    Reshards the parameters that may not have been resharded in the
    post-backward hook. This can happen when a module's output is used in the
    forward pass, meaning that its pre-backward hook runs (unsharding the
    parameter), but the post-backward hook does not run because the output
    was not used in the loss computation corresponding to this backward pass.

    Args:
        state (_FSDPState): FSDP state whose ``_handle`` is resharded if
            needed.

    Raises:
        Exception: Any exception raised while resharding is reported through
            ``_p_assert`` (without raising an assertion error) and then
            re-raised.
    """
    # The try-except only serves to provide a more informative traceback if
    # an error is raised
    try:
        handle = state._handle
        if not handle:
            return
        flat_param = handle.flat_param
        # TODO: This already-resharded check is brittle:
        # https://github.com/pytorch/pytorch/issues/83956
        points_to_local_shard = (
            flat_param.data_ptr() == flat_param._local_shard.data_ptr()
        )
        # If FSDP skipped using sharded views, then the flat parameter still
        # points to the sharded data, so a reshard is still needed to use
        # sharded views
        if points_to_local_shard and not handle._skipped_use_sharded_views:
            return
        _reshard(state, handle, _should_free_in_backward(state, handle))
    except Exception as e:
        _p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
        )
        raise e
1148

1149

1150
@no_type_check
1151
def _finalize_params(
1152
    state: _FSDPState,
1153
) -> None:
1154
    """Finalizes the parameters before the next iteration."""
1155
    handle = state._handle
1156
    if not handle:
1157
        return
1158
    flat_param = handle.flat_param
1159
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1160
        if hasattr(flat_param, "_post_backward_hook_handle"):
1161
            pbhs_handle = flat_param._post_backward_hook_handle
1162
            pbhs_handle.remove()
1163
            del flat_param._post_backward_hook_handle
1164
    else:
1165
        if hasattr(flat_param, "_post_backward_hook_state"):
1166
            post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
1167
            expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
1168
            _p_assert(
1169
                post_backward_hook_state_len == expected_post_backward_hook_state_len,
1170
                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
1171
            )
1172
            flat_param._post_backward_hook_state[-1].remove()
1173
            delattr(flat_param, "_post_backward_hook_state")
1174
    if flat_param.requires_grad:
1175
        if not state._sync_gradients:
1176
            # Preserve the gradient accumulation state if not synchronizing
1177
            # gradients: `.grad` remains the unsharded gradient  from prior
1178
            # `no_sync()` iterations, and `_saved_grad_shard` remains the
1179
            # sharded gradient from the last synchronized iteration
1180
            return
1181
        if not handle._has_optim_in_backward:
1182
            handle.prepare_gradient_for_optim()
1183
        _p_assert(
1184
            hasattr(flat_param, "_post_backward_called"),
1185
            "Expects `_post_backward_called` to be set on the `FlatParameter`",
1186
        )
1187
        flat_param._post_backward_called = False
1188

1189

1190
@no_type_check
def _prefetch_handle(
    state: _FSDPState,
    current_handle: Optional[FlatParamHandle],
    prefetch_mode: _PrefetchMode,
) -> None:
    """
    Unshards the next handle early (without synchronization) if there is one
    to prefetch. A falsy ``current_handle`` cannot prefetch.

    Args:
        state (_FSDPState): FSDP state providing the unshard streams.
        current_handle (Optional[FlatParamHandle]): Handle of the current
            module.
        prefetch_mode (_PrefetchMode): Whether this is a backward or a
            forward prefetch.

    Raises:
        ValueError: If ``prefetch_mode`` is neither ``BACKWARD`` nor
            ``FORWARD``.
    """
    if not current_handle:
        return
    target_handle = _get_handle_to_prefetch(state, current_handle)
    if not target_handle:
        return
    if prefetch_mode == _PrefetchMode.BACKWARD:
        emulated_state = HandleTrainingState.BACKWARD_PRE
    elif prefetch_mode == _PrefetchMode.FORWARD:
        emulated_state = HandleTrainingState.FORWARD
    else:
        raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
    # The training state is emulated only for the duration of `_unshard` to
    # ensure the correct `as_params` for `_use_unsharded_views()`
    saved_training_state = target_handle._training_state
    target_handle._training_state = emulated_state
    # No synchronization here so that the sync can happen as late as possible
    # to maximize overlap
    _unshard(state, target_handle, state._unshard_stream, state._pre_unshard_stream)
    target_handle._training_state = saved_training_state
    target_handle._prefetched = True
1219

1220

1221
@no_type_check
1222
def _get_handle_to_prefetch(
1223
    state: _FSDPState,
1224
    current_handle: FlatParamHandle,
1225
) -> FlatParamHandle:
1226
    """
1227
    Returns a :class:`list` of the handles keys to prefetch for the next
1228
    module(s), where ``current_handle`` represents the current module.
1229

1230
    "Prefetching" refers to running the unshard logic early (without
1231
    synchronization), and the "next" modules depend on the recorded execution
1232
    order and the current training state.
1233
    """
1234
    training_state = _get_training_state(current_handle)
1235
    valid_training_states = (
1236
        HandleTrainingState.BACKWARD_PRE,
1237
        HandleTrainingState.BACKWARD_POST,
1238
        HandleTrainingState.FORWARD,
1239
    )
1240
    _p_assert(
1241
        training_state in valid_training_states,
1242
        f"Prefetching is only supported in {valid_training_states} but "
1243
        f"currently in {training_state}",
1244
    )
1245
    eod = state._exec_order_data
1246
    target_handle: Optional[FlatParamHandle] = None
1247
    if (
1248
        training_state == HandleTrainingState.BACKWARD_PRE
1249
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
1250
    ) or (
1251
        training_state == HandleTrainingState.BACKWARD_POST
1252
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
1253
    ):
1254
        target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
1255
        if (
1256
            target_handle_candidate
1257
            and target_handle_candidate._needs_pre_backward_unshard
1258
            and not target_handle_candidate._prefetched
1259
        ):
1260
            target_handle = target_handle_candidate
1261
        else:
1262
            target_handle = None
1263
    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
1264
        target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
1265
        if (
1266
            target_handle_candidate
1267
            and target_handle_candidate._needs_pre_forward_unshard
1268
            and not target_handle_candidate._prefetched
1269
        ):
1270
            target_handle = target_handle_candidate
1271
        else:
1272
            target_handle = None
1273

1274
    return target_handle
1275

1276

1277
def _get_training_state(
1278
    handle: FlatParamHandle,
1279
) -> HandleTrainingState:
1280
    """Returns the training state of the handles in ``handle``."""
1281
    _p_assert(handle, "Expects a non-empty handle")
1282
    return handle._training_state
1283

1284

1285
@no_type_check
def _register_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a pre-forward hook on ``module``, replacing any pre-forward
    hooks previously tracked in ``state._pre_forward_handles``.

    Args:
        state (_FSDPState): FSDP state tracking the registered hook handles.
        module (nn.Module): Module to register the hook on.
    """
    tracked_handles = state._pre_forward_handles
    # Remove the previously registered hooks before registering a new one
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    pre_forward_hook = functools.partial(
        _pre_forward,
        state,
        state._fully_sharded_module_to_handle.get(module, None),
        _pre_forward_unshard,
    )
    removable_handle = module.register_forward_pre_hook(
        pre_forward_hook, prepend=True, with_kwargs=True
    )
    tracked_handles.append(removable_handle)
1303

1304

1305
@no_type_check
def _register_post_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a post-forward hook on ``module``, replacing any post-forward
    hooks previously tracked in ``state._post_forward_handles``. The hook is
    registered even if the module has no handles since it will register the
    module's pre-backward hook.

    Args:
        state (_FSDPState): FSDP state tracking the registered hook handles.
        module (nn.Module): Module to register the hook on.
    """
    tracked_handles = state._post_forward_handles
    # Remove the previously registered hooks before registering a new one
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    post_forward_hook = functools.partial(
        _post_forward,
        state,
        state._fully_sharded_module_to_handle.get(module, None),
        _post_forward_reshard,
    )
    removable_handle = module.register_forward_hook(post_forward_hook)
    tracked_handles.append(removable_handle)
1326

1327

1328
@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers the root pre-forward hook on ``module``, which should be the
    local FSDP root, replacing any hooks previously tracked in
    ``state._root_pre_forward_handles``.

    NOTE: For the current composable FSDP design, we have each application of
    ``fully_shard()`` to a module to indicate that that module is the local
    FSDP root. We may remove this assumption in the future, in which case we
    will need to register this root pre-forward hook on any candidate module
    that may be the local FSDP root.

    Args:
        state (_FSDPState): FSDP state tracking the registered hook handles.
        module (nn.Module): Module to register the hook on.
    """
    tracked_handles = state._root_pre_forward_handles
    # Remove the previously registered hooks before registering a new one
    for stale_handle in tracked_handles:
        stale_handle.remove()
    tracked_handles.clear()
    root_hook = functools.partial(_root_pre_forward, state)
    removable_handle = module.register_forward_pre_hook(
        root_hook, prepend=True, with_kwargs=True
    )
    tracked_handles.append(removable_handle)
1350

1351

1352
@no_type_check
1353
def _register_pre_backward_hooks(
1354
    state: _FSDPState,
1355
    module: nn.Module,
1356
    outputs: Any,
1357
    handle: FlatParamHandle,
1358
) -> None:
1359
    """
1360
    Registers pre-backward hooks on the tensors that require gradients in the
1361
    forward pass outputs ``outputs``, which were computed using the
1362
    ``FlatParameter`` s of ``handles``.
1363

1364
    Args:
1365
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
1366
            Module]).
1367

1368
    Returns:
1369
        Forward pass outputs with pre-backward hooks registered to tensors that
1370
        require gradients.
1371
    """
1372
    # If there is no gradient computation, then there is no need for
1373
    # pre-backward logic
1374
    if not torch.is_grad_enabled():
1375
        return outputs
1376
    if state._is_root:
1377
        state._post_backward_callback_queued = False  # only defined on the root
1378

1379
    if handle:
1380
        handle._needs_pre_backward_unshard = False
1381
        # Since these handles' `FlatParameter`s participated in a forward, we
1382
        # conservatively assume that they will be used in the backward
1383
        handle._ran_pre_backward_hook = False
1384

1385
    def _register_hook(t: torch.Tensor) -> torch.Tensor:
1386
        if t.requires_grad:
1387
            t.register_hook(
1388
                functools.partial(_pre_backward_hook, state, module, handle)
1389
            )
1390
            if handle:
1391
                handle._needs_pre_backward_unshard = True
1392
        return t
1393

1394
    return _apply_to_tensors(_register_hook, outputs)
1395

1396

1397
def _register_post_backward_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """
    Registers the post-backward hook (reshard and gradient reduce-scatter) for
    the ``FlatParameter`` of ``handle``.

    Outside of TorchDynamo compilation, the hook is attached to the
    ``FlatParameter`` 's ``AccumulateGrad`` object. That object is the last
    function that finalizes the gradient, so the hook only runs after the
    entire gradient computation has finished. While TorchDynamo is compiling,
    a post-accumulate-grad hook on the ``FlatParameter`` itself is used
    instead.

    The hook is registered only once, in the *first* forward that the
    ``FlatParameter`` participates in, which relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.

    NOTE: Preferring the *first* forward targets the parameter mixed precision
    case, where there are *separate* ``AccumulateGrad`` objects across the
    different forwards. (Without parameter mixed precision, the
    ``AccumulateGrad`` objects are the same.) Preferring the *last* forward
    would make the hook run early.
    """
    # Nothing to hook without gradient computation or without a handle
    if not torch.is_grad_enabled() or not handle:
        return
    flat_param = handle.flat_param
    compiling = torch.distributed._functional_collectives.is_torchdynamo_compiling()

    # Each mode records its registration under a different attribute name
    marker_attr = (
        "_post_backward_hook_handle" if compiling else "_post_backward_hook_state"
    )
    if hasattr(flat_param, marker_attr) or not flat_param.requires_grad:
        return

    if compiling:
        post_backward = functools.partial(_post_backward_hook, state, handle)
        removable = flat_param.register_post_accumulate_grad_hook(post_backward)
        flat_param._post_backward_hook_handle = removable  # type: ignore[attr-defined]
        return

    # Reach the `AccumulateGrad` object through a throwaway autograd view
    expanded = flat_param.expand_as(flat_param)
    _p_assert(
        expanded.grad_fn is not None,
        "The `grad_fn` is needed to access the `AccumulateGrad` and "
        "register the post-backward hook",
    )
    acc_grad = expanded.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
    assert acc_grad is not None
    removable = acc_grad.register_hook(
        functools.partial(_post_backward_hook, state, handle)
    )
    # Keep `acc_grad` referenced alongside the removable hook handle
    flat_param._post_backward_hook_state = (acc_grad, removable)  # type: ignore[attr-defined]
1451

1452

1453
def _register_post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Registers a post-backward hook that only reshards, for a flat parameter
    that does not require gradient. The hook is a multi-grad hook on the
    input activations in ``args`` / ``kwargs`` that require gradient, which
    ensures that all gradients that may depend on the parameter have been
    computed before resharding.
    """
    # Nothing to hook without gradient computation or without a handle
    if not torch.is_grad_enabled() or not handle:
        return
    flat_param = handle.flat_param
    compiling = torch.distributed._functional_collectives.is_torchdynamo_compiling()

    if compiling:
        registered = hasattr(flat_param, "_post_backward_hook_handle")
    else:
        registered = hasattr(flat_param, "_post_backward_hook_state")
    # Flat parameters that require gradient are not handled by this hook
    if registered or flat_param.requires_grad:
        return

    # The input tensors are only collected once we know a hook is needed,
    # which keeps the typical case (gradient required) free of this overhead
    grad_inputs: List[torch.Tensor] = [
        leaf
        for leaf in pytree.arg_tree_leaves(*args, **kwargs)
        if torch.is_tensor(leaf) and leaf.requires_grad
    ]
    reshard_handle = register_multi_grad_hook(
        grad_inputs,
        functools.partial(_post_backward_reshard_only_hook, state, handle),
    )
    if compiling:
        flat_param._post_backward_hook_handle = reshard_handle  # type: ignore[attr-defined, assignment]
    else:
        flat_param._post_backward_hook_state = (reshard_handle,)  # type: ignore[attr-defined, assignment]
1496

1497

1498
@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Queues the post-backward final callback, which runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward. Calling it again while a callback is
    already queued is a no-op.
    """
    _p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    already_queued = state._post_backward_callback_queued
    if already_queued:
        return
    _assert_in_training_states(state, [TrainingState.IDLE])
    # Trace does not need this callback
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        return
    # Mark as queued before handing the callback to the autograd engine
    state._post_backward_callback_queued = True
    final_callback = functools.partial(_post_backward_final_callback, state, module)
    Variable._execution_engine.queue_callback(final_callback)
1520

1521

1522
def _wait_for_computation_stream(
    computation_stream: torch.Stream,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
):
    """
    Makes both the unshard stream and the pre-unshard stream wait for the
    computation stream. For example, this should be called in the FSDP root's
    pre-forward to respect optimizer step computation.
    """
    # Tracing does not need to wait
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        # The pre-unshard stream waits as well, even if it ends up unused;
        # that is tolerable since this only runs once per iteration
        for waiting_stream in (unshard_stream, pre_unshard_stream):
            waiting_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
1540

1541

1542
def _reset_flat_param_grad_info_if_needed(
1543
    handles: List[FlatParamHandle],
1544
):
1545
    """
1546
    Clears the original parameters' gradients if needed. This method's CPU
1547
    overhead is minimal, so we may call it throughout FSDP methods, which serve
1548
    as callsites to free the gradient memory earlier.
1549
    """
1550
    if not isinstance(handles, list):
1551
        handles = [handles]
1552
    for handle in handles:
1553
        if handle._use_orig_params:
1554
            handle._reset_flat_param_grad_info_if_needed()
1555

1556

1557
@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Collects the buffers in the module tree rooted at ``root_module`` together
    with a parallel list of the dtypes to use for computation. Each dtype entry
    is the ``mixed_precision.buffer_dtype`` of the FSDP state under which the
    buffer was first encountered, i.e. ``None`` if buffer mixed precision is
    not enabled or the buffer low precision dtype otherwise. Buffers whose
    cleaned name is in that state's ``_ignored_buffer_names`` are left out.
    """
    _p_assert(state._is_root, "Expects the root to cast buffers")
    collected: List[torch.Tensor] = []
    collected_dtypes: List[Optional[torch.dtype]] = []
    seen: Set[torch.Tensor] = set()
    all_states, all_modules = traversal_utils._get_fsdp_states_with_modules(
        root_module
    )
    # Walk the FSDP states bottom-up so that the owning FSDP instance's mixed
    # precision setting wins for each buffer
    for owner_state, owner_module in zip(reversed(all_states), reversed(all_modules)):
        for name, buf in owner_module.named_buffers():
            if buf in seen:
                continue
            # Marked as seen even if ignored below, so that an enclosing
            # instance does not pick the buffer up later
            seen.add(buf)
            if clean_tensor_name(name) not in owner_state._ignored_buffer_names:
                collected.append(buf)
                collected_dtypes.append(owner_state.mixed_precision.buffer_dtype)
    assert len(collected) == len(collected_dtypes), f"{len(collected)} {len(collected_dtypes)}"
    return collected, collected_dtypes
1588

1589

1590
@no_type_check
1591
def _get_orig_buffer_dtypes(
1592
    state: _FSDPState,
1593
    buffer_names: List[str],
1594
) -> List[torch.dtype]:
1595
    """
1596
    Returns the original buffer types of the given buffer names.
1597
    """
1598
    buffer_dtypes: List[torch.dtype] = []
1599
    for buffer_name in buffer_names:
1600
        _p_assert(
1601
            buffer_name in state._buffer_name_to_orig_dtype,
1602
            f"{buffer_name} is missing from pre-computed dict on rank "
1603
            f"{state.rank}, which only has keys "
1604
            f"{state._buffer_name_to_orig_dtype.keys()}",
1605
        )
1606
        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
1607
    return buffer_dtypes
1608

1609

1610
def _cast_buffers_to_dtype_and_device(
1611
    buffers: List[torch.Tensor],
1612
    buffer_dtypes: List[Optional[torch.dtype]],
1613
    device: torch.device,
1614
) -> None:
1615
    """
1616
    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
1617
    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
1618
    corresponding buffer is only moved to ``device``.
1619
    """
1620
    _p_assert(
1621
        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
1622
        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
1623
        f"`buffer_dtypes` is specified but got {len(buffers)} and "
1624
        f"{len(buffer_dtypes)}",
1625
    )
1626
    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
1627
        if not torch.is_floating_point(buffer) or buffer_dtype is None:
1628
            buffer.data = buffer.to(device=device)
1629
        else:
1630
            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
1631

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.