import contextlib
import dataclasses
import functools
import logging
import os
import sys
import time
import warnings
from itertools import count

from typing import (
    Any,
    Callable,
    Dict,
    FrozenSet,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from unittest import mock

from functorch.compile import min_cut_rematerialization_partition

import torch._functorch.config as functorch_config

import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo import (
    compiled_autograd,
    config as dynamo_config,
    logging as dynamo_logging,
    utils as dynamo_utils,
)
from torch._dynamo.utils import (
    counters,
    detect_fake_mode,
    lazy_format_graph_code,
    optimus_scuba_log,
)
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._logging import trace_structured
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
from ..fx._lazy_graph_module import _use_lazy_graph_module  # type: ignore[attr-defined]
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics
from .debug import DebugContext
from .decomposition import select_decomp_table
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V

if config.is_fbcode():
    from torch._inductor.fb.utils import time_and_log
else:
    # no-op decorator
    def time_and_log(attr: str, extra_loggings: Optional[Dict[str, str]] = None):
        return dynamo_utils.identity


log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
ALIGNMENT = 16


@dataclasses.dataclass
class BoxedBool:
    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        if isinstance(obj, BoxedBool):
            obj.value = False
            return obj
        return False


@dataclasses.dataclass
class BoxedDeviceIndex:
    value: Optional[int]

    def set(self, device_idx):
        assert device_idx is None or isinstance(device_idx, int)
        self.value = device_idx
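
# Illustrative sketch (not part of the original module): BoxedBool and
# BoxedDeviceIndex are small mutable cells shared between the forward and
# backward compilations, so that disabling cudagraphs (or recording the forward
# device) in one place is visible in the other. A minimal usage sketch, assuming
# nothing beyond the dataclasses defined above; the helper name is hypothetical.
def _example_boxed_flags():
    cudagraphs = BoxedBool(True)
    device = BoxedDeviceIndex(None)
    # Disabling through the box mutates the shared value in place.
    BoxedBool.disable(cudagraphs)
    device.set(0)
    assert bool(cudagraphs) is False and device.value == 0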


# copy_ fails when trying to write to tensors with memory overlap.
# For expanded dimensions (a dimension which used to have size 1 -> ?),
# we can select one element from that dimension and write to it,
# which writes to all values of that dimension of the input tensor.
def get_expanded_dims(t):
    if not isinstance(t, torch.Tensor):
        return None
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    for expanded_dim in expanded_dims:
        t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
    return t


def complex_memory_overlap(t: torch.Tensor) -> bool:
    # if torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, let's dig deeper to find out whether it's true.
    t = index_expanded_dims(t, get_expanded_dims(t))
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
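
# Illustrative sketch (not part of the original module): how the helpers above
# treat expanded (stride-0) dimensions. Slicing each expanded dimension down to
# one element removes the internal overlap, so complex_memory_overlap() can
# reason about the remaining strides. A minimal sketch using only public torch
# APIs; the helper name is hypothetical.
def _example_expanded_dims():
    base = torch.randn(4, 1)
    expanded = base.expand(4, 8)          # dim 1 now has stride 0
    assert get_expanded_dims(expanded) == [1]
    sliced = index_expanded_dims(expanded, [1])
    assert sliced.shape == (4, 1)
    # After slicing, there is no remaining overlap to worry about.
    assert complex_memory_overlap(expanded) is False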


@functools.lru_cache(None)
def _step_logger():
    return dynamo_logging.get_step_logger(log)


@functools.lru_cache(None)
def _warn_tf32_disabled():
    if (
        torch.cuda.is_available()
        and not torch.backends.cuda.matmul.allow_tf32
        and torch.cuda.get_device_capability() >= (8, 0)
    ):
        warnings.warn(
            "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
            "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
        )


def _unlift_graph(mod, gm, graph_signature):
    from torch.export.unflatten import _assign_attr, _AttrKind

    state_dict = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        state_dict[name] = param
        _assign_attr(
            param,
            gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )
    for name, buffer in mod.named_buffers(remove_duplicate=False):
        state_dict[name] = buffer
        _assign_attr(
            buffer,
            gm,
            name,
            attr_kind=_AttrKind.BUFFER,
        )

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    lifted_inputs = []
    for node in placeholder_nodes:
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            lifted_inputs.append(graph_signature.inputs_to_parameters[node_name])
        elif node_name in graph_signature.inputs_to_buffers:
            lifted_inputs.append(graph_signature.inputs_to_buffers[node_name])
        else:
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    from torch.export._unlift import _unlift

    outputs = list(gm.graph.nodes)[-1].args[0]
    mutated_outputs = []
    for out in outputs:
        if out in graph_signature.buffers_to_mutate:
            mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
        else:
            mutated_outputs.append(None)

    unlifted_gm = _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
        state_dict,
        {},
    )
    return unlifted_gm


def _get_subgraph_names(gm):
    for node in gm.graph.nodes:
        if node.target == torch.ops.higher_order.cond:
            true_subgraph_name = node.args[1].name
            false_subgraph_name = node.args[2].name
            yield true_subgraph_name
            yield false_subgraph_name


def _recursive_pre_grad_passes(gm, example_inputs):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        # as we don't have recursive example inputs, passing None here
        new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
        setattr(gm, subgraph_name, new_subgraph)
    return pre_grad_passes(gm, example_inputs)


def _recursive_joint_graph_passes(gm):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_joint_graph_passes(subgraph)
    joint_graph_passes(gm)


def _recursive_post_grad_passes(gm, is_inference: bool = False):
    for subgraph_name in _get_subgraph_names(gm):
        subgraph = getattr(gm, subgraph_name)
        _recursive_post_grad_passes(subgraph, is_inference)
    post_grad_passes(gm, is_inference)


def split_const_gm(
    gm: torch.fx.GraphModule,
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
    """
    This function takes a GraphModule input "gm".
    The gm will be split into 2 components:
      1) const_gm, which consists of the subgraph of gm that can be constant folded.
      2) gm (modified in place), which returns the graph after constant folding.

    const_output_index is a mapping from the corresponding node name in gm to the
    output index of const_gm.
    Returns (const_gm, const_output_index)
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    const_gm = run_and_get_constant_graph(gm)
    const_result = const_gm()

    const_outputs = {
        x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
    }

    to_erase_node = []
    to_replace_node = []
    const_output_index = {}
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG:
            to_erase_node.append(node)

    for node in to_replace_node:
        new_const_name = "_FOLDED_CONST_" + node.name
        replace_node_with_constant(
            gm,
            node,
            const_result[const_outputs[node.name]],
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]
    for node in to_erase_node[::-1]:
        if node.users:
            for n in node.users:
                assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
        else:
            gm.graph.erase_node(node)
    gm.recompile()

    return const_gm, const_output_index


def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target in tf32_ops
            and isinstance(node.meta.get("val", None), torch.Tensor)
            and node.meta["val"].dtype == torch.float32
            and node.meta["val"].device.type == "cuda"
        ):
            return True
    return False


@DebugContext.wrap
def count_bytes_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    num_fixed: int = 0,
    **kwargs,
):
    shape_env = _shape_env_from_inputs(example_inputs)
    fake_mode = fake_tensor_prop(gm, example_inputs)

    with V.set_fake_mode(fake_mode):
        _recursive_post_grad_passes(gm, False)

    graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
    with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
        graph.run(*example_inputs)
        num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
        metrics.num_bytes_accessed += num_bytes
        metrics.nodes_num_elem += nodes_num_elem
        metrics.node_runtimes += node_runtimes
    return make_boxed_func(gm.forward)


def fake_tensor_prop(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    force_allow_non_fake_inputs: bool = False,
):
    """
    If we can not detect fake mode from the context of inputs, create one.

    The created fake mode will be returned.
    """
    fake_mode = detect_fake_mode(example_inputs)
    if not fake_mode:
        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
    else:
        ctx = (
            contextlib.nullcontext()
            if not force_allow_non_fake_inputs
            else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        )
        with ctx:  # type: ignore[attr-defined]
            FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
                *example_inputs
            )

    return fake_mode
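
# Illustrative sketch (not part of the original module): fake_tensor_prop()
# populates node.meta["val"] for the graph's nodes so later passes can inspect
# shapes and dtypes without running real kernels. A minimal sketch on a plain
# traced module, assuming only public torch APIs; the helper name is
# hypothetical.
def _example_fake_tensor_prop():
    class M(torch.nn.Module):
        def forward(self, x):
            return (torch.relu(x) + 1,)

    gm = torch.fx.symbolic_trace(M())
    fake_mode = fake_tensor_prop(gm, [torch.randn(2, 3)])
    # call_function nodes now carry a FakeTensor describing their output.
    vals = [n.meta.get("val") for n in gm.graph.nodes if n.op == "call_function"]
    assert all(isinstance(v, torch.Tensor) for v in vals)
    return fake_mode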


# pass config dict back to user
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
    with config.patch(config_patches):
        return config.get_config_copy()


@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(
    attr="compilation time (in seconds)",
    extra_loggings={"config_dict": str(get_patched_config_dict())},
)
# Need this decorator for compile_fx_inner even if we already have one for
# compile_fx. The reason is the compilation for backward graph may happen after
# compile_fx return and we may want to use the _LazyGraphModule for compiling
# the backward graph as well.
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
@dynamo_utils.dynamo_timed(phase_name="inductor_compile")
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    num_fixed: int = 0,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: FrozenSet[str] = frozenset(),
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.
    """
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            num_fixed=num_fixed,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "num_fixed": num_fixed,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    if config.fx_graph_cache and not aot_mode:
        compiled_graph = FxGraphCache.load(
            fx_codegen_and_compile, gm, example_inputs, graph_kwargs
        )
    else:
        compiled_graph = fx_codegen_and_compile(
            gm, example_inputs, **graph_kwargs  # type: ignore[arg-type]
        )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    # check cudagraph disabling reasons from inductor lowering
    if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
        perf_hint_log.warning(
            "skipping cudagraphs due to %s", compiled_graph.disabled_cudagraphs_reason
        )
        BoxedBool.disable(cudagraphs)

    # Return the output strides to the caller via TracingContext
    context = torch._guards.TracingContext.try_get()
    if context is not None and context.output_strides is not None:
        assert len(context.output_strides) == 0
        context.output_strides.extend(compiled_graph.output_strides)

    if aot_mode:
        return compiled_graph

    if cudagraphs:
        # output args are tuple of first argument
        output = list(gm.graph.nodes)[-1]
        assert len(output.args) == 1
        stack_traces = [
            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
            for arg in output.args[0]
        ]

        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t)
            for t in example_inputs
            if isinstance(t, torch.Tensor)
        )

        from torch._inductor.cudagraph_utils import check_for_mutation

        has_mutation_str = check_for_mutation(gm, compiled_graph, num_fixed)
        has_mutation = has_mutation_str is not None

        if has_mutation:
            compiled_graph.disabled_cudagraphs_reason = has_mutation_str

        cudagraph_tests = [
            (not has_mutation, "mutated inputs"),
            (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
            (not complex_memory_overlap_inputs, "complex memory overlap"),
            (
                all(
                    isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
                ),
                "non-Tensor inputs",
            ),
        ]
        cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]

        if not cudagraph_fail_reasons:
            if not config.triton.cudagraph_trees:
                # Force specialize all inputs so that CUDA graphs will work
                for t in example_inputs:
                    if isinstance(t, torch.SymInt):
                        int(t)  # guard

            if (
                boxed_forward_device_index is not None
                and not is_inference
                and not is_backward
            ):
                boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

            compiled_graph.current_callable = cudagraphify(
                compiled_graph.get_current_callable(),
                example_inputs,
                static_input_idxs=range(num_fixed),
                device_index=next(iter(compiled_graph.device_idxs)),
                stack_traces=stack_traces,
                is_backward=is_backward,
                is_inference=is_inference,
                constants=tuple(compiled_graph.constants.values()),
            )
        else:
            BoxedBool.disable(cudagraphs)

            # See [Backward Generation Handling]
            # if we cudagraph'd the forward and set the device, we need to let the cudagraph manager
            # know we are running the backward even if we will not run it in cudagraphs
            if is_backward and config.triton.cudagraph_trees:
                assert boxed_forward_device_index is not None
                assert boxed_forward_device_index.value is not None
                compiled_graph_callable = compiled_graph.get_current_callable()

                manager = torch._inductor.cudagraph_trees.get_manager(
                    boxed_forward_device_index.value, create_if_none_exists=False
                )
                # should already exist from forward
                assert manager is not None

                def compiled_artifact(new_inputs):
                    manager.set_to_running_backward()
                    return compiled_graph_callable(new_inputs)

                compiled_graph.current_callable = compiled_artifact

            if "cuda" in compiled_graph.device_types:
                # prefer better disable_cudagraphs_reason bc stack trace
                # TODO: migrate all disable reasons to stack trace, refactor
                if compiled_graph.disabled_cudagraphs_reason:
                    perf_hint_log.warning(compiled_graph.disabled_cudagraphs_reason)
                else:
                    perf_hint_log.warning(
                        "skipping cudagraphs due to %s", cudagraph_fail_reasons
                    )

    # cudagraphs does its own aligning of inputs
    if not cudagraphs:
        new_callable = align_inputs(
            compiled_graph.get_current_callable(), example_inputs, range(num_fixed)
        )
        if new_callable is not compiled_graph.get_current_callable():
            compiled_graph.current_callable = new_callable

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph


def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    num_fixed: int = 0,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    user_visible_outputs: FrozenSet[str] = frozenset(),
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)
    # TODO: Should we actually dump this?  It should be redundant with the aot
    # structured logs...
    # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

    shape_env = _shape_env_from_inputs(example_inputs)

    # Convert view to reshape in the graph. This is necessary primarily for
    # layout optimization. Do it unconditionally for uniformity.
    #
    # It's needed because when we do layout optimization, a contiguous tensor
    # in eager mode may become a channels-last tensor. A view op that could
    # previously be applied to the contiguous tensor may no longer be applicable
    # to the channels-last tensor. An error like
    #   RuntimeError: view size is not compatible with input tensor's size and stride
    #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    # will be printed.
    #
    # Replace view ops with reshape ops in this case.
    # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
    #
    # Also this has to be done before FakeTensorProp below to avoid the failed
    # .view() call.
    view_to_reshape(gm)

    # It is safe to run FakeTensorProp under no_grad because by the time
    # we're in inductor, we assume that AOTAutograd has already "taken care"
    # of autograd, so there should be no more autograd-related API's in the
    # graph.
    with torch.no_grad():
        fake_mode = fake_tensor_prop(gm, example_inputs)

    # pattern matcher passes might not preserve striding information
    # on node.meta["val"]. if in the future we rely on these being
    # correct we will need to fix.

    with V.set_fake_mode(fake_mode):
        # has some issues with memory in training
        _recursive_post_grad_passes(gm, is_inference=is_inference)
        V.debug.fx_graph_transformed(gm, example_inputs)
        post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: gm.print_readable(print_output=False),
        )
        optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
        signpost_event(
            "optimus",
            "compile_fx.post_grad_passes",
            optimus_scuba_log,
        )

    with V.set_fake_mode(fake_mode):
        const_output_index = None
        const_graph = None
        const_code = None

        if aot_mode and config.aot_inductor.use_runtime_constant_folding:
            const_gm, const_output_index = split_const_gm(gm)

            const_graph = GraphLowering(
                const_gm,
                example_inputs=[],
                shape_env=shape_env,
                num_static_inputs=num_fixed,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                is_const_graph=True,
            )
            with V.set_graph_handler(const_graph):
                assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                const_graph.run()

                const_code, _ = const_graph.codegen_with_cpp_wrapper()

        graph = GraphLowering(
            gm,
            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
            # we currently use fake tensors and defake them later.
            example_inputs=example_inputs,
            shape_env=shape_env,
            num_static_inputs=num_fixed,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            user_visible_outputs=user_visible_outputs,
            extern_node_serializer=extern_node_serializer,
            is_inference=is_inference,
            const_output_index=const_output_index,
            const_code=const_code,
            const_module=const_graph,
        )
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            output_strides: List[Optional[Tuple[int, ...]]] = []
            if graph.graph_outputs is not None:
                # We'll put the output strides in the compiled graph so we
                # can later return them to the caller via TracingContext
                for out in graph.graph_outputs:
                    if hasattr(out, "layout"):
                        output_strides.append(
                            tuple(
                                V.graph.sizevars.size_hint(s) for s in out.layout.stride
                            )
                        )
                    else:
                        output_strides.append(None)

            compiled_fn = graph.compile_to_fn()

            if V.aot_compilation is True:
                return compiled_fn

            if cudagraphs and not V.graph.disable_cudagraphs_reason:
                from torch._inductor.cudagraph_utils import (
                    check_lowering_disable_cudagraph,
                )

                V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
                    V.graph.device_node_mapping
                )

            compiled_graph = CompiledFxGraph(
                compiled_fn, graph, output_strides, V.graph.disable_cudagraphs_reason
            )

    return compiled_graph


def clone_preserve_strides(x: torch.Tensor):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
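
# Illustrative sketch (not part of the original module): clone_preserve_strides()
# sizes its scratch buffer from sum_i (size_i - 1) * stride_i + 1, the largest
# reachable linear offset plus one, so the clone can be re-viewed with the
# original (possibly non-contiguous) strides. A minimal check, assuming only
# public torch APIs; the helper name is hypothetical.
def _example_clone_preserve_strides():
    x = torch.arange(12.0).reshape(3, 4).t()   # shape (4, 3), strides (1, 4)
    y = clone_preserve_strides(x)
    assert y.size() == x.size() and y.stride() == x.stride()
    assert torch.equal(y, x) and y.data_ptr() != x.data_ptr()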


def copy_misaligned_inputs(
    new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
    for i in check_inputs_idxs:
        if new_inputs[i].data_ptr() % ALIGNMENT:
            new_inputs[i] = clone_preserve_strides(new_inputs[i])


def get_input_idxs_to_check(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    def is_aligned(storage_offset, dtype):
        return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

    ids_to_check = []
    for i, input in enumerate(inputs):
        if (
            isinstance(input, torch.Tensor)
            and (
                i not in static_input_idxs
                or not is_aligned(input.storage_offset(), input.dtype)
            )
            and input.device.type == "cuda"
        ):
            ids_to_check.append(i)
    return ids_to_check


def align_inputs_from_check_idxs(
    model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
):
    if len(inputs_to_check) == 0:
        return model

    def run(new_inputs):
        copy_misaligned_inputs(new_inputs, inputs_to_check)
        return model(new_inputs)

    return run


def align_inputs(
    model: Callable[[List[torch.Tensor]], Any],
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
    return align_inputs_from_check_idxs(model, inputs_to_check)
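
# Illustrative sketch (not part of the original module): the helpers above make
# sure every checked CUDA input ends up 16-byte aligned (ALIGNMENT), cloning it
# at call time if its data_ptr is misaligned. Non-static CUDA inputs are always
# flagged for the runtime check; static inputs are skipped when their storage
# offset is already aligned. A minimal sketch that only runs when CUDA is
# available; the helper name is hypothetical.
def _example_alignment_check():
    if not torch.cuda.is_available():
        return []
    base = torch.randn(65, device="cuda")
    dynamic = base[1:]                    # storage_offset 1 -> not 16-byte aligned
    static = torch.randn(64, device="cuda")
    idxs = get_input_idxs_to_check([dynamic, static], static_input_idxs=[1])
    assert list(idxs) == [0]
    return idxs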


@dynamo_utils.dynamo_timed
def cudagraphify(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
):
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    cudagraphify_fn: Callable[..., Any]
    if config.triton.cudagraph_trees:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
        )
    else:
        cudagraphify_fn = cudagraphify_impl

    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
    if not any(isinstance(inp, FakeTensor) for inp in inputs):
        return cudagraphify_fn(model, inputs, static_input_idxs)

    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is None:
            with dynamo_utils.preserve_rng_state():
                compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run


def remove_unaligned_input_idxs(
    inputs: Union[List[torch.Tensor], Sequence[int]],
    static_input_idxs: Sequence[int],
):
    """
    We require all inputs to be aligned, so introduce a copy for any
    that aren't.
    """
    aligned_static_input_idxs = []
    for idx, input in zip(static_input_idxs, inputs):
        if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
            aligned_static_input_idxs.append(idx)
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs


def static_input(x: torch.Tensor):
    """
    Copy an input while preserving strides
    """
    # TODO(jansel): figure out why this version doesn't work:
    # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
    return torch.as_strided(buffer, x.size(), x.stride())


def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    "Index into expanded dimensions of both dst and src then copy_"
    dst = index_expanded_dims(dst, expanded_dims)
    src = index_expanded_dims(src, expanded_dims)
    dst.copy_(src)


def cudagraphify_impl(
    model: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    """
    Assumes inputs[static_input_idxs[i]] are always the same memory address
    """
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
    copy_misaligned_inputs(inputs, check_input_idxs)

    assert isinstance(inputs, list)

    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # allocate static tensor inputs
    static_inputs = [
        x
        if not isinstance(x, torch.Tensor)
        else static_input(x)
        if idx not in static_input_idxs
        else x.detach()
        for idx, x in enumerate(inputs)
    ]

    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if not isinstance(dst, torch.Tensor):
                    pass
                elif idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                index_expanded_dims_and_copy_(
                    static_inputs[idx], new_inputs[idx], expanded_dims
                )
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return align_inputs_from_check_idxs(run, check_input_idxs)
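
# Illustrative sketch (not part of the original module): the capture/replay
# pattern used by cudagraphify_impl() above, reduced to its core. Inputs are
# copied into static buffers, the work is recorded once into a CUDA graph, and
# every later call just refreshes the static buffers and replays. A minimal
# sketch, assuming a CUDA device and only public torch APIs; the helper name is
# hypothetical.
def _example_cudagraph_capture(fn, example_input):
    static_in = example_input.clone()
    # Warm up on a side stream before capture (mirrors the code above).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = fn(static_in)

    def run(new_input):
        static_in.copy_(new_input)   # refresh the captured buffer in place
        graph.replay()
        return static_out

    return run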


def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_saved_tensor(x):
        return (
            "tangents" not in x.name
            and "bwd_seed" not in x.name
            and "bwd_base_offset" not in x.name
        )

    arg_count = 0
    static_arg_idxs = []
    for n in fx_g.graph.nodes:
        if n.op == "placeholder":
            if is_saved_tensor(n):
                static_arg_idxs.append(arg_count)
            arg_count += 1

    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)


def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    config_patches: Dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    if (
        "aot_inductor.output_path" not in config_patches
        and not config.aot_inductor.output_path
    ):
        config_patches = {
            **config_patches,
            "aot_inductor.output_path": code_hash(model_.code),
        }

    extern_node_serializer = config_patches.pop("extern_node_serializer", None)
    with V.set_aot_compilation(True):
        compiled_lib_path = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                aot_mode=True,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=config_patches,
        )
        assert os.path.exists(
            compiled_lib_path
        ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
        return compiled_lib_path


_graph_counter = count(0)


def fw_compiler_freezing(
    aot_autograd_model: torch.fx.GraphModule,
    aot_example_inputs: List[torch.Tensor],
    dynamo_model: torch.fx.GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
):
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)

    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly setup
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)

    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )

    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

    fake_mode = detect_fake_mode(aot_example_inputs)

    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    user_visible_outputs = [
        n.name for n in model_outputs if isinstance(n, torch.fx.Node)
    ]

    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context is not None:
        params_flat = tracing_context.params_flat
        assert params_flat is not None
        for i in range(len(params_flat)):
            if i not in preserved_arg_indices:
                params_flat[i] = None

    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            num_fixed=num_fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
            user_visible_outputs=user_visible_outputs,
        )

    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation is True:
        return optimized_function

    def wrapper(args):
        args_new = [args[i] for i in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True  # type: ignore[attr-defined]

    return wrapper


@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
):
    """Main entrypoint to compile a given FX graph"""
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
                decompositions=decompositions,
            )

    if config.cpp_wrapper:
        with config.patch(
            {
                "cpp_wrapper": False,
                "triton.autotune_cublasLt": False,
                "triton.cudagraphs": False,
                "triton.store_cubin": True,
            }
        ), V.set_real_inputs(example_inputs_):
            inputs_ = example_inputs_
            if isinstance(model_, torch.fx.GraphModule):
                fake_inputs = [
                    node.meta.get("val")
                    for node in model_.graph.nodes
                    if node.op == "placeholder"
                ]
                if all(v is not None for v in fake_inputs):
                    # Validate devices before switching to fake tensors.
                    for idx, fi, i in zip(count(), fake_inputs, inputs_):
                        if fi.device != i.device:
                            raise ValueError(
                                f"Device mismatch between fake input and example input at position #{idx}: "
                                f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
                                "make sure torch.export() and torch.aot_compile() run on the same device."
                            )
                    inputs_ = fake_inputs
            return compile_fx(
                model_,
                inputs_,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                decompositions=decompositions,
            )

    recursive_compile_fx = functools.partial(
        compile_fx,
        inner_compile=inner_compile,
        decompositions=decompositions,
    )

    if not graph_returns_tuple(model_):
        return make_graph_return_tuple(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    if isinstance(model_, torch.fx.GraphModule):
        if isinstance(model_.graph._codegen, _PyTreeCodeGen):
            # this graph is the result of dynamo.export()
            return handle_dynamo_export_graph(
                model_,
                example_inputs_,
                recursive_compile_fx,
            )

        model_ = _recursive_pre_grad_passes(model_, example_inputs_)
        optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
        signpost_event(
            "optimus",
            "compile_fx.pre_grad_passes",
            optimus_scuba_log,
        )

    if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
        return flatten_graph_inputs(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    assert not config._raise_error_for_testing
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(config.triton.cudagraphs)
    forward_device = BoxedDeviceIndex(None)

    graph_id = next(_graph_counter)

    decompositions = (
        decompositions if decompositions is not None else select_decomp_table()
    )

    @dynamo_utils.dynamo_timed
    def fw_compiler_base(
        model: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        is_inference: bool,
    ):
        if is_inference:
            # partition_fn won't be called
            _recursive_joint_graph_passes(model)

        num_rng_seed_offset_inputs = 2 if functorch_config.functionalize_rng_ops else 0
        fixed = len(example_inputs) - num_example_inputs - num_rng_seed_offset_inputs
        user_visible_outputs = set()

        if config.keep_output_stride:
            *_, model_outputs_node = model.graph.nodes
            assert model_outputs_node.op == "output"
            model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
            num_model_outputs = len(model_outputs)

            context = torch._guards.TracingContext.try_get()
            # See Note [User Outputs in the inductor graph]
            if context is not None and context.fw_metadata and not is_inference:
                original_output_start_index = (
                    context.fw_metadata.num_mutated_inp_runtime_indices
                )
            else:
                original_output_start_index = 0

            if isinstance(model_, torch.fx.GraphModule):
                *_, orig_model_outputs_node = model_.graph.nodes
                assert orig_model_outputs_node.op == "output"
                orig_model_outputs, _ = pytree.tree_flatten(
                    orig_model_outputs_node.args
                )
                num_orig_model_outputs = len(orig_model_outputs)
            else:
                num_orig_model_outputs = num_model_outputs

            assert num_orig_model_outputs <= num_model_outputs

            # Note [User Outputs in the inductor graph]
            # We make the following assumptions:
            # For inference
            #   len(orig_model_outputs) == len(model_outputs)
            # For training
            #   len(orig_model_outputs) <= len(model_outputs)
            # During training, most of the time model_outputs starts with the
            # original module's outputs, followed by saved activations.
            # But this may not be true if the model has in-place updated tensors.
            # AOTAutograd will return those tensors before the original
            # module's outputs.
            # To make things safe, we'll use the original_output_start_index field
            # set by AOTAutograd to decide where the original module outputs start.
            orig_output_end_idx = original_output_start_index + num_orig_model_outputs
            # Sanity check: we are about to splice out the "user" outputs from the full set
            # of "graph" outputs. Make sure we're within bounds.
            assert orig_output_end_idx <= num_model_outputs

            user_visible_outputs = {
                n.name
                for n in model_outputs[original_output_start_index:orig_output_end_idx]
                if isinstance(n, torch.fx.Node)
            }

        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=is_inference,
            boxed_forward_device_index=forward_device,
            user_visible_outputs=user_visible_outputs,
        )

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)

    if config.freezing and not torch.is_grad_enabled():
        inference_compiler = functools.partial(
            fw_compiler_freezing,
            dynamo_model=model_,
            num_example_inputs=num_example_inputs,
            inner_compile=inner_compile,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            forward_device=forward_device,
        )
    else:
        inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        _recursive_joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )

    @dynamo_utils.dynamo_timed
    @dynamo_utils.maybe_cprofile
    def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
            boxed_forward_device_index=forward_device,
        )

    # TODO: can add logging before/after the call to create_aot_dispatcher_function
    # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
    # once torchdynamo is merged into pytorch

    fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
        allow_non_fake_inputs=True
    )
    tracing_context = (
        torch._guards.TracingContext.try_get()
        or torch._guards.TracingContext(fake_mode)
    )

    if V.aot_compilation is True:
        gm, graph_signature = aot_export_module(
            model_, example_inputs_, trace_joint=False, decompositions=decompositions
        )
        unlifted_gm = _unlift_graph(model_, gm, graph_signature)
        if "dynamo_flat_name_to_original_fqn" in model_.meta:
            unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
                "dynamo_flat_name_to_original_fqn"
            ]
        with V.set_fake_mode(fake_mode), compiled_autograd.disable():
            return inference_compiler(unlifted_gm, example_inputs_)

    with V.set_fake_mode(fake_mode), torch._guards.tracing(
        tracing_context
    ), compiled_autograd.disable():
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            inference_compiler=inference_compiler,
            decompositions=decompositions,
            partition_fn=partition_fn,
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)
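
# Illustrative sketch (not part of the original module): compile_fx() above is
# the Inductor backend entrypoint that torch.compile(..., backend="inductor")
# ultimately routes dynamo-produced FX graphs into. A minimal end-to-end sketch,
# assuming only public torch APIs; the helper name is hypothetical.
def _example_compile_fx_entry():
    def f(x, y):
        return (x @ y).relu()

    compiled = torch.compile(f, backend="inductor")
    out = compiled(torch.randn(4, 8), torch.randn(8, 4))
    assert out.shape == (4, 4)
    return out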


def _shape_env_from_inputs(inputs: List[torch.Tensor]):
    shape_env = None
    fake_mode = detect_fake_mode(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    # assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # When there are no tensor inputs, get shape_env from the first SymInt.
    for input in inputs:
        if isinstance(input, torch.SymInt):
            return input.node.shape_env

    # TODO(voz): Should we always have one anyway?
    return None


def output_node(gm: torch.fx.GraphModule):
    """Get the output node from an FX graph"""
    last_node = next(iter(reversed(gm.graph.nodes)))
    assert last_node.op == "output"
    return last_node


def graph_returns_tuple(gm: torch.fx.GraphModule):
    """True if a FX graph returns a tuple"""
    if not isinstance(gm, torch.fx.GraphModule):
        return True  # can't check this, assume true
    (rv,) = output_node(gm).args
    if isinstance(rv, (list, tuple)):
        return True
    if (
        isinstance(rv, torch.fx.node.Node)
        and hasattr(rv.target, "_schema")
        and len(rv.target._schema.returns) > 1
        and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
    ):
        # for graphs whose result is one node with multiple outputs
        return True
    return False
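
# Illustrative sketch (not part of the original module): graph_returns_tuple()
# distinguishes dynamo-style graphs (which already return a tuple/list) from
# hand-traced graphs that return a bare tensor and therefore need
# make_graph_return_tuple() below. A minimal sketch using symbolic_trace; the
# helper name is hypothetical.
def _example_graph_returns_tuple():
    class Bare(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class Tupled(torch.nn.Module):
        def forward(self, x):
            return (x + 1,)

    assert not graph_returns_tuple(torch.fx.symbolic_trace(Bare()))
    assert graph_returns_tuple(torch.fx.symbolic_trace(Tupled()))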


def make_graph_return_tuple(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    Mutate gm so it returns a tuple.  This is only needed for graphs
    not created by torchdynamo that return non-tuples.
    """
    node = output_node(gm)
    (rv,) = node.args
    rv, spec = pytree.tree_flatten(rv)
    with gm.graph.inserting_before(node):
        gm.graph.output(rv)
    gm.graph.erase_node(node)
    assert graph_returns_tuple(gm)

    compiled_fn = compile_gm(gm, inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)

    return wrapper


def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
    """
    Mutate inputs so that they are flat and wrap gm such that it
    accepts those inputs.  This is only needed for graphs not created
    by torchdynamo that take bumpy inputs.
    """
    inputs, spec = pytree.tree_flatten(inputs)

    class GmWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gm = gm

        def forward(self, *args):
            args: List[Any] = list(args)
            return self.gm(*pytree.tree_unflatten(args, spec))

    compiled_fn = compile_gm(GmWrapper(), inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        # note this doesn't check the spec, assuming it is the same
        return compiled_fn(*pytree.arg_tree_leaves(*args))

    return wrapper


def handle_dynamo_export_graph(
    gm: torch.fx.GraphModule,
    inputs: List[torch.Tensor],
    compile_gm: Callable[..., Any],
):
    """
    `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
    convert that to a normal FX graph so inductor can compile it.
    """
    codegen = gm.graph._codegen
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.recompile()

    compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))

    return wrapper