pytorch
783 строки · 29.1 Кб
1import logging
2import operator
3from dataclasses import dataclass
4from enum import auto, Enum
5from functools import partial
6from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
7
8import torch
9import torch.distributed._spmd.experimental_ops
10import torch.fx as fx
11
12from torch.distributed._spmd.comm_tensor import _get_tracer
13from torch.distributed._spmd.graph_utils import OP
14from torch.distributed._spmd.log_utils import get_logger
15
16from torch.distributed._tensor import DeviceMesh, DTensor
17from torch.distributed._tensor.op_schema import OpSchema
18from torch.distributed._tensor.placement_types import (
19_Partial,
20DTensorSpec,
21Placement,
22Replicate,
23Shard,
24TensorMeta,
25)
26from torch.distributed._tensor.redistribute import redistribute_local_tensor
27from torch.fx.experimental.proxy_tensor import make_fx, proxy_slot
28from torch.utils import _pytree as pytree
29from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten
30
31
# Module-level logger. It starts as None; _convert_to_distributed assigns it
# with get_logger("spmd_exp") before processing any node, and code that logs
# through it asserts it is a logging.Logger first.
logger: Optional[logging.Logger] = None

# Shorthand for the ATen operator namespace used throughout this module.
aten = torch.ops.aten
35
36
class TrainingPhase(Enum):
    """Enumerates the two training phases: FORWARD and BACKWARD."""

    FORWARD = 1
    BACKWARD = 2
40
41
@dataclass
class Schema:
    """Layout of a DTensor: the device mesh it lives on plus its placements.

    _convert_to_distributed uses one Schema per example input to wrap the
    local shard into a DTensor, and reports the layout of every DTensor graph
    output as a Schema.
    """

    mesh: DeviceMesh  # device mesh the DTensor is distributed over
    placements: List[Placement]  # presumably one entry per mesh dimension — TODO confirm
46
47
@dataclass
class DSymInt:
    """A SymInt value read from a DTensor, kept in both global and local form.

    View and factory ops consult DSymInt to decide the placement and shape of
    their output tensor: those ops either have no DTensor input at all, or the
    DTensor input alone is not enough to determine the output placement.
    """

    global_value: int  # value the SymInt evaluates to on the whole DTensor
    local_value: int  # value the SymInt evaluates to on the local shard
    mesh: DeviceMesh  # device mesh of the DTensor this SymInt was read from

    def is_shard(self) -> bool:
        """Return True when the local value differs from the global value."""
        return self.global_value != self.local_value

    @classmethod
    def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
        """Evaluate ``node``'s SymInt op on ``dtensor`` globally and locally.

        Supported targets are ``aten.sym_size`` and ``aten.sym_stride`` (the
        dimension is taken from ``node.args[1]``) and ``aten.sym_numel``. The
        op is evaluated on the DTensor for ``global_value`` and on
        ``dtensor.to_local()`` for ``local_value``.

        Raises:
            NotImplementedError: for any other ``node.target``.
        """
        target = node.target
        if target == aten.sym_size:
            dim = cast(int, node.args[1])
            global_value = dtensor.size(dim)
            local_value = dtensor.to_local().size(dim)
        elif target == aten.sym_numel:
            global_value = dtensor.numel()
            local_value = dtensor.to_local().numel()
        elif target == aten.sym_stride:
            dim = cast(int, node.args[1])
            global_value = dtensor.stride(dim)
            local_value = dtensor.to_local().stride(dim)
        else:
            raise NotImplementedError(f"DSymInt does not support {node.target}")

        return cls(
            global_value=global_value,
            local_value=local_value,
            mesh=dtensor.device_mesh,
        )
89
90
def _is_partial_dtensor(obj: Any) -> bool:
    """Return True iff ``obj`` is a DTensor with at least one ``_Partial`` placement."""
    return isinstance(obj, DTensor) and any(
        isinstance(placement, _Partial) for placement in obj.placements
    )
103
104
def _dispatch_with_local_tensors(
    op: torch._ops.OpOverload,
    local_args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    specs: Optional[
        Dict[
            torch.Tensor,
            Tuple[torch.Size, DeviceMesh, Sequence[Placement], Sequence[Placement]],
        ]
    ] = None,
) -> Any:
    """Call ``op`` on local tensors, first redistributing those listed in ``specs``.

    Args:
        op: the operator overload to invoke.
        local_args: positional arguments. Every tensor leaf that is a key of
            ``specs`` is redistributed before the call; all other leaves
            (scalars, tensors without an entry, ...) are passed through as-is.
        kwargs: keyword arguments, forwarded to ``op`` unchanged.
        specs: maps a local tensor to ``(tensor_shape, mesh,
            current_placements, target_placements)``, as built by
            ``_update_specs_for_redistribute`` (where ``tensor_shape`` is the
            DTensor's global size).

    Returns:
        Whatever ``op`` returns.
    """
    if kwargs is None:
        kwargs = {}
    if specs is None:
        specs = {}

    def redistribute(arg: Any) -> Any:
        # Check membership BEFORE looking the entry up: non-tensor leaves and
        # tensors without a spec would otherwise raise KeyError on ``specs[arg]``.
        if not (isinstance(arg, torch.Tensor) and arg in specs):  # type: ignore[operator]
            return arg

        tensor_shape, mesh, current_placement, target_placement = specs[arg]  # type: ignore[index]
        tensor_meta = TensorMeta(
            tensor_shape,
            stride=arg.stride(),
            dtype=arg.dtype,
        )
        current_spec = DTensorSpec(
            mesh, tuple(current_placement), tensor_meta=tensor_meta
        )
        target_spec = DTensorSpec(
            mesh, tuple(target_placement), tensor_meta=tensor_meta
        )
        return redistribute_local_tensor(arg, current_spec, target_spec)

    # TODO: tensors passed through ``kwargs`` are not redistributed.
    return op(*tree_map(redistribute, local_args), **kwargs)
143
144
145# Figure out how to specify a type spec for the return specs value
146# without the entire structure.
147# pyre-fixme
148def _update_specs_for_redistribute(args, target_schema, redistribute):
149# Code adapted from pack_args_kwargs_with_local_tensor
150flatten_args, args_tree_spec = tree_flatten(args)
151flatten_args_schema = pytree.tree_leaves(target_schema.args_schema)
152
153specs: Dict[
154torch.Tensor,
155Tuple[
156torch.Size,
157DeviceMesh,
158Sequence[Placement],
159Sequence[Placement],
160],
161] = {}
162for i, arg in enumerate(flatten_args):
163if isinstance(arg, DTensor):
164if redistribute:
165specs[arg._local_tensor] = (
166arg.size(),
167flatten_args_schema[i].mesh,
168arg.placements,
169flatten_args_schema[i].placements,
170)
171flatten_args_schema[i] = arg._local_tensor
172
173unflattened_args = tree_unflatten(flatten_args_schema, args_tree_spec)
174return specs, unflattened_args
175
176
# When no tensor redistribution is required, only the non-tensor arguments of
# the node have to be refreshed from ``op_schema``; this avoids building a
# GraphModule for a single node.
def _update_node_from_op_schema(node: torch.fx.Node, op_schema: OpSchema) -> None:
    """Copy integer values from ``op_schema.args_schema`` into ``node.args``.

    ``node.args`` and ``op_schema.args_schema`` are flattened with pytree and
    must have the same number of leaves. A leaf of ``node.args`` is replaced by
    the matching schema leaf when the schema leaf is an ``int`` and the node
    leaf is either an ``int`` or an fx.Node whose target is ``aten.sym_size``,
    ``aten.sym_numel`` or ``aten.sym_stride``. The rebuilt positional args are
    written back one by one with ``node.update_arg``.
    """
    node_leaves, tree_spec = tree_flatten(node.args)
    schema_leaves = pytree.tree_leaves(op_schema.args_schema)
    assert len(node_leaves) == len(schema_leaves)

    def _holds_int(leaf: Union[int, torch.fx.Node]) -> bool:
        # Either a plain int or a graph node producing a SymInt.
        if not isinstance(leaf, torch.fx.Node):
            return isinstance(leaf, int)
        return leaf.target in [
            aten.sym_size,
            aten.sym_numel,
            aten.sym_stride,
        ]

    merged = [
        schema_leaf if _holds_int(leaf) and isinstance(schema_leaf, int) else leaf
        for leaf, schema_leaf in zip(node_leaves, schema_leaves)
    ]
    for position, value in enumerate(tree_unflatten(merged, tree_spec)):
        node.update_arg(position, value)
    return None
202
203
def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
    """Resolve ``arg`` to its concrete object when it is an fx.Node.

    Non-node values are returned unchanged. For a node, the object recorded in
    ``node_to_obj`` is returned. If a tracer is active (``_get_tracer()``), the
    object's ``proxy_slot`` entry left over from an earlier trace is deleted
    first, because the object is shared with that previous trace.
    """
    if not isinstance(arg, torch.fx.Node):
        return arg

    resolved = node_to_obj[arg]
    if _get_tracer():
        del cast(Dict[Any, Any], resolved.__dict__)[proxy_slot]
    return resolved
214
215
def unpack_sizes_and_dims(
    sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
) -> Tuple[List[int], List[Placement]]:
    """Split ``sizes`` into plain local sizes and the placements they imply.

    Every DSymInt contributes its ``local_value`` to the returned sizes. Each
    dimension whose DSymInt is sharded (local != global) adds ``Shard(dim)``;
    when no dimension is sharded the placements are ``[Replicate()]``.

    Raises:
        AssertionError: if the number of placements differs from ``mesh.ndim``.
    """
    local_sizes: List[int] = []
    placements: List[Placement] = []
    for dim, size in enumerate(sizes):
        if isinstance(size, DSymInt):
            local_sizes.append(size.local_value)
            if size.is_shard():
                placements.append(Shard(dim))
        else:
            local_sizes.append(size)

    if not placements:
        placements = [Replicate()]

    assert len(placements) == mesh.ndim, (
        f"The number of sharded dimensions ({len(placements)}) must "
        f"match number of dimensions in device mesh ({mesh.ndim})."
    )
    return local_sizes, placements
234
235
def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
    """Rule for view-like ops called as ``op(dtensor, sizes)`` with sharded DSymInt sizes.

    The output placements come from the sharded DSymInts in ``sizes`` (via
    ``unpack_sizes_and_dims``). ``node.args`` is rewritten to
    ``(node.args[0], local_sizes)`` and the op is run on the input's local
    tensor; the result is wrapped as a DTensor on the input's mesh.
    """
    assert len(args) == 2, f"Expect two args but got op {node.target} with args {args}"
    source, sizes = args
    assert isinstance(
        source, DTensor
    ), f"Expect 1st argument to be DTensor but got {args[0]}"
    assert isinstance(sizes, list), f"Expect 2nd argument as list but got {args[1]}"

    mesh = source.device_mesh
    local_sizes, placements = unpack_sizes_and_dims(sizes, mesh)

    # Make the traced node carry concrete local int sizes.
    node.args = (node.args[0], local_sizes)
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(source._local_tensor, local_sizes)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=mesh,
        placements=placements,
        run_check=False,
    )
256
257
def slice_backwad_sym_int_consumer_rule(
    node: fx.Node, args: Tuple[Any, ...]
) -> DTensor:
    """Rule for ``aten.slice_backward`` with DSymInt input sizes.

    ``args`` is ``(grad_output, input_sizes, dim, start, end, step)``. A zero
    tensor shaped by the local values of ``input_sizes`` (with grad_output's
    device and dtype) receives the local gradient via ``torch.slice_scatter``.
    The result keeps grad_output's mesh and placements. ``node`` is not used.
    """
    grad_output, input_sizes, dim, start, end, step = args

    local_sizes: List[int] = []
    for size in input_sizes:
        local_sizes.append(size.local_value if isinstance(size, DSymInt) else size)

    base = torch.zeros(local_sizes, device=grad_output.device, dtype=grad_output.dtype)
    scattered = torch.slice_scatter(base, grad_output.to_local(), dim, start, end, step)
    return DTensor.from_local(
        local_tensor=scattered,
        device_mesh=grad_output.device_mesh,
        placements=grad_output.placements,
        run_check=False,
    )
278
279
280def factory_with_sizes_rule(
281node: fx.Node,
282args: Tuple[Any, ...],
283kwargs: Dict[str, Any],
284default_mesh: DeviceMesh,
285) -> DTensor:
286flat_args = pytree.arg_tree_leaves(*args)
287assert not any(isinstance(a, DTensor) for a in flat_args), (
288f"Not expect DTensor argument for factory op, but got {node.target} "
289f"with arguments {args}."
290)
291assert isinstance(args[0], list), f"Expect 2nd argument as list but got {args[1]}"
292
293local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
294node.args = (local_sizes, *args[1:])
295op = cast(torch._ops.OpOverload, node.target)
296return DTensor.from_local(
297local_tensor=op(*node.args, **kwargs),
298device_mesh=default_mesh,
299placements=placements,
300run_check=False,
301)
302
303
def factory_arange_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    """Rule for ``aten.arange`` overloads called with sharded DSymInt arguments.

    Every DSymInt in ``args`` is lowered to its local value. The lowered args
    are stored on ``node.args`` and used to run the op locally, and the result
    is wrapped as a DTensor on ``default_mesh`` with a ``Replicate()``
    placement.
    """

    def _lower(value: Any) -> Any:
        if isinstance(value, DSymInt):
            return value.local_value
        return value

    node.args = tree_map(_lower, args)
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(*node.args, **kwargs)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )
318
319
def default_factory_op_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    """Fallback rule for factory ops, which DTensor cannot dispatch without DTensor inputs.

    Stores ``args``/``kwargs`` on ``node``, runs the op on them locally, and
    wraps the result as a DTensor on ``default_mesh`` with a ``Replicate()``
    placement.
    """
    node.args = args
    node.kwargs = kwargs
    op = cast(torch._ops.OpOverload, node.target)
    local_result = op(*node.args, **node.kwargs)
    return DTensor.from_local(
        local_tensor=local_result,
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )
334
335
# Dispatch overrides for view and factory ops that consume SymInt arguments.
# They are consulted by _get_dtensor_dispatch_graph only when at least one
# DSymInt argument is sharded (local value != global value). For the view ops
# handled by binop_sym_int_consumer_rule, and for the factory ops, the output
# placement follows the tensor dimension each sharded SymInt was read from.
# slice_backward instead keeps grad_output's placements.
VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten._unsafe_view.default: binop_sym_int_consumer_rule,
    aten.expand.default: binop_sym_int_consumer_rule,
    aten.slice_backward.default: slice_backwad_sym_int_consumer_rule,
    aten.view.default: binop_sym_int_consumer_rule,
}

# Factory ops whose size arguments may be sharded DSymInts. These rules receive
# (node, args, kwargs, default_mesh).
FACTORY_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten.full.default: factory_with_sizes_rule,
    aten.arange.default: factory_arange_rule,
    aten.arange.start: factory_arange_rule,
}
351
352
# Dispatch overrides for factory ops, because DTensor cannot propagate a
# sharding spec without DTensor inputs. _get_dtensor_dispatch_graph checks this
# table after all DSymInt args have been lowered to local values. Note that
# aten.arange.start also appears in FACTORY_SYM_INT_CONSUMERS; that table takes
# precedence when one of its DSymInt args is sharded.
FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
    aten.scalar_tensor.default: default_factory_op_rule,
    aten.arange.start: default_factory_op_rule,
    aten.zeros.default: default_factory_op_rule,
}
360
361
def _get_dtensor_dispatch_graph(
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
    *,
    force_make_fx: bool = False,
    default_mesh: Optional[DeviceMesh] = None,
) -> Optional[fx.GraphModule]:
    """Expand one OpOverload node of the local graph into its DTensor form.

    Args:
        node: a node whose target is an ``OpOverload``.
        node_to_obj: concrete values (DTensors, DSymInts, scalars, ...) of
            already-processed nodes. Read to remap ``node``'s args; written
            with the result when ``node`` is handled in place.
        force_make_fx: accepted for API compatibility; this function's body
            does not read it.
        default_mesh: mesh handed to factory-op rules, which have no DTensor
            input to take a mesh from. Required when a factory op in
            FACTORY_SYM_INT_CONSUMERS receives a sharded DSymInt.

    Returns:
        ``None`` when the node was handled in place by a SymInt-consumer or
        factory rule; ``node_to_obj[node]`` then holds the result, and
        ``node``'s args may have been rewritten. Otherwise, a GraphModule
        traced with ``make_fx`` that is meant to replace ``node``.
    """
    with torch.no_grad():
        # Replace fx.Node args with the concrete objects computed for them.
        args = tree_map(partial(_remap_arg, node_to_obj), node.args)
        kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

        op_overload = cast(torch._ops.OpOverload, node.target)

        if any(
            a.is_shard()
            for a in pytree.arg_tree_leaves(*args)
            if isinstance(a, DSymInt)
        ):
            # A SymInt argument differs between its local and global value,
            # so the output placement must be derived from that SymInt.
            if op_overload in VIEW_SYM_INT_CONSUMERS:
                assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
                node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
                return None
            elif op_overload in FACTORY_SYM_INT_CONSUMERS:
                assert default_mesh is not None, "Requires default mesh for factory ops"
                node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
                    node, args, kwargs, default_mesh
                )
                return None
            else:
                assert isinstance(logger, logging.Logger)
                # The two literals are concatenated, so the first one must end
                # with a space to keep "%s" and "is" apart.
                logger.warning(
                    "Assuming using local_value from SymInt for %s "
                    "is mathematically correct. Full args are %s.",
                    op_overload,
                    args,
                )

        if node.target == aten.view.default:
            # HACK: this is a hack to get around with the fact that some
            # view operations on a "global" tensor is invalid usage
            # but somehow the view operation on the batch input might hit it
            # so we convert the view op to reshape before calling DTensor
            op_overload = aten.reshape.default

        # Past this point every DSymInt is either unsharded (local and global
        # values agree) or the warning above has been emitted; lower them all
        # to their local values.
        args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
        kwargs = tree_map(
            lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
        )

        if op_overload in FACTORY_OPS:
            # Don't pass factory ops to DTensor dispatch, as DTensor cannot
            # propagate sharding spec without DTensor inputs.
            node_to_obj[node] = FACTORY_OPS[op_overload](
                node, args, kwargs, default_mesh
            )
            return None

        # NOTE(review): _dispatch_with_local_tensors annotates ``specs`` as a
        # Dict[Tensor, (shape, mesh, placements, placements)], but here it
        # receives the remapped positional args (a tuple). Also, unlike the
        # early-return paths above, this path does not store anything in
        # node_to_obj[node]. Confirm both are intended.
        dispatch = partial(
            _dispatch_with_local_tensors,
            op_overload,
            kwargs=kwargs,
            specs=args,
        )

        gm = make_fx(dispatch, _allow_non_fake_inputs=False)(args)
        # FIXME(@wanchaol, @mrshenli): the above seems to accidentally captured
        # DeviceMesh tensor ops when handling inplace operators? The ``_to_copy`` is
        # not connected to graph output. So, using DCE to get rid of it, but this
        # doesn't look correct.
        #
        # The following operators appear in the captured graph, where the dtype is
        # torch.int64.
        #
        # get_attr       _tensor_constant0  _tensor_constant0         ()
        # call_function  transpose          aten.transpose.int        (_tensor_constant0, -1, 0)
        # call_function  view               aten.view.default         (transpose, [-1, 2])
        # call_function  view_1             aten.view.default         (view, [2])
        # call_function  _to_copy           aten._to_copy.default     (view_1,)
        gm.graph.eliminate_dead_code()

        return gm
446
447
def _build_dummy_add_graph(
    dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
    """Trace ``grad + zero`` for a partial DTensor and expand it through DTensor dispatch.

    Adding a replicated zero tensor to a Partial DTensor is what triggers the
    all_reduce during DTensor expansion of the traced graph. The traced add's
    placeholders are bound in ``node_to_obj`` to ``dt`` and to a replicated
    zero DTensor on the same mesh before expansion.

    Returns:
        ``(traced_dispatch, result)``: the expanded GraphModule for the add,
        and the value found in ``node_to_obj`` for the add node afterwards.
    """

    def dummy_add(grad: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
        return grad + zero

    local_grad: torch.Tensor = dt._local_tensor
    local_zero: torch.Tensor = torch.zeros_like(dt._local_tensor)
    add_module = make_fx(dummy_add)(local_grad, local_zero)

    graph_nodes = list(add_module.graph.nodes)
    inputs = [n for n in graph_nodes if n.op == OP.PLACEHOLDER]
    calls = [n for n in graph_nodes if n.op == OP.CALL_FUNCTION]
    assert len(inputs) == 2
    assert len(calls) == 1
    grad_input, zero_input = inputs
    add_node = calls[0]

    node_to_obj[grad_input] = dt
    node_to_obj[zero_input] = DTensor.from_local(
        local_zero, dt.device_mesh, [Replicate()], run_check=False
    )

    expanded = _get_dtensor_dispatch_graph(add_node, node_to_obj, force_make_fx=True)
    assert expanded is not None

    # TODO(anj): This depends on the call function node -> actual DTensor output
    # mapping that we want to avoid for SPMD expansion
    return expanded, node_to_obj[add_node]
483
484
def _convert_output(
    gm: fx.GraphModule,
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
) -> fx.Node:
    """Make partial DTensor outputs of ``gm`` replicated by inlining a reduction.

    For every fx.Node in the output ``node``'s first argument whose value in
    ``node_to_obj`` is a DTensor with a ``_Partial`` placement:

    - an expanded dummy ``grad + zero`` graph is built (_build_dummy_add_graph);
    - its ``add`` node is replaced by its single wait node (named
      ``wait_comm`` or ``wait_tensor``);
    - the remaining nodes are copied into ``gm`` just before ``node``;
    - the copy of the wait node takes the argument's place in the new output.

    Args:
        gm: the graph module being expanded; mutated in place.
        node: the output node of ``gm.graph``.
        node_to_obj: concrete values for graph nodes; updated with entries for
            the wait node and its copy in ``gm``.

    Returns:
        A newly created output node if any argument was partial (the old
        output node is erased); otherwise ``node`` itself.
    """
    new_args = []
    has_partial = False
    for argument in node.args[0]:  # type: ignore[union-attr]
        # Non-node outputs (constants, None, ...) are forwarded unchanged.
        if not isinstance(argument, fx.Node):
            new_args.append(argument)
            continue

        obj = node_to_obj[argument]

        if not _is_partial_dtensor(obj):
            new_args.append(argument)
            continue

        has_partial = True

        # we know it's a dtensor from is partial DT check...
        dt = cast(DTensor, obj)

        traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)

        wait = [
            n
            for n in traced_dispatch.graph.nodes
            if n.name == "wait_comm" or n.name == "wait_tensor"
        ]
        add = [n for n in traced_dispatch.graph.nodes if n.name == "add"]
        assert len(wait) == 1 and len(add) == 1

        # remove add node and replace it with wait node
        add[0].replace_all_uses_with(wait[0])
        traced_dispatch.graph.eliminate_dead_code()
        # also update the actual DTensor corresponding to the node
        # TODO(anj): We require mapping of the final DTensor output to the wait
        # comm node.
        node_to_obj[wait[0]] = result_obj

        value_remap: Dict[fx.Node, fx.Node] = {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                # Every placeholder (both the grad and the zero input) is
                # mapped to the original output argument; placeholders are
                # not copied into gm.
                # NOTE(review): this assumes no copied node still reads the
                # zero placeholder once the add is gone — confirm.
                value_remap[dtn] = argument
            elif dtn.op == OP.OUTPUT:
                assert (
                    len(dtn.args) == 1 and len(dtn.args[0]) == 1
                ), f"Expecting single output, but got {dtn.args} {len(dtn.args)}"
                new_args.append(value_remap[dtn.args[0][0]])
                # the concrete DTensor value of output was added when creating the
                # inner graph (in _build_dummy_add_graph). Just add it to the final
                # output node so that we can report the final output specs correctly.
                # TODO(anj): We are depending on the concrete DTensor output of the dummy add.
                node_to_obj[value_remap[dtn.args[0][0]]] = node_to_obj[dtn.args[0][0]]

            else:
                if dtn.op == OP.GET_ATTR:
                    # Copy the attribute under the same name so the copied
                    # get_attr node resolves on gm.
                    # NOTE(review): with several partial outputs, equally named
                    # attributes from different dummy graphs overwrite each
                    # other here — confirm this cannot happen.
                    setattr(
                        gm,
                        dtn.target,
                        getattr(traced_dispatch, dtn.target),
                    )
                with gm.graph.inserting_before(node):
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
    if has_partial:
        # Replace the old output node with one that returns the rewritten args.
        gm.graph.erase_node(node)
        return gm.graph.output(new_args)
    else:
        return node
557
558
def _rebuild_graph(
    gm: fx.GraphModule,
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule],
) -> None:
    """Splice each replacement GraphModule into ``gm`` in place of its node.

    For every node of ``gm`` that is a key of ``node_replacements``:

    - the replacement's placeholders are bound, in order, to the flattened
      leaves of ``node.args`` (``node.kwargs`` is not considered);
    - its other non-output nodes are copied into ``gm`` just before ``node``;
    - uses of ``node`` are redirected to the copy of the replacement's output.
      With several outputs, they must all be ``getitem``s over one tuple, and
      uses are redirected to that tuple. For ``_foreach``/``_fused_adam`` ops,
      uses are redirected to the first matching copied op instead;
    - ``node`` is erased.

    Finally, dead code is eliminated and ``gm`` is recompiled.
    """
    # replace nodes in local traced graph with DTensor's dispatch graph
    for node in gm.graph.nodes:
        if node not in node_replacements:
            continue

        traced_dispatch = node_replacements[node]
        # Map DT's dispatch graph input placeholder nodes to the ones in
        # local traced graph. It uses index-based accessing, which is
        # brittle, just for testing purpose.
        flatten_args = pytree.arg_tree_leaves(*node.args)
        i, value_remap = 0, {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                value_remap[dtn] = flatten_args[i]
                i += 1

        # insert DT's dispatch graph to traced local graph.
        with gm.graph.inserting_before(node):
            for dtn in traced_dispatch.graph.nodes:
                if dtn.op == OP.PLACEHOLDER:
                    # do nothing, ignore placeholders, as it has already
                    # been prepared in value_remap
                    pass
                elif dtn.op == OP.OUTPUT:
                    assert (
                        len(dtn.args) == 1
                    ), f"Expecting single output, but got {dtn.args} {len(dtn.args[0])}"
                    outputs = dtn.args[0]
                    # we currently support two very specific types of output
                    # 1. single output
                    # 2. multiple outputs resulting from getitem of all elements of tuple
                    if len(outputs) == 1:
                        # for single output, we replace the node with the single node
                        output = outputs[0]
                    else:
                        # for multiple outputs, we check that these outputs correspond
                        # to all elements of a tuple. In that case, we replace
                        # uses of the output directly with the original tuple
                        source = None
                        # NOTE: ``i`` is reused here as the tuple position; the
                        # placeholder counter above is no longer needed.
                        for i, out in enumerate(outputs):
                            # we allow None outputs for certain items in the tuple
                            if out is None:
                                continue
                            assert out.op == "call_function"
                            assert out.target.__module__ == "_operator"
                            assert out.target.__name__ == "getitem"
                            assert source is None or source == out.args[0]
                            source = out.args[0]
                            assert out.args[1] == i
                        assert source is not None
                        output = source

                    new_node = value_remap[output]
                    node.replace_all_uses_with(new_node)
                else:
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
                    if all(
                        isinstance(n.target, torch._ops.OpOverload)
                        and n.target._schema.name.startswith(
                            ("aten::_foreach", "aten::_fused_adam")
                        )
                        for n in [dtn, node]
                    ):
                        # FIXME(@mrshenli): This is a temporary solution enable
                        # foreach ops. The problem is that foreach ops returns
                        # List[Tensor], but make_fx will flatten that before
                        # passing those tensors to output node, which will
                        # introduce additional getitem nodes. These redundant
                        # getitem nodes breaks graph correctness as we cannot do
                        # getitem(getitem(foreach_out, 0), 0). This temporary
                        # solution skips getitem nodes in DTensor expanded
                        # subgraphs.
                        node.replace_all_uses_with(value_remap[dtn])
                        break
            # explicitly erase node instead of relying on DCE, as DCE does not
            # remove inplace copy_ correctly.
            gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.recompile()
643
644
def _get_last_consumer_to_nodes(
    graph: fx.Graph,
) -> Dict[fx.Node, List[fx.Node]]:
    """Group every node under the node that consumes it last.

    The graph is walked from the end, so the first consumer met for a node is
    its last use in execution order. A node's value can therefore be freed
    once that consumer has run. Both positional args and kwargs count as uses.

    Returns:
        ``{consumer: [nodes whose last use is consumer, in argument order]}``.
        Nodes that are nobody's last consumer do not appear as keys.
    """
    last_user_of: Dict[fx.Node, fx.Node] = {}
    freed_after: Dict[fx.Node, List[fx.Node]] = {}

    for consumer in reversed(graph.nodes):
        inputs: List[fx.Node] = []
        # map_arg visits every fx.Node leaf: args first, then kwargs.
        fx.node.map_arg((consumer.args, consumer.kwargs), inputs.append)
        for producer in inputs:
            if producer in last_user_of:
                continue
            last_user_of[producer] = consumer
            freed_after.setdefault(consumer, []).append(producer)

    return freed_after
670
671
def _convert_to_distributed(
    gm: fx.GraphModule,
    inps: List[torch.Tensor],
    schemas: List[Schema],
    default_mesh: Optional[DeviceMesh] = None,
    _allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
    """Transform a graph module to a distributed graph module.

    Every node of ``gm`` is evaluated on concrete values kept in
    ``node_to_obj``: DTensors built from ``inps``/``schemas``, DSymInts, and
    Python scalars. OpOverload nodes that need DTensor expansion are collected
    and spliced into ``gm`` at the end by ``_rebuild_graph``.

    Args:
        gm: graph module traced on local tensors; modified in place.
        inps: example local shards. Each placeholder uses the entry at its
            position in the graph, so placeholders are expected to come first.
        schemas: mesh and placements used to wrap the matching ``inps`` entry
            into a DTensor.
        default_mesh: mesh handed to factory-op rules.
        _allow_partial: when False, partial DTensor outputs are made
            replicated by inserting a reduction before the output node.

    Returns:
        - transformed graph module
        - map from output node name to the ``Schema`` (mesh, placements) of
          each DTensor output

    Raises:
        ValueError: for a node whose op is not handled.
    """
    global logger
    logger = get_logger("spmd_exp")
    # Python-level operator callables (operator.add, operator.mul, ...); used
    # to recognise arithmetic on DSymInts.
    operators = {getattr(operator, name) for name in operator.__all__}
    node_to_obj: Dict[fx.Node, Any] = {}
    # map local op node in traced_f to its corresponding subgraph of
    # DTensor ops.
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule] = {}

    last_consumer_to_nodes = _get_last_consumer_to_nodes(gm.graph)

    output_schemas: Dict[str, Schema] = {}
    for i, node in enumerate(gm.graph.nodes):
        assert logger is not None
        logger.info("node%s: op=%s target=%s", i, node.op, node.target)
        if node.op == OP.PLACEHOLDER:
            assert i < len(
                inps
            ), f"got more placeholder nodes ({i + 1}) than inputs ({len(inps)})"

            # our example inputs are local shards. Create DTensors from them.
            node_to_obj[node] = DTensor.from_local(
                inps[i].clone(),  # use clone to avoid modifications from inplace ops
                schemas[i].mesh,
                schemas[i].placements,
                # prevent running this collective in backwards pass
                run_check=False,
            )
        elif isinstance(node.target, torch._ops.OpOverloadPacket):
            # SymInt queries (sym_size / sym_numel / sym_stride) on a DTensor.
            dtensor = cast(DTensor, node_to_obj[node.args[0]])
            node_to_obj[node] = DSymInt.from_node(node, dtensor)
        elif isinstance(node.target, torch._ops.OpOverload):
            replacement = _get_dtensor_dispatch_graph(
                node, node_to_obj, default_mesh=default_mesh
            )
            if replacement is not None:
                node_replacements[node] = replacement
        elif node.op == OP.OUTPUT:
            if not _allow_partial:
                # Returns an expanded dummy add node that ensures
                # that the partial output tensor has been converted
                # to a replicated tensor.
                # NOTE: ``node`` may be rebound to a new output node here, so
                # the last-consumer cleanup below then looks up that new node.
                node = _convert_output(gm, node, node_to_obj)

            # Save output sharding for the inputs to backward pass.
            # TODO(anj): Pipe the output schema for the BW pass
            # instead of requiring the full output DTensor to be
            # materialized.
            for inp_arg in node.args[0]:
                if isinstance(inp_arg, fx.Node):
                    obj = node_to_obj[inp_arg]
                    if isinstance(obj, DTensor):
                        output_schemas[inp_arg.name] = Schema(
                            obj.device_mesh, obj.placements  # type: ignore[arg-type]
                        )
        elif node.op == OP.CALL_FUNCTION:
            # Remaining call_function targets (not OpOverload/OpOverloadPacket),
            # e.g. Python operators applied to SymInts.
            args = tree_map(partial(_remap_arg, node_to_obj), node.args)
            kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

            # Only top-level args and kwarg values are inspected for DSymInts.
            dsymints = list(
                filter(lambda a: isinstance(a, DSymInt), args + tuple(kwargs.values()))
            )

            if node.target in operators and len(dsymints) > 0:
                assert all(
                    dsymints[0].mesh == d.mesh for d in dsymints
                ), "all DSymInts must have the same mesh. "

                # Evaluate the operator twice: once on local values and once
                # on global values, to build the resulting DSymInt.
                local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
                local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)

                global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
                global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)

                # The rewritten graph computes with local values.
                node.args = local_args
                node.kwargs = local_kwargs

                node_to_obj[node] = DSymInt(
                    local_value=node.target(*local_args, **local_kwargs),
                    global_value=node.target(*global_args, **global_kwargs),
                    mesh=dsymints[0].mesh,
                )
            else:
                assert len(dsymints) == 0, (
                    "SPMD expansion does not support SymInt in non-operator "
                    f"nodes, got {node.target}."
                )
                node_to_obj[node] = node.target(*args, **kwargs)
        else:
            raise ValueError(f"Unrecognized node.op type {node.op}")

        if node in last_consumer_to_nodes:
            # Save memory by deleting objs that won't be used anymore.
            for arg_node in last_consumer_to_nodes[node]:
                del node_to_obj[arg_node]

    _rebuild_graph(gm, node_replacements)

    return gm, output_schemas
784