pytorch

Форк
0
/
distribute.py 
783 строки · 29.1 Кб
1
import logging
2
import operator
3
from dataclasses import dataclass
4
from enum import auto, Enum
5
from functools import partial
6
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
7

8
import torch
9
import torch.distributed._spmd.experimental_ops
10
import torch.fx as fx
11

12
from torch.distributed._spmd.comm_tensor import _get_tracer
13
from torch.distributed._spmd.graph_utils import OP
14
from torch.distributed._spmd.log_utils import get_logger
15

16
from torch.distributed._tensor import DeviceMesh, DTensor
17
from torch.distributed._tensor.op_schema import OpSchema
18
from torch.distributed._tensor.placement_types import (
19
    _Partial,
20
    DTensorSpec,
21
    Placement,
22
    Replicate,
23
    Shard,
24
    TensorMeta,
25
)
26
from torch.distributed._tensor.redistribute import redistribute_local_tensor
27
from torch.fx.experimental.proxy_tensor import make_fx, proxy_slot
28
from torch.utils import _pytree as pytree
29
from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten
30

31

32
# Module-level logger; stays None until _convert_to_distributed assigns it.
logger: Optional[logging.Logger] = None
33

34
# Shorthand for the torch.ops.aten operator namespace used throughout this file.
aten = torch.ops.aten
35

36

37
class TrainingPhase(Enum):
    """Enumerates the two training phases named here: FORWARD and BACKWARD."""

    # Values are assigned automatically in declaration order.
    FORWARD = auto()
    BACKWARD = auto()
40

41

42
@dataclass
class Schema:
    """Pairs a device mesh with the list of placements used on that mesh."""

    # Device mesh the placements refer to.
    mesh: DeviceMesh
    # Placement entries; _convert_to_distributed hands them to DTensor.from_local.
    placements: List[Placement]
46

47

48
@dataclass
class DSymInt:
    """Value retrieved by a SymInt op from a DTensor.

    View and Factory ops use a DSymInt to work out the placement and shape of
    their output tensor, because those operators either have no input DTensor
    or the input DTensor alone is insufficient to determine the output
    tensor's placement.
    """

    global_value: int  # value that the SymInt evaluates to
    local_value: int  # value that this SymInt evaluates to on the local shard
    mesh: DeviceMesh  # device mesh of the DTensor where this SymInt is retrieved from

    def is_shard(self) -> bool:
        """Return True when the local value differs from the global value."""
        return self.local_value != self.global_value

    @classmethod
    def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
        """Build a DSymInt for a ``sym_size``/``sym_numel``/``sym_stride`` node.

        Args:
            node: graph node whose target selects which quantity is read; for
                ``sym_size`` and ``sym_stride`` the dimension is ``node.args[1]``.
            dtensor: DTensor the quantity is read from (global value) and whose
                ``to_local()`` result supplies the local value.

        Raises:
            NotImplementedError: if ``node.target`` is none of the three ops.
        """
        target = node.target
        if target == aten.sym_size:
            axis = cast(int, node.args[1])
            global_value = dtensor.size(axis)
            local_value = dtensor.to_local().size(axis)
        elif target == aten.sym_numel:
            global_value = dtensor.numel()
            local_value = dtensor.to_local().numel()
        elif target == aten.sym_stride:
            axis = cast(int, node.args[1])
            global_value = dtensor.stride(axis)
            local_value = dtensor.to_local().stride(axis)
        else:
            raise NotImplementedError(f"DSymInt does not support {node.target}")

        return cls(
            global_value=global_value,
            local_value=local_value,
            mesh=dtensor.device_mesh,
        )
89

90

91
def _is_partial_dtensor(obj: Any) -> bool:
    """Return True only for a DTensor that has at least one _Partial placement."""
    return isinstance(obj, DTensor) and any(
        isinstance(placement, _Partial) for placement in obj.placements
    )
103

104

105
def _dispatch_with_local_tensors(
    op: torch._ops.OpOverload,
    local_args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    specs: Optional[
        Dict[
            torch.Tensor,
            Tuple[torch.Size, DeviceMesh, Sequence[Placement], Sequence[Placement]],
        ]
    ] = None,
) -> Any:
    """Call ``op`` on local tensors, redistributing those registered in ``specs``.

    Args:
        op: operator overload to invoke.
        local_args: positional arguments; every pytree leaf is visited.
        kwargs: keyword arguments forwarded to ``op`` unchanged (default: none).
        specs: maps a local tensor to ``(shape, mesh, current placements,
            target placements)``. Leaves not present in this mapping are
            passed through untouched (default: empty mapping).

    Returns:
        Whatever ``op`` returns.
    """
    if kwargs is None:
        kwargs = {}
    if specs is None:
        specs = {}

    def redistribute(arg: Any) -> Any:
        # Decide first whether this leaf has a redistribution entry; looking
        # it up unconditionally would fail for every leaf without one.
        if not (isinstance(arg, torch.Tensor) and arg in specs):  # type: ignore[operator]
            return arg

        tensor_shape, mesh, current_placement, target_placement = specs[arg]  # type: ignore[index]
        tensor_meta = TensorMeta(
            tensor_shape,
            stride=arg.stride(),
            dtype=arg.dtype,
        )
        current_spec = DTensorSpec(
            mesh, tuple(current_placement), tensor_meta=tensor_meta
        )
        target_spec = DTensorSpec(
            mesh, tuple(target_placement), tensor_meta=tensor_meta
        )
        return redistribute_local_tensor(arg, current_spec, target_spec)

    # TODO: this is broken because it won't redistributed potential tensors on the kwargs
    return op(*tree_map(redistribute, local_args), **kwargs)
143

144

145
# Figure out how to specify a type spec for the return specs value
146
# without the entire structure.
147
# pyre-fixme
148
def _update_specs_for_redistribute(args, target_schema, redistribute):
149
    # Code adapted from pack_args_kwargs_with_local_tensor
150
    flatten_args, args_tree_spec = tree_flatten(args)
151
    flatten_args_schema = pytree.tree_leaves(target_schema.args_schema)
152

153
    specs: Dict[
154
        torch.Tensor,
155
        Tuple[
156
            torch.Size,
157
            DeviceMesh,
158
            Sequence[Placement],
159
            Sequence[Placement],
160
        ],
161
    ] = {}
162
    for i, arg in enumerate(flatten_args):
163
        if isinstance(arg, DTensor):
164
            if redistribute:
165
                specs[arg._local_tensor] = (
166
                    arg.size(),
167
                    flatten_args_schema[i].mesh,
168
                    arg.placements,
169
                    flatten_args_schema[i].placements,
170
                )
171
            flatten_args_schema[i] = arg._local_tensor
172

173
    unflattened_args = tree_unflatten(flatten_args_schema, args_tree_spec)
174
    return specs, unflattened_args
175

176

177
# When no tensor redistribution is required, we only need to update non-tensor args
178
# of the node according to op_schema and avoid building a GraphModule just for the
179
# node.
180
def _update_node_from_op_schema(node: torch.fx.Node, op_schema: OpSchema) -> None:
    """Overwrite int-like args of ``node`` with the ints found in ``op_schema``.

    A flattened argument is replaced when it is a plain ``int`` or a node whose
    target is ``sym_size``/``sym_numel``/``sym_stride`` and the schema leaf at
    the same position is an ``int``. All other arguments are kept.
    """
    node_leaves, tree_spec = tree_flatten(node.args)
    schema_leaves = pytree.tree_leaves(op_schema.args_schema)

    def is_sym_int_or_int(arg: Union[int, torch.fx.Node]) -> bool:
        if not isinstance(arg, torch.fx.Node):
            return isinstance(arg, int)
        return arg.target in [
            aten.sym_size,
            aten.sym_numel,
            aten.sym_stride,
        ]

    assert len(node_leaves) == len(schema_leaves)
    for pos, schema_leaf in enumerate(schema_leaves):
        if is_sym_int_or_int(node_leaves[pos]) and isinstance(schema_leaf, int):
            node_leaves[pos] = schema_leaf

    # Write the rebuilt arguments back one position at a time.
    for pos, new_arg in enumerate(tree_unflatten(node_leaves, tree_spec)):
        node.update_arg(pos, new_arg)
    return None
202

203

204
def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
    """Translate a graph node into its recorded object; pass anything else through.

    When a tracer is active, the ``proxy_slot`` entry is removed from the
    object's ``__dict__`` before it is returned.
    """
    if not isinstance(arg, torch.fx.Node):
        return arg

    mapped = node_to_obj[arg]
    if _get_tracer():
        # This is a shared arg, already has a tracer from previous
        # tracing. Delete the tracer.
        del cast(Dict[Any, Any], mapped.__dict__)[proxy_slot]
    return mapped
214

215

216
def unpack_sizes_and_dims(
    sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
) -> Tuple[List[int], List[Placement]]:
    """Split a size list into local int sizes and the placements it implies.

    Each DSymInt contributes its ``local_value``; plain ints are kept as-is.
    Every position holding a sharded DSymInt yields ``Shard(position)``; if
    there is none, the placements are ``[Replicate()]``.

    Raises:
        AssertionError: if the number of placements differs from ``mesh.ndim``.
    """
    local_sizes: List[int] = []
    sharded: List[Placement] = []
    for position, size in enumerate(sizes):
        if isinstance(size, DSymInt):
            local_sizes.append(size.local_value)
            if size.is_shard():
                sharded.append(Shard(position))
        else:
            local_sizes.append(size)

    placements: List[Placement] = sharded if sharded else [Replicate()]

    assert len(placements) == mesh.ndim, (
        f"The number of sharded dimensions ({len(placements)}) must "
        f"match number of dimensions in device mesh ({mesh.ndim})."
    )

    return local_sizes, placements
234

235

236
def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
    """Run a ``(DTensor, sizes)`` op on the local shard using local sizes.

    ``node.args`` is rewritten so its second argument holds the plain local
    int sizes, and the result is wrapped as a DTensor on the input's mesh with
    the placements derived from the size list.
    """
    assert len(args) == 2, f"Expect two args but got op {node.target} with args {args}"
    tensor_arg = args[0]
    size_arg = args[1]
    assert isinstance(
        tensor_arg, DTensor
    ), f"Expect 1st argument to be DTensor but got {args[0]}"
    assert isinstance(size_arg, list), f"Expect 2nd argument as list but got {args[1]}"

    # extract sharded dimensions in the size list, the output DTensor should
    # follow these placements.
    local_sizes, placements = unpack_sizes_and_dims(size_arg, tensor_arg.device_mesh)

    # set node args to real int sizes.
    node.args = (node.args[0], local_sizes)
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(tensor_arg._local_tensor, local_sizes)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=tensor_arg.device_mesh,
        placements=placements,
        run_check=False,
    )
256

257

258
def slice_backwad_sym_int_consumer_rule(
    node: fx.Node, args: Tuple[Any, ...]
) -> DTensor:
    """Compute ``slice_backward`` on local shards.

    A zero tensor with the local input sizes is created and the local part of
    ``grad_output`` is scattered into it with ``torch.slice_scatter``. The
    result keeps the mesh and placements of ``grad_output``. ``node`` is not
    used.
    """
    grad_output, input_sizes, dim, start, end, step = args

    local_sizes: List[int] = []
    for size in input_sizes:
        local_sizes.append(size.local_value if isinstance(size, DSymInt) else size)

    zeros = torch.zeros(
        local_sizes, device=grad_output.device, dtype=grad_output.dtype
    )
    scattered = torch.slice_scatter(
        zeros, grad_output.to_local(), dim, start, end, step
    )
    return DTensor.from_local(
        local_tensor=scattered,
        device_mesh=grad_output.device_mesh,
        placements=grad_output.placements,
        run_check=False,
    )
278

279

280
def factory_with_sizes_rule(
281
    node: fx.Node,
282
    args: Tuple[Any, ...],
283
    kwargs: Dict[str, Any],
284
    default_mesh: DeviceMesh,
285
) -> DTensor:
286
    flat_args = pytree.arg_tree_leaves(*args)
287
    assert not any(isinstance(a, DTensor) for a in flat_args), (
288
        f"Not expect DTensor argument for factory op, but got {node.target} "
289
        f"with arguments {args}."
290
    )
291
    assert isinstance(args[0], list), f"Expect 2nd argument as list but got {args[1]}"
292

293
    local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
294
    node.args = (local_sizes, *args[1:])
295
    op = cast(torch._ops.OpOverload, node.target)
296
    return DTensor.from_local(
297
        local_tensor=op(*node.args, **kwargs),
298
        device_mesh=default_mesh,
299
        placements=placements,
300
        run_check=False,
301
    )
302

303

304
def factory_arange_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    """Run an arange op with every DSymInt argument replaced by its local value.

    ``node.args`` is rewritten with the local values, and the op's result is
    wrapped as a replicated DTensor on ``default_mesh``.
    """

    def to_local_value(leaf: Any) -> Any:
        return leaf.local_value if isinstance(leaf, DSymInt) else leaf

    node.args = tree_map(to_local_value, args)
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(*node.args, **kwargs)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )
318

319

320
def default_factory_op_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    """Run a factory op as-is and wrap its result as a replicated DTensor.

    ``node.args`` and ``node.kwargs`` are overwritten with the given values
    before the op is called with them.
    """
    node.args = args
    node.kwargs = kwargs
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(*node.args, **node.kwargs)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )
334

335

336
# Dispatch override for view and factory ops that consume SymInt arguments,
# where the output spec should follow dimension placement where the SymInt comes
# from.
VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    # (DTensor, sizes) ops share the binary-op rule.
    aten._unsafe_view.default: binop_sym_int_consumer_rule,
    aten.expand.default: binop_sym_int_consumer_rule,
    aten.view.default: binop_sym_int_consumer_rule,
    # slice_backward takes more arguments and has its own rule.
    aten.slice_backward.default: slice_backwad_sym_int_consumer_rule,
}
345

346
# Factory ops consulted by _get_dtensor_dispatch_graph when an argument is a
# sharded DSymInt.
FACTORY_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    # Both arange overloads use the same rule.
    aten.arange.default: factory_arange_rule,
    aten.arange.start: factory_arange_rule,
    aten.full.default: factory_with_sizes_rule,
}
351

352

353
# Dispatch override for factory ops, as DTensor cannot propagate sharding spec
# without DTensor inputs.
FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
    aten.arange.start: default_factory_op_rule,
    aten.scalar_tensor.default: default_factory_op_rule,
    aten.zeros.default: default_factory_op_rule,
}
360

361

362
def _get_dtensor_dispatch_graph(
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
    *,
    force_make_fx: bool = False,
    default_mesh: Optional[DeviceMesh] = None,
) -> Optional[fx.GraphModule]:
    """Produce the dispatch graph for one operator node, or handle it directly.

    Args:
        node: operator node to expand; its args/kwargs are remapped through
            ``node_to_obj`` before use.
        node_to_obj: mapping from graph nodes to their concrete objects. When
            one of the SymInt-consumer or factory rules handles the node, the
            rule's result is stored here under ``node``.
        force_make_fx: accepted for callers; not read in this function.
        default_mesh: mesh handed to factory rules; must not be None when a
            rule from FACTORY_SYM_INT_CONSUMERS is selected.

    Returns:
        ``None`` when a rule handled the node, otherwise the graph module
        traced with ``make_fx`` (after dead-code elimination).
    """

    def _to_local_value(leaf: Any) -> Any:
        return leaf.local_value if isinstance(leaf, DSymInt) else leaf

    with torch.no_grad():
        # Args should be a list of objects post remapping.
        args = tree_map(partial(_remap_arg, node_to_obj), node.args)
        kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

        op_overload = cast(torch._ops.OpOverload, node.target)

        if any(
            a.is_shard()
            for a in pytree.arg_tree_leaves(*args)
            if isinstance(a, DSymInt)
        ):
            if op_overload in VIEW_SYM_INT_CONSUMERS:
                assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
                node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
                return None
            elif op_overload in FACTORY_SYM_INT_CONSUMERS:
                assert default_mesh is not None, "Requires default mesh for factory ops"
                node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
                    node, args, kwargs, default_mesh
                )
                return None
            else:
                assert isinstance(logger, logging.Logger)
                # The trailing space keeps the op name separated from the
                # rest of the sentence once the two literals are joined.
                logger.warning(
                    "Assuming using local_value from SymInt for %s "
                    "is mathematically correct. Full args are %s.",
                    op_overload,
                    args,
                )

        if node.target == aten.view.default:
            # HACK: this is a hack to get around with the fact that some
            # view operations on a "global" tensor is invalid usage
            # but somehow the view operation on the batch input might hit it
            # so we convert the view op to reshape before calling DTensor
            op_overload = aten.reshape.default

        # DSymInt args are not sharded on any dimension, local value and global
        # value should be the same
        args = tree_map(_to_local_value, args)
        kwargs = tree_map(_to_local_value, kwargs)

        if op_overload in FACTORY_OPS:
            # Don't pass factory ops to DTensor dispatch, as DTensor cannot
            # propagate sharding spec without DTensor inputs.
            node_to_obj[node] = FACTORY_OPS[op_overload](
                node, args, kwargs, default_mesh
            )
            return None

        # NOTE(review): ``specs`` receives the remapped args here, while
        # _dispatch_with_local_tensors annotates ``specs`` as a mapping from
        # tensor to redistribution spec -- confirm this is intended.
        dispatch = partial(
            _dispatch_with_local_tensors,
            op_overload,
            kwargs=kwargs,
            specs=args,
        )

        gm = make_fx(dispatch, _allow_non_fake_inputs=False)(args)
        # FIXME(@wanchaol, @mrshenli): the above seems to accidentally captured
        # DeviceMesh tensor ops when handling inplace operators? The ``_to_copy`` is
        # not connected to graph output. So, using DCE to get rid of it, but this
        # doesn't look correct.
        #
        # The following operators appear in the captured graph, where the dtype is
        # torch.int64.
        #
        # get_attr       _tensor_constant0  _tensor_constant0         ()
        # call_function  transpose          aten.transpose.int        (_tensor_constant0, -1, 0)
        # call_function  view               aten.view.default         (transpose, [-1, 2])
        # call_function  view_1             aten.view.default         (view, [2])
        # call_function  _to_copy           aten._to_copy.default     (view_1,)
        gm.graph.eliminate_dead_code()

        return gm
446

447

448
def _build_dummy_add_graph(
    dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
    """Trace ``grad + zero`` for a partial DTensor and expand it via DTensor dispatch.

    The dummy add exists to trigger an all_reduce on a Partial DTensor during
    the DTensor expansion of the traced graph. The two placeholders of the
    traced add are registered in ``node_to_obj`` (the given DTensor and a
    replicated zero DTensor) before the add node is dispatched.

    Returns:
        The dispatch graph for the add node together with the object recorded
        in ``node_to_obj`` for that node.
    """

    def dummy_add(grad: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
        return grad + zero

    grad: torch.Tensor = dt._local_tensor
    zero: torch.Tensor = torch.zeros_like(dt._local_tensor)

    traced_add = make_fx(dummy_add)(grad, zero)

    placeholders = []
    call_functions = []
    for traced_node in traced_add.graph.nodes:
        if traced_node.op == OP.PLACEHOLDER:
            placeholders.append(traced_node)
        elif traced_node.op == OP.CALL_FUNCTION:
            call_functions.append(traced_node)
    assert len(placeholders) == 2
    assert len(call_functions) == 1

    add_node = call_functions[0]
    node_to_obj[placeholders[0]] = dt
    node_to_obj[placeholders[1]] = DTensor.from_local(
        zero, dt.device_mesh, [Replicate()], run_check=False
    )

    traced_dispatch = _get_dtensor_dispatch_graph(
        add_node, node_to_obj, force_make_fx=True
    )
    assert traced_dispatch is not None

    # TODO(anj): This depends on the call function node -> actual DTensor output
    # mapping that we want to avoid for SPMD expansion
    return traced_dispatch, node_to_obj[add_node]
483

484

485
def _convert_output(
    gm: fx.GraphModule,
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
) -> fx.Node:
    """Rewrite the output node so partial DTensor outputs pass through a dummy add.

    For every output argument whose recorded object is a partial DTensor, the
    expanded dummy-add dispatch graph is copied into ``gm`` before ``node``
    and the argument is replaced by the copied result. Other arguments are
    kept unchanged.

    Returns:
        A newly created output node when at least one partial output was
        rewritten (the old node is erased); otherwise ``node`` itself.
    """
    rewritten_outputs = []
    found_partial = False
    for out_arg in node.args[0]:  # type: ignore[union-attr]
        obj = node_to_obj[out_arg] if isinstance(out_arg, fx.Node) else None
        if not _is_partial_dtensor(obj):
            # Non-node values and non-partial results are forwarded untouched.
            rewritten_outputs.append(out_arg)
            continue

        found_partial = True

        # we know it's a dtensor from is partial DT check...
        dt = cast(DTensor, obj)

        traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)

        wait_nodes = []
        add_nodes = []
        for inner in traced_dispatch.graph.nodes:
            if inner.name == "wait_comm" or inner.name == "wait_tensor":
                wait_nodes.append(inner)
            if inner.name == "add":
                add_nodes.append(inner)
        assert len(wait_nodes) == 1 and len(add_nodes) == 1
        wait_node = wait_nodes[0]

        # remove add node and replace it with wait node
        add_nodes[0].replace_all_uses_with(wait_node)
        traced_dispatch.graph.eliminate_dead_code()
        # also update the actual DTensor corresponding to the node
        # TODO(anj): We require mapping of the final DTensor output to the wait
        # comm node.
        node_to_obj[wait_node] = result_obj

        value_remap: Dict[fx.Node, fx.Node] = {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                # every placeholder of the dispatch graph stands for the
                # original output argument
                value_remap[dtn] = out_arg
                continue

            if dtn.op == OP.OUTPUT:
                assert (
                    len(dtn.args) == 1 and len(dtn.args[0]) == 1
                ), f"Expecting single output, but got {dtn.args} {len(dtn.args)}"
                inner_result = dtn.args[0][0]
                rewritten_outputs.append(value_remap[inner_result])
                # the concrete DTensor value of output was added when creating the
                # inner graph (in _build_dummy_add_graph). Just add it to the final
                # output node so that we can report the final output specs correctly.
                # TODO(anj): We are depending on the concrete DTensor output of the dummy add.
                node_to_obj[value_remap[inner_result]] = node_to_obj[inner_result]
                continue

            if dtn.op == OP.GET_ATTR:
                # copy the attribute referenced by the dispatch graph onto gm
                setattr(
                    gm,
                    dtn.target,
                    getattr(traced_dispatch, dtn.target),
                )
            with gm.graph.inserting_before(node):
                value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])

    if not found_partial:
        return node

    gm.graph.erase_node(node)
    return gm.graph.output(rewritten_outputs)
557

558

559
def _rebuild_graph(
    gm: fx.GraphModule,
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule],
) -> None:
    """Splice each replacement dispatch graph into ``gm`` in place of its node.

    For every node of ``gm`` that has an entry in ``node_replacements``, the
    nodes of the replacement graph are copied before it, uses of the node are
    redirected to the copied result, and the node is erased. Afterwards dead
    code is eliminated and ``gm`` is recompiled.
    """
    # replace nodes in local traced graph with DTensor's dispatch graph
    for node in gm.graph.nodes:
        if node not in node_replacements:
            continue

        traced_dispatch = node_replacements[node]
        # Map DT's dispatch graph input placeholder nodes to the ones in
        # local traced graph. It uses index-based accessing, which is
        # brittle, just for testing purpose.
        flatten_args = pytree.arg_tree_leaves(*node.args)
        placeholder_idx, value_remap = 0, {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                value_remap[dtn] = flatten_args[placeholder_idx]
                placeholder_idx += 1

        # insert DT's dispatch graph to traced local graph.
        with gm.graph.inserting_before(node):
            for dtn in traced_dispatch.graph.nodes:
                if dtn.op == OP.PLACEHOLDER:
                    # do nothing, ignore placeholders, as it has already
                    # been prepared in value_remap
                    pass
                elif dtn.op == OP.OUTPUT:
                    # Report len(dtn.args): indexing dtn.args[0] in the
                    # message would itself fail when dtn.args is empty.
                    assert (
                        len(dtn.args) == 1
                    ), f"Expecting single output, but got {dtn.args} {len(dtn.args)}"
                    outputs = dtn.args[0]
                    # we currently support two very specific types of output
                    # 1. single output
                    # 2. multiple outputs resulting from getitem of all elements of tuple
                    if len(outputs) == 1:
                        # for single output, we replace the node with the single node
                        output = outputs[0]
                    else:
                        # for multiple outputs, we check that these outputs correspond
                        # to all elements of a tuple. In that case, we replace
                        # uses of the output directly with the original tuple
                        source = None
                        for out_idx, out in enumerate(outputs):
                            # we allow None outputs for certain items in the tuple
                            if out is None:
                                continue
                            assert out.op == "call_function"
                            assert out.target.__module__ == "_operator"
                            assert out.target.__name__ == "getitem"
                            assert source is None or source == out.args[0]
                            source = out.args[0]
                            assert out.args[1] == out_idx
                        assert source is not None
                        output = source

                    new_node = value_remap[output]
                    node.replace_all_uses_with(new_node)
                else:
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
                    if all(
                        isinstance(n.target, torch._ops.OpOverload)
                        and n.target._schema.name.startswith(
                            ("aten::_foreach", "aten::_fused_adam")
                        )
                        for n in [dtn, node]
                    ):
                        # FIXME(@mrshenli): This is a temporary solution enable
                        # foreach ops. The problem is that foreach ops returns
                        # List[Tensor], but make_fx will flatten that before
                        # passing those tensors to output node, which will
                        # introduce additional getitem nodes. These redundant
                        # getitem nodes breaks graph correctness as we cannot do
                        # getitem(getitem(foreach_out, 0), 0). This temporary
                        # solution skips getitem nodes in DTensor expanded
                        # subgraphs.
                        node.replace_all_uses_with(value_remap[dtn])
                        break
            # explicitly erase node instead of relying on DCE, as DCE does not
            # remove inplace copy_ correctly.
            gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.recompile()
643

644

645
def _get_last_consumer_to_nodes(
    graph: fx.Graph,
) -> Dict[fx.Node, List[fx.Node]]:
    """Map each node to the input nodes for which it is the last consumer.

    The graph is walked in reverse, so the first consumer seen for an input is
    its last use in execution order. Callers can use the result to drop values
    once their last consumer has run.
    """
    last_consumer_of: Dict[fx.Node, fx.Node] = {}
    released_after: Dict[fx.Node, List[fx.Node]] = {}

    def _record_inputs(consumer: fx.Node) -> None:
        def visit(used: fx.Node) -> None:
            # Only the first sighting (in reverse order) counts.
            if used in last_consumer_of:
                return
            last_consumer_of[used] = consumer
            released_after.setdefault(consumer, []).append(used)

        # Positional arguments are visited before keyword arguments.
        for container in (consumer.args, consumer.kwargs):
            fx.node.map_arg(container, visit)

    for consumer in reversed(graph.nodes):
        _record_inputs(consumer)

    return released_after
670

671

672
def _convert_to_distributed(
    gm: fx.GraphModule,
    inps: List[torch.Tensor],
    schemas: List[Schema],
    default_mesh: Optional[DeviceMesh] = None,
    _allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
    """Transform a graph module to a distributed graph module.

    Args:
        gm: graph module to transform; it is modified and returned.
        inps: example inputs, treated as local shards, indexed by the
            position of the placeholder node in the graph.
        schemas: mesh and placements for each input, indexed like ``inps``.
        default_mesh: mesh forwarded to operator dispatch (used by factory ops).
        _allow_partial: when False, the output node is rewritten through
            ``_convert_output`` before output schemas are collected.

    Returns:
        - transformed graph module
        - map from output node name to the ``Schema`` (mesh and placements)
          of the DTensor recorded for that output

    Raises:
        ValueError: for a node whose ``op`` is not handled here.
    """
    global logger
    logger = get_logger("spmd_exp")
    python_operators = {getattr(operator, name) for name in operator.__all__}
    node_to_obj: Dict[fx.Node, Any] = {}
    # map local op node in traced_f to its corresponding subgraph of
    # DTensor ops.
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule] = {}

    last_consumer_to_nodes = _get_last_consumer_to_nodes(gm.graph)

    output_schemas: Dict[str, Schema] = {}
    for i, node in enumerate(gm.graph.nodes):
        assert logger is not None
        logger.info("node%s: op=%s target=%s", i, node.op, node.target)
        if node.op == OP.PLACEHOLDER:
            assert i < len(
                inps
            ), f"got more placeholder nodes ({i + 1}) than inputs ({len(inps)})"

            # our example inputs are local shards. Create DTensors from them.
            schema = schemas[i]
            node_to_obj[node] = DTensor.from_local(
                inps[i].clone(),  # use clone to avoid modifications from inplace ops
                schema.mesh,
                schema.placements,
                # prevent running this collective in backwards pass
                run_check=False,
            )
        elif isinstance(node.target, torch._ops.OpOverloadPacket):
            # Overload packets are handed to DSymInt.from_node together with
            # the DTensor recorded for the node's first argument.
            source_dtensor = cast(DTensor, node_to_obj[node.args[0]])
            node_to_obj[node] = DSymInt.from_node(node, source_dtensor)
        elif isinstance(node.target, torch._ops.OpOverload):
            dispatch_graph = _get_dtensor_dispatch_graph(
                node, node_to_obj, default_mesh=default_mesh
            )
            if dispatch_graph is not None:
                node_replacements[node] = dispatch_graph
        elif node.op == OP.OUTPUT:
            if not _allow_partial:
                # Returns an expanded dummy add node that ensures
                # that the partial output tensor has been converted
                # to a replicated tensor.
                node = _convert_output(gm, node, node_to_obj)

            # Save output sharding for the inputs to backward pass.
            # TODO(anj): Pipe the output schema for the BW pass
            # instead of requiring the full output DTensor to be
            # materialized.
            for out_arg in node.args[0]:
                if not isinstance(out_arg, fx.Node):
                    continue
                out_obj = node_to_obj[out_arg]
                if isinstance(out_obj, DTensor):
                    output_schemas[out_arg.name] = Schema(
                        out_obj.device_mesh, out_obj.placements  # type: ignore[arg-type]
                    )
        elif node.op == OP.CALL_FUNCTION:
            args = tree_map(partial(_remap_arg, node_to_obj), node.args)
            kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

            dsymints = [
                a for a in args + tuple(kwargs.values()) if isinstance(a, DSymInt)
            ]

            if node.target in python_operators and len(dsymints) > 0:
                assert all(
                    dsymints[0].mesh == d.mesh for d in dsymints
                ), "all DSymInts must have the same mesh. "

                local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
                local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)

                global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
                global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)

                # The node itself keeps the local values.
                node.args = local_args
                node.kwargs = local_kwargs

                # The operator is evaluated on both views of the values.
                node_to_obj[node] = DSymInt(
                    local_value=node.target(*local_args, **local_kwargs),
                    global_value=node.target(*global_args, **global_kwargs),
                    mesh=dsymints[0].mesh,
                )
            else:
                assert len(dsymints) == 0, (
                    "SPMD expansion does not support SymInt in non-operator "
                    f"nodes, got {node.target}."
                )
                node_to_obj[node] = node.target(*args, **kwargs)
        else:
            raise ValueError(f"Unrecognized node.op type {node.op}")

        if node in last_consumer_to_nodes:
            # Save memory by deleting objs that wont be used anymore.
            for finished_node in last_consumer_to_nodes[node]:
                del node_to_obj[finished_node]

    _rebuild_graph(gm, node_replacements)

    return gm, output_schemas
784

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.