from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle

if dist.is_available():
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.distributed.utils import (
        _alloc_storage,
        _cast_forward_inputs,
        _free_storage,
        _sync_module_states,
        _to_kwargs,
        _verify_param_shape_across_processes,
    )

RPC_AVAILABLE = False
if torch.distributed.rpc.is_available():
    RPC_AVAILABLE = True
    from torch.distributed.rpc import RRef

from torch._utils import _get_device_index

from ..modules import Module
from .scatter_gather import gather, scatter_kwargs

__all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)

@dataclass
class _MixedPrecision:
    """
    This configures DDP-native mixed precision training.

    Attributes:
        param_dtype (torch.dtype): This specifies the dtype for model
            parameters, inputs (when ``cast_forward_inputs`` is set to
            ``True``), and therefore the dtype for computation.
            However, outside the forward and backward passes, parameters are in
            full precision. Model checkpointing always happens in full
            precision.
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
            reduction, which is permitted to differ from ``param_dtype``.
        buffer_dtype (torch.dtype): This specifies the dtype for buffers.

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: ``state_dict`` checkpoints parameters and buffers in full
        precision.

    .. note:: Each low precision dtype must be specified explicitly. For
        example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
        the reduction dtype to be low precision, and DDP will not cast
        parameters or buffers.

    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
        happens in ``param_dtype`` if specified or the original parameter dtype
        otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
        would result in communication occurring in fp16.
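
    For example, a configuration that runs compute, gradient reduction, and
    buffers in ``torch.float16`` might look like the following sketch (this is
    a private, experimental API; ``model`` and ``rank`` are assumed to exist):

        >>> # xdoctest: +SKIP("experimental API, undefined variables")
        >>> mp = _MixedPrecision(
        >>>     param_dtype=torch.float16,
        >>>     reduce_dtype=torch.float16,
        >>>     buffer_dtype=torch.float16,
        >>> )
        >>> ddp = torch.nn.parallel.DistributedDataParallel(
        >>>     model, device_ids=[rank], mixed_precision=mp
        >>> )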
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None

def _cast_buffers(mixed_precision_config, root_module):
    """Casts buffers to the given ``buffer_dtype``."""
    for buf in root_module.buffers():
        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
            continue

        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)


def _setup_mixed_precision_params(mixed_precision_config, root_module):
    """Create and free storage for the mixed precision parameters."""
    for param in root_module.parameters():
        # Do not set up mixed precision for DDP-ignored parameters.
        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
            continue

        if not hasattr(param, "_mp_param"):
            param._mp_param = torch.zeros_like(
                param,
                device=param.device,
                dtype=mixed_precision_config.param_dtype,
                requires_grad=param.requires_grad,
            )
            _free_storage(param._mp_param)
            # _fp_param points to the full precision parameter so it can be
            # switched back at the end of forward/backward.
            param._fp_param = param.data

def _tree_flatten_with_rref(output):
    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
    if output_is_rref:
        output_tensor_list, treespec = tree_flatten(output.local_value())
    else:
        output_tensor_list, treespec = tree_flatten(output)
    # Return the flattened tensors, the spec to re-pack them, and whether the
    # return type was actually an RRef so it can be reconstructed.
    return output_tensor_list, treespec, output_is_rref


def _tree_unflatten_with_rref(output, treespec, output_is_rref):
    output = tree_unflatten(output, treespec)
    if output_is_rref:
        output = RRef(output)
    return output

def _find_tensors(obj):
    r"""Recursively find all tensors contained in the specified object."""
    if RPC_AVAILABLE and isinstance(obj, RRef):
        # If the current node is the owner of the RRef, unwrap it and try to
        # find tensors.
        if obj.is_owner():
            return _find_tensors(obj.local_value())
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain.from_iterable(map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
    if is_dataclass(obj):
        return itertools.chain.from_iterable(
            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
        )
    return []

def _dump_DDP_relevant_env_vars():
    relevant_env_vars = [
        "CUDA_VISIBLE_DEVICES",
        "GLOO_SOCKET_IFNAME",
        "GLOO_DEVICE_TRANSPORT",
        "NCCL_SOCKET_IFNAME",
        "TORCH_NCCL_BLOCKING_WAIT",
        "NCCL_SOCKET_NTHREADS",
        "NCCL_NSOCKS_PERTHREAD",
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CHECKS_DISABLE",
        "NCCL_CHECK_POINTERS",
        "NCCL_IB_AR_THRESHOLD",
        "NCCL_IB_CUDA_SUPPORT",
        "NCCL_NET_GDR_LEVEL",
        "NCCL_SINGLE_RING_THRESHOLD",
        "NCCL_TREE_THRESHOLD",
        "NCCL_IGNORE_CPU_AFFINITY",
        "NCCL_COLLNET_ENABLE",
        "NCCL_TOPO_DUMP_FILE",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    ]
    formatted_output = ""
    for var in relevant_env_vars:
        value = os.environ[var] if var in os.environ else "N/A"
        formatted_output += f"env:{var}={value}\n"
    print(formatted_output)

class _BufferCommHookLocation(Enum):
    PRE_FORWARD = auto()
    POST_FORWARD = auto()


@dataclass
class _BufferCommHook:
    buffer_comm_hook: Callable
    buffer_comm_hook_state: Any
    buffer_comm_hook_location: _BufferCommHookLocation

# _DDPSink runs various functions when backward starts, such as queueing the
# delayed allreduce callback for static graph training so that it fires after
# all gradients have been computed.
class _DDPSink(Function):
    @staticmethod
    def forward(ctx, ddp_weakref, *inputs):
        # set_materialize_grads(False) ensures that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        ctx.ddp_weakref = ddp_weakref
        ret = tuple(
            inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
        )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Enqueue the delayed allreduce for static graph training on the first
        # iteration.
        ddp_weakref = ctx.ddp_weakref()
        reducer = ddp_weakref.reducer
        static_graph = ddp_weakref.static_graph
        delay_ar_enqueued = (
            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
        )
        if static_graph and not delay_ar_enqueued:
            Variable._execution_engine.queue_callback(
                reducer._delay_all_reduce
            )
            ddp_weakref._static_graph_delay_allreduce_enqueued = True

        return (None, *grad_outputs)

class _DDPJoinHook(JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """Set config variables for internal usage."""
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel "
            "instance as the state"
        )
        assert ddp.logger is not None
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """Shadow the DDP collective communication operations in the forward and backward passes."""
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period.
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the forward pass.
        ddp._check_and_sync_module_buffers()

        # Check if we need to sync in the backward pass.
        should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
            is_joined_rank=True
        )
        # Forward parameter sync is disabled in the next iteration if we are
        # skipping gradient sync this iteration.
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward pass.
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters.
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period.
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """Sync the final model to ensure that the model is the same across all processes."""
        self.ddp._sync_final_model(is_last_joiner)
321
class DistributedDataParallel(Module, Joinable):
322
r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.
324
This container provides data parallelism by synchronizing gradients
325
across each model replica. The devices to synchronize across are
326
specified by the input ``process_group``, which is the entire world
327
by default. Note that ``DistributedDataParallel`` does not chunk or
328
otherwise shard the input across participating GPUs; the user is
329
responsible for defining how to do so, for example through the use
330
of a :class:`DistributedSampler`.
332
See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
333
The same constraints on input as in :class:`torch.nn.DataParallel` apply.
Creation of this class requires that ``torch.distributed`` be already
initialized by calling :func:`torch.distributed.init_process_group`.

``DistributedDataParallel`` is proven to be significantly faster than
:class:`torch.nn.DataParallel` for single-node multi-GPU data
parallel training.
342
To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
343
up ``N`` processes, ensuring that each process exclusively works on a single
344
GPU from 0 to N-1. This can be done by either setting
345
``CUDA_VISIBLE_DEVICES`` for every process or by calling:
347
>>> # xdoctest: +SKIP("undefined variables")
348
>>> torch.cuda.set_device(i)
where i is from 0 to N-1. In each process, you should refer to the following
to construct this module:
353
>>> # xdoctest: +SKIP("undefined variables")
354
>>> torch.distributed.init_process_group(
355
>>> backend='nccl', world_size=N, init_method='...'
357
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
359
In order to spawn up multiple processes per node, you can use either
360
``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
363
Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
364
for a brief introduction to all features related to distributed training.
367
``DistributedDataParallel`` can be used in conjunction with
368
:class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
369
per-rank optimizer states memory footprint. Please refer to
370
`ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
373
.. note:: ``nccl`` backend is currently the fastest and highly recommended
374
backend when using GPUs. This applies to both single-node and
375
multi-node distributed training.
.. note:: This module also supports mixed-precision distributed training.
This means that your model can have different types of parameters, such
as a mix of ``fp16`` and ``fp32``, and gradient reduction on these
mixed-precision parameters will just work fine.
382
.. note:: If you use ``torch.save`` on one process to checkpoint the module,
383
and ``torch.load`` on some other processes to recover it, make sure that
384
``map_location`` is configured properly for every process. Without
385
``map_location``, ``torch.load`` would recover the module to devices
386
where the module was saved from.
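For example, each rank can remap a checkpoint that was saved from GPU 0 onto
its own device (a sketch; the file name is illustrative):

>>> # xdoctest: +SKIP("undefined variables")
>>> map_location = {"cuda:0": f"cuda:{rank}"}
>>> ddp_model.load_state_dict(
>>>     torch.load("checkpoint.pt", map_location=map_location)
>>> )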
388
.. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
389
gradient will be ``M`` times smaller when compared to the same model
390
trained on a single node with ``batch=M*N`` if the loss is summed (NOT
391
averaged as usual) across instances in a batch (because the gradients
392
between different nodes are averaged). You should take this into
393
consideration when you want to obtain a mathematically equivalent
394
training process compared to the local training counterpart. But in most
395
cases, you can just treat a DistributedDataParallel wrapped model, a
396
DataParallel wrapped model and an ordinary model on a single GPU as the
397
same (E.g. using the same learning rate for equivalent batch size).
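As a concrete instance: with ``M = 2`` processes each computing a summed loss
over ``N`` samples, DDP averages the two per-process gradients, so the result
is half the gradient a single process would compute over the full ``2*N``
batch; scaling the loss (or the learning rate) by ``M`` restores the
equivalence.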
Parameters are never broadcast between processes. The module performs
an all-reduce step on gradients and assumes that they will be modified
by the optimizer in all processes in the same way. Buffers
(e.g. BatchNorm stats) are broadcast from the module of the rank-0
process to all other replicas in the system in every iteration.
407
If you are using DistributedDataParallel in conjunction with the
408
:ref:`distributed-rpc-framework`, you should always use
409
:meth:`torch.distributed.autograd.backward` to compute gradients and
410
:class:`torch.distributed.optim.DistributedOptimizer` for optimizing
415
>>> # xdoctest: +SKIP("undefined variables")
416
>>> import torch.distributed.autograd as dist_autograd
417
>>> from torch.nn.parallel import DistributedDataParallel as DDP
419
>>> from torch import optim
420
>>> from torch.distributed.optim import DistributedOptimizer
421
>>> import torch.distributed.rpc as rpc
422
>>> from torch.distributed.rpc import RRef
424
>>> t1 = torch.rand((3, 3), requires_grad=True)
425
>>> t2 = torch.rand((3, 3), requires_grad=True)
426
>>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
427
>>> ddp_model = DDP(my_model)
429
>>> # Setup optimizer
430
>>> optimizer_params = [rref]
431
>>> for param in ddp_model.parameters():
432
>>> optimizer_params.append(RRef(param))
434
>>> dist_optim = DistributedOptimizer(
436
>>> optimizer_params,
440
>>> with dist_autograd.context() as context_id:
441
>>> pred = ddp_model(rref.to_here())
442
>>> loss = loss_func(pred, target)
443
>>> dist_autograd.backward(context_id, [loss])
444
>>> dist_optim.step(context_id)
447
DistributedDataParallel currently offers limited support for gradient
448
checkpointing with :meth:`torch.utils.checkpoint`.
449
If the checkpoint is done with use_reentrant=False (recommended), DDP
450
will work as expected without any limitations.
If, however, the checkpoint is done with use_reentrant=True (the default),
DDP will work as expected when there are no unused parameters in the model
and each layer is checkpointed at most once (make sure you are not passing
`find_unused_parameters=True` to DDP). We currently do not support the
case where a layer is checkpointed multiple times, or when there are unused
parameters in the checkpointed model.
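For example, non-reentrant checkpointing inside the wrapped module's
``forward`` composes with DDP (a sketch; ``self.block`` and ``self.head`` are
illustrative submodules):

>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.utils.checkpoint import checkpoint
>>> def forward(self, x):
>>>     # Checkpointed segment using the recommended non-reentrant variant.
>>>     x = checkpoint(self.block, x, use_reentrant=False)
>>>     return self.head(x)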
459
To let a non-DDP model load a state dict from a DDP model,
460
:meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
461
needs to be applied to strip the prefix "module." in the DDP state dict before loading.
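For example (a sketch; the checkpoint path is illustrative):

>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
>>> state_dict = torch.load("ddp_checkpoint.pt")
>>> consume_prefix_in_state_dict_if_present(state_dict, "module.")
>>> plain_model.load_state_dict(state_dict)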
464
Constructor, forward method, and differentiation of the output (or a
465
function of the output of this module) are distributed synchronization
466
points. Take that into account in case different processes might be
467
executing different code.
470
This module assumes all parameters are registered in the model by the
471
time it is created. No parameters should be added nor removed later.
472
Same applies to buffers.
This module assumes that all parameters are registered in the model of each
distributed process in the same order. The module itself will
conduct gradient ``allreduce`` following the reverse order of the
registered parameters of the model. In other words, it is the users'
responsibility to ensure that each distributed process has the exact
same model and thus the exact same parameter registration order.
483
This module allows parameters with non-rowmajor-contiguous strides.
484
For example, your model may contain some parameters whose
485
:class:`torch.memory_format` is ``torch.contiguous_format``
486
and others whose format is ``torch.channels_last``. However,
corresponding parameters in different processes must have the
same strides.

This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
only work if gradients are to be accumulated in ``.grad`` attributes of
parameters).
496
If you plan on using this module with a ``nccl`` backend or a ``gloo``
497
backend (that uses Infiniband), together with a DataLoader that uses
498
multiple workers, please change the multiprocessing start method to
499
``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
500
Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
501
likely experience deadlocks if you don't change this setting.
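For example, one way to do this is to give the DataLoader a spawn-based
multiprocessing context (a sketch; ``dataset`` is assumed to exist):

>>> # xdoctest: +SKIP("undefined variables")
>>> import torch.multiprocessing as mp
>>> loader = torch.utils.data.DataLoader(
>>>     dataset,
>>>     num_workers=4,
>>>     multiprocessing_context=mp.get_context("spawn"),
>>> )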
You should never try to change your model's parameters after wrapping
up your model with ``DistributedDataParallel``, because when
wrapping up your model with ``DistributedDataParallel``, the constructor
of ``DistributedDataParallel`` will register the additional gradient
reduction functions on all the parameters of the model itself at the
time of construction. If you change the model's parameters afterwards,
the gradient reduction functions will no longer match the correct set of
parameters.
514
Using ``DistributedDataParallel`` in conjunction with the
515
:ref:`distributed-rpc-framework` is experimental and subject to change.
518
module (Module): module to be parallelized
519
device_ids (list of int or torch.device): CUDA devices.
520
1) For single-device modules, ``device_ids`` can
521
contain exactly one device id, which represents the only
522
CUDA device where the input module corresponding to this process resides.
523
Alternatively, ``device_ids`` can also be ``None``.
524
2) For multi-device modules and CPU modules,
525
``device_ids`` must be ``None``.
527
When ``device_ids`` is ``None`` for both cases,
528
both the input data for the forward pass and the actual module
529
must be placed on the correct device.
531
output_device (int or torch.device): Device location of output for
532
single-device CUDA modules. For multi-device modules and
533
CPU modules, it must be ``None``, and the module itself
534
dictates the output location. (default: ``device_ids[0]``
535
for single-device modules)
536
broadcast_buffers (bool): Flag that enables syncing (broadcasting)
537
buffers of the module at beginning of the ``forward``
538
function. (default: ``True``)
539
process_group: The process group to be used for distributed data
540
all-reduction. If ``None``, the default process group, which
541
is created by :func:`torch.distributed.init_process_group`,
542
will be used. (default: ``None``)
543
bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
544
multiple buckets so that gradient reduction of each
545
bucket can potentially overlap with backward computation.
546
:attr:`bucket_cap_mb` controls the bucket size in
547
MegaBytes (MB). (default: 25)
548
find_unused_parameters (bool): Traverse the autograd graph from all
549
tensors contained in the return value of the
550
wrapped module's ``forward`` function. Parameters
551
that don't receive gradients as part of this
552
graph are preemptively marked as being ready to
553
be reduced. In addition, parameters that may have
554
been used in the wrapped module's ``forward``
555
function but were not part of loss computation and
556
thus would also not receive gradients are
557
preemptively marked as ready to be reduced.
559
check_reduction: This argument is deprecated.
560
gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
561
pointing to different offsets of ``allreduce`` communication
562
buckets. This can reduce peak memory usage, where the
563
saved memory size will be equal to the total gradients
564
size. Moreover, it avoids the overhead of copying between
565
gradients and ``allreduce`` communication buckets. When
566
gradients are views, ``detach_()`` cannot be called on the
567
gradients. If hitting such errors, please fix it by
568
referring to the :meth:`~torch.optim.Optimizer.zero_grad`
569
function in ``torch/optim/optimizer.py`` as a solution.
570
Note that gradients will be views after first iteration, so
571
the peak memory saving should be checked after first iteration.
572
static_graph (bool): When set to ``True``, DDP knows the trained graph is
573
static. Static graph means 1) The set of used and unused
574
parameters will not change during the whole training loop; in
575
this case, it does not matter whether users set
576
``find_unused_parameters = True`` or not. 2) How the graph is trained
577
will not change during the whole training loop (meaning there is
578
no control flow depending on iterations).
When static_graph is set to ``True``, DDP will support cases that
could not be supported in the past:
1) Reentrant backwards.
2) Activation checkpointing multiple times.
3) Activation checkpointing when the model has unused parameters.
4) Model parameters that are outside of the forward function.
5) Potentially improved performance when there are unused parameters,
as DDP will not search the graph in each iteration to detect unused
parameters when static_graph is set to ``True``.
To check whether you can set static_graph to ``True``, one way is to
check DDP logging data at the end of your previous model training:
if ``ddp_logging_data.get("can_set_static_graph") == True``, you most likely
can set ``static_graph = True`` as well.
594
>>> # xdoctest: +SKIP("undefined variables")
595
>>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
598
>>> ddp_logging_data = model_DDP._get_ddp_logging_data()
599
>>> static_graph = ddp_logging_data.get("can_set_static_graph")
600
delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
601
of named parameters whose all reduce will be delayed when the gradient of
602
the parameter specified in ``param_to_hook_all_reduce`` is ready. Other
603
arguments of DDP do not apply to named params specified in this argument
604
as these named params will be ignored by DDP reducer.
605
param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
606
of parameters specified in ``delay_all_reduce_named_params``.
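For illustration, a hedged sketch of wiring these two arguments together
(the submodule name ``embedding``, the ``head`` parameter, and ``rank`` are
illustrative):

>>> # xdoctest: +SKIP("undefined variables")
>>> delayed = [
>>>     (name, p) for name, p in model.named_parameters() if "embedding" in name
>>> ]
>>> hook_param = next(model.head.parameters())
>>> ddp = torch.nn.parallel.DistributedDataParallel(
>>>     model,
>>>     device_ids=[rank],
>>>     delay_all_reduce_named_params=delayed,
>>>     param_to_hook_all_reduce=hook_param,
>>> )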
610
module (Module): the module to be parallelized.
614
>>> # xdoctest: +SKIP("undefined variables")
615
>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
616
>>> net = torch.nn.parallel.DistributedDataParallel(model)
620
_active_ddp_module: Optional["DistributedDataParallel"] = None
628
broadcast_buffers=True,
631
find_unused_parameters=False,
632
check_reduction=False,
633
gradient_as_bucket_view=False,
635
delay_all_reduce_named_params=None,
636
param_to_hook_all_reduce=None,
637
mixed_precision: Optional[_MixedPrecision] = None,
641
Joinable.__init__(self)
643
if bool(delay_all_reduce_named_params is not None) != bool(
644
param_to_hook_all_reduce is not None
648
"delay_all_reduce_named_params and param_to_hook_all_reduce "
649
"need to be set at the same time.",
652
self._delay_all_reduce_params = []
653
if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
654
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
656
self.parameters_to_ignore = set()
657
if delay_all_reduce_named_params is not None:
658
for name, param in delay_all_reduce_named_params:
659
self.parameters_to_ignore.add(name)
660
self._delay_all_reduce_params.append(param)
662
self._module_parameters = [
664
for n, p in module.named_parameters()
665
if n not in self.parameters_to_ignore
667
if not any(p.requires_grad for p in self._module_parameters):
668
if len(self._delay_all_reduce_params):
669
logger.info("Delay the AllReduce of all parameters.")
673
"DistributedDataParallel is not needed when a module "
674
"doesn't have any parameter that requires a gradient.",
677
if device_ids is not None and len(device_ids) > 1:
680
"device_ids can only be None or contain a single element.",
683
self.is_multi_device_module = (
684
len({p.device for p in self._module_parameters}) > 1
686
distinct_device_types = {
687
p.device.type for p in self._module_parameters if p.device is not None
689
if len(distinct_device_types) != 1:
692
"DistributedDataParallel's input module must be on "
693
f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
696
self.device_type = next(iter(distinct_device_types))
700
or len(device_ids) == 0
701
or self.device_type == "cpu"
702
or self.is_multi_device_module
704
if device_ids or output_device:
707
"DistributedDataParallel device_ids and output_device arguments "
708
"only work with single-device/multiple-device GPU modules or CPU modules, "
709
"but got device_ids {}, output_device {}, and module parameters {}.".format(
712
{p.device for p in self._module_parameters},
716
self.device_ids = None
717
self.output_device = None
719
self.device_ids = [_get_device_index(x, True) for x in device_ids]
721
if output_device is None:
722
output_device = device_ids[0]
724
self.output_device = _get_device_index(output_device, True)
726
if process_group and device_mesh is not None:
728
"Cannot specify both process_group and device_mesh arguments."
730
elif process_group is None and device_mesh is None:
731
self.process_group = _get_default_group()
732
elif device_mesh is None:
733
self.process_group = process_group
735
if device_mesh.ndim != 1:
737
f"Only 1D device mesh is supported, but got {device_mesh}."
739
self.device_mesh = device_mesh
740
self.process_group = device_mesh.get_group(mesh_dim=0)
742
self.static_graph = False
745
self.device = next(iter(self._module_parameters)).device
746
self.broadcast_buffers = broadcast_buffers
747
self.find_unused_parameters = find_unused_parameters
748
self.require_backward_grad_sync = True
749
self.require_forward_param_sync = True
750
self.gradient_as_bucket_view = gradient_as_bucket_view
751
self.mixed_precision = mixed_precision
752
if self.mixed_precision is not None:
753
logger.warning("Received mixed precision config %s", self.mixed_precision)
760
"The `check_reduction` argument in `DistributedDataParallel` "
761
"module is deprecated. Please avoid using it."
765
for param in self._module_parameters:
766
if isinstance(param, torch.nn.parameter.UninitializedParameter):
769
"Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
770
"Run a dummy forward pass to correctly initialize the modules",
773
self.broadcast_bucket_size = int(250 * 1024 * 1024)
776
self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
778
self.use_side_stream_for_tensor_copies = (
779
os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
783
self._delay_grad_buffer = None
784
self._delay_grad_views: List[torch.Tensor] = []
785
self._delay_all_reduce_all_params = False
786
if len(self._delay_all_reduce_params) != 0:
787
self._register_delay_all_reduce_hook(
788
bucket_cap_mb=bucket_cap_mb,
789
param_to_hook_all_reduce=param_to_hook_all_reduce,
790
device_ids=device_ids,
792
if self._delay_all_reduce_all_params:
796
parameters, expect_sparse_gradient = self._build_params_for_reducer()
798
_verify_param_shape_across_processes(self.process_group, parameters)
802
process_group=self.process_group,
803
broadcast_bucket_size=self.broadcast_bucket_size,
805
params_and_buffers_to_ignore=self.parameters_to_ignore,
806
broadcast_buffers=self.broadcast_buffers,
809
param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
812
self._ddp_init_helper(
814
expect_sparse_gradient,
815
param_to_name_mapping,
818
self._comm_hooks: List[Tuple[Callable, object]] = []
820
if self.mixed_precision is not None:
821
_setup_mixed_precision_params(self.mixed_precision, self.module)
822
_cast_buffers(self.mixed_precision, self.module)
824
self._mp_stream = torch.cuda.Stream()
825
self._submodule_to_event = defaultdict(deque)
828
self.module.register_forward_pre_hook(
829
self._root_copy_hook, prepend=False, with_kwargs=True
833
for module in self.module.modules():
834
module.register_forward_pre_hook(
835
self._module_wait_for_copy_hook,
843
from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
844
_AllreduceUpcastHookState,
845
_reducer_allreduce_and_upcast_hook,
848
upcast_hook_state = _AllreduceUpcastHookState(
849
ddp_weakref=weakref.ref(self),
850
upcast_stream=torch.cuda.Stream(),
852
self.register_comm_hook(
854
_reducer_allreduce_and_upcast_hook,
858
self.reducer._set_mixed_precision_param_dtype(
859
self.mixed_precision.param_dtype
862
self._has_rebuilt_buckets = False
865
self._set_static_graph()
867
self._lazy_init_ran = False
872
self._accum_grad_hooks: List[RemovableHandle] = []
873
optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
874
self._use_python_reducer = optimize_ddp in (
876
"python_reducer_without_compiled_forward",
878
self._force_to_disable_cpp_reducer = (
879
optimize_ddp == "python_reducer_without_compiled_forward"
881
if self._use_python_reducer:
882
self._register_accum_grad_hook()
884
def _register_accum_grad_hook(self):
885
import torch.distributed._functional_collectives as fcol
887
def compiled_accum_grad_hook(
892
if not self.require_backward_grad_sync:
895
if param.grad is None:
899
for hook, state in self._comm_hooks:
900
hook(state, (param.grad, param))
902
gradient = param.grad / self.process_group.size()
903
gradient = fcol.all_reduce(gradient, "sum", self.process_group)
904
param.grad.copy_(gradient)
906
for index, param in enumerate(self._module_parameters):
907
self._accum_grad_hooks.append(
908
param.register_post_accumulate_grad_hook(
910
compiled_accum_grad_hook,
916
def _delayed_all_reduce_hook(self, grad):
917
world_size = dist.get_world_size(self.process_group)
919
self._delay_grad_buffer.div_(world_size)
921
self._delay_grad_buffer, group=self.process_group, async_op=True
925
def _register_delay_all_reduce_hook(
928
param_to_hook_all_reduce,
932
device = torch.device("cpu") if device_ids is None else device_ids[0]
933
self._delay_grad_buffer = torch.zeros(
934
sum([p.numel() for p in self._delay_all_reduce_params]),
939
detached_params = [p.detach() for p in self._delay_all_reduce_params]
940
dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)
943
param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)
947
for param in self._delay_all_reduce_params:
948
grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
951
self._delay_grad_views.append(grad_view)
952
offset = offset + param.numel()
955
for module_name, module in self.module.named_modules():
956
for param_name, param in module.named_parameters(recurse=False):
957
if param.requires_grad:
958
full_name = f"{module_name}.{param_name}"
959
if full_name not in self.parameters_to_ignore:
964
self._delay_all_reduce_all_params = True
966
def _setup_in_backward_optimizers(self):
977
if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
978
torch._C._log_api_usage_once("ddp.optimizer_in_backward")
982
param_to_handle_map = (
983
dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
985
for p in self._module_parameters:
986
for handle in param_to_handle_map.get(p, []):
991
ddp_weakref = weakref.ref(self)
994
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
995
_apply_optim_in_backward_hook,
998
self.register_comm_hook(
1000
_apply_optim_in_backward_hook(
1001
gradient_is_bucket_view=self.gradient_as_bucket_view
1005
self.reducer._set_optimizer_in_backward()
1007
def _fire_reducer_autograd_hook(self, idx, *unused):
1009
Fire the reducer's autograd hook to allreduce params in a Reducer bucket.
1011
Note that this is only used during mixed precision training as the
1012
Reducer's hooks installed during construction time would not be called
1013
as we're working in the low precision parameter setting.
1015
self.reducer._autograd_hook(idx)
1017
def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
1019
For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.
1021
When training with DDP mixed precision, this root pre-forward hook kicks
1022
off low precision copies on a separate stream and creates respective
1023
events to wait for them.
1028
self._submodule_to_event = defaultdict(deque)
1029
with torch.cuda.stream(self._mp_stream):
1030
for submodule in self.module.modules():
1031
for param in submodule.parameters(recurse=False):
1033
if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
1035
_alloc_storage(param._mp_param, param.size())
1037
with torch.no_grad():
1038
param._mp_param.copy_(param.data)
1048
if param.grad is not None:
1049
param.grad.data = param.grad.to(
1050
self.mixed_precision.param_dtype
1052
param.data = param._mp_param
1053
copy_event = torch.cuda.Event()
1055
self._submodule_to_event[submodule].append(copy_event)
1057
def _module_wait_for_copy_hook(
1063
"""Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
1065
event = self._submodule_to_event[module].popleft()
1070
event.wait(stream=torch.cuda.current_stream())
1071
for p in module.parameters(recurse=False):
1073
if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
1078
tmp = p.expand_as(p)
1079
grad_acc = tmp.grad_fn.next_functions[0][0]
1081
hook = grad_acc.register_hook(
1082
functools.partial(self._fire_reducer_autograd_hook, p._idx)
1084
p._ddp_mp_hook_state = (grad_acc, hook)
1086
def _log_and_throw(self, err_type, err_msg):
1087
if self.logger is not None:
1088
self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
1089
raise err_type(err_msg)
1091
def _ddp_init_helper(
1094
expect_sparse_gradient,
1095
param_to_name_mapping,
1099
DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.
1101
Initialization helper function that does the following:
1102
(1) bucketing the parameters for reductions
1103
(2) resetting the bucketing states
1104
(3) registering the grad hooks
1105
(4) Logging construction-time DDP logging data
1106
(5) passing a handle of DDP to SyncBatchNorm Layer
1127
if static_graph is True or self.find_unused_parameters is False:
1128
bucket_size_limits = [sys.maxsize]
1130
bucket_size_limits = [
1131
dist._DEFAULT_FIRST_BUCKET_BYTES,
1132
self.bucket_bytes_cap,
1136
per_bucket_size_limits,
1137
) = dist._compute_bucket_assignment_by_size(
1140
expect_sparse_gradient,
1145
if self.mixed_precision is not None:
1146
for i, p in enumerate(parameters):
1152
self.reducer = dist.Reducer(
1154
list(reversed(bucket_indices)),
1155
list(reversed(per_bucket_size_limits)),
1157
expect_sparse_gradient,
1163
self.bucket_bytes_cap,
1164
self.find_unused_parameters,
1165
self.gradient_as_bucket_view,
1166
param_to_name_mapping,
1169
dist._DEFAULT_FIRST_BUCKET_BYTES,
1172
self.logger = dist.Logger(self.reducer)
1175
self.reducer.set_logger(self.logger)
1178
for submodule in self.module.modules():
1179
if isinstance(submodule, torch.nn.SyncBatchNorm):
1184
self.logger.set_construction_data_and_log(
1185
self.module.__class__.__name__,
1186
[] if self.device_ids is None else self.device_ids,
1187
-1 if self.output_device is None else self.output_device,
1188
self.broadcast_buffers,
1194
self._passing_sync_batchnorm_handle(self.module)
1196
def __getstate__(self):
1197
self._check_default_group()
1198
attrs = copy.copy(self.__dict__)
1199
del attrs["process_group"]
1200
del attrs["reducer"]
1204
def __setstate__(self, state):
1206
self.process_group = _get_default_group()
1207
super().__setstate__(state)
1208
self.__dict__.setdefault("require_forward_param_sync", True)
1209
self.__dict__.setdefault("require_backward_grad_sync", True)
1210
parameters, expect_sparse_gradient = self._build_params_for_reducer()
1212
param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
1214
self._ddp_init_helper(
1216
expect_sparse_gradient,
1217
param_to_name_mapping,
1220
if self.static_graph:
1221
self.reducer._set_static_graph()
1222
assert self.logger is not None
1223
self.logger._set_static_graph()
1225
def _build_params_for_reducer(self):
1227
modules_and_parameters = [
1229
for module_name, module in self.module.named_modules()
1236
for param_name, param in module.named_parameters(recurse=False)
1237
if param.requires_grad
1238
and f"{module_name}.{param_name}" not in self.parameters_to_ignore
1244
modules_and_parameters = [
1248
for m, p in modules_and_parameters
1249
if p not in memo and not memo.add(p)
1253
parameters = [parameter for _, parameter in modules_and_parameters]
1256
def produces_sparse_gradient(module):
1257
if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
1258
return module.sparse
1263
expect_sparse_gradient = [
1264
produces_sparse_gradient(module) for module, _ in modules_and_parameters
1267
self._assign_modules_buffers()
1269
return parameters, expect_sparse_gradient
1271
def _assign_modules_buffers(self):
Assign self.module.named_buffers to self.modules_buffers.

Assigns module buffers to self.modules_buffers which are then used to
broadcast across ranks when broadcast_buffers=True. Note that this
must be called every time buffers need to be synced because buffers can
be reassigned by the user module; see
https://github.com/pytorch/pytorch/issues/63916.
1282
named_module_buffers = [
1283
(buffer, buffer_name)
1284
for buffer_name, buffer in self.module.named_buffers()
1285
if buffer_name not in self.parameters_to_ignore
1287
self.modules_buffers = [
1288
buffer for (buffer, buffer_name) in named_module_buffers
1291
self.named_module_buffers = {
1292
buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
1295
def _build_debug_param_to_name_mapping(self, parameters):
1296
param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
1297
param_set = set(parameters)
1298
param_index_to_param_fqn = {}
1299
for module_name, module in self.module.named_modules():
1300
for param_name, param in module.named_parameters(recurse=False):
1301
fqn = f"{module_name}.{param_name}"
1304
if fqn not in self.parameters_to_ignore and param.requires_grad:
1305
if param not in param_set:
1306
self._log_and_throw(
1308
f"Param with name {fqn} found in module parameters, but not DDP parameters."
1309
" This indicates a bug in DDP, please report an issue to PyTorch.",
1311
param_index = param_to_param_index[param]
1312
param_index_to_param_fqn[param_index] = fqn
1315
if len(param_set) != len(param_index_to_param_fqn):
1316
self._log_and_throw(
1319
"Expected param to name mapping to cover all parameters, but"
1320
f" got conflicting lengths: {len(param_set)} vs "
1321
f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
1322
", please report an issue to PyTorch."
1326
return param_index_to_param_fqn
1328
def _get_parameters(self, m, recurse=True):
1329
"""Return a generator of module parameters."""
1331
def model_parameters(m):
1333
m._former_parameters.values()
1334
if hasattr(m, "_former_parameters")
1335
else m.parameters(recurse=False)
1339
for mod in m.modules() if recurse else [m]:
1340
yield from model_parameters(mod)
1342
def _check_default_group(self):
1343
pickle_not_supported = False
1345
if self.process_group != _get_default_group():
1346
pickle_not_supported = True
1347
except RuntimeError:
1348
pickle_not_supported = True
1350
if pickle_not_supported:
1351
self._log_and_throw(
1353
"DDP Pickling/Unpickling are only supported "
1354
"when using DDP with the default process "
1355
"group. That is, when you have called "
1356
"init_process_group and have not passed "
1357
"process_group argument to DDP constructor",
1363
Context manager to disable gradient synchronizations across DDP processes.
1365
Within this context, gradients will be accumulated on module
1366
variables, which will later be synchronized in the first
1367
forward-backward pass exiting the context.
1371
>>> # xdoctest: +SKIP("undefined variables")
1372
>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
1373
>>> with ddp.no_sync():
1374
>>> for input in inputs:
1375
>>> ddp(input).backward() # no synchronization, accumulate grads
1376
>>> ddp(another_input).backward() # synchronize grads
1379
The forward pass should be included inside the context manager, or
1380
else gradients will still be synchronized.
1382
old_require_backward_grad_sync = self.require_backward_grad_sync
1383
self.require_backward_grad_sync = False
1387
self.require_backward_grad_sync = old_require_backward_grad_sync
1390
def _get_active_ddp_module(cls):
1391
"""`TorchDynamo` requires DDP's status and module for cooperative optimization."""
1392
return cls._active_ddp_module
1398
@torch._disable_dynamo(recursive=False)
1399
def _inside_ddp_forward(self):
1400
DistributedDataParallel._active_ddp_module = self
1404
DistributedDataParallel._active_ddp_module = None
1406
def _run_ddp_forward(self, *inputs, **kwargs):
1407
if self._use_python_reducer:
1408
return self.module(*inputs, **kwargs)
1410
with self._inside_ddp_forward():
1411
return self.module(*inputs, **kwargs)
1413
def _clear_grad_buffer(self):
1418
if self._delay_grad_buffer is not None:
1421
all_param_grad_none = all(
1422
param.grad is None for param in self._delay_all_reduce_params
1425
for index, param in enumerate(self._delay_all_reduce_params):
1426
if param.grad is None:
1427
param.grad = self._delay_grad_views[index]
1428
if not all_param_grad_none:
1431
if all_param_grad_none:
1432
self._delay_grad_buffer.zero_()
1434
def _lazy_init(self):
1437
self._setup_in_backward_optimizers()
1438
self._lazy_init_ran = True
1440
def _should_disable_cpp_reducer(self) -> bool:
1441
return self._use_python_reducer and (
1442
torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
1445
def _pre_forward(self, *inputs, **kwargs):
1446
if self._should_disable_cpp_reducer():
1447
return inputs, kwargs
1450
if self._accum_grad_hooks:
1451
for index, h in enumerate(self._accum_grad_hooks):
1453
self._accum_grad_hooks.clear()
1455
if not self._lazy_init_ran and not torch._utils.is_compiling():
1458
if self._delay_all_reduce_all_params:
1459
return inputs, kwargs
1461
if torch.is_grad_enabled() and self.require_backward_grad_sync:
1462
assert self.logger is not None
1463
self.logger.set_runtime_stats_and_log()
1464
self.reducer.prepare_for_forward()
1468
work = Join.notify_join_context(self)
1470
self.reducer._set_forward_pass_work_handle(
1471
work, self._divide_by_initial_world_size
1480
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
1481
logger.info("Reducer buckets have been rebuilt in this iteration.")
1482
self._has_rebuilt_buckets = True
1486
if self._check_sync_bufs_pre_fwd():
1487
self._sync_buffers()
1489
if self._join_config.enable:
1491
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
1494
moved_inputs, moved_kwargs = _to_kwargs(
1497
torch.device(self.device_type, self.device_ids[0]),
1498
self.use_side_stream_for_tensor_copies,
1500
args, kwargs = moved_inputs[0], moved_kwargs[0]
1502
if self.mixed_precision is not None:
1503
args, kwargs = _cast_forward_inputs(
1504
self.mixed_precision.param_dtype,
1512
if self.mixed_precision is not None:
1513
inputs, kwargs = _cast_forward_inputs(
1514
self.mixed_precision.param_dtype,
1518
return inputs, kwargs
1520
def _post_forward(self, output):
1521
if self._should_disable_cpp_reducer():
1524
if self._delay_all_reduce_all_params:
1525
self._clear_grad_buffer()
1530
if self._check_sync_bufs_post_fwd():
1531
self._sync_buffers()
1533
if torch.is_grad_enabled() and self.require_backward_grad_sync:
1534
self.require_forward_param_sync = True
1540
if self.find_unused_parameters and not self.static_graph:
1542
self.reducer.prepare_for_backward(list(_find_tensors(output)))
1544
self.reducer.prepare_for_backward([])
1546
self.require_forward_param_sync = False
1550
if (self.find_unused_parameters and not self.static_graph) or (
1551
self.static_graph and not self._static_graph_delay_allreduce_enqueued
1557
) = _tree_flatten_with_rref(output)
1558
output_placeholders = [None for _ in range(len(output_tensor_list))]
1561
for i, output in enumerate(output_tensor_list):
1562
if torch.is_tensor(output) and output.grad_fn is None:
1563
output_placeholders[i] = output
1570
passthrough_tensor_list = _DDPSink.apply(
1572
*output_tensor_list,
1574
for i in range(len(output_placeholders)):
1575
if output_placeholders[i] is None:
1576
output_placeholders[i] = passthrough_tensor_list[i]
1579
output = _tree_unflatten_with_rref(
1580
output_placeholders, treespec, output_is_rref
1584
self._clear_grad_buffer()
1587
def forward(self, *inputs, **kwargs):
1588
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
1589
inputs, kwargs = self._pre_forward(*inputs, **kwargs)
1591
self.module.forward(*inputs, **kwargs)
1592
if self._delay_all_reduce_all_params
1593
else self._run_ddp_forward(*inputs, **kwargs)
1595
return self._post_forward(output)
1597
def scatter(self, inputs, kwargs, device_ids):
1598
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
1600
def to_kwargs(self, inputs, kwargs, device_id):
1605
torch.device(self.device_type, device_id),
1606
self.use_side_stream_for_tensor_copies,
1609
def gather(self, outputs, output_device):
1610
return gather(outputs, output_device, dim=self.dim)
1612
def train(self, mode=True):
1618
def _check_global_requires_backward_grad_sync(self, is_joined_rank):
1619
if not is_joined_rank and self.require_backward_grad_sync:
1620
requires_sync_tensor = torch.ones(1, device=self.device)
1622
requires_sync_tensor = torch.zeros(1, device=self.device)
1624
work = dist.all_reduce(
1625
requires_sync_tensor, group=self.process_group, async_op=True
1635
should_sync_backwards = requires_sync_tensor.item() != 0
1636
return should_sync_backwards
1642
def _check_and_sync_module_buffers(self):
1643
if self._check_sync_bufs_pre_fwd():
1644
authoritative_rank = self._find_common_rank(self._distributed_rank, False)
1645
self._sync_module_buffers(authoritative_rank)
1649
def _sync_final_model(self, is_last_joiner):
1653
self._authoritative_rank = self._find_common_rank(
1654
self._distributed_rank, is_last_joiner
1656
_sync_module_states(
1658
process_group=self.process_group,
1659
broadcast_bucket_size=self.broadcast_bucket_size,
1660
src=self._authoritative_rank,
1661
params_and_buffers_to_ignore=self.parameters_to_ignore,
1662
broadcast_buffers=self.broadcast_buffers,
1667
def _match_all_reduce_for_bwd_pass(self):
1676
grad_buckets = self.reducer._get_zeros_like_grad_buckets()
1677
for grad_bucket in grad_buckets:
1682
work = self.reducer._run_comm_hook(grad_bucket)
1683
comm_work.append(work)
1684
for work in comm_work:
1688
def _match_unused_params_allreduce(self):
1689
locally_used_param_map = self.reducer._get_local_used_map()
1690
self.process_group.allreduce(locally_used_param_map)
1694
divide_by_initial_world_size: bool = True,
1695
enable: bool = True,
1696
throw_on_early_termination: bool = False,
1699
Context manager for training with uneven inputs across processes in DDP.
1701
This context manager will keep track of already-joined DDP processes,
1702
and "shadow" the forward and backward passes by inserting collective
1703
communication operations to match with the ones created by non-joined
1704
DDP processes. This will ensure each collective call has a corresponding
1705
call by already-joined DDP processes, preventing hangs or errors that
1706
would otherwise happen when training with uneven inputs across
1707
processes. Alternatively, if the flag ``throw_on_early_termination`` is
1708
specified to be ``True``, all trainers will throw an error once one rank
1709
runs out of inputs, allowing these errors to be caught and handled
1710
according to application logic.
1712
Once all DDP processes have joined, the context manager will broadcast
1713
the model corresponding to the last joined process to all processes to
1714
ensure the model is the same across all processes
1715
(which is guaranteed by DDP).
1717
To use this to enable training with uneven inputs across processes,
1718
simply wrap this context manager around your training loop. No further
1719
modifications to the model or data loading is required.
1722
If the model or training loop this context manager is wrapped around
1723
has additional distributed collective operations, such as
1724
``SyncBatchNorm`` in the model's forward pass, then the flag
1725
``throw_on_early_termination`` must be enabled. This is because this
1726
context manager is not aware of non-DDP collective communication.
1727
This flag will cause all ranks to throw when any one rank
1728
exhausts inputs, allowing these errors to be caught and recovered
1729
from across all ranks.
1732
divide_by_initial_world_size (bool): If ``True``, will divide
1733
gradients by the initial ``world_size`` DDP training was launched
1734
with. If ``False``, will compute the effective world size
1735
(number of ranks that have not depleted their inputs yet) and
1736
divide gradients by that during allreduce. Set
1737
``divide_by_initial_world_size=True`` to ensure every input
1738
sample including the uneven inputs have equal weight in terms of
1739
how much they contribute to the global gradient. This is
1740
achieved by always dividing the gradient by the initial
1741
``world_size`` even when we encounter uneven inputs. If you set
1742
this to ``False``, we divide the gradient by the remaining
1743
number of nodes. This ensures parity with training on a smaller
1744
``world_size`` although it also means the uneven inputs would
1745
contribute more towards the global gradient. Typically, you
1746
would want to set this to ``True`` for cases where the last few
1747
inputs of your training job are uneven. In extreme cases, where
1748
there is a large discrepancy in the number of inputs, setting
1749
this to ``False`` might provide better results.
1750
enable (bool): Whether to enable uneven input detection or not. Pass
1751
in ``enable=False`` to disable in cases where you know that
1752
inputs are even across participating processes. Default is
1754
throw_on_early_termination (bool): Whether to throw an error
1755
or continue training when at least one rank has exhausted
1756
inputs. If ``True``, will throw upon the first rank reaching end
1757
of data. If ``False``, will continue training with a smaller
1758
effective world size until all ranks are joined. Note that if
1759
this flag is specified, then the flag
1760
``divide_by_initial_world_size`` would be ignored. Default
1766
>>> # xdoctest: +SKIP("Distributed")
1768
>>> import torch.distributed as dist
1770
>>> import torch.multiprocessing as mp
1771
>>> import torch.nn as nn
1772
>>> # On each spawned worker
1773
>>> def worker(rank):
1774
>>> dist.init_process_group("nccl", rank=rank, world_size=2)
1775
>>> torch.cuda.set_device(rank)
1776
>>> model = nn.Linear(1, 1, bias=False).to(rank)
1777
>>> model = torch.nn.parallel.DistributedDataParallel(
1778
>>> model, device_ids=[rank], output_device=rank
1780
>>> # Rank 1 gets one more input than rank 0.
1781
>>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
1782
>>> with model.join():
1783
>>> for _ in range(5):
1784
>>> for inp in inputs:
1785
>>> loss = model(inp).sum()
1787
>>> # Without the join() API, the below synchronization will hang
1788
>>> # blocking for rank 1's allreduce to complete.
1789
>>> torch.cuda.synchronize(device=rank)
1794
throw_on_early_termination,
1795
divide_by_initial_world_size=divide_by_initial_world_size,
1803
DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.
1806
kwargs (dict): a :class:`dict` containing any keyword arguments
1807
to modify the behavior of the join hook at run time; all
1808
:class:`Joinable` instances sharing the same join context
1809
manager are forwarded the same value for ``kwargs``.
1811
The hook supports the following keyword arguments:
1812
divide_by_initial_world_size (bool, optional):
1813
If ``True``, then gradients are divided by the initial world
1814
size that DDP was launched with.
1815
If ``False``, then gradients are divided by the effective world
1816
size (i.e. the number of non-joined processes), meaning that
1817
the uneven inputs contribute more toward the global gradient.
1818
Typically, this should be set to ``True`` if the degree of
1819
unevenness is small but can be set to ``False`` in extreme
1820
cases for possibly better results.
1821
Default is ``True``.
1823
divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
1824
return _DDPJoinHook(
1825
self, divide_by_initial_world_size=divide_by_initial_world_size
1829
def join_device(self):
1833
def join_process_group(self):
1834
return self.process_group
1836
def _register_buffer_comm_hook(
1840
comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
Allow custom registration of hooks that define how buffers are synchronized across ranks.

The hook takes in an optional state and is passed in a Dict[str, Tensor]
corresponding to buffer names and the buffers, and can run arbitrary reductions
on buffers as opposed to DDP's default broadcast from rank 0. This is useful,
for example, if a counter needs to be summed or averaged across ranks every iteration.
1851
state (Any): Optional state that is passed to the hook.
1852
hook (Callable): Callable with the following signature:
1853
``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``
1854
comm_hook_location (_BufferCommHookLocation): Enum value indicating
1855
where to run the hook.
1856
_BufferCommHookLocation.PRE_FORWARD means that the
1857
hook will run _before_ the forward pass, and
1858
_BufferCommHookLocation.POST_FORWARD means that the
1859
hook will run _after_ the forward pass.
1861
NOTE: To maximize performance, users can return a
1862
List[torch.futures.Future] from their hook, and DDP will
1863
install and await these hooks appropriately at the end of
1864
the backward pass. This will ensure all buffers are
1865
synchronized by the end of the backward pass. If this
1866
setting is used, it is recommended to pass
1867
comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
1868
which will trigger the hook after the forward pass.
1869
If _BufferCommHookLocation.PRE_FORWARD is used, users must
1870
ensure appropriate synchronization when manipulating GPU
1871
buffers in the forward pass.
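A hedged sketch of such a hook that sums buffers across ranks and returns a
list of futures (this is a private API and the hook parameter names below are
illustrative):

>>> # xdoctest: +SKIP("undefined variables")
>>> def sum_buffers_hook(state, named_buffers):
>>>     futs = []
>>>     for buf in named_buffers.values():
>>>         work = dist.all_reduce(buf, op=dist.ReduceOp.SUM, async_op=True)
>>>         futs.append(work.get_future())
>>>     return futs
>>> ddp._register_buffer_comm_hook(None, sum_buffers_hook)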
1873
assert callable(hook)
1874
self.buffer_hook = _BufferCommHook(
1875
buffer_comm_hook=hook,
1876
buffer_comm_hook_state=state,
1877
buffer_comm_hook_location=comm_hook_location,
1880
def register_comm_hook(self, state: object, hook: Callable):
1882
Register communication hook for user-defined DDP aggregation of gradients across multiple workers.
1884
This hook would be very useful for researchers to try out new ideas. For
1885
example, this hook can be used to implement several algorithms like GossipGrad
1886
and gradient compression which involve different communication strategies for
1887
parameter syncs while running Distributed DataParallel training.
1890
state (object): Passed to the hook to maintain any state information during the training process.
1891
Examples include error feedback in gradient compression,
1892
peers to communicate with next in GossipGrad, etc.
1894
It is locally stored by each worker
1895
and shared by all the gradient tensors on the worker.
1896
hook (Callable): Callable with the following signature:
1897
``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
1899
This function is called once the bucket is ready. The
1900
hook can perform whatever processing is needed and return
1901
a Future indicating completion of any async work (ex: allreduce).
1902
If the hook doesn't perform any communication, it still
1903
must return a completed Future. The Future should hold the
1904
new value of grad bucket's tensors. Once a bucket is ready,
1905
c10d reducer would call this hook and use the tensors returned
1906
by the Future and copy grads to individual parameters.
1907
Note that the future's return type must be a single tensor.
1909
We also provide an API called ``get_future`` to retrieve a
1910
Future associated with the completion of ``c10d.ProcessGroup.Work``.
1911
``get_future`` is currently supported for NCCL and also supported for most
1912
operations on GLOO and MPI, except for peer to peer operations (send/recv).
Grad bucket's tensors will not be predivided by world_size. The user is responsible
for dividing by the world_size in case of operations like allreduce.
1919
DDP communication hook can only be registered once and should be registered
1920
before calling backward.
1923
The Future object that hook returns should contain a single tensor
1924
that has the same shape with the tensors inside grad bucket.
1927
``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
1928
for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
1931
Below is an example of a noop hook that returns the same tensor.
1933
>>> # xdoctest: +SKIP('undefined name')
1934
>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
1935
>>> fut = torch.futures.Future()
1936
>>> fut.set_result(bucket.buffer())
1938
>>> ddp.register_comm_hook(state=None, hook=noop)
1941
Below is an example of a Parallel SGD algorithm where gradients are encoded before
1942
allreduce, and then decoded after allreduce.
1944
>>> # xdoctest: +SKIP('undefined name')
1945
>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
1946
>>> encoded_tensor = encode(bucket.buffer()) # encode gradients
1947
>>> fut = torch.distributed.all_reduce(encoded_tensor).get_future()
1948
>>> # Define the then callback to decode.
1949
>>> def decode(fut):
1950
>>> decoded_tensor = decode(fut.value()[0]) # decode gradients
1951
>>> return decoded_tensor
1952
>>> return fut.then(decode)
1953
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
        """
        self._check_comm_hook(hook)
        if hook.__name__ in ["bf16_compress_hook", "fp16_compress_hook"]:
            # These built-in compression hooks operate on the default process group.
            state = dist.group.WORLD
        assert self.logger is not None
        self.logger._set_comm_hook_name(hook.__qualname__)
        self._comm_hooks.append((hook, state))
        dist._register_comm_hook(self.reducer, state, hook)

    def _register_builtin_comm_hook(self, comm_hook_type):
        r"""
        Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.

        The built-in hooks aim to provide efficient C++ implementations for certain hooks,
        which might not be as efficient if implemented in Python using a Python communication hook.

        Args:
            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        Example::
            Below is an example of a FP16 compression where gradients are
            compressed into 16-bit floating-point numbers before allreduce, and
            then decompressed after allreduce.

            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
        """
        assert self.logger is not None
        self.logger._set_comm_hook_name(str(comm_hook_type))
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
        r"""
        Register an optimizer in DDP to optimize parameter immediately after its gradient reduction.

        Registers an optimizer with DDP such that the optimization for a
        parameter will run immediately when that parameter's gradient is
        finished with reduction, instead of waiting for all parameters'
        gradients to finish reduction. This can result in a training speedup
        depending on your workload, since the optimizer can run while gradient
        reduction for other parameters is still ongoing. In addition, this has
        the potential to reduce peak memory consumption during training, as it
        only needs to load the per-parameter optimizer states of a single
        parameter at a time, instead of loading all per-parameter optimizer
        states at once.

        Args:
            optim (Type): a ``torch.optim.Optimizer`` class to be registered
                as a fused optimizer.
            *args (Sequence[Any]): Arguments to forward to ``optim``.
            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
                to optimize, similar to the ``params`` argument of traditional
                ``torch.optim`` Optimizers. If this is omitted, all DDP model
                parameters will be optimized.
            **kwargs (Dict[str, Any]): Keyword arguments to forward to ``optim``.

        .. warning ::
            _register_fused_optim should only be called once on a DDP instance,
            and registering multiple fused optimizers for the same DDP model
            is not currently supported. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            _register_fused_optim and register_comm_hook currently do not
            compose together, meaning that custom DDP communication hooks are
            not supported with overlapped optimizers. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        .. warning ::
            Gradient accumulation and DDP `no_sync` are currently not supported
            with overlapped optimizer. Please ping
            https://github.com/pytorch/pytorch/issues/71595 if this is necessary
            for your use case.

        Example::

            >>> # xdoctest: +SKIP("No rendezvous handler")
            >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> lr = 1e-2
            >>> betas = (0.9, 0.99)
            >>> eps = 1e-6
            >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
            >>> # Example with subset of parameters
            >>> params_to_opt = [list(net.parameters())[0]]
            >>> net._register_fused_optim(
            ...   torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
            ... )
        """
        # Import here to avoid a circular import at module load time.
        from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim

        overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
        try:
            overlapped_optim.register_ddp(self)
        except NotImplementedError as e:
            raise RuntimeError(
                f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
            ) from e

    def _distributed_broadcast_coalesced(
        self, tensors, buffer_size, authoritative_rank=0
    ):
        dist._broadcast_coalesced(
            self.process_group, tensors, buffer_size, authoritative_rank
        )

    def _check_sync_bufs_post_fwd(self):
        return (
            self.will_sync_module_buffers()
            and hasattr(self, "buffer_hook")
            and self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.POST_FORWARD
        )

    def _check_sync_bufs_pre_fwd(self):
        return self.will_sync_module_buffers() and (
            not hasattr(self, "buffer_hook")
            or self.buffer_hook.buffer_comm_hook_location
            == _BufferCommHookLocation.PRE_FORWARD
        )

    def will_sync_module_buffers(self):
        return (
            self.require_forward_param_sync
            and self.broadcast_buffers
            and len(self.modules_buffers) > 0
        )

    def _find_common_rank(self, input_rank, rank_cond):
        # -1 indicates that this rank is not under consideration as the
        # common rank; the max across ranks picks a rank where rank_cond holds.
        rank_to_use = torch.tensor(
            [input_rank if rank_cond else -1],
            device=self.device,
        )
        dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
        if rank_to_use.item() == -1:
            self._log_and_throw(
                ValueError,
                "BUG! Expected rank_cond to be true for at least one process."
                " This indicates a bug in PyTorch, please report an issue.",
            )
        return rank_to_use.item()

    def _sync_buffers(self):
        with torch.no_grad():
            # If we are running DDP with the join manager, we have to agree
            # upon a rank to sync module buffers from, since rank 0 may
            # already have joined and have stale module buffers.
            if self._join_config.enable:
                authoritative_rank = self._find_common_rank(
                    self._distributed_rank, True
                )
            else:
                # The process with rank 0 is considered the authoritative copy.
                authoritative_rank = 0
            # Update self.modules_buffers in case any buffers were reassigned.
            self._assign_modules_buffers()
            self._sync_module_buffers(authoritative_rank)

    def _sync_module_buffers(self, authoritative_rank):
        if not hasattr(self, "buffer_hook"):
            self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
        else:
            hook = self.buffer_hook.buffer_comm_hook
            state = self.buffer_hook.buffer_comm_hook_state
            futs = hook(state, self.named_module_buffers)
            if futs is not None:
                self.reducer._install_post_backward_futures(futs)

    def _default_broadcast_coalesced(
        self, bufs=None, bucket_size=None, authoritative_rank=0
    ):
        """
        Broadcasts buffers from rank 0 to rest of workers.

        If bufs, bucket_size are None, default values self.modules_buffers
        and self.broadcast_bucket_size are used instead.
        """
        if bufs is None:
            bufs = self.modules_buffers
        if bucket_size is None:
            bucket_size = self.broadcast_bucket_size

        self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)

    def _passing_sync_batchnorm_handle(self, module):
        for layer in module.modules():
            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                if self.device_type == "cpu":
                    self._log_and_throw(
                        ValueError,
                        "SyncBatchNorm layers only work with GPU modules",
                    )

    def _check_comm_hook(self, hook):
        if not callable(hook):
            self._log_and_throw(TypeError, "Communication hook must be callable.")

        sig = inspect.signature(hook)
        if (
            sig.parameters["bucket"].annotation != inspect._empty
            and sig.parameters["bucket"].annotation != dist.GradBucket
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: bucket annotation should be dist.GradBucket.",
            )

        if (
            sig.return_annotation != inspect._empty
            and sig.return_annotation != torch.futures.Future[torch.Tensor]
        ):
            self._log_and_throw(
                ValueError,
                "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
            )

        if hook.__name__ in [
            "bf16_compress_hook",
            "bf16_compress_wrapper_hook",
        ] and (
            (torch.version.cuda is None and torch.version.hip is None)
            or (
                torch.version.cuda is not None
                and int(torch.version.cuda.split(".")[0]) < 11
            )
            or not dist.is_available()
            or not dist.is_nccl_available()
            or torch.cuda.nccl.version() < (2, 10)
        ):
            self._log_and_throw(
                TypeError,
                "BF16 all reduce communication hook requires CUDA 11+ and NCCL 2.10+.",
            )

    @property
    def _distributed_rank(self):
        return dist.get_rank(self.process_group)

    @staticmethod
    def _get_data_parallel_params(module, named_params=False):
        """Return a generator of parameters managed by a given DDP unit."""
        for param in (
            module.parameters() if not named_params else module.named_parameters()
        ):
            if not hasattr(param, "_ddp_ignored"):
                yield param

    @staticmethod
    def _set_params_and_buffers_to_ignore_for_model(
        module, params_and_buffers_to_ignore
    ):
        """
        Set parameters and buffers to be ignored by DDP.

        Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and
        similarly, {module_name}.{buffer_name} for buffers. For example:

        params_to_ignore = []
        # NB: model here is a vanilla PyTorch module, not yet wrapped with DDP.
        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                if should_ignore(param):
                    # Create expected format
                    fqn = f"{module_name}.{param_name}"
                    params_to_ignore.append(fqn)
        torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            model, params_to_ignore
        )
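
        Buffers can be ignored in the same way. As a minimal sketch, assuming a
        hypothetical should_ignore_buffer predicate:

        for module_name, module in model.named_modules():
            for buffer_name, buffer in module.named_buffers(recurse=False):
                if should_ignore_buffer(buffer):
                    params_to_ignore.append(f"{module_name}.{buffer_name}")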
        """
        module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
        for name, param in module.named_parameters():
            if name in params_and_buffers_to_ignore:
                param._ddp_ignored = True
        for name, buffer in module.named_buffers():
            if name in params_and_buffers_to_ignore:
                buffer._ddp_ignored = True

    def _get_ddp_logging_data(self):
        r"""
        Return a dictionary of logging data for debugging and analysis.

        This interface can be called after DistributedDataParallel() is
        constructed. It returns a dictionary of logging data that can help
        with debugging and analysis. The logging data includes DistributedDataParallel
        constructor input parameters, some internal states of DistributedDataParallel,
        and performance metrics. Simply print the dictionary and see what
        these stats mean.

        This is a prototype interface and subject to change in the future.
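
        Example::
            A minimal usage sketch, assuming ``ddp`` is a constructed
            ``DistributedDataParallel`` instance:

            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp_logging_data = ddp._get_ddp_logging_data()
            >>> print(ddp_logging_data)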
        """
        assert self.logger is not None
        ddp_logging_data = self.logger._get_ddp_logging_data()
        return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}

    def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
        r"""
        Set sample_rate of collecting runtime stats.

        This interface allows users to set the sample rate for collecting
        runtime stats. Runtime stats are recorded for the first 10 iterations;
        after that, they are recorded once every ``sample_rate`` training
        iterations. By default (``kDDPRuntimeLoggingSampleRate = 100``), runtime
        stats are recorded once every 100 training iterations after the first 10.

        This is a prototype interface and subject to change in the future.
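
        Example::
            A minimal usage sketch, assuming ``ddp`` is a constructed
            ``DistributedDataParallel`` instance:

            >>> # xdoctest: +SKIP('undefined name')
            >>> # Record runtime stats every 50 iterations after the first 10.
            >>> ddp._set_ddp_runtime_logging_sample_rate(50)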
        """
        if sample_rate < 1:
            self._log_and_throw(
                ValueError,
                "DDP runtime logging sample rate should be equal or greater than 1",
            )
        self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)

    def _set_static_graph(self):
        """
        Set static graph for DDP.

        It is recommended to set static graph in the DDP constructor, which will
        call this private API internally.
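
        Example::
            Static graph is normally enabled through the DDP constructor; a minimal
            sketch, assuming ``model`` and an initialized process group:

            >>> # xdoctest: +SKIP('undefined name')
            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, static_graph=True)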
        """
        if self.static_graph:
            warnings.warn(
                "You've set static_graph to be True, no need to set it again."
            )
            return
        self.static_graph = True
        self._static_graph_delay_allreduce_enqueued = False
        self.reducer._set_static_graph()
        assert self.logger is not None
        self.logger._set_static_graph()
        if self.find_unused_parameters:
            warnings.warn(
                "You passed find_unused_parameters=true to DistributedDataParallel, "
                "`_set_static_graph` will detect unused parameters automatically, so "
                "you do not need to set find_unused_parameters=true, just be sure these "
                "unused parameters will not change during training loop while calling "
                "`_set_static_graph`."
            )

    def _remove_autograd_hooks(self):
        """Remove autograd hooks registered by the reducer on the model parameters."""
        self.reducer._remove_autograd_hooks()

    def _check_reducer_finalized(self):
        """
        Check if the reducer has processed all buckets and finalized the backward appropriately.

        It is useful to call this method after calling .backward() in your training loop
        in order to avoid later hard-to-debug errors caused by the reducer not
        finalizing backward.
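
        Example::
            A minimal usage sketch, assuming ``ddp`` and a computed ``loss``:

            >>> # xdoctest: +SKIP('undefined name')
            >>> loss.backward()
            >>> ddp._check_reducer_finalized()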
        """
        self.reducer._check_reducer_finalized()

    def _set_sparse_metadata(self, global_unique_ids):
        self.reducer._set_sparse_metadata(global_unique_ids)

    def _update_process_group(self, new_process_group):
        """
        Dynamically updates the process group for DDP so that we can shrink/expand DDP
        world size without having to reinitialize DDP.

        NOTE: If you are using custom communication hooks via ``register_comm_hook``,
        you need to update the process groups for those hooks separately.
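
        Example::
            A minimal sketch, assuming ``ddp`` is an existing ``DistributedDataParallel``
            instance and ``ranks`` is the new set of participating ranks:

            >>> # xdoctest: +SKIP('undefined name')
            >>> new_pg = torch.distributed.new_group(ranks=ranks)
            >>> ddp._update_process_group(new_pg)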
        """
        # Force the buckets to be rebuilt against the new process group and
        # reset the reducer state accordingly.
        self._has_rebuilt_buckets = False
        self.reducer._reset_state()

        if not _rank_not_in_group(new_process_group):
            self.process_group = new_process_group
            self.reducer._update_process_group(new_process_group)