9
from itertools import count
22
from unittest import mock
24
from functorch.compile import min_cut_rematerialization_partition
26
import torch._functorch.config as functorch_config
29
import torch.utils._pytree as pytree
30
from torch._dynamo import (
32
config as dynamo_config,
33
logging as dynamo_logging,
34
utils as dynamo_utils,
36
from torch._dynamo.utils import (
39
lazy_format_graph_code,
42
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
43
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
45
from torch._inductor.debug import save_args_for_compile_fx_inner
46
from torch._logging import trace_structured
47
from torch._ops import OpOverload
48
from torch._subclasses.fake_tensor import FakeTensor
49
from torch._utils_internal import signpost_event
50
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
52
from .._dynamo.backends.common import aot_autograd
53
from ..fx._lazy_graph_module import _use_lazy_graph_module
54
from ..fx.graph import _PyTreeCodeGen
55
from . import config, metrics
56
from .debug import DebugContext
57
from .decomposition import select_decomp_table
58
from .fx_passes.joint_graph import joint_graph_passes
59
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
60
from .fx_passes.pre_grad import pre_grad_passes
61
from .graph import GraphLowering
62
from .ir import ExternKernelNode
63
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
64
from .virtualized import V
67
from torch._inductor.fb.utils import time_and_log
70
def time_and_log(attr: str, extra_loggings: Optional[Dict[str, str]] = None):
    """No-op fallback for the fbcode-only ``time_and_log`` decorator factory.

    Both arguments are accepted purely for signature compatibility with the
    internal implementation and are ignored; the returned decorator leaves
    the wrapped function unchanged.
    """
    return dynamo_utils.identity
74
log = logging.getLogger(__name__)
75
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
76
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
89
if isinstance(obj, BoxedBool):
96
class BoxedDeviceIndex:
99
def set(self, device_idx):
    """Store ``device_idx`` as the boxed value; only an int or None is accepted."""
    assert isinstance(device_idx, int) or device_idx is None
    self.value = device_idx
108
def get_expanded_dims(t):
    """Return the dimension indices of ``t`` that are broadcast/expanded
    (stride 0 with size > 1). Non-tensor inputs have no expanded dims."""
    if not isinstance(t, torch.Tensor):
        return []
    return [d for d in range(t.ndim) if t.stride(d) == 0 and t.size(d) != 1]
114
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    """Slice ``t`` down to length 1 along every dimension in ``expanded_dims``,
    returning the narrowed view."""
    for dim in expanded_dims:
        t = torch.ops.aten.slice(t, dim, 0, 1)
    return t
120
def complex_memory_overlap(t: torch.Tensor) -> bool:
123
t = index_expanded_dims(t, get_expanded_dims(t))
124
if torch._debug_has_internal_overlap(t) != 0:
127
indices = list(range(len(strides)))
128
indices = [x for _, x in sorted(zip(strides, indices))]
129
for i in range(len(strides)):
130
prev_stride = 1 if i == 0 else strides[indices[i - 1]]
131
prev_size = 1 if i == 0 else sizes[indices[i - 1]]
132
if strides[indices[i]] < prev_stride * prev_size:
137
@functools.lru_cache(None)
139
return dynamo_logging.get_step_logger(log)
142
@functools.lru_cache(None)
143
def _warn_tf32_disabled():
145
torch.cuda.is_available()
146
and not torch.backends.cuda.matmul.allow_tf32
147
and torch.cuda.get_device_capability() >= (8, 0)
150
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
151
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
155
def _unlift_graph(mod, gm, graph_signature):
156
from torch.export.unflatten import _assign_attr, _AttrKind
159
for name, param in mod.named_parameters(remove_duplicate=False):
160
state_dict[name] = param
165
attr_kind=_AttrKind.PARAMETER,
167
for name, buffer in mod.named_buffers(remove_duplicate=False):
168
state_dict[name] = buffer
173
attr_kind=_AttrKind.BUFFER,
176
placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
178
for node in placeholder_nodes:
179
node_name = node.name
180
if node_name in graph_signature.inputs_to_parameters:
181
lifted_inputs.append(graph_signature.inputs_to_parameters[node_name])
182
elif node_name in graph_signature.inputs_to_buffers:
183
lifted_inputs.append(graph_signature.inputs_to_buffers[node_name])
185
assert node_name in graph_signature.user_inputs
186
lifted_inputs.append(None)
188
from torch.export._unlift import _unlift
190
outputs = list(gm.graph.nodes)[-1].args[0]
193
if out in graph_signature.buffers_to_mutate:
194
mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
196
mutated_outputs.append(None)
198
unlifted_gm = _unlift(
210
def _get_subgraph_names(gm):
211
for node in gm.graph.nodes:
212
if node.target == torch.ops.higher_order.cond:
213
true_subgraph_name = node.args[1].name
214
false_subgraph_name = node.args[2].name
215
yield true_subgraph_name
216
yield false_subgraph_name
219
def _recursive_pre_grad_passes(gm, example_inputs):
    """Apply ``pre_grad_passes`` to ``gm``, first recursing depth-first into
    every cond subgraph. Subgraphs get ``example_inputs=None`` since no
    per-subgraph example inputs exist; transformed subgraphs are re-attached."""
    for name in _get_subgraph_names(gm):
        child = getattr(gm, name)
        setattr(gm, name, _recursive_pre_grad_passes(child, example_inputs=None))
    return pre_grad_passes(gm, example_inputs)
228
def _recursive_joint_graph_passes(gm):
    """Run ``joint_graph_passes`` over ``gm`` after first processing every
    cond subgraph depth-first (graphs are mutated in place)."""
    for name in _get_subgraph_names(gm):
        _recursive_joint_graph_passes(getattr(gm, name))
    joint_graph_passes(gm)
235
def _recursive_post_grad_passes(gm, is_inference: bool = False):
    """Run ``post_grad_passes`` over ``gm`` and, depth-first, over every cond
    subgraph, forwarding the ``is_inference`` flag (mutates in place)."""
    for name in _get_subgraph_names(gm):
        _recursive_post_grad_passes(getattr(gm, name), is_inference)
    post_grad_passes(gm, is_inference)
243
gm: torch.fx.GraphModule,
244
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
246
This function takes an GraphModule input "gm".
247
The gm will be split into 2 components,
248
1) const_gm, which consists the subgraph of gm that can be constant folded.
249
2) gm (being inplace modified,) which returns the graph after constant folding.
251
const_output_index is a mapping of corresponding node name from gm to the
252
output index of const_gm.
253
Returns (const_gm, const_output_index)
255
from torch._inductor.constant_folding import (
259
replace_node_with_constant,
260
run_and_get_constant_graph,
263
const_gm = run_and_get_constant_graph(gm)
264
const_result = const_gm()
267
x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
272
const_output_index = {}
273
for node in gm.graph.nodes:
274
if node.name in const_outputs:
275
to_replace_node.append(node)
276
elif node.meta[META_TAG] == CONST_MODULE_TAG:
277
to_erase_node.append(node)
279
for node in to_replace_node:
280
new_const_name = "_FOLDED_CONST_" + node.name
281
replace_node_with_constant(
284
const_result[const_outputs[node.name]],
287
const_output_index[new_const_name] = const_outputs[node.name]
288
for node in to_erase_node[::-1]:
291
assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
293
gm.graph.erase_node(node)
296
return const_gm, const_output_index
299
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
300
aten = torch.ops.aten
305
aten.baddbmm.default,
307
for node in gm.graph.nodes:
309
node.op == "call_function"
310
and node.target in tf32_ops
311
and isinstance(node.meta.get("val", None), torch.Tensor)
312
and node.meta["val"].dtype == torch.float32
313
and node.meta["val"].device.type == "cuda"
320
def count_bytes_inner(
321
gm: torch.fx.GraphModule,
322
example_inputs: List[torch.Tensor],
326
shape_env = _shape_env_from_inputs(example_inputs)
327
fake_mode = fake_tensor_prop(gm, example_inputs)
329
with V.set_fake_mode(fake_mode):
330
_recursive_post_grad_passes(gm, False)
332
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
333
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
334
graph.run(*example_inputs)
335
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
336
metrics.num_bytes_accessed += num_bytes
337
metrics.nodes_num_elem += nodes_num_elem
338
metrics.node_runtimes += node_runtimes
339
return make_boxed_func(gm.forward)
343
gm: torch.fx.GraphModule,
344
example_inputs: List[torch.Tensor],
345
force_allow_non_fake_inputs: bool = False,
348
If we can not detect fake mode from the context of inputs, create one.
350
The created fake mode will be returned.
352
fake_mode = detect_fake_mode(example_inputs)
354
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
355
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
358
contextlib.nullcontext()
359
if not force_allow_non_fake_inputs
360
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
363
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
371
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
372
with config.patch(config_patches):
373
return config.get_config_copy()
377
@torch.utils._python_dispatch._disable_current_modes()
379
attr="compilation time (in seconds)",
380
extra_loggings={"config_dict": str(get_patched_config_dict())},
386
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
387
@dynamo_utils.dynamo_timed(phase_name="inductor_compile")
389
gm: torch.fx.GraphModule,
390
example_inputs: List[torch.Tensor],
391
cudagraphs: Optional[BoxedBool] = None,
393
is_backward: bool = False,
394
graph_id: Optional[int] = None,
395
cpp_wrapper: bool = False,
396
aot_mode: bool = False,
397
is_inference: bool = False,
398
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
399
user_visible_outputs: FrozenSet[str] = frozenset(),
400
layout_opt: Optional[bool] = None,
401
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
402
) -> Union[CompiledFxGraph, str]:
404
Inductor API that compiles a single graph.
406
If you change the argument list for this function, make sure you
407
also update the call to save_args_for_compile_fx_inner below accordingly.
409
if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
412
from torch.fx._lazy_graph_module import _LazyGraphModule
414
_LazyGraphModule.force_recompile(gm)
415
return make_boxed_func(gm.forward)
418
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
419
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
422
save_args_for_compile_fx_inner(
425
cudagraphs=cudagraphs,
427
is_backward=is_backward,
429
cpp_wrapper=cpp_wrapper,
431
is_inference=is_inference,
432
boxed_forward_device_index=boxed_forward_device_index,
433
user_visible_outputs=user_visible_outputs,
434
layout_opt=layout_opt,
437
if cudagraphs is None:
438
cudagraphs = BoxedBool(config.triton.cudagraphs)
444
"cudagraphs": cudagraphs,
445
"num_fixed": num_fixed,
446
"is_backward": is_backward,
447
"graph_id": graph_id,
448
"cpp_wrapper": cpp_wrapper,
449
"aot_mode": aot_mode,
450
"is_inference": is_inference,
451
"user_visible_outputs": user_visible_outputs,
452
"layout_opt": layout_opt,
453
"extern_node_serializer": extern_node_serializer,
458
if config.fx_graph_cache and not aot_mode:
459
compiled_graph = FxGraphCache.load(
460
fx_codegen_and_compile, gm, example_inputs, graph_kwargs
463
compiled_graph = fx_codegen_and_compile(
464
gm, example_inputs, **graph_kwargs
467
log.debug("FX codegen and compilation took %.3fs", time.time() - start)
470
if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
471
perf_hint_log.warning(
472
"skipping cudagraphs due to %s", compiled_graph.disabled_cudagraphs_reason
474
BoxedBool.disable(cudagraphs)
477
context = torch._guards.TracingContext.try_get()
478
if context is not None and context.output_strides is not None:
479
assert len(context.output_strides) == 0
480
context.output_strides.extend(compiled_graph.output_strides)
483
return compiled_graph
487
output = list(gm.graph.nodes)[-1]
488
assert len(output.args) == 1
490
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
491
for arg in output.args[0]
494
complex_memory_overlap_inputs = any(
495
complex_memory_overlap(t)
496
for t in example_inputs
497
if isinstance(t, torch.Tensor)
500
from torch._inductor.cudagraph_utils import check_for_mutation
502
has_mutation_str = check_for_mutation(gm, compiled_graph, num_fixed)
503
has_mutation = has_mutation_str is not None
506
compiled_graph.disabled_cudagraphs_reason = has_mutation_str
509
(not has_mutation, "mutated inputs"),
510
(not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
511
(not complex_memory_overlap_inputs, "complex memory overlap"),
514
isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
519
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
521
if not cudagraph_fail_reasons:
522
if not config.triton.cudagraph_trees:
524
for t in example_inputs:
525
if isinstance(t, torch.SymInt):
529
boxed_forward_device_index is not None
533
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
535
compiled_graph.current_callable = cudagraphify(
536
compiled_graph.get_current_callable(),
538
static_input_idxs=range(num_fixed),
539
device_index=next(iter(compiled_graph.device_idxs)),
540
stack_traces=stack_traces,
541
is_backward=is_backward,
542
is_inference=is_inference,
543
constants=tuple(compiled_graph.constants.values()),
546
BoxedBool.disable(cudagraphs)
551
if is_backward and config.triton.cudagraph_trees:
552
assert boxed_forward_device_index is not None
553
assert boxed_forward_device_index.value is not None
554
compiled_graph_callable = compiled_graph.get_current_callable()
556
manager = torch._inductor.cudagraph_trees.get_manager(
557
boxed_forward_device_index.value, create_if_none_exists=False
560
assert manager is not None
562
def compiled_artifact(new_inputs):
563
manager.set_to_running_backward()
564
return compiled_graph_callable(new_inputs)
566
compiled_graph.current_callable = compiled_artifact
568
if "cuda" in compiled_graph.device_types:
571
if compiled_graph.disabled_cudagraphs_reason:
572
perf_hint_log.warning(compiled_graph.disabled_cudagraphs_reason)
574
perf_hint_log.warning(
575
"skipping cudagraphs due to %s", cudagraph_fail_reasons
580
new_callable = align_inputs(
581
compiled_graph.get_current_callable(), example_inputs, range(num_fixed)
583
if new_callable is not compiled_graph.get_current_callable():
584
compiled_graph.current_callable = new_callable
588
"torchinductor done compiling "
589
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
594
compiled_graph._boxed_call = True
595
return compiled_graph
598
def fx_codegen_and_compile(
599
gm: torch.fx.GraphModule,
600
example_inputs: List[torch.Tensor],
601
cudagraphs: Optional[BoxedBool] = None,
603
is_backward: bool = False,
604
graph_id: Optional[int] = None,
605
cpp_wrapper: bool = False,
606
aot_mode: bool = False,
607
is_inference: bool = False,
608
user_visible_outputs: FrozenSet[str] = frozenset(),
609
layout_opt: Optional[bool] = None,
610
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
611
) -> Union[CompiledFxGraph, str]:
612
if is_tf32_warning_applicable(gm):
613
_warn_tf32_disabled()
617
sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
621
"torchinductor compiling "
622
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
625
V.debug.fx_graph(gm, example_inputs)
630
shape_env = _shape_env_from_inputs(example_inputs)
654
with torch.no_grad():
655
fake_mode = fake_tensor_prop(gm, example_inputs)
661
with V.set_fake_mode(fake_mode):
663
_recursive_post_grad_passes(gm, is_inference=is_inference)
664
V.debug.fx_graph_transformed(gm, example_inputs)
665
post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
667
"inductor_post_grad_graph",
668
payload_fn=lambda: gm.print_readable(print_output=False),
670
optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
673
"compile_fx.post_grad_passes",
677
with V.set_fake_mode(fake_mode):
678
const_output_index = None
682
if aot_mode and config.aot_inductor.use_runtime_constant_folding:
683
const_gm, const_output_index = split_const_gm(gm)
685
const_graph = GraphLowering(
689
num_static_inputs=num_fixed,
691
cpp_wrapper=cpp_wrapper,
693
user_visible_outputs=user_visible_outputs,
694
extern_node_serializer=extern_node_serializer,
695
is_inference=is_inference,
698
with V.set_graph_handler(const_graph):
699
assert cpp_wrapper, "AOT mode only supports C++ wrapper"
702
const_code, _ = const_graph.codegen_with_cpp_wrapper()
704
graph = GraphLowering(
709
example_inputs=example_inputs,
711
num_static_inputs=num_fixed,
713
cpp_wrapper=cpp_wrapper,
715
user_visible_outputs=user_visible_outputs,
716
extern_node_serializer=extern_node_serializer,
717
is_inference=is_inference,
718
const_output_index=const_output_index,
719
const_code=const_code,
720
const_module=const_graph,
722
with V.set_graph_handler(graph):
723
graph.run(*example_inputs)
724
output_strides: List[Optional[Tuple[int, ...]]] = []
725
if graph.graph_outputs is not None:
728
for out in graph.graph_outputs:
729
if hasattr(out, "layout"):
730
output_strides.append(
732
V.graph.sizevars.size_hint(s) for s in out.layout.stride
736
output_strides.append(None)
738
compiled_fn = graph.compile_to_fn()
740
if V.aot_compilation is True:
743
if cudagraphs and not V.graph.disable_cudagraphs_reason:
744
from torch._inductor.cudagraph_utils import (
745
check_lowering_disable_cudagraph,
748
V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
749
V.graph.device_node_mapping
752
compiled_graph = CompiledFxGraph(
753
compiled_fn, graph, output_strides, V.graph.disable_cudagraphs_reason
756
return compiled_graph
759
def clone_preserve_strides(x: torch.Tensor):
    """Clone ``x`` into fresh storage while keeping its exact size/stride
    layout (a plain ``clone()`` would make the result contiguous)."""
    # Smallest flat buffer the strided view can address: the largest
    # reachable element offset, plus one.
    max_offset = sum((dim - 1) * st for dim, st in zip(x.size(), x.stride()))
    needed_size = max_offset + 1
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
767
def copy_misaligned_inputs(
    new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
    """Mutate ``new_inputs`` in place: any tensor at an index in
    ``check_inputs_idxs`` whose data pointer is not ALIGNMENT-aligned is
    replaced with a stride-preserving aligned clone."""
    for idx in check_inputs_idxs:
        tensor = new_inputs[idx]
        if tensor.data_ptr() % ALIGNMENT:
            new_inputs[idx] = clone_preserve_strides(tensor)
775
def get_input_idxs_to_check(
776
inputs: Union[List[torch.Tensor], Sequence[int]],
777
static_input_idxs: Sequence[int],
779
def is_aligned(storage_offset, dtype):
780
return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0
783
for i, input in enumerate(inputs):
785
isinstance(input, torch.Tensor)
787
i not in static_input_idxs
788
or not is_aligned(input.storage_offset(), input.dtype)
790
and input.device.type == "cuda"
792
ids_to_check.append(i)
796
def align_inputs_from_check_idxs(
797
model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
799
if len(inputs_to_check) == 0:
803
copy_misaligned_inputs(new_inputs, inputs_to_check)
804
return model(new_inputs)
810
model: Callable[[List[torch.Tensor]], Any],
811
inputs: List[torch.Tensor],
812
static_input_idxs: Sequence[int] = (),
814
inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
815
return align_inputs_from_check_idxs(model, inputs_to_check)
818
@dynamo_utils.dynamo_timed
820
model: torch.fx.GraphModule,
821
inputs: List[torch.Tensor],
822
static_input_idxs: Sequence[int] = (),
825
stack_traces: List[Optional[str]],
828
constants: Tuple[torch.Tensor, ...] = (),
830
from torch._inductor.cudagraph_trees import (
831
cudagraphify_impl as new_cudagraphify_impl,
834
cudagraphify_fn: Callable[..., Any]
835
if config.triton.cudagraph_trees:
836
cudagraphify_fn = functools.partial(
837
new_cudagraphify_impl,
838
device_index=device_index,
839
stack_traces=stack_traces,
840
is_backward=is_backward,
841
is_inference=is_inference,
845
cudagraphify_fn = cudagraphify_impl
848
if not any(isinstance(inp, FakeTensor) for inp in inputs):
849
return cudagraphify_fn(model, inputs, static_input_idxs)
855
if compiled_fn is None:
856
with dynamo_utils.preserve_rng_state():
857
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
858
return compiled_fn(new_inputs)
863
def remove_unaligned_input_idxs(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
):
    """Filter ``static_input_idxs`` down to indices whose tensor is
    ALIGNMENT-aligned.

    We require all inputs used inside a cudagraph to be aligned, so any
    static input that is misaligned must be dropped from the static set
    (it will then be copied into an aligned buffer instead of reused in
    place).

    Fix: look up ``inputs[idx]`` for each static index rather than zipping
    ``static_input_idxs`` positionally against ``inputs`` — the positional
    pairing was only correct when the static indices happened to be exactly
    ``0..k-1``.

    Returns the original ``static_input_idxs`` object unchanged when nothing
    was filtered, matching the previous behavior.
    """
    aligned_static_input_idxs = []
    for idx in static_input_idxs:
        inp = inputs[idx]
        if isinstance(inp, torch.Tensor) and (inp.data_ptr() % ALIGNMENT) == 0:
            aligned_static_input_idxs.append(idx)
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs
880
def static_input(x: torch.Tensor):
    """Allocate a fresh, uninitialized tensor with the same size, stride,
    dtype and device as ``x`` — used as a stable buffer for cudagraph
    static inputs."""
    # Smallest flat allocation that the strided view can address.
    max_offset = sum((dim - 1) * st for dim, st in zip(x.size(), x.stride()))
    buffer = torch.empty(max_offset + 1, dtype=x.dtype, device=x.device)
    return torch.as_strided(buffer, x.size(), x.stride())
893
def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    "Index into expanded dimensions of both dst and src then copy_"
    # NOTE(review): the parameter lines and the trailing in-place copy were
    # lost in this extraction; reconstructed from the docstring and the
    # trailing-underscore naming convention — confirm against upstream.
    dst = index_expanded_dims(dst, expanded_dims)
    src = index_expanded_dims(src, expanded_dims)
    dst.copy_(src)
904
def cudagraphify_impl(
905
model: torch.fx.GraphModule,
906
inputs: List[torch.Tensor],
907
static_input_idxs: Sequence[int] = (),
910
Assumes inputs[static_input_idxs[i]] are always the same memory address
912
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
913
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
914
copy_misaligned_inputs(inputs, check_input_idxs)
916
assert isinstance(inputs, list)
918
inps_expanded_dims = [
919
get_expanded_dims(x) if idx not in static_input_idxs else []
920
for idx, x in enumerate(inputs)
926
if not isinstance(x, torch.Tensor)
928
if idx not in static_input_idxs
930
for idx, x in enumerate(inputs)
934
for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
935
if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
936
index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
939
torch.cuda.synchronize()
940
stream = torch.cuda.Stream()
941
stream.wait_stream(torch.cuda.current_stream())
943
with torch.cuda.stream(stream):
944
model(list(static_inputs))
946
torch.cuda.current_stream().wait_stream(stream)
947
torch.cuda.synchronize()
950
graph = torch.cuda.CUDAGraph()
951
with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
952
static_outputs = model(list(static_inputs))
953
if not isinstance(static_outputs, (list, tuple)):
954
static_outputs = (static_outputs,)
956
if config.size_asserts:
959
assert len(static_inputs) == len(new_inputs)
960
for idx, (dst, src, expanded_dims) in enumerate(
961
zip(static_inputs, new_inputs, inps_expanded_dims)
963
if not isinstance(dst, torch.Tensor):
965
elif idx in static_input_idxs:
966
assert dst.data_ptr() == src.data_ptr()
971
index_expanded_dims_and_copy_(dst, src, expanded_dims)
974
return static_outputs
978
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
982
for idx in copy_indices:
983
expanded_dims = inps_expanded_dims[idx]
984
index_expanded_dims_and_copy_(
985
static_inputs[idx], new_inputs[idx], expanded_dims
989
return static_outputs
991
return align_inputs_from_check_idxs(run, check_input_idxs)
994
def count_tangents(fx_g: torch.fx.GraphModule):
996
Infers which inputs are static for a backwards graph
999
def is_saved_tensor(x):
1001
"tangents" not in x.name
1002
and "bwd_seed" not in x.name
1003
and "bwd_base_offset" not in x.name
1007
static_arg_idxs = []
1008
for n in fx_g.graph.nodes:
1009
if n.op == "placeholder":
1010
if is_saved_tensor(n):
1011
static_arg_idxs.append(arg_count)
1014
assert static_arg_idxs == list(range(len(static_arg_idxs)))
1015
return len(static_arg_idxs)
1019
model_: torch.fx.GraphModule,
1020
example_inputs_: List[torch.Tensor],
1021
inner_compile: Callable[..., Any] = compile_fx_inner,
1022
config_patches: Optional[Dict[str, Any]] = None,
1024
config_patches: Dict[str, Any] = (
1025
{"cpp_wrapper": True}
1026
if config_patches is None
1027
else {**config_patches, "cpp_wrapper": True}
1030
"aot_inductor.output_path" not in config_patches
1031
and not config.aot_inductor.output_path
1035
"aot_inductor.output_path": code_hash(model_.code),
1038
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
1039
with V.set_aot_compilation(True):
1040
compiled_lib_path = compile_fx(
1043
inner_compile=functools.partial(
1046
extern_node_serializer=extern_node_serializer,
1048
config_patches=config_patches,
1050
assert os.path.exists(
1052
), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
1053
return compiled_lib_path
1056
_graph_counter = count(0)
1059
def fw_compiler_freezing(
1060
aot_autograd_model: torch.fx.GraphModule,
1061
aot_example_inputs: List[torch.Tensor],
1062
dynamo_model: torch.fx.GraphModule,
1063
num_example_inputs: int,
1064
inner_compile: Callable[..., Any],
1065
cudagraphs: BoxedBool,
1067
forward_device: BoxedDeviceIndex,
1069
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
1072
_recursive_joint_graph_passes(aot_autograd_model)
1074
layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
1077
fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
1078
convert_conv_weights_to_channels_last(aot_autograd_model)
1080
opt_model, preserved_arg_indices = freeze(
1086
aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
1087
num_fixed = len(preserved_arg_indices) - num_example_inputs
1089
fake_mode = detect_fake_mode(aot_example_inputs)
1092
*_, model_outputs_node = opt_model.graph.nodes
1093
model_outputs = model_outputs_node.args[0]
1094
user_visible_outputs = [
1095
n.name for n in model_outputs if isinstance(n, torch.fx.Node)
1099
tracing_context = torch._guards.TracingContext.try_get()
1100
if tracing_context is not None:
1101
params_flat = tracing_context.params_flat
1102
assert params_flat is not None
1103
for i in range(len(params_flat)):
1104
if i not in preserved_arg_indices:
1105
params_flat[i] = None
1107
with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
1108
optimized_function = inner_compile(
1111
num_fixed=num_fixed,
1112
cudagraphs=cudagraphs,
1115
boxed_forward_device_index=forward_device,
1116
layout_opt=layout_opt,
1117
user_visible_outputs=user_visible_outputs,
1122
if V.aot_compilation is True:
1123
return optimized_function
1126
args_new = [args[i] for i in preserved_arg_indices]
1128
return optimized_function(args_new)
1130
wrapper._boxed_call = True
1135
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
1137
model_: torch.fx.GraphModule,
1138
example_inputs_: List[torch.Tensor],
1139
inner_compile: Callable[..., Any] = compile_fx_inner,
1140
config_patches: Optional[Dict[str, Any]] = None,
1141
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
1143
"""Main entrypoint to a compile given FX graph"""
1145
with config.patch(config_patches):
1150
inner_compile=config.patch(config_patches)(inner_compile),
1151
decompositions=decompositions,
1154
if config.cpp_wrapper:
1157
"cpp_wrapper": False,
1158
"triton.autotune_cublasLt": False,
1159
"triton.cudagraphs": False,
1160
"triton.store_cubin": True,
1162
), V.set_real_inputs(example_inputs_):
1163
inputs_ = example_inputs_
1164
if isinstance(model_, torch.fx.GraphModule):
1166
node.meta.get("val")
1167
for node in model_.graph.nodes
1168
if node.op == "placeholder"
1170
if all(v is not None for v in fake_inputs):
1172
for idx, fi, i in zip(count(), fake_inputs, inputs_):
1173
if fi.device != i.device:
1175
f"Device mismatch between fake input and example input at position #{idx}: "
1176
f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
1177
"make sure torch.export() and torch.aot_compile() run on the same device."
1179
inputs_ = fake_inputs
1183
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
1184
decompositions=decompositions,
1187
recursive_compile_fx = functools.partial(
1189
inner_compile=inner_compile,
1190
decompositions=decompositions,
1193
if not graph_returns_tuple(model_):
1194
return make_graph_return_tuple(
1197
recursive_compile_fx,
1200
if isinstance(model_, torch.fx.GraphModule):
1201
if isinstance(model_.graph._codegen, _PyTreeCodeGen):
1203
return handle_dynamo_export_graph(
1206
recursive_compile_fx,
1209
model_ = _recursive_pre_grad_passes(model_, example_inputs_)
1210
optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
1213
"compile_fx.pre_grad_passes",
1217
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
1218
return flatten_graph_inputs(
1221
recursive_compile_fx,
1224
assert not config._raise_error_for_testing
1225
num_example_inputs = len(example_inputs_)
1226
cudagraphs = BoxedBool(config.triton.cudagraphs)
1227
forward_device = BoxedDeviceIndex(None)
1229
graph_id = next(_graph_counter)
1232
decompositions if decompositions is not None else select_decomp_table()
1235
@dynamo_utils.dynamo_timed
1236
def fw_compiler_base(
1237
model: torch.fx.GraphModule,
1238
example_inputs: List[torch.Tensor],
1243
_recursive_joint_graph_passes(model)
1245
num_rng_seed_offset_inputs = 2 if functorch_config.functionalize_rng_ops else 0
1246
fixed = len(example_inputs) - num_example_inputs - num_rng_seed_offset_inputs
1247
user_visible_outputs = set()
1249
if config.keep_output_stride:
1250
*_, model_outputs_node = model.graph.nodes
1251
assert model_outputs_node.op == "output"
1252
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
1253
num_model_outputs = len(model_outputs)
1255
context = torch._guards.TracingContext.try_get()
1257
if context is not None and context.fw_metadata and not is_inference:
1258
original_output_start_index = (
1259
context.fw_metadata.num_mutated_inp_runtime_indices
1262
original_output_start_index = 0
1264
if isinstance(model_, torch.fx.GraphModule):
1265
*_, orig_model_outputs_node = model_.graph.nodes
1266
assert orig_model_outputs_node.op == "output"
1267
orig_model_outputs, _ = pytree.tree_flatten(
1268
orig_model_outputs_node.args
1270
num_orig_model_outputs = len(orig_model_outputs)
1272
num_orig_model_outputs = num_model_outputs
1274
assert num_orig_model_outputs <= num_model_outputs
1289
orig_output_end_idx = original_output_start_index + num_orig_model_outputs
1292
assert orig_output_end_idx <= num_model_outputs
1294
user_visible_outputs = {
1296
for n in model_outputs[original_output_start_index:orig_output_end_idx]
1297
if isinstance(n, torch.fx.Node)
1300
return inner_compile(
1304
cudagraphs=cudagraphs,
1306
is_inference=is_inference,
1307
boxed_forward_device_index=forward_device,
1308
user_visible_outputs=user_visible_outputs,
1311
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
1313
if config.freezing and not torch.is_grad_enabled():
1314
inference_compiler = functools.partial(
1315
fw_compiler_freezing,
1316
dynamo_model=model_,
1317
num_example_inputs=num_example_inputs,
1318
inner_compile=inner_compile,
1319
cudagraphs=cudagraphs,
1321
forward_device=forward_device,
1324
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
1326
def partition_fn(graph, joint_inputs, **kwargs):
1327
_recursive_joint_graph_passes(graph)
1328
return min_cut_rematerialization_partition(
1329
graph, joint_inputs, **kwargs, compiler="inductor"
1332
@dynamo_utils.dynamo_timed
1333
@dynamo_utils.maybe_cprofile
1334
def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
1335
fixed = count_tangents(model)
1336
return inner_compile(
1340
cudagraphs=cudagraphs,
1343
boxed_forward_device_index=forward_device,
1350
fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
1351
allow_non_fake_inputs=True
1354
torch._guards.TracingContext.try_get()
1355
or torch._guards.TracingContext(fake_mode)
1358
if V.aot_compilation is True:
1359
gm, graph_signature = aot_export_module(
1360
model_, example_inputs_, trace_joint=False, decompositions=decompositions
1362
unlifted_gm = _unlift_graph(model_, gm, graph_signature)
1363
if "dynamo_flat_name_to_original_fqn" in model_.meta:
1364
unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
1365
"dynamo_flat_name_to_original_fqn"
1367
with V.set_fake_mode(fake_mode), compiled_autograd.disable():
1368
return inference_compiler(unlifted_gm, example_inputs_)
1370
with V.set_fake_mode(fake_mode), torch._guards.tracing(
1372
), compiled_autograd.disable():
1373
return aot_autograd(
1374
fw_compiler=fw_compiler,
1375
bw_compiler=bw_compiler,
1376
inference_compiler=inference_compiler,
1377
decompositions=decompositions,
1378
partition_fn=partition_fn,
1379
keep_inference_input_mutations=True,
1380
)(model_, example_inputs_)
1383
def _shape_env_from_inputs(inputs: List[torch.Tensor]):
1385
fake_mode = detect_fake_mode(inputs)
1392
if fake_mode is not None:
1393
return fake_mode.shape_env
1396
for input in inputs:
1397
if isinstance(input, torch.SymInt):
1398
return input.node.shape_env
1404
def output_node(gm: torch.fx.GraphModule):
    """Get the output node from an FX graph"""
    last = next(iter(reversed(gm.graph.nodes)))
    assert last.op == "output"
    return last
1411
def graph_returns_tuple(gm: torch.fx.GraphModule):
1412
"""True if a FX graph returns a tuple"""
1413
if not isinstance(gm, torch.fx.GraphModule):
1415
(rv,) = output_node(gm).args
1416
if isinstance(rv, (list, tuple)):
1419
isinstance(rv, torch.fx.node.Node)
1420
and hasattr(rv.target, "_schema")
1421
and len(rv.target._schema.returns) > 1
1422
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
1429
def make_graph_return_tuple(
1430
gm: torch.fx.GraphModule,
1431
inputs: List[torch.Tensor],
1432
compile_gm: Callable[..., Any],
1435
Mutate gm so it returns a tuple. This is only needed for graphs
1436
not created by torchdynamo that return non-tuples.
1438
node = output_node(gm)
1440
rv, spec = pytree.tree_flatten(rv)
1441
with gm.graph.inserting_before(node):
1443
gm.graph.erase_node(node)
1444
assert graph_returns_tuple(gm)
1446
compiled_fn = compile_gm(gm, inputs)
1448
@functools.wraps(compiled_fn)
1449
def wrapper(*args, **kwargs):
1450
return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
1455
def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
1457
Mutate inputs so that they are flat and wrap gm such that it
1458
accepts those inputs. This is only needed for graphs not created
1459
by torchdynamo that take bumpy inputs.
1461
inputs, spec = pytree.tree_flatten(inputs)
1463
class GmWrapper(torch.nn.Module):
1468
def forward(self, *args):
1469
args: List[Any] = list(args)
1470
return self.gm(*pytree.tree_unflatten(args, spec))
1472
compiled_fn = compile_gm(GmWrapper(), inputs)
1474
@functools.wraps(compiled_fn)
1477
return compiled_fn(*pytree.arg_tree_leaves(*args))
1482
def handle_dynamo_export_graph(
1483
gm: torch.fx.GraphModule,
1484
inputs: List[torch.Tensor],
1485
compile_gm: Callable[..., Any],
1488
`torch._dynamo.export` embeds pytrees in the FX graph codegen object,
1489
convert that to a normal FX graph so inductor can compile it.
1491
codegen = gm.graph._codegen
1492
gm.graph._codegen = torch.fx.graph.CodeGen()
1495
compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))
1497
@functools.wraps(compiled_fn)
1499
return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))