# pytorch / _functional_collectives.py
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch._custom_ops import impl_abstract
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode

from . import _functional_collectives_impl as fun_col_impl
from ._functional_collectives_impl import (  # noqa: F401
    _register_tensor_wrapper,
    native_funcol_enabled,
)

try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]


if torch._running_with_deploy():

    def is_torchdynamo_compiling():
        """Can't import torchdynamo in torchdeploy builds currently."""
        return False

else:
    try:
        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
    except Exception:
        warnings.warn(
            "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
        )

        def is_torchdynamo_compiling():
            return False


"""
New traceable, functional collectives.
RFC: https://github.com/pytorch/pytorch/issues/93173

  compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
  eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
         automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
         a downstream op.

Issues:
* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
"""

"""
Functional collectives are asynchronous only and we perform implicit stream synchronization
on behalf of the user.

We use AsyncCollectiveTensor to wrap the result tensor of a collective, which lets us observe
the first usage of the tensor and insert the cross-stream sync at the right place.

The above are the easy bits; the hard part is how we match the Work object returned by
c10d and the tensor AsyncCollectiveTensor wraps. We allocate the tensor inside the collective
op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).

This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
allocated by our implementations, and that makes pairing the AsyncCollectiveTensor to the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.

Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any in the dictionary.

With Tensor identity out of the question, we decided to use the tensor data pointer, which
should be stable across all the Tensor changes done during dispatch.

We keep a dictionary of tensor::data_ptr -> Work that we insert into right after we call into c10d.

We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait().

Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
can clean up stale entries in the dictionary.

To eliminate the possibility of races, we have a global version counter that is used by the finalizer.

As a wise man once said: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)

"""
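# Illustrative pseudocode (not the actual bookkeeping, which lives in
# _functional_collectives_impl): the pattern described above roughly amounts to
#
#   data_ptr_to_work: Dict[int, Work] = {}
#   data_ptr_to_work[result.data_ptr()] = work   # recorded right after the c10d call
#   ...
#   data_ptr_to_work.pop(t.data_ptr()).wait()    # performed on first use of the wrapper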

"""
Functional collectives can accept any of these types to describe the ranks participating in collectives.

The different types will be desugared to a canonical format.
"""
RANK_TYPES = Union[
    List[int],
    List[List[int]],
    dist.ProcessGroup,
    DeviceMesh,
    Tuple["dist._tensor.DeviceMesh", int],
    str,
]
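# For illustration, all of the following are accepted ``group`` values (assuming the
# corresponding objects exist): ``[0, 1, 2, 3]``, ``[[0, 1], [2, 3]]``, ``dist.group.WORLD``,
# a 1D ``DeviceMesh``, ``(mesh, 0)``, or a registered process group name string.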


"""
User facing APIs for functional collectives
-------------------------------------------

These APIs are called by user code and are expected to work both in eager execution and compilation,
but there are significant differences in how the two modes are implemented underneath.

Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
just before the tensor is first used.  Compiled tracing currently relies on the compiler to perform this optimization,
and cannot yet correctly trace the AsyncTensor wrapper class.  In the future, these paths may be unified
if sufficient subclass support is added in dynamo.

Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.

Here's how it works under torch.compile/dynamo:
all_reduce(...)
  |--> _expand_group(...)               - desugars processgroup into canonical/traceable format
  |--> c10d_functional.all_reduce(...)  - dynamo captures this op call, doesn't trace deeper
  |--> _maybe_wrap_tensor(...)          - wait_tensor() op is immediately called, no AsyncTensor subclass needed

And under eager execution:
all_reduce(...)
  |--> _expand_group(...)               - same as above, but less critical for eager
  |--> c10d_functional.all_reduce(...)  - dispatches to real kernel OR records op in trace
  |--> _maybe_wrap_tensor(...)          - AsyncTensor wrapper applied to returned tensor,
                                          which issues wait_tensor() at the time of first use
"""
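# Illustrative usage sketch (not part of the original file); assumes a default process
# group has been initialized with torch.distributed.init_process_group:
#
#   t = torch.ones(4)
#   out = all_reduce(t, "sum", group=dist.group.WORLD)
#   # `out` is an AsyncCollectiveTensor in eager mode; the wait_tensor() sync is
#   # issued implicitly the first time `out` is consumed by a downstream op.
#   total = out + 1
#
# Calling out.wait() explicitly is also possible when precise control over the
# synchronization point is needed.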


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collective ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    if native_funcol_enabled():
        return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
    else:
        return torch.ops.c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]


def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
    """
    Broadcasts the tensor to all processes in the given process group.

    Args:
        src (int): Source rank
        group (ProcessGroup or List[int]): The process group to work on.
        tag (str, optional): A unique identifier for the collective. Default: empty string
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.broadcast(
            self, src, tag, rankset, group_size
        )
    return _maybe_wrap_tensor(tensor)


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor = torch.ops._c10d_functional.all_reduce(
            self, reduceOp.lower(), group_name
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_reduce(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,
            group_size,
        )
    return _maybe_wrap_tensor(tensor)


def all_gather_tensor(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that it currently only supports gather_dim = 0.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    assert self.is_contiguous()
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        tensor = torch.ops._c10d_functional.all_gather_into_tensor(
            self, group_size, group_name
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_gather_into_tensor(  # type: ignore[attr-defined]
            self,
            tag,
            rankset,
            group_size,
        )
    res = _maybe_wrap_tensor(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we already need to wait here; doing the wait
        # first and then chunk + cat avoids going through the ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res


def reduce_scatter_tensor(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    if native_funcol_enabled():
        tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
            self,
            reduceOp.lower(),
            group_size,
            group_name,  # type: ignore[possibly-undefined]
        )
    else:
        tensor = torch.ops.c10d_functional.reduce_scatter_tensor(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,  # type: ignore[possibly-undefined]
            group_size,
        )
    res = _maybe_wrap_tensor(tensor)
    return res

def all_reduce_coalesced(
    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result.

    All tensors in the input list are left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
            self,
            reduceOp.lower(),
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor_list = torch.ops.c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
            self,
            reduceOp,
            tag,
            rankset,
            group_size,
        )
    return list(map(_maybe_wrap_tensor, tensor_list))


def all_gather_into_tensor_coalesced(
    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Gather a list of tensors from all machines.

    Note that it currently only supports gather_dim = 0.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
            self,
            group_size,
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
            self,
            tag,
            rankset,
            group_size,
        )
    return list(map(_maybe_wrap_tensor, tensor_list))


def reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduceOp: str,
    scatter_dim: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
    else:
        tag, rankset, group_size = _expand_group(group, tag)

    assert len(scatter_dim) == len(inputs)
    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
        assert (
            tensor.size(dim) % group_size == 0
        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
        if dim != 0:
            tensor_list = torch.chunk(tensor, group_size, dim=dim)
            inputs[idx] = torch.cat(tensor_list)

    if native_funcol_enabled():
        tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
            inputs,
            reduceOp.lower(),
            group_size,
            group_name,  # type: ignore[possibly-undefined]
        )
    else:
        tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
            inputs,
            reduceOp,
            tag,
            rankset,  # type: ignore[possibly-undefined]
            group_size,
        )

    return list(map(_maybe_wrap_tensor, tensor_list))

# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
    assert isinstance(tgt, torch._ops.OpOverload)
    schema = tgt._schema
    if len(schema.arguments) > 0:
        first_arg = schema.arguments[0]
        # check if op is a view
        return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def all_to_all_single(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Each process splits the input tensor and then scatters the split list
    to all processes in the group. The received tensors from all processes in
    the group are then concatenated and returned as a single output tensor.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
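    # Illustrative example (assumes 4 ranks): with input_split_sizes=[2, 2, 2, 2],
    # each rank sends rows [0:2] of its input to rank 0, rows [2:4] to rank 1, and so
    # on; with output_split_sizes=[2, 2, 2, 2], the chunks received from the 4 ranks
    # are concatenated along dim 0 into a single 8-row output tensor.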
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes
    if native_funcol_enabled():
        group_name = _resolve_group_name(group, tag)
        group_size = c10d._get_group_size_by_name(group_name)
        if output_split_sizes is None or input_split_sizes is None:
            assert output_split_sizes is None and input_split_sizes is None, (
                "output_split_sizes and input_split_sizes must either be "
                "specified together or both set to None"
            )
            output_split_sizes = [self.shape[0] // group_size] * group_size
            input_split_sizes = output_split_sizes
        tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
            self,
            output_split_sizes,
            input_split_sizes,
            group_name,
        )
    else:
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
            self,
            output_split_sizes,
            input_split_sizes,
            tag,
            rankset,
            group_size,
        )
    return _maybe_wrap_tensor(tensor)


def permute_tensor(
    self: torch.Tensor,
    src_dst: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
    be defined such that src_dst[m] == n means m sends to n.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
    """
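    # Illustrative example (assumes 4 ranks): src_dst = [1, 2, 3, 0] sends rank 0's
    # tensor to rank 1, rank 1's to rank 2, rank 2's to rank 3, and rank 3's to rank 0.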
    t, rankset, group_size = _expand_group(group, tag)
    local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)

    output_split_sizes = [0] * group_size
    input_split_sizes = [0] * group_size
    for src, dst in enumerate(src_dst):
        if src == dist.get_rank(local_pg):
            input_split_sizes[dst] = self.numel()
        if dst == dist.get_rank(local_pg):
            output_split_sizes[src] = self.numel()

    return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)


class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.
    Use it inside functional collective pytorch wrappers like the following:
    def functional_collective(self, group, tag):
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
        return _maybe_wrap_tensor(tensor)
    """
    elem: torch.Tensor
    completed: bool

    __slots__ = ["elem", "completed"]

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=False,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        return ["elem"], None

    def tolist(self):
        self.trigger_wait()
        return self.elem.tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        self.trigger_wait()
        return f"AsyncCollectiveTensor({self.elem})"

    def trigger_wait(self):
        if not self.completed:
            wait_tensor(self.elem)
            self.completed = True
        return self.elem

    def wait(self) -> torch.Tensor:
        wait_tensor(self.elem)
        return self.elem

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test whether a tensor is an ACS."""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast-path aten.view, since a lot of view-related ops eventually go to
            # aten.view; this avoids the pytree slowdown
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            _register_tensor_wrapper(wrapper_res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # wait_tensor is idempotent and will do stream sync only once
            if not is_view_op:
                e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # wait_tensor is idempotent and will do stream sync only once
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            _register_tensor_wrapper(res)
            return res

        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we should re-wrap the outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        return self.wait().numpy()


"""
Utils and infrastructure for tracing support
"""


def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
    """
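    # Illustrative desugaring (not exhaustive): group=[[0, 1], [2, 3]] yields
    # (tag, rankset=[0, 1, 2, 3], group_size=2), while group=[0, 1, 2, 3] yields
    # (tag, rankset=[0, 1, 2, 3], group_size=4).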
    # had to define this hack _inside_ expand_group to avoid a
    # graph_break ('torch.* op returned non-Tensor int ...)
    # caused by `cast_*` functions being treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(List[List[int]], x)

        def cast_listint(x):
            return cast(List[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects:
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast_listlistint(group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical, found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        # TODO: it should run collective in the whole mesh instead of dim 0
        tag, rankset, _ = group._dim_group_infos[0]
        group_size = len(rankset)
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            tag, rankset, _ = dmesh._dim_group_infos[dim]
            group_size = len(rankset)
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, ProcessGroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)


def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
    """
    Given group in RANK_TYPES, return the group name.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        return group
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        return group._dim_group_infos[0][2]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            return dmesh._dim_group_infos[dim][2]
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    elif isinstance(group, list):
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead."
            )
        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")


def _are_we_tracing() -> bool:
    if is_torchdynamo_compiling():
        return True
    # If functionalization is turned on, we are almost definitely compiling/tracing.
    # (In particular, AOTAutograd traces a model once with functionalization on
    #  but proxy tracing turned off, so this is how we detect it).
    if (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    ):
        return True
    mode = get_innermost_proxy_mode()
    if mode is None:
        return False
    return mode.tracer is not None

def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        return wait_tensor(self)
    res = AsyncCollectiveTensor(self)
    _register_tensor_wrapper(res)
    return cast(torch.Tensor, res)


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in self]


# We now register meta kernels to deal with tracing
def _broadcast_meta(self, *args):
    return torch.empty_like(self)


def _all_reduce_meta(self, *args):
    return torch.empty_like(self)


def _wait_tensor_meta(self, *args):
    return torch.empty_like(self)


def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
    out_size = list(shard.size())
    out_size[0] *= group_size
    return shard.new_empty(out_size)


def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size = list(input.size())
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _all_reduce__meta(inp, *args):
    return inp


def _broadcast__meta(inp, *args):
    return inp


def _all_reduce_coalesced__meta(inputs, *args):
    return inputs


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in inputs]


# NB: We often say all_to_all has dynamic output size, but this is not
# technically true: instead, what typically happens is you manually
# communicate the output_split_sizes ahead of time (which is dynamic),
# but then you pass those sizes explicitly, and the all_to_all itself
# isn't dynamic, it just follows the specified output splits
def _all_to_all_single_meta(
    input, output_split_sizes, input_split_sizes, *args, **kwargs
):
    if output_split_sizes is None:
        return input.new_empty(input.size())
    else:
        for s in output_split_sizes:
            torch._check_is_size(s)
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        return input.new_empty(out_size)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(
    inputs, reduce_op, group_size, group_name
):
    return [
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]


def _register_ops():
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]

    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        meta_impl = getattr(my_module, f"_{op_name}_meta")
        c10_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd")
        impl_abstract(f"c10d_functional::{op_name}")(meta_impl)


if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work
    # Creating a "DEF" Library always crashes torch::deploy so we create our Library instances here
    #   guarded against running inside it
    c10_lib = torch.library.Library("c10d_functional", "DEF")
    c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    _register_ops()

    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    _c10_lib_impl.impl(
        "all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta"
    )
    _c10_lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    _c10_lib_impl.impl(
        "reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta"
    )
    _c10_lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    _c10_lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    _c10_lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    _c10_lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


"""
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it with a corresponding op via
the mapping dict below.

These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are trying to remap from.
"""
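# For illustration: under torch.compile, a call such as
# torch.distributed.all_reduce(t, group=pg) can be remapped via the dict at the
# bottom of this file to all_reduce_inplace, i.e. a functional all_reduce
# followed by a copy_ back into `t`.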


def all_gather_tensor_inplace(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group,  # TODO add a type,
    async_op: bool = False,
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))


def reduce_scatter_tensor_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
    group=None,  # TODO add a type
    async_op: bool = False,
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))


def all_reduce_inplace(
    tensor: torch.Tensor,
    op: str = "sum",
    group=None,
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    return tensor.copy_(all_reduce(tensor, op, group, tag))


def all_to_all_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    return output.copy_(
        all_to_all_single(input, output_split_sizes, input_split_sizes, group, tag)
    )


def all_gather_inplace(
    tensor_list: List[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    output = all_gather_tensor(tensor, 0, group, tag)
    for dst, src in zip(
        tensor_list, output.split([t.size(0) for t in tensor_list], dim=0)
    ):
        dst.copy_(src)
    return tensor_list


from torch.distributed.distributed_c10d import (
    _all_gather_base as legacy_all_gather_base,
    _reduce_scatter_base as legacy_reduce_scatter_base,
    all_gather as legacy_all_gather,
    all_gather_into_tensor as legacy_allgather,
    all_reduce as legacy_allreduce,
    all_to_all_single as legacy_all_to_all_single,
    reduce_scatter_tensor as legacy_reducescatter,
)

# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}
