distributed.py · 2350 lines · 103.6 KB
1
import copy
2
import functools
3
import inspect
4
import itertools
5
import logging
6
import os
7
import sys
8
import warnings
9
import weakref
10
from collections import defaultdict, deque
11
from contextlib import contextmanager
12
from dataclasses import dataclass, fields, is_dataclass
13
from enum import auto, Enum
14
from typing import Any, Callable, List, Optional, Tuple, Type
15

16
import torch
17
import torch.distributed as dist
18
from torch.autograd import Function, Variable
19
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
20
from torch.utils._pytree import tree_flatten, tree_unflatten
21
from torch.utils.hooks import RemovableHandle
22

23
RPC_AVAILABLE = False
24
if dist.is_available():
25
    from torch.distributed.distributed_c10d import (
26
        _get_default_group,
27
        _rank_not_in_group,
28
        ReduceOp,
29
    )
30
    from torch.distributed.utils import (
31
        _alloc_storage,
32
        _cast_forward_inputs,
33
        _free_storage,
34
        _sync_module_states,
35
        _to_kwargs,
36
        _verify_param_shape_across_processes,
37
    )
38
if torch.distributed.rpc.is_available():
39
    RPC_AVAILABLE = True
40
    from torch.distributed.rpc import RRef
41

42
from torch._utils import _get_device_index
43

44
from ..modules import Module
45
from .scatter_gather import gather, scatter_kwargs  # noqa: F401
46

47
__all__ = ["DistributedDataParallel"]
48

49
logger = logging.getLogger(__name__)
50

51

52
@dataclass
53
class _MixedPrecision:
54
    """
55
    This configures DDP-native mixed precision training.
56

57
    Attributes:
58
        param_dtype (torch.dtype): This specifies the dtype for model
59
            parameters, inputs (when ``cast_forward_inputs`` is set to
60
            ``True``), and therefore the dtype for computation.
61
            However, outside the forward and backward passes, parameters are in
62
            full precision. Model checkpointing always happens in full
63
            precision.
64
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
65
            reduction, which is permitted to differ from ``param_dtype``.
66
        buffer_dtype (torch.dtype): This specifies the dtype for buffers.
67

68
    .. note:: This API is experimental and subject to change.
69

70
    .. note:: Only floating point tensors are cast to their specified dtypes.
71

72
    .. note:: ``state_dict`` checkpoints parameters and buffers in full
73
        precision.
74

75
    .. note:: Each low precision dtype must be specified explicitly. For
76
        example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
77
        the reduction dtype to be low precision, and DDP will not cast
78
        parameters or buffers.
79

80
    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
81
        happens in ``param_dtype`` if specified or the original parameter dtype
82
        otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
83
        would result in communication occurring in fp16.
84
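
    .. note:: The following is a small illustrative sketch of how this config
        is passed to DDP via its ``mixed_precision`` argument; ``model`` and
        ``rank`` are placeholders::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> mp_config = _MixedPrecision(
            >>>     param_dtype=torch.float16,
            >>>     reduce_dtype=torch.float16,
            >>>     buffer_dtype=torch.float16,
            >>> )
            >>> ddp = torch.nn.parallel.DistributedDataParallel(
            >>>     model, device_ids=[rank], mixed_precision=mp_config
            >>> )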
    """
85

86
    param_dtype: Optional[torch.dtype] = None
87
    reduce_dtype: Optional[torch.dtype] = None
88
    buffer_dtype: Optional[torch.dtype] = None
89
    # TODO (rohan-varma): keep_low_precision_grads: bool = False
90
    # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
91
    # in full precision. For DDP, this can be implemented by not performing the
92
    # parameter cast for BN and LN units.
93

94

95
def _cast_buffers(mixed_precision_config, root_module):
96
    """Casts buffers to the given ``buffer_dtype``."""
97
    for buf in root_module.buffers():
98
        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
99
            continue
100

101
        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
102

103

104
def _setup_mixed_precision_params(mixed_precision_config, root_module):
105
    """Create and free storage for the mixed precision parameters."""
106
    for param in root_module.parameters():
107
        # Do not setup mixed precision for DDP ignored parameters.
108
        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
109
            continue
110

111
        if not hasattr(param, "_mp_param"):
112
            param._mp_param = torch.zeros_like(
113
                param,
114
                device=param.device,
115
                dtype=mixed_precision_config.param_dtype,
116
                requires_grad=param.requires_grad,
117
            )
118
            _free_storage(param._mp_param)
119
            # _fp_param will point to the full precision param so it can be switched
120
            # back to at the end of forward / backward.
121
            param._fp_param = param.data
122

123

124
def _tree_flatten_with_rref(output):
125
    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
126
    if output_is_rref:
127
        output_tensor_list, treespec = tree_flatten(output.local_value())
128
    else:
129
        output_tensor_list, treespec = tree_flatten(output)
130
    # Return the flattened tensors, the spec needed to re-pack them, and
    # whether the return value was an RRef so it can be reconstructed.
132
    return output_tensor_list, treespec, output_is_rref
133

134

135
def _tree_unflatten_with_rref(output, treespec, output_is_rref):
136
    output = tree_unflatten(output, treespec)
137
    if output_is_rref:
138
        output = RRef(output)
139
    return output
140
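
# For reference, the underlying pytree round trip used above looks like the
# following sketch (``output`` stands for any nested structure of tensors):
#
#     flat_tensors, spec = tree_flatten(output)
#     same_output = tree_unflatten(flat_tensors, spec)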

141

142
def _find_tensors(obj):
143
    r"""Recursively find all tensors contained in the specified object."""
144
    if RPC_AVAILABLE and isinstance(obj, RRef):
145
        # If the current node is the owner of the RRef, unwrap it and try to
146
        # find Tensors.
147
        # TODO: Expand to remote RRefs.
148
        if obj.is_owner():
149
            return _find_tensors(obj.local_value())
150
    if isinstance(obj, torch.Tensor):
151
        return [obj]
152
    if isinstance(obj, (list, tuple)):
153
        return itertools.chain.from_iterable(map(_find_tensors, obj))
154
    if isinstance(obj, dict):
155
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
156
    if is_dataclass(obj):
157
        return itertools.chain.from_iterable(
158
            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
159
        )
160

161
    return []
162

163

164
def _dump_DDP_relevant_env_vars():
165
    relevant_env_vars = [
166
        "RANK",
167
        "LOCAL_RANK",
168
        "WORLD_SIZE",
169
        "MASTER_PORT",
170
        "MASTER_ADDR",
171
        "CUDA_VISIBLE_DEVICES",
172
        "GLOO_SOCKET_IFNAME",
173
        "GLOO_DEVICE_TRANSPORT",
174
        "NCCL_SOCKET_IFNAME",
175
        "TORCH_NCCL_BLOCKING_WAIT",
176
        "NCCL_DEBUG",
177
        "NCCL_DEBUG_SUBSYS",
178
        "NCCL_IB_DISABLE",
179
        # More NCCL env vars:
180
        "NCCL_P2P_DISABLE",
181
        "NCCL_P2P_LEVEL",
182
        "NCCL_SHM_DISABLE",
183
        "NCCL_SOCKET_NTHREADS",
184
        "NCCL_NSOCKS_PERTHREAD",
185
        "NCCL_BUFFSIZE",
186
        "NCCL_NTHREADS",
187
        "NCCL_RINGS",
188
        "NCCL_MAX_NCHANNELS",
189
        "NCCL_MIN_NCHANNELS",
190
        "NCCL_CHECKS_DISABLE",
191
        "NCCL_CHECK_POINTERS",
192
        "NCCL_LAUNCH_MODE",
193
        "NCCL_IB_HCA",
194
        "NCCL_IB_TIMEOUT",
195
        "NCCL_IB_RETRY_CNT",
196
        "NCCL_IB_GID_INDEX",
197
        "NCCL_IB_SL",
198
        "NCCL_IB_TC",
199
        "NCCL_IB_AR_THRESHOLD",
200
        "NCCL_IB_CUDA_SUPPORT",
201
        "NCCL_NET_GDR_LEVEL",
202
        "NCCL_NET_GDR_READ",
203
        "NCCL_SINGLE_RING_THRESHOLD",
204
        "NCCL_LL_THRESHOLD",
205
        "NCCL_TREE_THRESHOLD",
206
        "NCCL_ALGO",
207
        "NCCL_PROTO",
208
        "NCCL_IGNORE_CPU_AFFINITY",
209
        "NCCL_DEBUG_FILE",
210
        "NCCL_COLLNET_ENABLE",
211
        "NCCL_TOPO_FILE",
212
        "NCCL_TOPO_DUMP_FILE",
213
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
214
    ]
215
    formatted_output = ""
216
    for var in relevant_env_vars:
217
        value = os.environ[var] if var in os.environ else "N/A"
218
        formatted_output += f"env:{var}={value}\n"
219
    print(formatted_output)
220

221

222
class _BufferCommHookLocation(Enum):
223
    PRE_FORWARD = auto()
224
    POST_FORWARD = auto()
225

226

227
@dataclass
228
class _BufferCommHook:
229
    buffer_comm_hook: Callable
230
    buffer_comm_hook_state: Any
231
    buffer_comm_hook_location: _BufferCommHookLocation
232

233

234
# Add a _DDPSink to run various functions when backward starts, such as
# queueing the callback of the outermost backward/graph task; this ensures
# the callback fires only after all gradient computation has completed.
238
class _DDPSink(Function):
239
    @staticmethod
240
    def forward(ctx, ddp_weakref, *inputs):
241
        # set_materialize_grads(False) will ensure that None gradients stay as
242
        # None and are not filled with zeros.
243
        ctx.set_materialize_grads(False)
244
        ctx.ddp_weakref = ddp_weakref
245
        ret = tuple(
246
            inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
247
        )
248
        return ret
249

250
    @staticmethod
251
    def backward(ctx, *grad_outputs):
252
        # Enqueue delay allreduce for static graph training on the first
253
        # iteration.
254
        ddp_weakref = ctx.ddp_weakref()
255
        reducer = ddp_weakref.reducer
256
        static_graph = ddp_weakref.static_graph
257
        delay_ar_enqueued = (
258
            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
259
        )
260
        if static_graph and not delay_ar_enqueued:
261
            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
262
                reducer._delay_all_reduce
263
            )
264
            ddp_weakref._static_graph_delay_allreduce_enqueued = True
265

266
        return (None, *grad_outputs)
267

268

269
class _DDPJoinHook(JoinHook):
270
    def __init__(self, ddp, divide_by_initial_world_size):
271
        """Set config variables for internal usage."""
272
        assert isinstance(ddp, DistributedDataParallel), (
273
            "DDP join hook requires passing in a DistributedDataParallel "
274
            "instance as the state"
275
        )
276
        assert ddp.logger is not None
277
        ddp.logger._set_uneven_input_join()
278
        self.ddp = ddp
279
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
280
        super().__init__()
281

282
    def main_hook(self):
283
        """Shadow the DDP collective communication operations in the forward and backward passes."""
284
        ddp = self.ddp
285
        # Buckets are rebuilt only once during a training period
286
        ddp.reducer._rebuild_buckets()
287

288
        # Schedule a broadcast if we are syncing module buffers in the
289
        # forward pass
290
        # TODO: make DDP uneven inputs context manager support buffer
291
        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
292
        ddp._check_and_sync_module_buffers()
293

294
        # Check if need to sync in the backward pass
295
        should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
296
            is_joined_rank=True
297
        )
298
        # Forward parameter sync is disabled in the next iteration if we
299
        # are skipping gradient sync this iteration, so set
300
        # `require_forward_param_sync` accordingly
301
        ddp.require_forward_param_sync = should_sync_backwards
302
        if not should_sync_backwards:
303
            return
304

305
        # Schedule one allreduce per gradient bucket to match the backward
306
        # pass allreduce
307
        ddp._match_all_reduce_for_bwd_pass()
308

309
        # Check if we need to allreduce locally unused parameters
310
        if ddp.find_unused_parameters:
311
            ddp._match_unused_params_allreduce()
312

313
        # Rebuilt parameters are pushed only once during a training period
314
        ddp.reducer._push_all_rebuilt_params()
315

316
    def post_hook(self, is_last_joiner: bool):
317
        """Sync the final model to ensure that the model is the same across all processes."""
318
        self.ddp._sync_final_model(is_last_joiner)
319

320

321
class DistributedDataParallel(Module, Joinable):
322
    r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.
323

324
    This container provides data parallelism by synchronizing gradients
325
    across each model replica. The devices to synchronize across are
326
    specified by the input ``process_group``, which is the entire world
327
    by default. Note that ``DistributedDataParallel`` does not chunk or
328
    otherwise shard the input across participating GPUs; the user is
329
    responsible for defining how to do so, for example through the use
330
    of a :class:`DistributedSampler`.
331

332
    See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
333
    The same constraints on input as in :class:`torch.nn.DataParallel` apply.
334

335
    Creation of this class requires that ``torch.distributed`` be already
    initialized, by calling :func:`torch.distributed.init_process_group`.
337

338
    ``DistributedDataParallel`` is proven to be significantly faster than
339
    :class:`torch.nn.DataParallel` for single-node multi-GPU data
340
    parallel training.
341

342
    To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
343
    up ``N`` processes, ensuring that each process exclusively works on a single
344
    GPU from 0 to N-1. This can be done by either setting
345
    ``CUDA_VISIBLE_DEVICES`` for every process or by calling:
346

347
        >>> # xdoctest: +SKIP("undefined variables")
348
        >>> torch.cuda.set_device(i)
349

350
    where i is from 0 to N-1. In each process, you should refer to the
    following to construct this module:
352

353
        >>> # xdoctest: +SKIP("undefined variables")
354
        >>> torch.distributed.init_process_group(
355
        >>>     backend='nccl', world_size=N, init_method='...'
356
        >>> )
357
        >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
358

359
    In order to spawn up multiple processes per node, you can use either
360
    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
361

362
    .. note::
363
        Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
364
        for a brief introduction to all features related to distributed training.
365

366
    .. note::
367
        ``DistributedDataParallel`` can be used in conjunction with
368
        :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
369
        per-rank optimizer states memory footprint. Please refer to
370
        `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
371
        for more details.
372
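
        For example, a brief sketch (``ddp_model`` stands for a
        ``DistributedDataParallel``-wrapped model)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.optim import ZeroRedundancyOptimizer
            >>> optimizer = ZeroRedundancyOptimizer(
            >>>     ddp_model.parameters(),
            >>>     optimizer_class=torch.optim.SGD,
            >>>     lr=0.01,
            >>> )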

373
    .. note:: ``nccl`` backend is currently the fastest and highly recommended
374
        backend when using GPUs. This applies to both single-node and
375
        multi-node distributed training.
376

377
    .. note:: This module also supports mixed-precision distributed training.
        This means that your model can have parameters of different types, such
        as a mix of ``fp16`` and ``fp32``, and gradient reduction on these
        mixed-type parameters will work correctly.
381

382
    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
        and ``torch.load`` on some other processes to recover it, make sure that
        ``map_location`` is configured properly for every process. Without
        ``map_location``, ``torch.load`` would load the module onto the devices
        it was saved from.
387

388
    .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
        gradient will be ``M`` times smaller compared to the same model
        trained on a single node with ``batch=M*N`` if the loss is summed (NOT
        averaged, as is usual) across instances in a batch, because the
        gradients between different nodes are averaged. For example, with
        ``M=2`` the averaged DDP gradient is half of the summed gradient that a
        single process would compute over the combined ``2*N`` samples. You
        should take this into consideration when you want to obtain a
        mathematically equivalent training process compared to the local
        training counterpart. But in most cases, you can just treat a
        DistributedDataParallel wrapped model, a DataParallel wrapped model and
        an ordinary model on a single GPU as the same (e.g. use the same
        learning rate for an equivalent batch size).
398

399
    .. note::
400
        Parameters are never broadcast between processes. The module performs
401
        an all-reduce step on gradients and assumes that they will be modified
402
        by the optimizer in all processes in the same way. Buffers
403
        (e.g. BatchNorm stats) are broadcast from the module in process of rank
404
        0, to all other replicas in the system in every iteration.
405

406
    .. note::
407
        If you are using DistributedDataParallel in conjunction with the
408
        :ref:`distributed-rpc-framework`, you should always use
409
        :meth:`torch.distributed.autograd.backward` to compute gradients and
410
        :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
411
        parameters.
412

413
        Example::
414

415
            >>> # xdoctest: +SKIP("undefined variables")
416
            >>> import torch.distributed.autograd as dist_autograd
417
            >>> from torch.nn.parallel import DistributedDataParallel as DDP
418
            >>> import torch
419
            >>> from torch import optim
420
            >>> from torch.distributed.optim import DistributedOptimizer
421
            >>> import torch.distributed.rpc as rpc
422
            >>> from torch.distributed.rpc import RRef
423
            >>>
424
            >>> t1 = torch.rand((3, 3), requires_grad=True)
425
            >>> t2 = torch.rand((3, 3), requires_grad=True)
426
            >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
427
            >>> ddp_model = DDP(my_model)
428
            >>>
429
            >>> # Setup optimizer
430
            >>> optimizer_params = [rref]
431
            >>> for param in ddp_model.parameters():
432
            >>>     optimizer_params.append(RRef(param))
433
            >>>
434
            >>> dist_optim = DistributedOptimizer(
435
            >>>     optim.SGD,
436
            >>>     optimizer_params,
437
            >>>     lr=0.05,
438
            >>> )
439
            >>>
440
            >>> with dist_autograd.context() as context_id:
441
            >>>     pred = ddp_model(rref.to_here())
442
            >>>     loss = loss_func(pred, target)
443
            >>>     dist_autograd.backward(context_id, [loss])
444
            >>>     dist_optim.step(context_id)
445

446
    .. note::
447
        DistributedDataParallel currently offers limited support for gradient
448
        checkpointing with :meth:`torch.utils.checkpoint`.
449
        If the checkpoint is done with use_reentrant=False (recommended), DDP
450
        will work as expected without any limitations.
451
        If, however, the checkpoint is done with use_reentrant=True (the default),
452
        DDP will work as expected when there are no unused parameters in the model
453
        and each layer is checkpointed at most once (make sure you are not passing
454
        `find_unused_parameters=True` to DDP). We currently do not support the
455
        case where a layer is checkpointed multiple times, or when there are
        unused parameters in the checkpointed model.
457
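
        For example, a small illustrative sketch (``Block`` is a placeholder
        submodule)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.utils.checkpoint import checkpoint
            >>> class Net(torch.nn.Module):
            >>>     def __init__(self):
            >>>         super().__init__()
            >>>         self.block = Block()
            >>>     def forward(self, x):
            >>>         # Non-reentrant checkpointing composes with DDP.
            >>>         return checkpoint(self.block, x, use_reentrant=False)
            >>> ddp_model = torch.nn.parallel.DistributedDataParallel(Net().cuda())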

458
    .. note::
459
        To let a non-DDP model load a state dict from a DDP model,
460
        :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
461
        needs to be applied to strip the prefix "module." in the DDP state dict before loading.
462
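
        For example, a minimal sketch (``model`` and ``PATH`` are placeholders,
        and the checkpoint at ``PATH`` is assumed to hold a DDP-wrapped model's
        ``state_dict()``)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
            >>> state_dict = torch.load(PATH, map_location="cpu")
            >>> consume_prefix_in_state_dict_if_present(state_dict, "module.")
            >>> model.load_state_dict(state_dict)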

463
    .. warning::
464
        Constructor, forward method, and differentiation of the output (or a
465
        function of the output of this module) are distributed synchronization
466
        points. Take that into account in case different processes might be
467
        executing different code.
468

469
    .. warning::
470
        This module assumes all parameters are registered in the model by the
471
        time it is created. No parameters should be added or removed later.
472
        Same applies to buffers.
473

474
    .. warning::
475
        This module assumes that the parameters registered in the model of each
        distributed process are in the same order. The module itself will
477
        conduct gradient ``allreduce`` following the reverse order of the
478
        registered parameters of the model. In other words, it is users'
479
        responsibility to ensure that each distributed process has the exact
480
        same model and thus the exact same parameter registration order.
481

482
    .. warning::
483
        This module allows parameters with non-rowmajor-contiguous strides.
484
        For example, your model may contain some parameters whose
485
        :class:`torch.memory_format` is ``torch.contiguous_format``
486
        and others whose format is ``torch.channels_last``.  However,
487
        corresponding parameters in different processes must have the
488
        same strides.
489

490
    .. warning::
491
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
492
        only work if gradients are to be accumulated in ``.grad`` attributes of
493
        parameters).
494

495
    .. warning::
496
        If you plan on using this module with a ``nccl`` backend or a ``gloo``
497
        backend (that uses Infiniband), together with a DataLoader that uses
498
        multiple workers, please change the multiprocessing start method to
499
        ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
500
        Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
501
        likely experience deadlocks if you don't change this setting.
502

503
    .. warning::
504
        You should never try to change your model's parameters after wrapping
        up your model with ``DistributedDataParallel``. At construction time,
        the ``DistributedDataParallel`` constructor registers additional
        gradient reduction functions on all of the model's parameters. If you
        change the model's parameters afterwards, the gradient reduction
        functions no longer match the correct set of parameters.
512

513
    .. warning::
514
        Using ``DistributedDataParallel`` in conjunction with the
515
        :ref:`distributed-rpc-framework` is experimental and subject to change.
516

517
    Args:
518
        module (Module): module to be parallelized
519
        device_ids (list of int or torch.device): CUDA devices.
520
                   1) For single-device modules, ``device_ids`` can
521
                   contain exactly one device id, which represents the only
522
                   CUDA device where the input module corresponding to this process resides.
523
                   Alternatively, ``device_ids`` can also be ``None``.
524
                   2) For multi-device modules and CPU modules,
525
                   ``device_ids`` must be ``None``.
526

527
                   When ``device_ids`` is ``None`` for both cases,
528
                   both the input data for the forward pass and the actual module
529
                   must be placed on the correct device.
530
                   (default: ``None``)
531
        output_device (int or torch.device): Device location of output for
532
                      single-device CUDA modules. For multi-device modules and
533
                      CPU modules, it must be ``None``, and the module itself
534
                      dictates the output location. (default: ``device_ids[0]``
535
                      for single-device modules)
536
        broadcast_buffers (bool): Flag that enables syncing (broadcasting)
537
                          buffers of the module at beginning of the ``forward``
538
                          function. (default: ``True``)
539
        process_group: The process group to be used for distributed data
540
                       all-reduction. If ``None``, the default process group, which
541
                       is created by :func:`torch.distributed.init_process_group`,
542
                       will be used. (default: ``None``)
543
        bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
544
                       multiple buckets so that gradient reduction of each
545
                       bucket can potentially overlap with backward computation.
546
                       :attr:`bucket_cap_mb` controls the bucket size in
547
                       MegaBytes (MB). (default: 25)
548
        find_unused_parameters (bool): Traverse the autograd graph from all
549
                               tensors contained in the return value of the
550
                               wrapped module's ``forward`` function. Parameters
551
                               that don't receive gradients as part of this
552
                               graph are preemptively marked as being ready to
553
                               be reduced. In addition, parameters that may have
554
                               been used in the wrapped module's ``forward``
555
                               function but were not part of loss computation and
556
                               thus would also not receive gradients are
557
                               preemptively marked as ready to be reduced.
558
                               (default: ``False``)
559
        check_reduction: This argument is deprecated.
560
        gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
561
                      pointing to different offsets of ``allreduce`` communication
562
                      buckets. This can reduce peak memory usage, where the
563
                      saved memory size will be equal to the total gradients
564
                      size. Moreover, it avoids the overhead of copying between
565
                      gradients and ``allreduce`` communication buckets. When
566
                      gradients are views, ``detach_()`` cannot be called on the
567
                      gradients. If hitting such errors, please fix it by
568
                      referring to the :meth:`~torch.optim.Optimizer.zero_grad`
569
                      function in ``torch/optim/optimizer.py`` as a solution.
570
                      Note that gradients will be views after the first iteration, so
                      the peak memory saving should be checked after the first iteration.
572
        static_graph (bool): When set to ``True``, DDP knows the trained graph is
573
                     static. Static graph means 1) The set of used and unused
574
                     parameters will not change during the whole training loop; in
575
                     this case, it does not matter whether users set
576
                     ``find_unused_parameters = True`` or not. 2) How the graph is trained
577
                     will not change during the whole training loop (meaning there is
578
                     no control flow depending on iterations).
579
                     When static_graph is set to be ``True``, DDP will support cases that
                     could not be supported in the past:
                     1) Reentrant backwards.
                     2) Activation checkpointing multiple times.
                     3) Activation checkpointing when the model has unused parameters.
                     4) Model parameters that are outside of the forward function.
                     5) Potentially improved performance when there are unused parameters,
                     as DDP will not search the graph in each iteration to detect unused
                     parameters when static_graph is set to be ``True``.
                     To check whether you can set static_graph to be ``True``, one way is to
                     check DDP logging data at the end of your previous model training;
                     if ``ddp_logging_data.get("can_set_static_graph") == True``, you can
                     mostly set ``static_graph = True`` as well.
592

593
                     Example::
594
                         >>> # xdoctest: +SKIP("undefined variables")
595
                         >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
596
                         >>> # Training loop
597
                         >>> ...
598
                         >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
599
                         >>> static_graph = ddp_logging_data.get("can_set_static_graph")
600
        delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
601
                    of named parameters whose all reduce will be delayed when the gradient of
602
                    the parameter specified in ``param_to_hook_all_reduce`` is ready. Other
603
                    arguments of DDP do not apply to named params specified in this argument
604
                    as these named params will be ignored by DDP reducer.
605
        param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
606
                    of parameters specified in ``delay_all_reduce_named_params``.
607

608

609
    Attributes:
610
        module (Module): the module to be parallelized.
611

612
    Example::
613

614
        >>> # xdoctest: +SKIP("undefined variables")
615
        >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
616
        >>> net = torch.nn.parallel.DistributedDataParallel(model)
617
    """
618

619
    # used to track whether the given thread is inside ddp forward for torchdynamo purposes
620
    _active_ddp_module: Optional["DistributedDataParallel"] = None
621

622
    def __init__(
623
        self,
624
        module,
625
        device_ids=None,
626
        output_device=None,
627
        dim=0,
628
        broadcast_buffers=True,
629
        process_group=None,
630
        bucket_cap_mb=25,
631
        find_unused_parameters=False,
632
        check_reduction=False,
633
        gradient_as_bucket_view=False,
634
        static_graph=False,
635
        delay_all_reduce_named_params=None,
636
        param_to_hook_all_reduce=None,
637
        mixed_precision: Optional[_MixedPrecision] = None,
638
        device_mesh=None,
639
    ):
640
        super().__init__()
641
        Joinable.__init__(self)
642
        self.logger = None
643
        if bool(delay_all_reduce_named_params is not None) != bool(
644
            param_to_hook_all_reduce is not None
645
        ):
646
            self._log_and_throw(
647
                ValueError,
648
                "delay_all_reduce_named_params and param_to_hook_all_reduce "
649
                "need to be set at the same time.",
650
            )
651

652
        self._delay_all_reduce_params = []
653
        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
654
            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
655
        else:
656
            self.parameters_to_ignore = set()
657
        if delay_all_reduce_named_params is not None:
658
            for name, param in delay_all_reduce_named_params:
659
                self.parameters_to_ignore.add(name)
660
                self._delay_all_reduce_params.append(param)
661

662
        self._module_parameters = [
663
            p
664
            for n, p in module.named_parameters()
665
            if n not in self.parameters_to_ignore
666
        ]
667
        if not any(p.requires_grad for p in self._module_parameters):
668
            if len(self._delay_all_reduce_params):
669
                logger.info("Delay the AllReduce of all parameters.")
670
            else:
671
                self._log_and_throw(
672
                    RuntimeError,
673
                    "DistributedDataParallel is not needed when a module "
674
                    "doesn't have any parameter that requires a gradient.",
675
                )
676

677
        if device_ids is not None and len(device_ids) > 1:
678
            self._log_and_throw(
679
                ValueError,
680
                "device_ids can only be None or contain a single element.",
681
            )
682

683
        self.is_multi_device_module = (
684
            len({p.device for p in self._module_parameters}) > 1
685
        )
686
        distinct_device_types = {
687
            p.device.type for p in self._module_parameters if p.device is not None
688
        }
689
        if len(distinct_device_types) != 1:
690
            self._log_and_throw(
691
                ValueError,
692
                "DistributedDataParallel's input module must be on "
693
                f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
694
            )
695

696
        self.device_type = next(iter(distinct_device_types))
697

698
        if (
699
            device_ids is None
700
            or len(device_ids) == 0  # For backward compatibility.
701
            or self.device_type == "cpu"
702
            or self.is_multi_device_module
703
        ):
704
            if device_ids or output_device:
705
                self._log_and_throw(
706
                    ValueError,
707
                    "DistributedDataParallel device_ids and output_device arguments "
708
                    "only work with single-device/multiple-device GPU modules or CPU modules, "
709
                    "but got device_ids {}, output_device {}, and module parameters {}.".format(
710
                        device_ids,
711
                        output_device,
712
                        {p.device for p in self._module_parameters},
713
                    ),
714
                )
715

716
            self.device_ids = None
717
            self.output_device = None
718
        else:
719
            self.device_ids = [_get_device_index(x, True) for x in device_ids]
720

721
            if output_device is None:
722
                output_device = device_ids[0]
723

724
            self.output_device = _get_device_index(output_device, True)
725

726
        if process_group and device_mesh is not None:
727
            raise RuntimeError(
728
                "Cannot specify both process_group and device_mesh arguments."
729
            )
730
        elif process_group is None and device_mesh is None:
731
            self.process_group = _get_default_group()
732
        elif device_mesh is None:
733
            self.process_group = process_group
734
        else:
735
            if device_mesh.ndim != 1:
736
                raise RuntimeError(
737
                    f"Only 1D device mesh is supported, but got {device_mesh}."
738
                )
739
            self.device_mesh = device_mesh
740
            self.process_group = device_mesh.get_group(mesh_dim=0)
741

742
        self.static_graph = False
743
        self.dim = dim
744
        self.module = module
745
        self.device = next(iter(self._module_parameters)).device
746
        self.broadcast_buffers = broadcast_buffers
747
        self.find_unused_parameters = find_unused_parameters
748
        self.require_backward_grad_sync = True
749
        self.require_forward_param_sync = True
750
        self.gradient_as_bucket_view = gradient_as_bucket_view
751
        self.mixed_precision = mixed_precision
752
        if self.mixed_precision is not None:
753
            logger.warning("Received mixed precision config %s", self.mixed_precision)
754

755
        if check_reduction:
756
            # This argument is no longer used since the reducer
757
            # will ensure reduction completes even if some parameters
758
            # do not receive gradients.
759
            warnings.warn(
760
                "The `check_reduction` argument in `DistributedDataParallel` "
761
                "module is deprecated. Please avoid using it."
762
            )
763

764
        # Check that a module does not have Uninitialized parameters
765
        for param in self._module_parameters:
766
            if isinstance(param, torch.nn.parameter.UninitializedParameter):
767
                self._log_and_throw(
768
                    RuntimeError,
769
                    "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
770
                    "Run a dummy forward pass to correctly initialize the modules",
771
                )
772
        # used for intra-node param sync and inter-node sync as well
773
        self.broadcast_bucket_size = int(250 * 1024 * 1024)
774

775
        # reduction bucket size
776
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
777
        # Whether to perform input tensor CPU to GPU copies on a side-stream
778
        self.use_side_stream_for_tensor_copies = (
779
            os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
780
        )
781

782
        # Initialize gradient buffers and register all reduce hook
783
        self._delay_grad_buffer = None
784
        self._delay_grad_views: List[torch.Tensor] = []
785
        self._delay_all_reduce_all_params = False
786
        if len(self._delay_all_reduce_params) != 0:
787
            self._register_delay_all_reduce_hook(
788
                bucket_cap_mb=bucket_cap_mb,
789
                param_to_hook_all_reduce=param_to_hook_all_reduce,
790
                device_ids=device_ids,
791
            )
792
            if self._delay_all_reduce_all_params:
793
                return
794

795
        # Build parameters for reducer.
796
        parameters, expect_sparse_gradient = self._build_params_for_reducer()
797
        # Verify model equivalence.
798
        _verify_param_shape_across_processes(self.process_group, parameters)
799
        # Sync params and buffers. Ensures all DDP models start off at the same value.
800
        _sync_module_states(
801
            module=self.module,
802
            process_group=self.process_group,
803
            broadcast_bucket_size=self.broadcast_bucket_size,
804
            src=0,
805
            params_and_buffers_to_ignore=self.parameters_to_ignore,
806
            broadcast_buffers=self.broadcast_buffers,
807
        )
808
        # In debug mode, build a mapping of parameter index -> parameter.
809
        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
810

811
        # Builds reducer.
812
        self._ddp_init_helper(
813
            parameters,
814
            expect_sparse_gradient,
815
            param_to_name_mapping,
816
            static_graph,
817
        )
818
        self._comm_hooks: List[Tuple[Callable, object]] = []
819

820
        if self.mixed_precision is not None:
821
            _setup_mixed_precision_params(self.mixed_precision, self.module)
822
            _cast_buffers(self.mixed_precision, self.module)
823
            # Stream used for async low precision copies.
824
            self._mp_stream = torch.cuda.Stream()
825
            self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
826
            # Add forward pre-hook to root module to kick off copies to lower
827
            # precision.
828
            self.module.register_forward_pre_hook(
829
                self._root_copy_hook, prepend=False, with_kwargs=True
830
            )
831
            # Add forward pre hook to all submodules to wait for copy events
832
            # before running computation.
833
            for module in self.module.modules():
834
                module.register_forward_pre_hook(
835
                    self._module_wait_for_copy_hook,
836
                    prepend=False,
837
                    with_kwargs=True,
838
                )
839
            # Set up callbacks in backward to upcast and use full precision
840
            # params. TODO (rohan-varma): Make this compose with general
841
            # comm hooks and apply_optimizer_in_backward. Importing inline to
842
            # avoid circular import issue.
843
            from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
844
                _AllreduceUpcastHookState,
845
                _reducer_allreduce_and_upcast_hook,
846
            )
847

848
            upcast_hook_state = _AllreduceUpcastHookState(
849
                ddp_weakref=weakref.ref(self),
850
                upcast_stream=torch.cuda.Stream(),
851
            )
852
            self.register_comm_hook(
853
                upcast_hook_state,
854
                _reducer_allreduce_and_upcast_hook,
855
            )
856
            # Inform reducer of reduced precision param dtype for correctness
857
            # of type checks between gradient and bucket.
858
            self.reducer._set_mixed_precision_param_dtype(  # type: ignore[attr-defined]
859
                self.mixed_precision.param_dtype
860
            )
861

862
        self._has_rebuilt_buckets = False
863

864
        if static_graph:
865
            self._set_static_graph()
866

867
        self._lazy_init_ran = False
868

869
        # Register the AccumulateGrad post hooks if optimize_ddp is
870
        # True. The hooks will be deregistered if compiled_autograd is not
871
        # enabled.
872
        self._accum_grad_hooks: List[RemovableHandle] = []
873
        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
874
        self._use_python_reducer = optimize_ddp in (
875
            "python_reducer",
876
            "python_reducer_without_compiled_forward",
877
        )
878
        self._force_to_disable_cpp_reducer = (
879
            optimize_ddp == "python_reducer_without_compiled_forward"
880
        )
881
        if self._use_python_reducer:
882
            self._register_accum_grad_hook()
883

884
    def _register_accum_grad_hook(self):
885
        import torch.distributed._functional_collectives as fcol
886

887
        def compiled_accum_grad_hook(
888
            param,
889
            *,
890
            param_index: int,
891
        ):
892
            if not self.require_backward_grad_sync:
893
                return
894

895
            if param.grad is None:
896
                return
897

898
            if self._comm_hooks:
899
                for hook, state in self._comm_hooks:
900
                    hook(state, (param.grad, param))
901
            else:
902
                gradient = param.grad / self.process_group.size()
903
                gradient = fcol.all_reduce(gradient, "sum", self.process_group)
904
                param.grad.copy_(gradient)
905

906
        for index, param in enumerate(self._module_parameters):
907
            self._accum_grad_hooks.append(
908
                param.register_post_accumulate_grad_hook(
909
                    functools.partial(
910
                        compiled_accum_grad_hook,
911
                        param_index=index,
912
                    )
913
                )
914
            )
915

916
    def _delayed_all_reduce_hook(self, grad):
917
        world_size = dist.get_world_size(self.process_group)
918

919
        self._delay_grad_buffer.div_(world_size)  # type: ignore[union-attr]
920
        _ = dist.all_reduce(
921
            self._delay_grad_buffer, group=self.process_group, async_op=True
922
        )
923
        return grad
924

925
    def _register_delay_all_reduce_hook(
926
        self,
927
        bucket_cap_mb,
928
        param_to_hook_all_reduce,
929
        device_ids,
930
    ):
931
        # 1. Create gradient buffer
932
        device = torch.device("cpu") if device_ids is None else device_ids[0]
933
        self._delay_grad_buffer = torch.zeros(
934
            sum([p.numel() for p in self._delay_all_reduce_params]),
935
            device=device,
936
        )
937

938
        # 2. Broadcast the parameters
939
        detached_params = [p.detach() for p in self._delay_all_reduce_params]
940
        dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)
941

942
        # 3. Hook all reduce to the specified parameter
943
        param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)
944

945
        # 4. Build tensor views for gradients
946
        offset = 0
947
        for param in self._delay_all_reduce_params:
948
            grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
949
                param.shape
950
            )
951
            self._delay_grad_views.append(grad_view)
952
            offset = offset + param.numel()
953

954
        # 5. Check whether the all reduce of all params requiring grad is delayed.
955
        for module_name, module in self.module.named_modules():
956
            for param_name, param in module.named_parameters(recurse=False):
957
                if param.requires_grad:
958
                    full_name = f"{module_name}.{param_name}"
959
                    if full_name not in self.parameters_to_ignore:
960
                        # There is at least a param whose all reduce will not be delayed.
961
                        # In this case, we should not set self._delay_all_reduce_all_params
962
                        # to True.
963
                        return
964
        self._delay_all_reduce_all_params = True
965

966
    def _setup_in_backward_optimizers(self):
967
        # Check if user has used apply_optim_in_backward to overlap optimizer
968
        # step + DDP backward. Current constraints:
969
        # 1. Only allreduce is supported at the moment, no custom communication.
970
        # 2. For DDP-managed parameters that have their optimizer run in
971
        # backward, their gradients are set to ``None``. If your use case
972
        # requires DDP parameters grad not to be set to ``None`` after their
973
        # in-backward optimizer runs, please ping
974
        # https://github.com/pytorch/pytorch/issues/90052.
975
        # NOTE: we use self._module_parameters instead of .parameters() since
976
        # the former excludes ignored (non-DDP managed) parameters.
977
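        #
        # For reference, the opt-in on the user side looks roughly like the
        # sketch below (hypothetical values; assumes the private helper
        # ``_apply_optimizer_in_backward`` from ``torch.distributed.optim``):
        #
        #     from torch.distributed.optim import _apply_optimizer_in_backward
        #     _apply_optimizer_in_backward(
        #         optimizer_class=torch.optim.SGD,
        #         params=model.parameters(),
        #         optimizer_kwargs={"lr": 0.01},
        #     )
        #     ddp_model = DistributedDataParallel(model)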
        if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
978
            torch._C._log_api_usage_once("ddp.optimizer_in_backward")
979
            # Remove hooks that apply_optim_in_backward had registered because
980
            # DDP customizes how optimizer is overlapped with backward due to
981
            # the allreduce.
982
            param_to_handle_map = (
983
                dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
984
            )
985
            for p in self._module_parameters:
986
                for handle in param_to_handle_map.get(p, []):
987
                    handle.remove()
988

989
            # Need a weakref to DDP instance to run all_reduce (from reducer)
990
            # and get managed DDP parameters.
991
            ddp_weakref = weakref.ref(self)
992
            # Note: importing in function, otherwise this will cause a circular
993
            # import.
994
            from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
995
                _apply_optim_in_backward_hook,
996
            )
997

998
            self.register_comm_hook(
999
                ddp_weakref,
1000
                _apply_optim_in_backward_hook(
1001
                    gradient_is_bucket_view=self.gradient_as_bucket_view
1002
                ),
1003
            )
1004

1005
            self.reducer._set_optimizer_in_backward()  # type: ignore[attr-defined]
1006

1007
    def _fire_reducer_autograd_hook(self, idx, *unused):
1008
        """
1009
        Fire the reducer's autograd hook to allreduce params in a Reducer bucket.
1010

1011
        Note that this is only used during mixed precision training as the
1012
        Reducer's hooks installed during construction time would not be called
1013
        as we're working in the low precision parameter setting.
1014
        """
1015
        self.reducer._autograd_hook(idx)  # type: ignore[attr-defined]
1016

1017
    def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
1018
        """
1019
        For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.
1020

1021
        When training with DDP mixed precision, this root pre-forward hook kicks
1022
        off low precision copies on a separate stream and creates respective
1023
        events to wait for them.
1024
        """
1025
        # Clear out previous iteration submodule to event. This is because we
1026
        # may have populated some events for modules that didn't end up being
1027
        # used.
1028
        self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
1029
        with torch.cuda.stream(self._mp_stream):
1030
            for submodule in self.module.modules():
1031
                for param in submodule.parameters(recurse=False):
1032
                    # Do not cast DDP ignored parameters.
1033
                    if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
1034
                        continue
1035
                    _alloc_storage(param._mp_param, param.size())
1036
                    # copy() implicitly casts to low precision
1037
                    with torch.no_grad():
1038
                        param._mp_param.copy_(param.data)
1039
                        # TODO: when zero_grad(set_to_none=False) or in grad
1040
                        # accumulation case, accumulated grads can be in fp32
1041
                        # which can cause errors when running DDP backwards due
1042
                        # to mismatched incoming and accumulated gradient types.
1043
                        # So we manually cast the accumulated grad down for now,
1044
                        # in the future we may shift to FSDP style gradient
1045
                        # accumulation management where the accumulated gradient
1046
                        # is saved and .grad field is set to None, bypassing
1047
                        # this issue.
1048
                        if param.grad is not None:
1049
                            param.grad.data = param.grad.to(
1050
                                self.mixed_precision.param_dtype  # type: ignore[union-attr]
1051
                            )
1052
                    param.data = param._mp_param
1053
                copy_event = torch.cuda.Event()
1054
                copy_event.record()
1055
                self._submodule_to_event[submodule].append(copy_event)
1056

1057
    def _module_wait_for_copy_hook(
1058
        self,
1059
        module,
1060
        *args: Any,
1061
        **kwargs: Any,
1062
    ) -> None:
1063
        """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
1064
        try:
1065
            event = self._submodule_to_event[module].popleft()
1066
        except IndexError:
1067
            # copy event has already been waited on
1068
            return
1069

1070
        event.wait(stream=torch.cuda.current_stream())
1071
        for p in module.parameters(recurse=False):
1072
            # Don't register hooks if param does not require grad
1073
            if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
1074
                continue
1075
            # We need to register autograd hook here instead of DDP's ctor
1076
            # since we're working with the low precision param. Register them
1077
            # via obtaining the gradient accumulator.
1078
            tmp = p.expand_as(p)
1079
            grad_acc = tmp.grad_fn.next_functions[0][0]
1080

1081
            hook = grad_acc.register_hook(
1082
                functools.partial(self._fire_reducer_autograd_hook, p._idx)
1083
            )
1084
            p._ddp_mp_hook_state = (grad_acc, hook)
1085

1086
    def _log_and_throw(self, err_type, err_msg):
1087
        if self.logger is not None:
1088
            self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
1089
        raise err_type(err_msg)
1090

1091
    def _ddp_init_helper(
1092
        self,
1093
        parameters,
1094
        expect_sparse_gradient,
1095
        param_to_name_mapping,
1096
        static_graph,
1097
    ):
1098
        """
1099
        DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.
1100

1101
        Initialization helper function that does the following:
1102
        (1) bucketing the parameters for reductions
1103
        (2) resetting the bucketing states
1104
        (3) registering the grad hooks
1105
        (4) Logging construction-time DDP logging data
1106
        (5) passing a handle of DDP to SyncBatchNorm Layer
1107
        """
1108
        # Note that the parameter order is not necessarily the order in which
        # the parameters are used, especially in models with control flow.
        #
        # Moreover, when the parameters are not listed in the real execution
        # order and a model also happens to
        #   1) have other collective comm ops in its backward graph, or
        #   2) have parameters that are unused on a subset of ranks in the world,
        # bucketing could insert an ALL-REDUCE comm op too early on the ranks
        # with unused parameters, unexpectedly matching up with other
        # collective comm ops on other ranks.
        #
        # To handle this corner case, when the parameters are not in the real
        # execution order, we don't do bucketing; thus only one ALL-REDUCE is
        # inserted after all the gradients of the whole graph are computed.
        #
        # Note that we only disable bucketing for the first iteration. After
        # the first iteration, it is OK to rebuild buckets, because "bucket
        # rebuild" bucketizes parameters based on their real execution order
        # in the backward graph.
1125

1126
        # Can remove this branching once #73732 is landed.
1127
        if static_graph is True or self.find_unused_parameters is False:
1128
            bucket_size_limits = [sys.maxsize]
1129
        else:
1130
            bucket_size_limits = [
1131
                dist._DEFAULT_FIRST_BUCKET_BYTES,
1132
                self.bucket_bytes_cap,
1133
            ]
1134
        (
1135
            bucket_indices,
1136
            per_bucket_size_limits,
1137
        ) = dist._compute_bucket_assignment_by_size(
1138
            parameters,
1139
            bucket_size_limits,
1140
            expect_sparse_gradient,
1141
        )
1142

1143
        # Remember index for parameters if we are in mixed precision, as we
1144
        # need to pass in index to Reducer's autograd hook via python.
1145
        if self.mixed_precision is not None:
1146
            for i, p in enumerate(parameters):
1147
                p._idx = i
1148

1149
        # Note: reverse list of buckets because we want to approximate the
1150
        # order in which their gradients are produced, and assume they
1151
        # are used in the forward pass in the order they are defined.
1152
        self.reducer = dist.Reducer(
1153
            parameters,
1154
            list(reversed(bucket_indices)),
1155
            list(reversed(per_bucket_size_limits)),
1156
            self.process_group,
1157
            expect_sparse_gradient,
1158
            # The bucket size limit is specified in the constructor.
1159
            # Additionally, we allow for a single small bucket for parameters
1160
            # that are defined first, such that their gradients don't spill into
1161
            # a much larger bucket, adding unnecessary latency after gradient
1162
            # computation finishes. Experiments showed 1MB is a reasonable value.
1163
            self.bucket_bytes_cap,
1164
            self.find_unused_parameters,
1165
            self.gradient_as_bucket_view,
1166
            param_to_name_mapping,
1167
            # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
1168
            # bucket.
1169
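            # A minimal sketch (an illustration, not executed here), assuming it
            # runs on every rank before the DDP constructor:
            #   import torch.distributed as dist
            #   dist._DEFAULT_FIRST_BUCKET_BYTES = 4 * 1024 * 1024  # use a 4 MiB first bucket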
            dist._DEFAULT_FIRST_BUCKET_BYTES,
1170
        )
1171

1172
        self.logger = dist.Logger(self.reducer)
1173
        # Set as a weak reference to avoid reference cycle between
1174
        # logger and reducer.
1175
        self.reducer.set_logger(self.logger)
1176

1177
        has_sync_bn = False
1178
        for submodule in self.module.modules():
1179
            if isinstance(submodule, torch.nn.SyncBatchNorm):
1180
                has_sync_bn = True
1181
                break
1182

1183
        # Set logging data that can be collected at construction time.
1184
        self.logger.set_construction_data_and_log(
1185
            self.module.__class__.__name__,
1186
            [] if self.device_ids is None else self.device_ids,
1187
            -1 if self.output_device is None else self.output_device,
1188
            self.broadcast_buffers,
1189
            has_sync_bn,
1190
            static_graph,
1191
        )
1192

1193
        # passing a handle to torch.nn.SyncBatchNorm layer
1194
        self._passing_sync_batchnorm_handle(self.module)
1195

1196
    def __getstate__(self):
1197
        self._check_default_group()
1198
        attrs = copy.copy(self.__dict__)
1199
        del attrs["process_group"]
1200
        del attrs["reducer"]
1201
        del attrs["logger"]
1202
        return attrs
1203

1204
    def __setstate__(self, state):
1205
        # If serializable, then the process group should be the default one
1206
        self.process_group = _get_default_group()
1207
        super().__setstate__(state)
1208
        self.__dict__.setdefault("require_forward_param_sync", True)
1209
        self.__dict__.setdefault("require_backward_grad_sync", True)
1210
        parameters, expect_sparse_gradient = self._build_params_for_reducer()
1211
        # In debug mode, build a mapping of parameter index -> parameter name.
1212
        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
1213
        # Builds reducer.
1214
        self._ddp_init_helper(
1215
            parameters,
1216
            expect_sparse_gradient,
1217
            param_to_name_mapping,
1218
            self.static_graph,
1219
        )
1220
        if self.static_graph:
1221
            self.reducer._set_static_graph()
1222
            assert self.logger is not None
1223
            self.logger._set_static_graph()
1224

1225
    def _build_params_for_reducer(self):
1226
        # Build a list of (module, parameter) tuples for all parameters that require grad.
1227
        modules_and_parameters = [
1228
            (module, parameter)
1229
            for module_name, module in self.module.named_modules()
1230
            for parameter in [
1231
                param
1232
                # Note that we access module.named_parameters instead of
1233
                # parameters(module). parameters(module) is only needed in the
1234
                # single-process multi device case, where it accesses replicated
1235
                # parameters through _former_parameters.
1236
                for param_name, param in module.named_parameters(recurse=False)
1237
                if param.requires_grad
1238
                and f"{module_name}.{param_name}" not in self.parameters_to_ignore
1239
            ]
1240
        ]
1241

1242
        # Deduplicate any parameters that might be shared across child modules.
1243
        memo = set()
1244
        modules_and_parameters = [
1245
            # "p not in memo" is the deduplication check.
1246
            # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
1247
            (m, p)
1248
            for m, p in modules_and_parameters
1249
            if p not in memo and not memo.add(p)  # type: ignore[func-returns-value]
1250
        ]
1251

1252
        # Build list of parameters.
1253
        parameters = [parameter for _, parameter in modules_and_parameters]
1254

1255
        # Checks if a module will produce a sparse gradient.
1256
        def produces_sparse_gradient(module):
1257
            if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
1258
                return module.sparse
1259
            return False
1260

1261
        # Build list of booleans indicating whether or not to expect sparse
1262
        # gradients for the corresponding parameters.
1263
        expect_sparse_gradient = [
1264
            produces_sparse_gradient(module) for module, _ in modules_and_parameters
1265
        ]
1266

1267
        self._assign_modules_buffers()
1268

1269
        return parameters, expect_sparse_gradient
1270

1271
    def _assign_modules_buffers(self):
1272
        """
1273
        Assign self.module.named_buffers to self.modules_buffers.
1274

1275
        Assigns module buffers to self.modules_buffers which are then used to
1276
        broadcast across ranks when broadcast_buffers=True. Note that this
1277
        must be called every time buffers need to be synced because buffers can
1278
        be reassigned by the user's module;
1279
        see https://github.com/pytorch/pytorch/issues/63916.
1280
        """
1281
        # Collect buffers for modules, filtering out buffers that should be ignored.
1282
        named_module_buffers = [
1283
            (buffer, buffer_name)
1284
            for buffer_name, buffer in self.module.named_buffers()
1285
            if buffer_name not in self.parameters_to_ignore
1286
        ]
1287
        self.modules_buffers = [
1288
            buffer for (buffer, buffer_name) in named_module_buffers
1289
        ]
1290
        # Dict[str, tensor] representing module buffers not ignored by DDP.
1291
        self.named_module_buffers = {
1292
            buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
1293
        }
1294

1295
    def _build_debug_param_to_name_mapping(self, parameters):
1296
        param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
1297
        param_set = set(parameters)
1298
        param_index_to_param_fqn = {}
1299
        for module_name, module in self.module.named_modules():
1300
            for param_name, param in module.named_parameters(recurse=False):
1301
                fqn = f"{module_name}.{param_name}"
1302
                # Bypass ignored parameters since those are not reduced by DDP
1303
                # to begin with.
1304
                if fqn not in self.parameters_to_ignore and param.requires_grad:
1305
                    if param not in param_set:
1306
                        self._log_and_throw(
1307
                            ValueError,
1308
                            f"Param with name {fqn} found in module parameters, but not DDP parameters."
1309
                            " This indicates a bug in DDP, please report an issue to PyTorch.",
1310
                        )
1311
                    param_index = param_to_param_index[param]
1312
                    param_index_to_param_fqn[param_index] = fqn
1313

1314
        # Ensure we covered all parameters
1315
        if len(param_set) != len(param_index_to_param_fqn):
1316
            self._log_and_throw(
1317
                ValueError,
1318
                (
1319
                    "Expected param to name mapping to cover all parameters, but"
1320
                    f" got conflicting lengths: {len(param_set)} vs "
1321
                    f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
1322
                    ", please report an issue to PyTorch."
1323
                ),
1324
            )
1325

1326
        return param_index_to_param_fqn
1327

1328
    def _get_parameters(self, m, recurse=True):
1329
        """Return a generator of module parameters."""
1330

1331
        def model_parameters(m):
1332
            ps = (
1333
                m._former_parameters.values()
1334
                if hasattr(m, "_former_parameters")
1335
                else m.parameters(recurse=False)
1336
            )
1337
            yield from ps
1338

1339
        for mod in m.modules() if recurse else [m]:
1340
            yield from model_parameters(mod)
1341

1342
    def _check_default_group(self):
1343
        pickle_not_supported = False
1344
        try:
1345
            if self.process_group != _get_default_group():
1346
                pickle_not_supported = True
1347
        except RuntimeError:
1348
            pickle_not_supported = True
1349

1350
        if pickle_not_supported:
1351
            self._log_and_throw(
1352
                RuntimeError,
1353
                "DDP Pickling/Unpickling are only supported "
1354
                "when using DDP with the default process "
1355
                "group. That is, when you have called "
1356
                "init_process_group and have not passed "
1357
                "process_group argument to DDP constructor",
1358
            )
1359

1360
    @contextmanager
1361
    def no_sync(self):
1362
        r"""
1363
        Context manager to disable gradient synchronizations across DDP processes.
1364

1365
        Within this context, gradients will be accumulated on module
1366
        variables, which will later be synchronized in the first
1367
        forward-backward pass exiting the context.
1368

1369
        Example::
1370

1371
            >>> # xdoctest: +SKIP("undefined variables")
1372
            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
1373
            >>> with ddp.no_sync():
1374
            >>>     for input in inputs:
1375
            >>>         ddp(input).backward()  # no synchronization, accumulate grads
1376
            >>> ddp(another_input).backward()  # synchronize grads
1377

1378
        .. warning::
1379
            The forward pass should be included inside the context manager, or
1380
            else gradients will still be synchronized.
1381
        """
1382
        old_require_backward_grad_sync = self.require_backward_grad_sync
1383
        self.require_backward_grad_sync = False
1384
        try:
1385
            yield
1386
        finally:
1387
            self.require_backward_grad_sync = old_require_backward_grad_sync
1388

1389
    @classmethod
1390
    def _get_active_ddp_module(cls):
1391
        """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
1392
        return cls._active_ddp_module
1393

1394
    # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
1395
    # for the 'module_to_run' underneath
1396
    # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
1397
    @contextmanager
1398
    @torch._disable_dynamo(recursive=False)
1399
    def _inside_ddp_forward(self):
1400
        DistributedDataParallel._active_ddp_module = self
1401
        try:
1402
            yield
1403
        finally:
1404
            DistributedDataParallel._active_ddp_module = None
1405

1406
    def _run_ddp_forward(self, *inputs, **kwargs):
1407
        if self._use_python_reducer:
1408
            return self.module(*inputs, **kwargs)  # type: ignore[index]
1409
        else:
1410
            with self._inside_ddp_forward():
1411
                return self.module(*inputs, **kwargs)  # type: ignore[index]
1412

1413
    def _clear_grad_buffer(self):
1414
        # Making param.grad point to the grad buffers before backward is based on the
1415
        # assumption that grad accumulation is done in place in the autograd engine.
1416
        # In some edge cases, if grad accumulation in the autograd engine is not in
1417
        # place, then param.grad and the grad buffers become detached.
1418
        if self._delay_grad_buffer is not None:
1419
            # We batch zero_grad for all params by resetting the whole grad
1420
            # buffer when the grad of all params is set to None.
1421
            all_param_grad_none = all(
1422
                param.grad is None for param in self._delay_all_reduce_params
1423
            )
1424

1425
            for index, param in enumerate(self._delay_all_reduce_params):
1426
                if param.grad is None:
1427
                    param.grad = self._delay_grad_views[index]
1428
                    if not all_param_grad_none:
1429
                        param.grad.zero_()
1430

1431
            if all_param_grad_none:
1432
                self._delay_grad_buffer.zero_()
1433

1434
    def _lazy_init(self):
1435
        # Initialization for DDP that occurs after construction, but lazily
1436
        # before the first forward pass.
1437
        self._setup_in_backward_optimizers()
1438
        self._lazy_init_ran = True
1439

1440
    def _should_disable_cpp_reducer(self) -> bool:
1441
        return self._use_python_reducer and (
1442
            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
1443
        )
1444

1445
    def _pre_forward(self, *inputs, **kwargs):
1446
        if self._should_disable_cpp_reducer():
1447
            return inputs, kwargs
1448

1449
        # Disable the python reducer if compiled_autograd is not enabled.
1450
        if self._accum_grad_hooks:
1451
            for index, h in enumerate(self._accum_grad_hooks):
1452
                h.remove()
1453
            self._accum_grad_hooks.clear()
1454

1455
        if not self._lazy_init_ran and not torch._utils.is_compiling():
1456
            self._lazy_init()
1457

1458
        if self._delay_all_reduce_all_params:
1459
            return inputs, kwargs
1460

1461
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
1462
            assert self.logger is not None
1463
            self.logger.set_runtime_stats_and_log()
1464
            self.reducer.prepare_for_forward()
1465

1466
        # Notify the join context that this process has not joined, if
1467
        # needed
1468
        work = Join.notify_join_context(self)
1469
        if work:
1470
            self.reducer._set_forward_pass_work_handle(
1471
                work, self._divide_by_initial_world_size  # type: ignore[arg-type]
1472
            )
1473

1474
        # Call _rebuild_buckets before forward computation.
1476
        # _rebuild_buckets may allocate new buckets before deallocating old ones,
1477
        # so to keep peak memory usage low,
1478
        # call it before memory usage increases
1479
        # during forward computation.
1480
        # Buckets should be rebuilt only once during the whole training period.
1480
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
1481
            logger.info("Reducer buckets have been rebuilt in this iteration.")
1482
            self._has_rebuilt_buckets = True
1483

1484
        # Sync buffers before the forward pass if the user specified that
1486
        # location as part of a buffer hook (the default when no hook is specified).
1486
        if self._check_sync_bufs_pre_fwd():
1487
            self._sync_buffers()
1488

1489
        if self._join_config.enable:
1490
            # Notify joined ranks whether they should sync in backwards pass or not.
1491
            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
1492

1493
        if self.device_ids:
1494
            moved_inputs, moved_kwargs = _to_kwargs(
1495
                inputs,
1496
                kwargs,
1497
                torch.device(self.device_type, self.device_ids[0]),
1498
                self.use_side_stream_for_tensor_copies,
1499
            )
1500
            args, kwargs = moved_inputs[0], moved_kwargs[0]
1501
            # Cast inputs to reduced precision if needed.
1502
            if self.mixed_precision is not None:
1503
                args, kwargs = _cast_forward_inputs(
1504
                    self.mixed_precision.param_dtype,
1505
                    *args,
1506
                    **kwargs,
1507
                )
1508
            return args, kwargs
1509
        else:
1510
            # Cast inputs to reduced precision if needed.
1511
            # TODO (rohan-varma) test this codepath.
1512
            if self.mixed_precision is not None:
1513
                inputs, kwargs = _cast_forward_inputs(
1514
                    self.mixed_precision.param_dtype,
1515
                    *inputs,
1516
                    **kwargs,
1517
                )
1518
            return inputs, kwargs
1519

1520
    def _post_forward(self, output):
1521
        if self._should_disable_cpp_reducer():
1522
            return output
1523

1524
        if self._delay_all_reduce_all_params:
1525
            self._clear_grad_buffer()
1526
            return output
1527

1528
        # Sync buffers after the forward pass if the user specified that
1530
        # location as part of a buffer hook.
1530
        if self._check_sync_bufs_post_fwd():
1531
            self._sync_buffers()
1532

1533
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
1534
            self.require_forward_param_sync = True
1535
            # We'll return the output object verbatim since it is a freeform
1536
            # object. We need to find any tensors in this object, though,
1537
            # because we need to figure out which parameters were used during
1538
            # this forward pass, to ensure we short circuit reduction for any
1539
            # unused parameters. Only if `find_unused_parameters` is set.
1540
            if self.find_unused_parameters and not self.static_graph:
1541
                # Do not need to populate this for static graph.
1542
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
1543
            else:
1544
                self.reducer.prepare_for_backward([])
1545
        else:
1546
            self.require_forward_param_sync = False
1547

1548
        # TODO: DDPSink is currently enabled for unused parameter detection and
1549
        # static graph training for first iteration.
1550
        if (self.find_unused_parameters and not self.static_graph) or (
1551
            self.static_graph and not self._static_graph_delay_allreduce_enqueued
1552
        ):
1553
            (
1554
                output_tensor_list,
1555
                treespec,
1556
                output_is_rref,
1557
            ) = _tree_flatten_with_rref(output)
1558
            output_placeholders = [None for _ in range(len(output_tensor_list))]
1559
            # Do not touch tensors that have no grad_fn, which can cause issues
1560
            # such as https://github.com/pytorch/pytorch/issues/60733
1561
            for i, output in enumerate(output_tensor_list):
1562
                if torch.is_tensor(output) and output.grad_fn is None:
1563
                    output_placeholders[i] = output
1564

1565
            # When find_unused_parameters=True, make tensors that require grad
1566
            # run through the DDPSink backward pass. When not all outputs are
1567
            # used in the loss, the corresponding tensors receive an
1568
            # undefined gradient, which the reducer then handles to ensure the
1569
            # param.grad field is not touched and we don't error out.
1570
            passthrough_tensor_list = _DDPSink.apply(
1571
                weakref.ref(self),
1572
                *output_tensor_list,
1573
            )
1574
            for i in range(len(output_placeholders)):
1575
                if output_placeholders[i] is None:
1576
                    output_placeholders[i] = passthrough_tensor_list[i]
1577

1578
            # Reconstruct output data structure.
1579
            output = _tree_unflatten_with_rref(
1580
                output_placeholders, treespec, output_is_rref
1581
            )
1582

1583
        # At the end of the forward pass, reset the grad buffer and grad views
1584
        self._clear_grad_buffer()
1585
        return output
1586

1587
    def forward(self, *inputs, **kwargs):
1588
        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
1589
            inputs, kwargs = self._pre_forward(*inputs, **kwargs)
1590
            output = (
1591
                self.module.forward(*inputs, **kwargs)
1592
                if self._delay_all_reduce_all_params
1593
                else self._run_ddp_forward(*inputs, **kwargs)
1594
            )
1595
            return self._post_forward(output)
1596

1597
    def scatter(self, inputs, kwargs, device_ids):
1598
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
1599

1600
    def to_kwargs(self, inputs, kwargs, device_id):
1601
        # Kept for BC
1602
        return _to_kwargs(
1603
            inputs,
1604
            kwargs,
1605
            torch.device(self.device_type, device_id),
1606
            self.use_side_stream_for_tensor_copies,
1607
        )
1608

1609
    def gather(self, outputs, output_device):
1610
        return gather(outputs, output_device, dim=self.dim)
1611

1612
    def train(self, mode=True):
1613
        super().train(mode)
1614
        return self
1615

1616
    # When running in join mode, schedules an allreduce to notify joined ranks
1617
    # of whether backwards pass synchronization will run this iteration or not.
1618
    def _check_global_requires_backward_grad_sync(self, is_joined_rank):
1619
        if not is_joined_rank and self.require_backward_grad_sync:
1620
            requires_sync_tensor = torch.ones(1, device=self.device)
1621
        else:
1622
            requires_sync_tensor = torch.zeros(1, device=self.device)
1623

1624
        work = dist.all_reduce(
1625
            requires_sync_tensor, group=self.process_group, async_op=True
1626
        )
1627

1628
        # (kwen2501) This if condition is a plain translation of previous
1629
        # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()`
1630
        # is not called and it doesn't care about the result. I am guessing
1631
        # that it just wants to fire a matching all-reduce and does not want
1632
        # the main stream to wait.
1633
        if is_joined_rank:
1634
            work.wait()
1635
            should_sync_backwards = requires_sync_tensor.item() != 0
1636
            return should_sync_backwards
1637
        else:
1638
            return None  # Return value is not/should not be used.
1639

1640
    # When running in join mode, checks and performs sync of module buffers if
1641
    # the models have buffers that should be synchronized in the forward pass.
1642
    def _check_and_sync_module_buffers(self):
1643
        if self._check_sync_bufs_pre_fwd():
1644
            authoritative_rank = self._find_common_rank(self._distributed_rank, False)
1645
            self._sync_module_buffers(authoritative_rank)
1646

1647
    # When running in join mode, agrees upon a common rank and broadcasts model
1648
    # parameters to all other ranks.
1649
    def _sync_final_model(self, is_last_joiner):
1650
        # Agree upon the process that will be the authoritative model copy.
1651
        # The current rank is a candidate for being the authoritative copy if
1652
        # is_last_joiner=True. We break ties via picking the larger rank.
1653
        self._authoritative_rank = self._find_common_rank(
1654
            self._distributed_rank, is_last_joiner
1655
        )
1656
        _sync_module_states(
1657
            module=self.module,
1658
            process_group=self.process_group,
1659
            broadcast_bucket_size=self.broadcast_bucket_size,
1660
            src=self._authoritative_rank,
1661
            params_and_buffers_to_ignore=self.parameters_to_ignore,
1662
            broadcast_buffers=self.broadcast_buffers,
1663
        )
1664

1665
    # Schedule comm ops to match those scheduled in the reducer's backward
1666
    # pass.
1667
    def _match_all_reduce_for_bwd_pass(self):
1668
        comm_work = []
1669
        # Schedule comm in the same order as Reducer schedules them, i.e.
1670
        # the order of the buckets. Retrieving the bucket order from the reducer
1671
        # ensures that we keep the same order in join mode, such as when bucket
1672
        # order is rebuilt dynamically.
1673

1674
        # Returns grad_buckets in order, but real tensors are substituted with
1675
        # zero tensors of the same shape.
1676
        grad_buckets = self.reducer._get_zeros_like_grad_buckets()
1677
        for grad_bucket in grad_buckets:
1678
            # Joined processes contribute zero gradient. In the case that
1679
            # divide_by_initial_world_size=True, we divide grads by the static
1680
            # world size, if not, the dividing factor is reduced by the number
1681
            # of joined processes.
1682
            work = self.reducer._run_comm_hook(grad_bucket)
1683
            comm_work.append(work)
1684
        for work in comm_work:
1685
            work.wait()
1686

1687
    # Allreduces the used parameter mapping across ranks.
1688
    def _match_unused_params_allreduce(self):
1689
        locally_used_param_map = self.reducer._get_local_used_map()
1690
        self.process_group.allreduce(locally_used_param_map)
1691

1692
    def join(
1693
        self,
1694
        divide_by_initial_world_size: bool = True,
1695
        enable: bool = True,
1696
        throw_on_early_termination: bool = False,
1697
    ):
1698
        r"""
1699
        Context manager for training with uneven inputs across processes in DDP.
1700

1701
        This context manager will keep track of already-joined DDP processes,
1702
        and "shadow" the forward and backward passes by inserting collective
1703
        communication operations to match the ones created by non-joined
1704
        DDP processes. This will ensure each collective call has a corresponding
1705
        call by already-joined DDP processes, preventing hangs or errors that
1706
        would otherwise happen when training with uneven inputs across
1707
        processes. Alternatively, if the flag ``throw_on_early_termination`` is
1708
        specified to be ``True``, all trainers will throw an error once one rank
1709
        runs out of inputs, allowing these errors to be caught and handled
1710
        according to application logic.
1711

1712
        Once all DDP processes have joined, the context manager will broadcast
1713
        the model corresponding to the last joined process to all processes to
1714
        ensure the model is the same across all processes
1715
        (which is guaranteed by DDP).
1716

1717
        To use this to enable training with uneven inputs across processes,
1718
        simply wrap this context manager around your training loop. No further
1719
        modifications to the model or data loading are required.
1720

1721
        .. warning::
1722
            If the model or training loop this context manager is wrapped around
1723
            has additional distributed collective operations, such as
1724
            ``SyncBatchNorm`` in the model's forward pass, then the flag
1725
            ``throw_on_early_termination`` must be enabled. This is because this
1726
            context manager is not aware of non-DDP collective communication.
1727
            This flag will cause all ranks to throw when any one rank
1728
            exhausts inputs, allowing these errors to be caught and recovered
1729
            from across all ranks.
1730

1731
        Args:
1732
            divide_by_initial_world_size (bool): If ``True``, will divide
1733
                gradients by the initial ``world_size`` DDP training was launched
1734
                with. If ``False``, will compute the effective world size
1735
                (number of ranks that have not depleted their inputs yet) and
1736
                divide gradients by that during allreduce. Set
1737
                ``divide_by_initial_world_size=True`` to ensure every input
1738
                sample, including the uneven inputs, has equal weight in terms of
1739
                how much they contribute to the global gradient. This is
1740
                achieved by always dividing the gradient by the initial
1741
                ``world_size`` even when we encounter uneven inputs. If you set
1742
                this to ``False``, we divide the gradient by the remaining
1743
                number of nodes. This ensures parity with training on a smaller
1744
                ``world_size`` although it also means the uneven inputs would
1745
                contribute more towards the global gradient. Typically, you
1746
                would want to set this to ``True`` for cases where the last few
1747
                inputs of your training job are uneven. In extreme cases, where
1748
                there is a large discrepancy in the number of inputs, setting
1749
                this to ``False`` might provide better results.
1750
            enable (bool): Whether to enable uneven input detection or not. Pass
1751
                in ``enable=False`` to disable in cases where you know that
1752
                inputs are even across participating processes. Default is
1753
                ``True``.
1754
            throw_on_early_termination (bool): Whether to throw an error
1755
                or continue training when at least one rank has exhausted
1756
                inputs. If ``True``, will throw upon the first rank reaching end
1757
                of data. If ``False``, will continue training with a smaller
1758
                effective world size until all ranks are joined. Note that if
1759
                this flag is specified, then the flag
1760
                ``divide_by_initial_world_size`` would be ignored. Default
1761
                is ``False``.
1762

1763

1764
        Example::
1765

1766
            >>> # xdoctest: +SKIP("Distributed")
1767
            >>> import torch
1768
            >>> import torch.distributed as dist
1769
            >>> import os
1770
            >>> import torch.multiprocessing as mp
1771
            >>> import torch.nn as nn
1772
            >>> # On each spawned worker
1773
            >>> def worker(rank):
1774
            >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
1775
            >>>     torch.cuda.set_device(rank)
1776
            >>>     model = nn.Linear(1, 1, bias=False).to(rank)
1777
            >>>     model = torch.nn.parallel.DistributedDataParallel(
1778
            >>>         model, device_ids=[rank], output_device=rank
1779
            >>>     )
1780
            >>>     # Rank 1 gets one more input than rank 0.
1781
            >>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
1782
            >>>     with model.join():
1783
            >>>         for _ in range(5):
1784
            >>>             for inp in inputs:
1785
            >>>                 loss = model(inp).sum()
1786
            >>>                 loss.backward()
1787
            >>>     # Without the join() API, the below synchronization will hang
1788
            >>>     # blocking for rank 1's allreduce to complete.
1789
            >>>     torch.cuda.synchronize(device=rank)
1790
        """
1791
        return Join(
1792
            [self],
1793
            enable,
1794
            throw_on_early_termination,
1795
            divide_by_initial_world_size=divide_by_initial_world_size,
1796
        )
1797

1798
    def join_hook(
1799
        self,
1800
        **kwargs,
1801
    ):
1802
        r"""
1803
        DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.
1804

1805
        Arguments:
1806
            kwargs (dict): a :class:`dict` containing any keyword arguments
1807
                to modify the behavior of the join hook at run time; all
1808
                :class:`Joinable` instances sharing the same join context
1809
                manager are forwarded the same value for ``kwargs``.
1810

1811
        The hook supports the following keyword arguments:
1812
            divide_by_initial_world_size (bool, optional):
1813
                If ``True``, then gradients are divided by the initial world
1814
                size that DDP was launched with.
1815
                If ``False``, then gradients are divided by the effective world
1816
                size (i.e. the number of non-joined processes), meaning that
1817
                the uneven inputs contribute more toward the global gradient.
1818
                Typically, this should be set to ``True`` if the degree of
1819
                unevenness is small but can be set to ``False`` in extreme
1820
                cases for possibly better results.
1821
                Default is ``True``.
1822
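
        Example::
            A minimal sketch (not part of the original documentation): the hook
            is consumed by the generic ``Join`` context manager rather than
            called directly; ``ddp_model`` and ``inputs`` are assumed to be
            defined.

            >>> # xdoctest: +SKIP("Distributed")
            >>> from torch.distributed.algorithms.join import Join
            >>> with Join([ddp_model], divide_by_initial_world_size=False):
            >>>     for inp in inputs:
            >>>         ddp_model(inp).sum().backward()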
        """
1823
        divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
1824
        return _DDPJoinHook(
1825
            self, divide_by_initial_world_size=divide_by_initial_world_size
1826
        )
1827

1828
    @property
1829
    def join_device(self):
1830
        return self.device
1831

1832
    @property
1833
    def join_process_group(self):
1834
        return self.process_group
1835

1836
    def _register_buffer_comm_hook(
1837
        self,
1838
        state,
1839
        hook: Callable,
1840
        comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
1841
    ):
1842
        r"""
1843
        Allow custom registration of hooks that define how buffers are synchronized across ranks.
1844

1845
        The hook takes in an optional state and is passed in a Dict[str, Tensor]
1846
        corresponding to buffer names and the buffers, and can run arbitrary reductions
1847
        on buffers as opposed to DDP's default broadcast from rank 0. This is useful for
1848
        example if a counter needs to be summed or averaged across ranks every iteration.
1849

1850
        Args:
1851
            state (Any): Optional state that is passed to the hook.
1852
            hook (Callable): Callable with the following signature:
1853
                         ``hook(state: object, buffers: Dict[str, torch.Tensor]) -> Optional[List[torch.futures.Future[torch.Tensor]]]``
1854
            comm_hook_location (_BufferCommHookLocation): Enum value indicating
1855
                            where to run the hook.
1856
                            _BufferCommHookLocation.PRE_FORWARD means that the
1857
                            hook will run _before_ the forward pass, and
1858
                            _BufferCommHookLocation.POST_FORWARD means that the
1859
                            hook will run _after_ the forward pass.
1860

1861
            NOTE: To maximize performance, users can return a
1862
                List[torch.futures.Future] from their hook, and DDP will
1863
                install and await these futures appropriately at the end of
1864
                the backward pass. This will ensure all buffers are
1865
                synchronized by the end of the backward pass. If this
1866
                setting is used, it is recommended to pass
1867
                comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
1868
                which will trigger the hook after the forward pass.
1869
                If _BufferCommHookLocation.PRE_FORWARD is used, users must
1870
                ensure appropriate synchronization when manipulating GPU
1871
                buffers in the forward pass.
1872
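
        Example::
            A minimal sketch (not part of the original documentation) of a hook
            that averages floating-point buffers across ranks instead of
            broadcasting them from rank 0. It assumes an initialized default
            process group and a backend where ``get_future`` is supported
            (e.g. NCCL or GLOO).

            >>> # xdoctest: +SKIP('undefined name')
            >>> def average_buffers(state, named_buffers):
            >>>     world_size = dist.get_world_size()
            >>>     futs = []
            >>>     for buf in named_buffers.values():
            >>>         if buf.is_floating_point():
            >>>             buf.div_(world_size)
            >>>             work = dist.all_reduce(buf, async_op=True)
            >>>             futs.append(work.get_future())
            >>>     return futs
            >>> ddp._register_buffer_comm_hook(state=None, hook=average_buffers)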
        """
1873
        assert callable(hook)
1874
        self.buffer_hook = _BufferCommHook(
1875
            buffer_comm_hook=hook,
1876
            buffer_comm_hook_state=state,
1877
            buffer_comm_hook_location=comm_hook_location,
1878
        )
1879

1880
    def register_comm_hook(self, state: object, hook: Callable):
1881
        r"""
1882
        Register communication hook for user-defined DDP aggregation of gradients across multiple workers.
1883

1884
        This hook would be very useful for researchers to try out new ideas. For
1885
        example, this hook can be used to implement several algorithms like GossipGrad
1886
        and gradient compression which involve different communication strategies for
1887
        parameter syncs while running Distributed DataParallel training.
1888

1889
        Args:
1890
            state (object): Passed to the hook to maintain any state information during the training process.
1891
                            Examples include error feedback in gradient compression,
1892
                            peers to communicate with next in GossipGrad, etc.
1893

1894
                            It is locally stored by each worker
1895
                            and shared by all the gradient tensors on the worker.
1896
            hook (Callable): Callable with the following signature:
1897
                             ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
1898

1899
                             This function is called once the bucket is ready. The
1900
                             hook can perform whatever processing is needed and return
1901
                             a Future indicating completion of any async work (ex: allreduce).
1902
                             If the hook doesn't perform any communication, it still
1903
                             must return a completed Future. The Future should hold the
1904
                             new value of grad bucket's tensors. Once a bucket is ready,
1905
                             c10d reducer would call this hook and use the tensors returned
1906
                             by the Future and copy grads to individual parameters.
1907
                             Note that the future's return type must be a single tensor.
1908

1909
                             We also provide an API called ``get_future`` to retrieve a
1910
                             Future associated with the completion of ``c10d.ProcessGroup.Work``.
1911
                             ``get_future`` is currently supported for NCCL and also supported for most
1912
                             operations on GLOO and MPI, except for peer to peer operations (send/recv).
1913

1914
        .. warning ::
1915
            Grad bucket's tensors will not be predivided by world_size. The user is responsible
1916
            for dividing by the world_size in the case of operations like allreduce.
1917

1918
        .. warning ::
1919
            DDP communication hook can only be registered once and should be registered
1920
            before calling backward.
1921

1922
        .. warning ::
1923
            The Future object that hook returns should contain a single tensor
1924
            that has the same shape as the tensors inside the grad bucket.
1925

1926
        .. warning ::
1927
            ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
1928
            for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
1929

1930
        Example::
1931
            Below is an example of a noop hook that returns the same tensor.
1932

1933
            >>> # xdoctest: +SKIP('undefined name')
1934
            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
1935
            >>>     fut = torch.futures.Future()
1936
            >>>     fut.set_result(bucket.buffer())
1937
            >>>     return fut
1938
            >>> ddp.register_comm_hook(state=None, hook=noop)
1939

1940
        Example::
1941
            Below is an example of a Parallel SGD algorithm where gradients are encoded before
1942
            allreduce, and then decoded after allreduce.
1943

1944
            >>> # xdoctest: +SKIP('undefined name')
1945
            >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
1946
            >>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
1947
            >>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
1948
            >>>     # Define the then callback to decode.
1949
            >>>     def decode(fut):
1950
            >>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
1951
            >>>         return decoded_tensor
1952
            >>>     return fut.then(decode)
1953
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
1954
        """
1955
        self._check_comm_hook(hook)
1956
        if hook.__name__ in ["bf16_compress_hook", "fp16_compress_hook"]:
1957
            # If we pass None, then the hook will try to get the world size
1958
            # by calling `dist.group.WORLD.size()`, which causes compilation
1959
            # errors. So we pre-decode the process group and pass it to the
1960
            # hook.
1961
            if state is None:
1962
                state = dist.group.WORLD
1963
        assert self.logger is not None
1964
        self.logger._set_comm_hook_name(hook.__qualname__)
1965
        self._comm_hooks.append((hook, state))
1966
        dist._register_comm_hook(self.reducer, state, hook)
1967

1968
    def _register_builtin_comm_hook(self, comm_hook_type):
1969
        r"""
1970
        Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.
1971

1972
        The built-in hooks aim to provide efficient C++ implementations for certain hooks,
1973
        which might not be as efficient if implemented in Python using a Python communication hook.
1974

1975
        Args:
1976
            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.
1977

1978
        .. warning ::
1979
            DDP communication hook can only be registered once and should be registered
1980
            before calling backward.
1981

1982
        Example::
1983
            Below is an example of a FP16 compression where gradients are
1984
            compressed into 16-bit floating-point numbers before allreduce, and
1985
            then decompressed after allreduce.
1986

1987
            >>> # xdoctest: +SKIP('undefined name')
1988
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
1989

1990
        """
1991
        assert self.logger is not None
1992
        self.logger._set_comm_hook_name(str(comm_hook_type))
1993
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
1994

1995
    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
1996
        r"""
1997
        Register an optimizer in DDP to optimize parameter immediately after its gradient reduction.
1998

1999
        Registers an optimizer with DDP such that the optimization for a
2000
        parameter will run immediately when that parameter's gradient is
2001
        finished with reduction, instead of waiting for all parameters'
2002
        gradients to finish reduction. This can result in a training speedup
2003
        depending on your workload since the optimizer can run while gradient
2004
        reduction for other parameters is still ongoing. In addition, this has
2005
        the potential to reduce peak memory consumption during training, as it
2006
        only needs to load the per-parameter optimizer states of a single
2007
        parameter at a time, instead of loading all per-parameter optimizer
2008
        states at once.
2009

2010
        Args:
2011
            optim (Type): a ``torch.optim.Optimizer`` class to be registered
2012
                as a fused optimizer.
2013
            *args (Sequence[Any]): Arguments to forward to `optim`.
2014
            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
2015
                to optimize, similar to the `params` argument of traditional `torch.optim`
2016
                Optimizers. If this is omitted, all DDP model parameters will be
2017
                optimized.
2018
            **kwargs (Dict[str, Any]): Keyword arguments to forward to `optim`.
2019

2020
        .. warning ::
2021
            _register_fused_optim should only be called once on a DDP instance,
2022
            and registering multiple fused optimizers for the same DDP model
2023
            is not currently supported. Please ping
2024
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2025
            for your use case.
2026

2027
        .. warning ::
2028
            _register_fused_optim and register_comm_hook currently do not
2029
            compose together, meaning that custom DDP communication hooks are
2030
            not supported with overlapped optimizers. Please ping
2031
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2032
            for your use case.
2033

2034
        .. warning ::
2035
            Gradient accumulation and DDP `no_sync` are currently not supported
2036
            with overlapped optimizer. Please ping
2037
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
2038
            for your use case.
2039

2040
        Example::
2041

2042
            >>> # xdoctest: +SKIP("No rendezvous handler")
2043
            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
2044
            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
2045
            >>> lr = 1e-2
2046
            >>> betas = (0.9, 0.99)
2047
            >>> eps = 1e-6
2048
            >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
2049
            >>> # Example with subset of parameters
2050
            >>> params_to_opt = [list(net.parameters())[0]]
2051
            >>> net._register_fused_optim(
2052
            ...   torch.optim.Adam, lr, optim_params=params_to_opt,  betas=betas, eps=eps
2053
            ... )
2054
        """
2055
        # Note: importing in function, otherwise this will cause a circular
2056
        # import as optimizer_overlap module needs to import DistributedDataParallel.
2057
        from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim
2058

2059
        overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
2060
        try:
2061
            overlapped_optim.register_ddp(self)
2062
        except NotImplementedError as e:
2063
            raise RuntimeError(
2064
                f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
2065
            ) from e
2066

2067
    def _distributed_broadcast_coalesced(
2068
        self, tensors, buffer_size, authoritative_rank=0
2069
    ):
2070
        dist._broadcast_coalesced(
2071
            self.process_group, tensors, buffer_size, authoritative_rank
2072
        )
2073

2074
    def _check_sync_bufs_post_fwd(self):
2075
        return (
2076
            self.will_sync_module_buffers()
2077
            and hasattr(self, "buffer_hook")
2078
            and self.buffer_hook.buffer_comm_hook_location
2079
            == _BufferCommHookLocation.POST_FORWARD
2080
        )
2081

2082
    def _check_sync_bufs_pre_fwd(self):
2083
        return self.will_sync_module_buffers() and (
2084
            not hasattr(self, "buffer_hook")
2085
            or self.buffer_hook.buffer_comm_hook_location
2086
            == _BufferCommHookLocation.PRE_FORWARD
2087
        )
2088

2089
    def will_sync_module_buffers(self):
2090
        return (
2091
            self.require_forward_param_sync
2092
            and self.broadcast_buffers
2093
            and len(self.modules_buffers) > 0
2094
        )
2095

2096
    def _find_common_rank(self, input_rank, rank_cond):
2097
        # -1 indicates that this rank is not under consideration to be the
2098
        # common_rank
2099
        rank_to_use = torch.tensor(
2100
            [input_rank if rank_cond else -1],
2101
            device=self.device,
2102
        )
2103
        dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
2104
        if rank_to_use.item() == -1:
2105
            self._log_and_throw(
2106
                ValueError,
2107
                "BUG! Expected rank_cond to be true for at least one process."
2108
                " This indicates a bug in PyTorch, please report an issue.",
2109
            )
2110
        return rank_to_use.item()
2111

2112
    def _sync_buffers(self):
2113
        with torch.no_grad():
2114
            # module buffer sync
2115
            # Synchronize buffers across processes.
2116
            # If we are running DDP with the join manager, we have to agree
2117
            # upon a rank to sync module buffers from, since rank 0 may
2118
            # already have been joined and have stale module buffers.
2119
            if self._join_config.enable:
2120
                authoritative_rank = self._find_common_rank(
2121
                    self._distributed_rank, True
2122
                )
2123
            else:
2124
                # The process with rank 0 is considered the authoritative copy.
2125
                authoritative_rank = 0
2126
            # Update self.modules_buffers in case any buffers were
2127
            # reassigned.
2128
            self._assign_modules_buffers()
2129
            self._sync_module_buffers(authoritative_rank)
2130

2131
    def _sync_module_buffers(self, authoritative_rank):
2132
        if not hasattr(self, "buffer_hook"):
2133
            self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
2134
        else:
2135
            hook = self.buffer_hook.buffer_comm_hook
2136
            state = self.buffer_hook.buffer_comm_hook_state
2137
            futs = hook(state, self.named_module_buffers)
2138
            if futs is not None:
2139
                self.reducer._install_post_backward_futures(futs)
2140

2141
    def _default_broadcast_coalesced(
2142
        self, bufs=None, bucket_size=None, authoritative_rank=0
2143
    ):
2144
        """
2145
        Broadcasts buffers from rank 0 to the rest of the workers.
2146

2147
        If bufs or bucket_size is None, the default values self.modules_buffers
2148
        and self.broadcast_bucket_size are used instead.
2149
        """
2150
        if bufs is None:
2151
            bufs = self.modules_buffers
2152
        if bucket_size is None:
2153
            bucket_size = self.broadcast_bucket_size
2154

2155
        self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
2156

2157
    def _passing_sync_batchnorm_handle(self, module):
2158
        for layer in module.modules():
2159
            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
2160
                if self.device_type == "cpu":
2161
                    self._log_and_throw(
2162
                        ValueError,
2163
                        "SyncBatchNorm layers only work with GPU modules",
2164
                    )
2165

2166
    def _check_comm_hook(self, hook):
2167
        if not callable(hook):
2168
            self._log_and_throw(TypeError, "Communication hook must be callable.")
2169

2170
        sig = inspect.signature(hook)
2171
        if (
2172
            sig.parameters["bucket"].annotation != inspect._empty
2173
            and sig.parameters["bucket"].annotation != dist.GradBucket
2174
        ):
2175
            self._log_and_throw(
2176
                ValueError,
2177
                "Communication hook: bucket annotation should be dist.GradBucket.",
2178
            )
2179

2180
        if (
2181
            sig.return_annotation != inspect._empty
2182
            and sig.return_annotation != torch.futures.Future[torch.Tensor]
2183
        ):
2184
            self._log_and_throw(
2185
                ValueError,
2186
                "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
2187
            )
2188

2189
        if hook.__name__ in [
2190
            "bf16_compress_hook",
2191
            "bf16_compress_wrapper_hook",
2192
        ] and (
2193
            (torch.version.cuda is None and torch.version.hip is None)
2194
            or (
2195
                torch.version.cuda is not None
2196
                and int(torch.version.cuda.split(".")[0]) < 11
2197
            )
2198
            or not dist.is_available()
2199
            or not dist.is_nccl_available()
2200
            or torch.cuda.nccl.version() < (2, 10)
2201
        ):
2202
            self._log_and_throw(
2203
                TypeError,
2204
                "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.",
2205
            )
2206

2207
    @property
2208
    def _distributed_rank(self):
2209
        return dist.get_rank(self.process_group)
2210

2211
    @staticmethod
2212
    def _get_data_parallel_params(module, named_params=False):
2213
        """Return a generator of parameters managed by a given DDP unit."""
2214
        for param in (
2215
            module.parameters() if not named_params else module.named_parameters()
2216
        ):
2217
            if not hasattr(param, "_ddp_ignored"):
2218
                yield param
2219

2220
    @staticmethod
2221
    def _set_params_and_buffers_to_ignore_for_model(
2222
        module, params_and_buffers_to_ignore
2223
    ):
2224
        """
2225
        Set parameters and buffers to be ignored by DDP.
2226

2227
        Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and
2228
        similarly, {module_name}.{buffer_name} for buffers. For example:
2229
        params_to_ignore = []
2230
        # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
2231
        for module_name, module in model.named_modules():
2232
            for param_name, param in module.named_parameters(recurse=False):
2233
                if should_ignore(param):
2234
                    # Create expected format
2235
                    fqn = f"{module_name}.{param_name}"
2236
                    params_to_ignore.append(fqn)
2237
        torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
2238
            model,
2239
            params_to_ignore
2240
        )
2241
        """
2242
        # This is a workaround to set parameters and buffers DDP should ignore
2243
        # during synchronization. It will be removed when the API is finalized
2244
        # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
2245
        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
2246
        for name, param in module.named_parameters():
2247
            if name in params_and_buffers_to_ignore:
2248
                param._ddp_ignored = True
2249
        for name, buffer in module.named_buffers():
2250
            if name in params_and_buffers_to_ignore:
2251
                buffer._ddp_ignored = True
2252

2253
    def _get_ddp_logging_data(self):
2254
        r"""
2255
        Return a dictionary of logging data for debugging and analysis.
2256

2257
        This interface can be called after DistributedDataParallel() is
2258
        constructed. It returns a dictionary of logging data that can help
2259
        with debugging and analysis. The logging data includes DistributedDataParallel
2260
        constructor input parameters, some internal states of DistributedDataParallel
2261
        and performance metrics. Simply print the dictionary and see what
2262
        these metrics are.
2263
        This is a prototype interface and subject to change in the future.
2264
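
        Example::
            A minimal sketch (not part of the original documentation), assuming
            ``ddp`` is an already-constructed DistributedDataParallel instance:

            >>> # xdoctest: +SKIP("Distributed")
            >>> logging_data = ddp._get_ddp_logging_data()
            >>> for key, value in sorted(logging_data.items()):
            >>>     print(key, value)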
        """
2265
        assert self.logger is not None
2266
        ddp_logging_data = self.logger._get_ddp_logging_data()
2267
        return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
2268

2269
    def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
2270
        r"""
2271
        Set sample_rate of collecting runtime stats.
2272

2273
        This interface allows users to set the sample_rate for collecting
2273
        runtime stats. Runtime stats are recorded for the
2274
        first 10 iterations; after that, they are recorded
2275
        once every "sample_rate" training iterations. By
2276
        default, runtime stats are recorded for the first 10 iterations and,
2277
        after 10 iterations, once every
2278
        "kDDPRuntimeLoggingSampleRate=100" training iterations.
2280
        This is a prototype interface and subject to change in the future.
2281
        """
2282
        if sample_rate < 1:
2283
            self._log_and_throw(
2284
                ValueError,
2285
                "DDP runtime logging sample rate should be equal or greater than 1",
2286
            )
2287
        self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
2288

2289
    def _set_static_graph(self):
2290
        """
2291
        Set static graph for DDP.
2292

2293
        It is recommended to set static graph in the DDP constructor, which will
2294
        call this private API internally.
2295
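
        Example::
            A minimal sketch (not part of the original documentation): prefer
            setting ``static_graph`` in the constructor instead of calling this
            private API; ``model`` is assumed to be defined.

            >>> # xdoctest: +SKIP("Distributed")
            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, static_graph=True)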
        """
2296
        # If self.static_graph has been set, no need to set it again
2297
        if self.static_graph:
2298
            warnings.warn(
2299
                "You've set static_graph to be True, no need to set it again."
2300
            )
2301
            return
2302
        self.static_graph = True
2303
        self._static_graph_delay_allreduce_enqueued = False
2304
        self.reducer._set_static_graph()
2305
        assert self.logger is not None
2306
        self.logger._set_static_graph()
2307
        if self.find_unused_parameters:
2308
            warnings.warn(
2309
                "You passed find_unused_parameters=true to DistributedDataParallel, "
2310
                "`_set_static_graph` will detect unused parameters automatically, so "
2311
                "you do not need to set find_unused_parameters=true, just be sure these "
2312
                "unused parameters will not change during training loop while calling "
2313
                "`_set_static_graph`."
2314
            )
2315

2316
    def _remove_autograd_hooks(self):
2317
        """Remove autograd hooks registered by the reducer on the model parameters."""
2318
        self.reducer._remove_autograd_hooks()
2319

2320
    def _check_reducer_finalized(self):
2321
        """
2322
        Check if the reducer has processed all buckets and finalized the backward appropriately.
2323

2324
        It is useful to call this method after calling .backward() in your training loop
2325
        in order to avoid subsequent hard to debug errors down the road due to the
2326
        reducer not finalizing backward.
2327
        """
2328
        self.reducer._check_reducer_finalized()
2329

2330
    def _set_sparse_metadata(self, global_unique_ids):
2331
        self.reducer._set_sparse_metadata(global_unique_ids)
2332

2333
    def _update_process_group(self, new_process_group):
2334
        """
2335
        Dynamically updates the process group for DDP so that we can shrink/expand DDP
2336
        world size without having to reinitialize DDP.
2337

2338
        NOTE: If you are using custom communication hooks via register_comm_hook,
2339
        you need to update the process groups for those hooks separately.
2340
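
        Example::
            A minimal sketch (not part of the original documentation), assuming
            ``ddp`` is an existing DistributedDataParallel instance and
            ``remaining_ranks`` lists the ranks that stay in the job:

            >>> # xdoctest: +SKIP("Distributed")
            >>> new_pg = dist.new_group(ranks=remaining_ranks)
            >>> ddp._update_process_group(new_pg)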
        """
2341
        # Force a rebuild of buckets for a new process group. This ensures all ranks
2342
        # are synchronized in terms of when they will rebuild buckets and also
2343
        # re-evaluates previous assumptions of buckets given the world size might have
2344
        # changed.
2345
        self._has_rebuilt_buckets = False
2346
        self.reducer._reset_state()
2347

2348
        if not _rank_not_in_group(new_process_group):
2349
            self.process_group = new_process_group
2350
            self.reducer._update_process_group(new_process_group)
2351
