import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch._custom_ops import impl_abstract
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode

from . import _functional_collectives_impl as fun_col_impl
from ._functional_collectives_impl import (  # noqa: F401
    _register_tensor_wrapper,
    native_funcol_enabled,
)

try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]


if torch._running_with_deploy():

    def is_torchdynamo_compiling():
        """Can't import torchdynamo in torchdeploy builds currently."""
        return False

else:
    try:
        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
    except Exception:
        warnings.warn(
            "Unable to import torchdynamo util `is_torchdynamo_compiling`, "
            "so torchdynamo will not be supported correctly"
        )

        def is_torchdynamo_compiling():
            return False
42"""
43New traceable, functional collectives.
44RFC: https://github.com/pytorch/pytorch/issues/93173
45
46compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
47eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
48automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
49a downstream op.
50
51Issues:
52* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
53* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
54"""
55
56"""
57Functional collectives are asynchronous only and we perform implicit stream synchronization
58on behalf of the user.
59
60We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
61first usage of the tensor and insert cross stream sync at the right place.
62
63The above are the easy bits, the hard one is how we match the Work object returned by
64c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
65op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
66dispatcher which might call other implementations that are allowed to change the returned
67tensor - even return a tensor with a different shape (see ``torch.vmap``).
68
69This means the caller of our ops receives a Tensor that is not guaranteed to be the same
70allocated by our implementations and that makes pairing The AsyncTensor to the original
71tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
72
73Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
74identity is not stable across dispatch, the op caller would end up with a different Tensor
75instance that would not match any in the dictionary.
76
77With Tensor identity out of the question, we decided use the tensor data pointer, which
78should be stable across all the Tensor changes done during dispatch.
79
80We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
81
82We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
83
84Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
85can clean up stale entries in the dictionary.
86
87To eliminate the possibility of races we have a global version counter that is used by the finalizer.
88
89As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
90
91"""

"""
Functional collectives can accept any of these types to describe the ranks participating in collectives.

The different types will be desugared to a canonical format.
"""
RANK_TYPES = Union[
    List[int],
    List[List[int]],
    dist.ProcessGroup,
    DeviceMesh,
    Tuple["dist._tensor.DeviceMesh", int],
    str,
]
108"""
109User facing APIs for functional collectives
110-------------------------------------------
111
112These apis are called by user code and expected to work both in eager execution and compilation,
113but there are significant differences to how the two modes are implemented underneath.
114
115Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
116just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
117and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
118if sufficient subclass support is added in dynamo.
119
120Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
121
122Here's how it works under torch.compile/dynamo:
123all_reduce(...)
124|--> _expand_group(...) - desugars processgroup into canonical/traceable format
125|--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper
126|--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed
127
128And under eager execution:
129all_reduce(...)
130|--> _expand_group(...) - same as above, but less critical for eager
131|--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace
132|--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor,
133which issues wait_tensor() at the time of first use
134"""


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collectives ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    if native_funcol_enabled():
        return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
    else:
        return torch.ops.c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]


def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
    """
    Broadcasts the tensor to all processes in the given process group.

    Args:
        src (int): Source rank
        group (ProcessGroup or List[int]): The process group to work on.
        tag (str, optional): A unique identifier for the collective. Default: empty string
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.broadcast(
            self, src, tag, rankset, group_size
        )
    return _maybe_wrap_tensor(tensor)


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor = torch.ops._c10d_functional.all_reduce(
            self, reduceOp.lower(), group_name
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_reduce(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,
            group_size,
        )
    return _maybe_wrap_tensor(tensor)
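
# The same collective expressed with different group types (a sketch; assumes
# a 4-rank job with an initialized default group, a 1D DeviceMesh ``mesh``,
# and a 2D DeviceMesh ``mesh_2d``):
#
#     out = all_reduce(t, "sum", [0, 1, 2, 3])      # explicit rank list
#     out = all_reduce(t, "sum", dist.group.WORLD)  # ProcessGroup
#     out = all_reduce(t, "sum", mesh)              # 1D DeviceMesh (SPMD)
#     out = all_reduce(t, "sum", (mesh_2d, 0))      # one mesh dimension (MPMD)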


def all_gather_tensor(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data across all machines and concatenate over ``gather_dim``.

    Note that the underlying op only supports gather_dim = 0; other values are
    emulated via chunk + cat after waiting.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    assert self.is_contiguous()
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        tensor = torch.ops._c10d_functional.all_gather_into_tensor(
            self, group_size, group_name
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_gather_into_tensor(  # type: ignore[attr-defined]
            self,
            tag,
            rankset,
            group_size,
        )
    res = _maybe_wrap_tensor(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we would need to wait here anyway; waiting
        # first and then doing chunk + cat avoids going through the ACT dispatching
        # logic again.
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res
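
# Sketch: gathering along a non-zero dim (assumes 2 ranks, each holding a
# (4, 2) shard; the result is (4, 4) because the gather happens along dim 0
# and is then rearranged via chunk + cat along ``gather_dim``):
#
#     shard = torch.randn(4, 2, device="cuda")
#     full = all_gather_tensor(shard, gather_dim=1, group=[0, 1])
#     assert full.shape == (4, 4)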


def reduce_scatter_tensor(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the result to the corresponding ranks.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    if native_funcol_enabled():
        tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
            self,
            reduceOp.lower(),
            group_size,
            group_name,  # type: ignore[possibly-undefined]
        )
    else:
        tensor = torch.ops.c10d_functional.reduce_scatter_tensor(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,  # type: ignore[possibly-undefined]
            group_size,
        )
    res = _maybe_wrap_tensor(tensor)
    return res
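
# Sketch: reduce_scatter as the inverse of all_gather (assumes 2 ranks, each
# contributing an (8, 2) tensor of ones; each rank receives a reduced (4, 2)
# shard):
#
#     t = torch.ones(8, 2, device="cuda")
#     shard = reduce_scatter_tensor(t, "sum", scatter_dim=0, group=[0, 1])
#     assert shard.shape == (4, 2)  # values are 2.0 after the sum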


def all_reduce_coalesced(
    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result.

    All tensors in the input list are left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
            self,
            reduceOp.lower(),
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor_list = torch.ops.c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,
            group_size,
        )
    return list(map(_maybe_wrap_tensor, tensor_list))
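
# Sketch: coalescing several small all-reduces into a single call (assumes a
# 2-rank job; each returned tensor is independently awaited on first use):
#
#     grads = [torch.randn(128), torch.randn(64), torch.randn(32)]
#     reduced = all_reduce_coalesced(grads, "sum", [0, 1])
#     assert len(reduced) == len(grads)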


def all_gather_into_tensor_coalesced(
    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Gather a list of tensors from all machines.

    Note that it currently only supports gather_dim = 0.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
            self,
            group_size,
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
            self,
            tag,
            rankset,
            group_size,
        )
    return list(map(_maybe_wrap_tensor, tensor_list))


def reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduceOp: str,
    scatter_dim: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)

    assert len(scatter_dim) == len(inputs)
    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
        assert (
            tensor.size(dim) % group_size == 0
        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
        if dim != 0:
            tensor_list = torch.chunk(tensor, group_size, dim=dim)
            inputs[idx] = torch.cat(tensor_list)

    if native_funcol_enabled():
        tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
            inputs,
            reduceOp.lower(),
            group_size,
            group_name,  # type: ignore[possibly-undefined]
        )
    else:
        tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
            inputs,
            reduceOp,
            tag,
            rankset,  # type: ignore[possibly-undefined]
            group_size,
        )

    return list(map(_maybe_wrap_tensor, tensor_list))


# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
    assert isinstance(tgt, torch._ops.OpOverload)
    schema = tgt._schema
    if len(schema.arguments) > 0:
        first_arg = schema.arguments[0]
        # check if op is a view
        return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def all_to_all_single(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Each process splits the input tensor and then scatters the split list
    to all processes in the group. It then concatenates the tensors received
    from all processes in the group and returns a single output tensor.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        if output_split_sizes is None or input_split_sizes is None:
            assert output_split_sizes is None and input_split_sizes is None, (
                "output_split_sizes and input_split_sizes must either be "
                "specified together or both set to None"
            )
            output_split_sizes = [self.shape[0] // group_size] * group_size
            input_split_sizes = output_split_sizes
        tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
            self,
            output_split_sizes,
            input_split_sizes,
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
            self,
            output_split_sizes,
            input_split_sizes,
            tag,
            rankset,
            group_size,
        )
    return _maybe_wrap_tensor(tensor)
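
# Sketch: even all-to-all with explicit splits (assumes 2 ranks; each rank
# keeps half of its tensor and exchanges the other half):
#
#     t = torch.arange(4, device="cuda") + 4 * dist.get_rank()
#     out = all_to_all_single(t, [2, 2], [2, 2], group=[0, 1])
#     # rank 0 receives [0, 1, 4, 5]; rank 1 receives [2, 3, 6, 7]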


def permute_tensor(
    self: torch.Tensor,
    src_dst: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Permutes the elements of the tensor according to the given source/destination pairs. ``src_dst`` should
    be defined such that src_dst[m] == n means m sends to n.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh.
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh.
    """
    t, rankset, group_size = _expand_group(group, tag)
    local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)

    output_split_sizes = [0] * group_size
    input_split_sizes = [0] * group_size
    for src, dst in enumerate(src_dst):
        if src == dist.get_rank(local_pg):
            input_split_sizes[dst] = self.numel()
        if dst == dist.get_rank(local_pg):
            output_split_sizes[src] = self.numel()

    return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
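
# Sketch: a ring shift implemented as a permutation (assumes 3 ranks; rank m
# sends its tensor to rank (m + 1) % 3, per the src_dst[m] == n convention):
#
#     shifted = permute_tensor(t, src_dst=[1, 2, 0], group=[0, 1, 2])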


class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.
    Use it inside functional collective pytorch wrappers like the following:

        def functional_collective(self, group, tag):
            tag, rankset, group_size = _expand_group(group, tag)
            tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
            return _maybe_wrap_tensor(tensor)
    """
    elem: torch.Tensor
    completed: bool

    __slots__ = ["elem", "completed"]

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=False,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        return ["elem"], None

    def tolist(self):
        self.trigger_wait()
        return self.elem.tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        self.trigger_wait()
        return f"AsyncCollectiveTensor({self.elem})"

    def trigger_wait(self):
        if not self.completed:
            wait_tensor(self.elem)
            self.completed = True
        return self.elem

    def wait(self) -> torch.Tensor:
        wait_tensor(self.elem)
        return self.elem

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast-path aten.view, since many view-related ops eventually go to
            # aten.view; this avoids the pytree slowdown.
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            _register_tensor_wrapper(wrapper_res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # wait_tensor is idempotent and will do the stream sync only once
            if not is_view_op:
                e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # wait_tensor is idempotent and will do the stream sync only once
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            _register_tensor_wrapper(res)
            return res

        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we should re-wrap the outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        return self.wait().numpy()
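
# Sketch of the wrapper's observable behavior (assumes an initialized 2-rank
# job):
#
#     out = all_reduce(t, "sum", [0, 1])
#     isinstance(out, AsyncCollectiveTensor)  # True in eager mode
#     plain = out.wait()                      # explicit sync, returns a plain Tensor
#     y = out * 2                             # or: the first real use syncs implicitly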
643"""
644Utils and infrastructure for tracing support
645"""
646
647
648def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:649"""650_expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.
651
652By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
653torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
654"""
655# had to define this hack _inside_ expand_group to avoid656# graph_break [('torch.* op returned non-Tensor int657# caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc)658if TYPE_CHECKING:659
660def cast_listlistint(x):661return cast(List[List[int]], x)662
663def cast_listint(x):664return cast(List[int], x)665
666else:667# fake cast op for use at runtime since dynamo doesn't support real cast668# also, dynamo didn't like encountering 'typing' objects ()669# NotImplementedError: argument of type: <class 'typing._GenericAlias'>670def cast_listlistint(x):671return x672
673def cast_listint(x):674return x675
676rankset: List[int]677if isinstance(group, list):678if isinstance(group[0], list):679nested_list = cast_listlistint(group)680rankset = []681group_size = -1682for rs in nested_list:683rankset.extend(rs)684if group_size != -1 and group_size != len(rs):685raise ValueError(686f"group sizes must be identical found {group_size} and {len(rs)}"687)688group_size = len(rs)689else:690rankset = cast_listint(group)691group_size = len(rankset)692elif isinstance(group, dist.ProcessGroup):693rankset = dist.get_process_group_ranks(group)694group_size = len(rankset)695tag = tag or c10d._get_group_tag(group)696elif isinstance(group, DeviceMesh):697assert (698group.ndim == 1699), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"700# TODO: it should run collective in the whole mesh instead of dim 0701tag, rankset, _ = group._dim_group_infos[0]702group_size = len(rankset)703elif isinstance(group, tuple):704if (705len(group) == 2706and isinstance(group[0], DeviceMesh)707and isinstance(group[1], int)708):709dmesh = group[0]710dim = group[1]711tag, rankset, _ = dmesh._dim_group_infos[dim]712group_size = len(rankset)713else:714raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")715else:716raise ValueError(717"Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."718)719
720return (tag, rankset, group_size)721
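
# Sketch of the canonical (tag, rankset, group_size) format this produces:
#
#     _expand_group([0, 1, 2, 3])      -> ("", [0, 1, 2, 3], 4)
#     _expand_group([[0, 1], [2, 3]])  -> ("", [0, 1, 2, 3], 2)  # 2 MPMD groups of size 2
#     _expand_group(pg)                -> (pg's tag, pg's ranks, len(ranks))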


def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
    """
    Given a group in RANK_TYPES, return the group name.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        return group
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        return group._dim_group_infos[0][2]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            return dmesh._dim_group_infos[dim][2]
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    elif isinstance(group, list):
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead."
            )
        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")


def _are_we_tracing() -> bool:
    if is_torchdynamo_compiling():
        return True
    # If functionalization is turned on, we are almost definitely compiling/tracing.
    # (In particular, AOTAutograd traces a model once with functionalization on
    # but proxy tracing turned off, so this is how we detect it.)
    if (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    ):
        return True
    mode = get_innermost_proxy_mode()
    if mode is None:
        return False
    return mode.tracer is not None


def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        return wait_tensor(self)
    res = AsyncCollectiveTensor(self)
    _register_tensor_wrapper(res)
    return cast(torch.Tensor, res)


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in self]


# We now register meta kernels to deal with tracing
def _broadcast_meta(self, *args):
    return torch.empty_like(self)


def _all_reduce_meta(self, *args):
    return torch.empty_like(self)


def _wait_tensor_meta(self, *args):
    return torch.empty_like(self)


def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
    out_size = list(shard.size())
    out_size[0] *= group_size
    return shard.new_empty(out_size)


def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size = list(input.size())
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _all_reduce__meta(inp, *args):
    return inp


def _broadcast__meta(inp, *args):
    return inp


def _all_reduce_coalesced__meta(inputs, *args):
    return inputs


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in inputs]


# NB: We often say all_to_all has a dynamic output size, but this is not
# technically true: instead, what typically happens is that you manually
# communicate the output_split_sizes ahead of time (which is dynamic),
# but then you pass those sizes explicitly, and the all_to_all itself
# isn't dynamic; it just follows the specified output splits
def _all_to_all_single_meta(
    input, output_split_sizes, input_split_sizes, *args, **kwargs
):
    if output_split_sizes is None:
        return input.new_empty(input.size())
    else:
        for s in output_split_sizes:
            torch._check_is_size(s)
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        return input.new_empty(out_size)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(
    inputs, reduce_op, group_size, group_name
):
    return [
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]


def _register_ops():
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]

    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        meta_impl = getattr(my_module, f"_{op_name}_meta")
        c10_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd")
        impl_abstract(f"c10d_functional::{op_name}")(meta_impl)


if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work.
    # Creating a "DEF" Library always crashes torch::deploy, so we create our
    # Library instances here, guarded against running inside it.
    c10_lib = torch.library.Library("c10d_functional", "DEF")
    c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    _register_ops()

    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    _c10_lib_impl.impl(
        "all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta"
    )
    _c10_lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    _c10_lib_impl.impl(
        "reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta"
    )
    _c10_lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    _c10_lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    _c10_lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    _c10_lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


"""
Dynamo remappings allow seamless translation from non-functional collectives of a supported form into
functional collective calls followed by in-place copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it with the corresponding op via
the mapping dict below.

These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are remapping from.
"""


def all_gather_tensor_inplace(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group,  # TODO add a type
    async_op: bool = False,
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))


def reduce_scatter_tensor_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
    group=None,  # TODO add a type
    async_op: bool = False,
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))


def all_reduce_inplace(
    tensor: torch.Tensor,
    op: str = "sum",
    group=None,
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    return tensor.copy_(all_reduce(tensor, op, group, tag))


def all_to_all_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output.copy_(
        all_to_all_single(input, output_split_sizes, input_split_sizes, group, tag)
    )


def all_gather_inplace(
    tensor_list: List[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    output = all_gather_tensor(tensor, 0, group, tag)
    for dst, src in zip(
        tensor_list, output.split([t.size(0) for t in tensor_list], dim=0)
    ):
        dst.copy_(src)
    return tensor_list


from torch.distributed.distributed_c10d import (
    _all_gather_base as legacy_all_gather_base,
    _reduce_scatter_base as legacy_reduce_scatter_base,
    all_gather as legacy_all_gather,
    all_gather_into_tensor as legacy_allgather,
    all_reduce as legacy_allreduce,
    all_to_all_single as legacy_all_to_all_single,
    reduce_scatter_tensor as legacy_reducescatter,
)

# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}