pytorch / aot_autograd.py · 1557 lines · 64.2 KB
1
# mypy: ignore-errors
2

3
import itertools
4
from contextlib import contextmanager, nullcontext
5
from functools import partial, wraps
6
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
7
from unittest.mock import patch
8

9
import torch
10
import torch._dynamo.logging
11
import torch.nn as nn
12
import torch.utils._pytree as pytree
13
import torch.utils.dlpack
14
from torch import Tensor
15
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
16
from torch._dispatch.python import enable_python_dispatcher
17
from torch._dynamo import compiled_autograd
18
from torch._dynamo.utils import dynamo_timed, preserve_rng_state
19
from torch._guards import detect_fake_mode
20
from torch._inductor.utils import BoxedBool
21
from torch._subclasses import FakeTensor, FakeTensorMode
22
from torch.fx.experimental.proxy_tensor import make_fx
23
from torch.fx.experimental.symbolic_shapes import ShapeEnv
24
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
25

26

27
static_inputs_log = torch._logging.getArtifactLogger(
28
    __name__, "cudagraph_static_inputs"
29
)
30

31
from . import config
32
from ._aot_autograd.autograd_cache import (  # noqa: F401
33
    AOTAutogradCache,
34
    autograd_cache_key,
35
)
36
from ._aot_autograd.collect_metadata_analysis import (  # noqa: F401
37
    run_functionalized_fw_and_collect_metadata,
38
)
39
from ._aot_autograd.functional_utils import (  # noqa: F401
40
    _check_if_mutation_can_be_in_graph,
41
    are_all_mutations_hidden_from_autograd,
42
    are_all_mutations_under_no_grad_or_inference_mode,
43
    assert_functional_graph,
44
    from_fun,
45
    gen_alias_from_base,
46
    has_data_mutation,
47
    has_metadata_mutation,
48
    is_fun,
49
    sync_functional_tensor,
50
    to_fun,
51
)
52
from ._aot_autograd.input_output_analysis import (  # noqa: F401
53
    _tensors_definitely_do_not_overlap,
54
    compute_overlapping_inputs,
55
    create_graph_signature,
56
    create_synthetic_base_metadata,
57
    remove_dupe_metadata,
58
)
59
from ._aot_autograd.jit_compile_runtime_wrappers import (  # noqa: F401
60
    aot_dispatch_autograd,
61
    aot_dispatch_base,
62
    aot_dispatch_export,
63
)
64
from ._aot_autograd.logging_utils import (  # noqa: F401
65
    callback_set,
66
    describe_input,
67
    format_guard_bug_msg,
68
    get_aot_compilation_context,
69
    get_aot_graph_name,
70
    get_graph_being_compiled,
71
    graph_being_compiled,
72
    model_name,
73
    nth_graph,
74
    set_model_name,
75
    setup_stacktrace_preservation_hooks,
76
    track_graph_compiling,
77
)
78
from ._aot_autograd.runtime_wrappers import (  # noqa: F401
79
    AOTDedupeWrapper,
80
    AOTSyntheticBaseWrapper,
81
)
82
from ._aot_autograd.schemas import (  # noqa: F401
83
    AOTConfig,
84
    BackwardSignature,
85
    FQN,
86
    GraphInputName,
87
    GraphOutputName,
88
    GraphSignature,
89
    InputAliasInfo,
90
    MutationType,
91
    OutputAliasInfo,
92
    OutputType,
93
    SubclassCreationMeta,
94
    SubclassMeta,
95
    TensorAlias,
96
    ViewAndMutationMeta,
97
)
98
from ._aot_autograd.subclass_utils import (  # noqa: F401
99
    create_metadata_for_subclass,
100
    requires_subclass_dispatch,
101
    unwrap_tensor_subclasses,
102
    wrap_tensor_subclasses,
103
    wrap_tensor_subclasses_maybe_joint,
104
)
105
from ._aot_autograd.traced_function_transforms import (  # noqa: F401
106
    aot_dispatch_subclass,
107
    create_functional_call,
108
    create_functionalized_fn,
109
    create_functionalized_rng_ops_wrapper,
110
    create_joint,
111
    fn_input_mutations_to_outputs,
112
    fn_prepped_for_autograd,
113
)
114
from ._aot_autograd.utils import (  # noqa: F401
115
    _get_autocast_states,
116
    _get_symint_hints,
117
    call_func_at_runtime_with_args,
118
    create_tree_flattened_fn,
119
    KNOWN_TYPES,
120
    make_boxed_compiler,
121
    make_boxed_func,
122
    maybe_to_fresh_input,
123
    normalize_as_list,
124
    partial_flatten_asdict,
125
    root_module_when_exporting_non_strict,
126
    strict_zip,
127
)
128
from .partitioners import default_partition
129

130

131
zip = strict_zip
132

133
# This global counter increments every time we compile a graph with
134
# AOTAutograd.  You can use this to correlate runtime error messages
135
# with compile time (e.g., if you get an error at runtime saying
136
# compiled graph 3 failed, you can set a breakpoint at compile time
137
# for this graph number to investigate further at compile time.)
138
#
139
# NB: this is different from get_aot_compilation_context, which tracks
140
# each underlying graph that is compiled.  In contrast, AOT_COUNTER
141
# corresponds to top-level invocations of aot_module/aot_function;
142
# one counter is allocated per entire compiled block (but this block
143
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
144
AOT_COUNTER = itertools.count()
145

146
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
147
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
148
#
149
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
150
# that are external to the graph (they show up as side effects in some way when you run the graph).
151
#
152
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
154
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
155
#
156
# Note [AOT Autograd: input data mutations]
157
#
158
# If we compile a function that mutates inputs, then those input mutations are real side effects
159
# that a user expects to see after running the compiled graph.
160
# However, the graph that we want to send to a backend needs to be *entirely* functional.
161
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile
162
# but we update the graph to return (updated_inputs, user_outputs).
163
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
164
#
165
# Example: original user code:
166
# def f(x):
167
#     x.mul_(2)
168
#     out = x.mul(3)
169
#     return out
170
#
171
# After AOT Autograd compiles, we end up with:
# (a) a compiled graph
# (b) an autograd.Function.forward() method that executes the compiled graph
# (c) a wrapper function that calls the autograd.Function.forward() and performs the epilogue
#
# The code for (a), (b), and (c) is written out below.
177
#
178
# def compiled_forward_graph(x):
179
#     x_updated = x.mul(2)
180
#     out = x_updated.mul(3)
181
#     return x_updated, out
182
#
183
# # x_updated gets a gradient in the compiled backward
184
# def compiled_backward_graph(grad_x_updated, grad_out):
185
#     grad_x = ...
186
#     return grad_x
187
#
188
# def autograd.Function.forward(x):
189
#     x_updated, out = compiled_forward_graph(x)
190
#     return x_updated, out
191
#
192
# def compiled_wrapper(x):
193
#     x_updated, out = autograd.Function.apply(x)
194
#     x.copy_(x_updated)
195
#     return out
196
#
197
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
198
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
199
# (due to updated inputs showing up as graph outputs),
200
# the compiled backward gets an additional N inputs.
201
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
202
# back to the original input.
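#
# A minimal way to observe this end to end (an illustrative sketch, not part of this
# module; `print_compile_fn` is a stand-in compiler that just prints the graph it is
# handed, and `f` is the example function above):
#
#   def print_compile_fn(fx_module, args):
#       print(fx_module)
#       return fx_module
#
#   x = torch.randn(3)
#   out = aot_function(f, print_compile_fn)(x)
#   # The printed forward graph is functional and returns (x_updated, out);
#   # the x.copy_(x_updated) happens in the wrapper's epilogue, so x is still
#   # mutated in place here, exactly as in eager mode.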
203

204

205
# Note [AOT Autograd: input metadata mutations]
206
#
207
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
208
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph
209
#
210
# Example: original user code:
211
# def f(x):
212
#     x.t_()
213
#     out = x.mul(3)
214
#     return out
215
#
216
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
217
# def compiled_forward_graph(x):
218
#     x_updated = x.t()
219
#     out = x_updated.mul(3)
220
#     return x_updated, out
221
#
222
# # x_updated does *not* get a gradient in the compiled backward
223
# def compiled_backward_graph(grad_out):
224
#     grad_x = ...
225
#     return grad_x
226
#
227
# def autograd.Function.forward(x):
228
#     x_updated, out = compiled_forward_graph(x)
229
#     return x_updated, out
230
#
231
# def compiled_wrapper(x):
232
#     x_updated, out = autograd.Function.apply(x)
233
#     x.as_strided_(x_updated)
234
#     return out
235

236

237
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
238
#
239
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
240
# Why?
241
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
242
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
243
#     in an epilogue.
244
# For outputs that alias inputs, we do the following:
245
# (a) *still* return the aliased output as a graph output
246
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
247
#
248
# For outputs that alias *intermediates*, we do the following:
249
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
250
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
251
# You might wonder why we return the aliased output directly in the graph (and make the graph compute it),
252
# only to not return it and instead generate a fresh alias off of the intermediate,
253
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
254
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
255
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
256
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
257
#     when it has a different set of strides.
258
#     By including the view op directly in the graph, inductor takes that into account when deciding what memory format
259
#     the graph intermediate should be.
260
#
261
# Another important thing to note is how our traced backward() graph handles aliases.
262
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
263
#  *and* updated inputs returned in the compiled forward due to metadata-only mutations).
264
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
265
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
266
# at the end of the forward.
267
#
268
# Example: original user code:
269
# def f(x):
270
#     out1 = x.t()
271
#     intermediate = x.mul(2)
272
#     out2 = intermediate.view(-1)
273
#     return out1, out2
274
#
275
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
276
# def compiled_forward_graph(x):
277
#     out1 = x.t()
278
#     intermediate = x.mul(2)
279
#     out2 = intermediate.view(-1)
280
#     # the compiled graph also returns the intermediate
281
#     return out1, out2, intermediate
282
#
283
# # intermediate gets a gradient in the compiled backward.
284
# # both output aliases (out1 and out2) do not.
285
# def compiled_backward_graph(grad_intermediate):
286
#     grad_x = ...
287
#     return grad_x
288
#
289
# def autograd.Function.forward(x):
290
#     out1, out2, intermediate = compiled_forward_graph(x)
291
#     return out1, out2, intermediate
292
#
293
# def compiled_wrapper(x):
294
#     out1, out2, intermediate = autograd.Function.apply(x)
295
#     # regenerate out1 from the input
296
#     out1_regenerated = out1._view_func(x)
297
#     # regenerate out2 from the intermediate
298
#     out2_regenerated = out2._view_func(intermediate)
299
#     return out1_regenerated, out2_regenerated
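#
# A quick way to see this in practice (an illustrative sketch, reusing the example `f`
# above; `print_compile_fn` is a stand-in compiler that just prints the graph):
#
#   def print_compile_fn(fx_module, args):
#       print(fx_module)  # note that the forward also returns `intermediate`
#       return fx_module
#
#   x = torch.randn(2, 3, requires_grad=True)
#   out1, out2 = aot_function(f, print_compile_fn)(x)
#   # The out1 and out2 handed back to the caller are regenerated in the epilogue,
#   # so they still alias x and the graph intermediate respectively, matching
#   # the aliasing behavior of eager mode.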
300

301

302
# Note [AOT Autograd: mutations to inputs that alias other inputs]
303
#
304
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
305
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
306
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
307
# given the mutation that occurred.
308
#
309
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
310
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
311
# inside of the compiled function.
312
#
313
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
314
#
315
# Example: original user code:
316
# def f(x, x_view):
317
#     x.mul_(2)
318
#     out = x * x_view
319
#     return out
320
# f(x, x.view(-1))
321
#
322
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
323
# def compiled_forward_graph(base):
324
#     x = generate_x(base)
325
#     x_view = generate_x_view(base)
326
#     x_updated = x.mul(2)
327
#     x_view_updated = x_updated.view(-1)
328
#     out = x_updated * x_view_updated
329
#     return x_updated, out
330
#
331
# # The calling convention change from (aliases) -> (base) happens
332
# # *outside* of the autograd.Function.forward().
333
# # That means the forward() only has 1 input (base),
334
# # and the backward() only has 1 output (grad_base)
335
# def compiled_backward_graph(grad_out):
336
#     grad_base = ...
337
#     return grad_base
338
#
339
# def autograd.Function.forward(base):
340
#     x_updated, out = compiled_forward_graph(base)
341
#     return x_updated, out
342
#
343
# # The compiled wrapper is where we create synthetic bases.
344
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
345
# def compiled_wrapper(x, x_view):
346
#     base = merge_view_inputs(x, x_view)
347
#     x_updated, out = autograd.Function.apply(base)
348
#     # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
349
#     x.copy_(x_updated)
350
#     return out
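#
# A small end-to-end check of this behavior (an illustrative sketch, using the
# example `f` above and a stand-in `print_compile_fn` compiler):
#
#   def print_compile_fn(fx_module, args):
#       print(fx_module)  # per the note above, the graph receives the synthetic base
#       return fx_module
#
#   x = torch.randn(4)
#   x_ref = x.clone()
#   out = aot_function(f, print_compile_fn)(x, x.view(-1))
#   # x is mutated in place through the epilogue, just as it would be in eager mode:
#   assert torch.allclose(x, x_ref * 2)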
351

352

353
# Note [AOT Autograd: Views to avoid tangents aliasing inputs]
354
#
355
# We view every forward output when creating our tangent tensors to handle the problematic
356
# case in which a subclass does extra aliasing between graph outputs/inputs in a way that
357
# is not visible above the subclass.
358
#
359
# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd,
360
# we're guaranteed that the tangent tensors that we pass
361
# into the joint are distinct tensors from the primals. This is because when we
# decide which forward outputs to create tangents for, we only create tangents
363
# for forward outputs that are not aliases of inputs (See Note
364
# [AOT Autograd: outputs aliasing inputs or intermediates!]).
365
#
366
# However, when wrapper tensor subclasses enter the picture, it is possible
367
# to have an output of the forward that is a subclass that is not an
368
# input / alias of an input, but one of its inner tensors is an alias!
369
# NestedTensor is an example: Performing an out-of-place pointwise op on a
370
# NestedTensor constructs a fresh NestedTensor that holds onto the input's
371
# offsets tensor directly.
372
#
373
# Having tangent tensors that are the same as the (primal) forward inputs,
374
# can cause problems during tracing as make_fx() will specialize on our
375
# duplicate inputs: If we passed in the same tensor for primals_1 and
376
# tangents_1 during tracing, make_fx() will happily sub out all usages of
377
# tangents_1 with primals_1 in the graph, which is not what we want.
378
#
379
# To work around this, we view every forward output when creating our tangent
380
# tensors so that tangents can never be the same as forward inputs even if
381
# forward inputs alias forward outputs.
382

383
# Note [Side-Effectful Tokens in AOTAutograd]
384
#
385
# We allow some side-effectful operators in
386
# the post-AOTAutograd (functional) graph, such as prints and torchbind operations.
387
# To ensure that these side-effects are compatible with future graph passes that
388
# assume that the graph is functional, we will thread "effect tokens" to show
389
# data dependence between these side-effectful operators. Practically speaking,
390
# effect tokens are just dummy values (torch.tensor([])). The graph would look
391
# like the following:
392
#
393
# def gm(self, token0, reader):
394
#    token1, frame = with_token(ordered_effect_op, (reader,), token0)
395
#    frame = frame * 2
396
#    token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
397
#    frame2 = frame2 * 2
398
#    return token2, frame, frame2
399
#
400
# We will pass the token as an input to the graph, thread it through
401
# side-effectful operators using the `with_effects` high order operator, and then
402
# return the updated token as an output.
403
# So the signature of the graph input would look something like
404
# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
405
# output would look something like (*tokens, *outputs).
406
#
407
# However, Inductor does not want the concept of tokens in the final generated
408
# code's input and output. Since changing the graph signature inside of inductor
409
# is difficult, after generating the forward graph, we will run a pass to
# remove the tokens from the inputs and outputs and generate the following
# graph for Inductor, where the tokens are created and sunk within the graph,
# rather than being passed in as inputs and outputs:
413
#
414
# def gm(self, reader):
415
#    token0 = torch.ops.prims._make_token()
416
#    token1, frame = with_token(ordered_effect_op, (reader,), token0)
417
#    frame = frame * 2
418
#    token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
419
#    frame2 = frame2 * 2
420
#    sink_token = torch.ops.prims._sink_tokens([token2])
421
#    return frame, frame2
422

423
#
424
#
425
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
426
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427

428

429
aot_autograd_decompositions = {}
430

431
FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])
432

433

434
def process_inputs(
435
    flat_args: List[Any],
436
    aot_config: AOTConfig,
437
    fake_mode: FakeTensorMode,
438
    shape_env: Optional[ShapeEnv],
439
) -> FakifiedFlatArgs:
440
    with fake_mode:
441

442
        def convert(idx, x):
443
            if shape_env is not None:
444
                from torch._dynamo.source import ConstantSource
445

446
                if isinstance(x, int):
447
                    # We always specialize on scalar values in export.
448
                    if aot_config.is_export:
449
                        return x
450
                    source = ConstantSource(f"sym_{idx}")
451
                    return shape_env.create_symintnode(
452
                        shape_env.create_symbol(x, source), hint=x, source=source
453
                    )
454
            if isinstance(x, torch.ScriptObject):
455
                return torch._library.fake_class_registry.maybe_to_fake_obj(
456
                    fake_mode, x
457
                )
458
            if not isinstance(x, torch.Tensor):
459
                return x
460
            if isinstance(x, FakeTensor):
461
                assert x.fake_mode is fake_mode
462
                return x
463
            if is_traceable_wrapper_subclass(x):
464
                attrs, _ = x.__tensor_flatten__()
465
                if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
466
                    assert all(
467
                        getattr(x, attr).fake_mode is fake_mode for attr in attrs
468
                    )
469
                    return x
470

471
            # see note [Tensor Fakification and Symbol Caching]
472
            symbolic_context = None
473
            source = None
474
            trace = True
475
            if tracing_context := torch._guards.TracingContext.try_get():
476
                if x in tracing_context.tensor_to_context:
477
                    symbolic_context = tracing_context.tensor_to_context[x]
478
                    source = symbolic_context.tensor_source
479
                    # We already fakeified this tensor in Dynamo, don't
480
                    # dump the trace for it again
481
                    trace = False
482
            if (
483
                idx < aot_config.num_params_buffers
484
                and config.static_weight_shapes
485
                and not symbolic_context
486
            ):
487
                # TODO: Ensure that this codepath is never exercised from
488
                # Dynamo
489
                return fake_mode.from_tensor(x, static_shapes=True)
490

491
            return fake_mode.from_tensor(
492
                x,
493
                static_shapes=False,
494
                symbolic_context=symbolic_context,
495
                source=source,
496
                trace=trace,
497
            )
498

499
        return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])
500

501

502
def construct_fake_mode(
503
    flat_args: List[Any], aot_config: AOTConfig
504
) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]:
505
    fake_mode = detect_fake_mode(flat_args)
506
    if fake_mode is None:
507
        shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
508
        fake_mode = FakeTensorMode(shape_env=shape_env)
509
    else:
510
        shape_env = fake_mode.shape_env
511
    return (fake_mode, shape_env)
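
# The two helpers above are typically used together with
# create_aot_dispatcher_function (defined below); a sketch of the sequence,
# mirroring what aot_function / aot_module_simplified do further down:
#
#   fake_mode, shape_env = construct_fake_mode(flat_args, aot_config)
#   fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env)
#   compiled_fn, fw_metadata = create_aot_dispatcher_function(
#       flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
#   )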
512

513

514
def create_aot_dispatcher_function(
515
    flat_fn,
516
    fake_flat_args: FakifiedFlatArgs,
517
    aot_config: AOTConfig,
518
    fake_mode: FakeTensorMode,
519
    shape_env: Optional[ShapeEnv],
520
) -> Tuple[Callable, ViewAndMutationMeta]:
521
    with dynamo_timed("create_aot_dispatcher_function"):
522
        return _create_aot_dispatcher_function(
523
            flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
524
        )
525

526

527
def _create_aot_dispatcher_function(
528
    flat_fn,
529
    fake_flat_args: FakifiedFlatArgs,
530
    aot_config: AOTConfig,
531
    fake_mode: FakeTensorMode,
532
    shape_env: Optional[ShapeEnv],
533
) -> Tuple[Callable, ViewAndMutationMeta]:
534
    """
535
    Traces the forward and backward graphs of the :attr:`flat_fn` to generate a
536
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
537
    the tracing mechanism to understand the graph capturing details.
538

539
    The joint graph is then passed through :attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided :attr:`fw_compiler` and :attr:`bw_compiler`.
542

543
    The resulting compiled forward and backward graphs are then wrapped up in a
544
    ``torch.autograd.Function`` object.
545

546
    The calling convention here is that the first aot_config.num_params_buffers
547
    inputs in flat_args are parameters and buffers, and the rest are inputs.
548

549
    We use this to assume that parameters' and buffers' shapes don't change.
550

551
    Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
552
        When aot_config.is_export is True, we return an FX graph + metadata
553
        When aot_config.is_export is False, we return an ordinary runtime function
554
    """
555

556
    # This is the main entry point.
557
    # TODO: Chillee argues that dynamo itself should pass in fake tensors to
558
    # the list of arguments when compiling; at the moment we do not do this
559

560
    if aot_config.decompositions is None:
561
        aot_config.decompositions = {}
562

563
    aot_config.decompositions = {
564
        **aot_autograd_decompositions,
565
        **aot_config.decompositions,
566
    }
567

568
    if config.functionalize_rng_ops:
569
        # Update the decompositions with functionalized random decompositions
570
        aot_config.decompositions = {
571
            **rng_decompositions,
572
            **aot_config.decompositions,
573
        }
574

575
    # Check flat_args to see if they're already fake.  If so, use that fake
576
    # mode instead.
577

578
    python_dispatcher_mode = (
579
        enable_python_dispatcher() if shape_env is not None else nullcontext()
580
    )
581

582
    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
583
    # If any saved tensor hooks are active, we **don't** want to trace them.
584
    # Instead, we'll let them run at runtime, around the custom autograd.Function
585
    # that we generate in torch.compile.
586
    with torch.autograd.set_multithreading_enabled(
587
        False
588
    ), preserve_rng_state(), (
589
        fake_mode
590
    ), (
591
        python_dispatcher_mode
592
    ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
593
        from torch._library.fake_class_registry import (
594
            FakeScriptObject,
595
            maybe_to_fake_obj,
596
        )
597

598
        # Tracing may mutate the state of the fake script object,
599
        # so we need to duplicate the fake script objects so that subsequent tracing
600
        # won't be affected.
601
        def _dup_fake_script_obj(fake_flat_args):
602
            return [
603
                maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
604
                if isinstance(arg, FakeScriptObject)
605
                else arg
606
                for arg in fake_flat_args
607
            ]
608

609
        needs_autograd = any(
610
            x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
611
        )
612

613
        with enable_python_dispatcher():
614
            # Patch set_rng_state as set_rng_state with fake tensors is
615
            # nonsensical. This does not affect the collection of metadata.
616
            with patch("torch.cuda.set_rng_state", lambda *args: None):
617
                mod = root_module_when_exporting_non_strict(flat_fn)
618
                if mod is not None:
619
                    ctx = _detect_attribute_assignment(mod)
620
                else:
621
                    ctx = nullcontext()
622
                with ctx:
623
                    fw_metadata = run_functionalized_fw_and_collect_metadata(
624
                        flat_fn,
625
                        static_input_indices=aot_config.static_input_indices,
626
                        keep_input_mutations=aot_config.keep_inference_input_mutations,
627
                        is_train=needs_autograd,
628
                        pre_dispatch=aot_config.pre_dispatch,
629
                    )(*_dup_fake_script_obj(fake_flat_args))
630

631
                req_subclass_dispatch = requires_subclass_dispatch(
632
                    fake_flat_args, fw_metadata
633
                )
634

635
                output_and_mutation_safe = not any(
636
                    x.requires_grad
637
                    # view-type operations preserve requires_grad even in no_grad.
638
                    # Do not count aliases of inputs with requires_grad as reason to make a training graph,
639
                    # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
640
                    # setting their grad_fn properly.
641
                    and not (
642
                        x.output_type
643
                        in (OutputType.alias_of_input, OutputType.is_input)
644
                        and fw_metadata.input_info[x.base_idx].requires_grad
645
                    )
646
                    for x in fw_metadata.output_info
647
                ) and not any(
648
                    x.requires_grad
649
                    and x.mutates_data
650
                    and not x.mutations_under_no_grad_or_inference_mode
651
                    and not x.mutations_hidden_from_autograd
652
                    for x in fw_metadata.input_info
653
                )
654

655
                if needs_autograd and output_and_mutation_safe:
656
                    # We realized that none of the outputs require grad,
657
                    # and none of the inputs that require grad are mutated,
                    # so we actually have an inference graph.
659
                    needs_autograd = False
660
                    # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta
661
                    # changes depending on whether we pass in is_train / keep_input_mutations,
662
                    # so we're forced to recompute the metadata.
663
                    # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata
664
                    # so that this is unnecessary.
665
                    if req_subclass_dispatch:
666
                        fw_metadata = run_functionalized_fw_and_collect_metadata(
667
                            flat_fn,
668
                            keep_input_mutations=aot_config.keep_inference_input_mutations,
669
                            is_train=False,
670
                            pre_dispatch=aot_config.pre_dispatch,
671
                            static_input_indices=aot_config.static_input_indices,
672
                        )(*fake_flat_args)
673
                    else:
674
                        fw_metadata = ViewAndMutationMeta(
675
                            input_info=fw_metadata.input_info,
676
                            output_info=fw_metadata.output_info,
677
                            num_intermediate_bases=fw_metadata.num_intermediate_bases,
678
                            keep_input_mutations=aot_config.keep_inference_input_mutations,
679
                            traced_tangents=fw_metadata.traced_tangents,
680
                            subclass_inp_meta=fw_metadata.subclass_inp_meta,
681
                            subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
682
                            subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
683
                            is_train=False,
684
                            tokens=fw_metadata.tokens,
685
                            static_input_indices=fw_metadata.static_input_indices,
686
                        )
687

688
        if fw_metadata.num_intermediate_bases > 0:
689
            assert not req_subclass_dispatch, f"""\
690
torch.compile is currently being used with tensor subclass inputs:
691
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to compile a graph with two graph outputs
692
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
693
please file a github issue"""
694

695
        if aot_config.is_export:
696
            # aot_export: ban input metadata mutations for now to keep shared code paths simpler.
697
            # Keeping .resize_() in the graph will require some work
698
            # Allowing it but keeping the graph functional will require some calling convention changes.
699
            if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
700
                raise RuntimeError(
701
                    f"""\
702
Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
703
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
704

705
fw_metadata={str(fw_metadata)}"""
706
                )
707
            # In export, banning data mutations on inputs that require grad for now.
708
            # This should be rare, and is tricky to get right. When we trace the backward,
709
            # we currently trace with autograd.grad instead of .backward(), which makes it difficult
710
            # to ensure that we run autograd all the way through the input **before** it saw the mutation.
711
            if (
712
                len(
713
                    [
714
                        x
715
                        for x in fw_metadata.input_info
716
                        if x.requires_grad and x.mutates_data
717
                    ]
718
                )
719
                != 0
720
            ):
721
                raise RuntimeError(
722
                    f"""\
723
Found a graph input that requires gradients, and received a mutation.
724
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
725

726
fw_metadata={str(fw_metadata)}"""
727
                )
728
            if req_subclass_dispatch:
729
                raise RuntimeError(
730
                    """\
731
aot_export is not currently supported with traceable tensor subclass.
732
If you need this feature, please comment on <CREATE_ISSUE_LINK>"""
733
                )
734

735
            # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad,
736
            # and turning it on will require a non-trivial calling convention change for any export runtime.
737
            if config.functionalize_rng_ops:
738
                raise RuntimeError(
739
                    """\
740
Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
741
or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
742
                )
743

744
        def choose_dispatcher(needs_autograd, aot_config):
745
            """
746
            Pick a dispatcher based on the config rules.
747
            """
748
            if aot_config.is_export:
749
                # export uses just the "graph bits", whereas the other
750
                # two dispatchers include some extra work around handling a runtime epilogue
751
                return partial(aot_dispatch_export, needs_autograd=needs_autograd)
752
            elif needs_autograd and not aot_config.pre_dispatch:
753
                return aot_dispatch_autograd
754
            else:
755
                return aot_dispatch_base
756

757
        compiler_fn = choose_dispatcher(needs_autograd, aot_config)
758

759
        compiled_fn, fw_metadata = compiler_fn(
760
            flat_fn,
761
            _dup_fake_script_obj(fake_flat_args),
762
            aot_config,
763
            fw_metadata=fw_metadata,
764
        )
765
        return compiled_fn, fw_metadata
766

767

768
def aot_function(
769
    fn: Callable,
770
    fw_compiler: Callable,
771
    bw_compiler: Optional[Callable] = None,
772
    partition_fn: Callable = default_partition,
773
    decompositions: Optional[Dict] = None,
774
    num_params_buffers: int = 0,
775
    keep_inference_input_mutations: bool = False,
776
    inference_compiler: Optional[Callable] = None,
777
    *,
778
    # Whether or not to trace with dynamic shapes
779
    dynamic=False,
780
    enable_log=True,
781
) -> Callable:
782
    """
783
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
784
    mechanism, and then compiles the generated forward and backward graphs
785
    through :attr:`fw_compiler` and :attr:`bw_compiler`.
786

787
    :func:`aot_function` traces the forward and backward graph ahead of time,
788
    and generates a joint forward and backward graph.  :attr:`partition_fn` is
789
    then used to separate out forward and backward graphs. The partitioner
790
    function can be used to perform optimizations such as recomputation. One can
791
    set `decompositions` dictionary to decompose the operators into a sequence
792
    of core or simpler operators supported by the backend compilers.
793

794
    .. warning::
795
        This API is experimental and likely to change.
796

797
    Args:
798
        fn (Callable): A Python function that takes one or more arguments. Must
799
            return one or more Tensors.
800
        fw_compiler (Callable): A Python function that accepts an Fx graph with
801
            Aten ops and input args, and returns a Callable that semantically is
802
            equivalent to the input Fx graph.
803
        bw_compiler (Optional[Callable]): A Python function that accepts an
804
            Fx graph with Aten ops and input args, and returns a Callable that
805
            semantically is equivalent to the input Fx graph.  Default: None
806
            (when None, it defaults to the :attr:`fw_compiler`)
807
        partition_fn (Callable): A Python function that takes a joint forward
808
            and backward graph, and partitions it into separate forward and
809
            backward graphs.
810
        decompositions (Dict): A dictionary to define the decomposition of
811
            larger Aten ops into simpler or core Aten ops.
812
        inference_compiler (Optional[Callable]): A Python function that accepts an
813
            Fx graph with Aten ops and input args, and returns a Callable that
814
            semantically is equivalent to the input Fx graph. inference_compiler is invoked
815
            if no autograd is needed. Default: None
816
            (when None, it defaults to the :attr:`fw_compiler`)
817
    Returns:
818
        Returns a ``Callable`` that retains the eager behavior of the original
819
        :attr:`fn`, but with forward and backward graph compiled via
820
        :attr:`fw_compiler` and :attr:`bw_compiler`.
821

822
    A simple example usage of :func:`aot_function` is as follows. This example
823
    will print the forward and backward graphs of the function ``fn``
824

825
        >>> fn = lambda x : x.sin().cos()
826
        >>> def print_compile_fn(fx_module, args):
827
        >>>     print(fx_module)
828
        >>>     return fx_module
829
        >>> aot_fn = aot_function(fn, print_compile_fn)
830
        >>> x = torch.randn(4, 5, requires_grad=True)
831
        >>> aot_fn(x)
832
    """
833

834
    if bw_compiler is None:
835
        bw_compiler = fw_compiler
836
    if inference_compiler is None:
837
        inference_compiler = fw_compiler
838
    aot_config = AOTConfig(
839
        fw_compiler=fw_compiler,
840
        bw_compiler=bw_compiler,
841
        inference_compiler=inference_compiler,
842
        partition_fn=partition_fn,
843
        decompositions=decompositions,
844
        num_params_buffers=num_params_buffers,
845
        aot_id=next(AOT_COUNTER),
846
        keep_inference_input_mutations=keep_inference_input_mutations,
847
        dynamic_shapes=dynamic,
848
        aot_autograd_arg_pos_to_source=None,
849
        is_export=False,
850
        no_tangents=False,
851
        enable_log=enable_log,
852
    )
853
    cached_res = None
854

855
    @wraps(fn)
856
    def returned_function(*args, **kwargs):
857
        nonlocal cached_res
858
        # Now flatten the tensor args
859
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)
860

861
        # Compile the function and save it in the cache
862
        if cached_res is None:
863
            flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
864
            (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config)
865
            fake_flat_args: FakifiedFlatArgs = process_inputs(
866
                flat_args, aot_config, fake_mode, shape_env
867
            )
868
            compiled_fn, _ = create_aot_dispatcher_function(
869
                flat_fn,
870
                fake_flat_args,
871
                aot_config,
872
                fake_mode,
873
                shape_env,
874
            )
875
            cached_res = (compiled_fn, out_spec)
876

877
        cached_fn, out_spec = cached_res
878
        out = cached_fn(flat_args)
879
        return out_spec.unflatten(out)
880

881
    return returned_function
882

883

884
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
885
    """
886
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
887
    tracing mechanism. It is a wrapper function that underneath uses
888
    :func:`aot_function` to perform tracing and compilation.
889

890
    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
891
    to a new callable which is then compiled through :func:`aot_function`.
892

893
    .. warning::
894
        This API is experimental and likely to change.
895

896
    Args:
897
        mod (nn.Module): A ``nn.Module`` to be compiled.
898
        args : args to be passed to :func:`aot_function`
899
        kwargs : kwargs to be passed to :func:`aot_function`
900

901
    Returns:
902
        Returns a ``nn.Module`` that retains the eager behavior of the original
903
        :attr:`mod`, but with forward and backward graph compiled.
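
    A simple illustrative example, in the spirit of the :func:`aot_function` example
    above (``print_compile_fn`` is just a stand-in compiler that prints the graph it
    receives):

        >>> mod = nn.Linear(4, 4)
        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> aot_mod = aot_module(mod, print_compile_fn)
        >>> out = aot_mod(torch.randn(2, 4))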
904

905
    """
906
    # See Note: [Fake Modules and AOTAutograd]
907
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)
908

909
    def functional_call(named_params, named_buffers, *args, **kwargs):
910
        params_and_buffers = {**named_params, **named_buffers}
911
        return torch.func.functional_call(mod, params_and_buffers, args, kwargs)
912

913
    named_params = dict(mod.named_parameters(remove_duplicate=False))
914
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
915
    num_params_buffers = len(named_params) + len(named_buffers)
916
    compiled_f = aot_function(
917
        functional_call, *args, num_params_buffers=num_params_buffers, **kwargs
918
    )
919

920
    class AOTModule(nn.Module):
921
        def __init__(self) -> None:
922
            super().__init__()
923
            self.orig_module = mod
924

925
        def forward(self, *args, **kwargs):
926
            return compiled_f(
927
                named_params,
928
                named_buffers,
929
                *args,
930
                **kwargs,
931
            )
932

933
    return AOTModule()
934

935

936
def aot_module_simplified(
937
    mod: nn.Module,
938
    args,
939
    fw_compiler: Callable,
940
    bw_compiler: Optional[Callable] = None,
941
    partition_fn: Callable = default_partition,
942
    decompositions: Optional[Dict] = None,
943
    keep_inference_input_mutations=False,
944
    inference_compiler: Optional[Callable] = None,
945
    cudagraphs: Optional[BoxedBool] = None,
946
) -> nn.Module:
947
    """
948
    This is the simplified or low-overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove:
951
        (1) pytree overhead to parse inputs/outputs,
952
        (2) AOT Autograd cache,
953
        (3) Reading of params/buffers in every forward call
954

955
    :func:`aot_module_simplified` removes these overheads.
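
    A rough usage sketch (illustrative; in practice ``mod`` is usually the
    ``fx.GraphModule`` produced by TorchDynamo, and ``print_compile_fn`` below is a
    stand-in compiler that just prints the graph it receives):

        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> x = torch.randn(4, 4)
        >>> compiled_forward = aot_module_simplified(mod, (x,), fw_compiler=print_compile_fn)
        >>> out = compiled_forward(x)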
956
    """
957
    params = {
958
        **dict(mod.named_parameters(remove_duplicate=False)),
959
        **dict(mod.named_buffers(remove_duplicate=False)),
960
    }
961
    params_flat, params_spec = pytree.tree_flatten(params)
962
    params_flat = list(params_flat)
963
    params_len = len(params_flat)
964

965
    if cudagraphs is None:
966
        cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
967

968
    if bw_compiler is None:
969
        bw_compiler = fw_compiler
970
    if inference_compiler is None:
971
        inference_compiler = fw_compiler
972

973
    seen_sources = set()
974

975
    full_args = []
976
    # First, the params
977
    full_args.extend(params_flat)
978

979
    if tracing_context := torch._guards.TracingContext.try_get():
980
        tracing_context.params_flat = params_flat
981

982
    aot_autograd_arg_pos_to_source = None
983
    # Then, the params 1:1 mapped sources, if relevant.
984
    if hasattr(mod, "_param_name_to_source"):
985
        aot_autograd_arg_pos_to_source = []
986
        # We now know this came from dynamo, and (1) we care about guards,
987
        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
988
        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
989
        for name in params.keys():
990
            assert name in mod._param_name_to_source, f"{name} not found."
991
            source = mod._param_name_to_source[name]
992
            assert source not in seen_sources, source
993
            seen_sources.add(source)
994
            aot_autograd_arg_pos_to_source.append(source)
995

996
    # Next, the input args
997
    full_args.extend(args)
998

999
    static_input_indices = []
1000
    if hasattr(mod, "graph"):
1001
        # Non dynamo entrypoints can get to here...
1002
        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
1003
            if hasattr(node, "_dynamo_source"):
1004
                # ... but not here!
1005
                if aot_autograd_arg_pos_to_source is None:
1006
                    aot_autograd_arg_pos_to_source = []
1007
                source = node._dynamo_source
1008
                assert source not in seen_sources, source
1009
                seen_sources.add(source)
1010
                aot_autograd_arg_pos_to_source.append(source)
1011
                source_name = source.name() if source else str(source)
1012

1013
                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
1014
                    "_dynamo_static_input_type", None
1015
                ):
1016
                    static_inputs_log.debug(
1017
                        "Adding static input pos %s for source %s", pos, source_name
1018
                    )
1019
                    static_input_indices.append(pos)
1020
                else:
1021
                    static_inputs_log.debug(
1022
                        "Non-static input pos %s for source %s", pos, source_name
1023
                    )
1024

1025
    if aot_autograd_arg_pos_to_source is not None:
1026
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)
1027

1028
    dynamic_shapes = False
1029
    for x in full_args:
1030
        if isinstance(x, FakeTensor):
1031
            dynamic_shapes = x.fake_mode.shape_env is not None
1032
            break
1033

1034
    aot_config = AOTConfig(
1035
        fw_compiler=fw_compiler,
1036
        bw_compiler=bw_compiler,
1037
        inference_compiler=inference_compiler,
1038
        partition_fn=partition_fn,
1039
        decompositions=decompositions,
1040
        num_params_buffers=params_len,
1041
        aot_id=next(AOT_COUNTER),
1042
        keep_inference_input_mutations=keep_inference_input_mutations,
1043
        dynamic_shapes=dynamic_shapes,
1044
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
1045
        static_input_indices=static_input_indices,
1046
        is_export=False,
1047
        no_tangents=False,
1048
        cache_key=None,
1049
    )
1050
    fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
1051
    fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env)
1052

1053
    def dispatch_and_compile():
1054
        functional_call = create_functional_call(mod, params_spec, params_len)
1055
        with compiled_autograd.disable():
1056
            compiled_fn, _ = create_aot_dispatcher_function(
1057
                functional_call,
1058
                fake_flat_args,
1059
                aot_config,
1060
                fake_mode,
1061
                shape_env,
1062
            )
1063
        return compiled_fn
1064

1065
    # Autograd cache stuff
1066
    if config.enable_autograd_cache:
1067
        compiled_fn = AOTAutogradCache.load(
1068
            dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs
1069
        )
1070
    else:
1071
        compiled_fn = dispatch_and_compile()
1072

1073
    if isinstance(mod, torch._dynamo.utils.GmWrapper):
1074
        # This function is called by the flatten_graph_inputs wrapper, which boxes
1075
        # the inputs so that they can be freed before the end of this scope.
1076
        # For overhead reasons, this is not the default wrapper, see comment:
1077
        # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
1078
        def boxed_forward(runtime_args: List[Any]):
1079
            flat_args = []
1080
            flat_args.extend(params_flat)
1081
            flat_args.extend(runtime_args)
1082
            runtime_args.clear()
1083
            return compiled_fn(flat_args)
1084

1085
        # Just for convenience
1086
        boxed_forward.zero_grad = mod.zero_grad
1087
        boxed_forward.named_parameters = mod.named_parameters
1088
        boxed_forward.named_buffers = mod.named_buffers
1089
        return boxed_forward
1090

1091
    # TODO: There is something deeply wrong here; compiled_fn running with
1092
    # the boxed calling convention, but aot_module_simplified somehow
1093
    # historically returned a function that was not the boxed calling
1094
    # convention.  This should get fixed...
1095
    # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
1096
    def forward(*runtime_args: Tuple[Any]):
1097
        full_args = []
1098
        full_args.extend(params_flat)
1099
        full_args.extend(runtime_args)
1100
        return compiled_fn(full_args)
1101

1102
    # Just for convenience
1103
    forward.zero_grad = mod.zero_grad
1104
    forward.named_parameters = mod.named_parameters
1105
    forward.named_buffers = mod.named_buffers
1106

1107
    return forward
1108

1109

1110
def aot_export_module(
1111
    mod: nn.Module,
1112
    args,
1113
    *,
1114
    decompositions: Optional[Dict] = None,
1115
    # If true, we'll return a joint forward-backward graph,
1116
    # As well as metadata on the loss + gradients in the backward.
1117
    trace_joint: bool,
1118
    # If trace_joint is True, we expect your module to return a scalar loss.
1119
    # Your module can return multiple outputs, so you must specify which output the loss is.
1120
    output_loss_index: Optional[int] = None,
1121
    pre_dispatch: bool = False,
1122
    # If None, will be inferred from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong.
1123
    dynamic_shapes: Optional[bool] = None,
1124
    kwargs=None,
1125
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
1126
    """
1127
    This function takes in a module, and returns:
1128
    (1) an FX graph that can be exported
1129
    (2) some metadata about the graph
1130

1131
    If `trace_joint=True` we will return a joint graph of the forward + backward.
1132

1133
    The traced FX graph will have the following properties compared to the original module:
1134
    (1) Inputs and outputs to the module will be pytree-flattened
1135
    (2) Parameters and buffers on the module will be lifted into graph inputs,
1136
        graph_inputs = (*parameters, *buffers, *user_inputs)
1137
    (3) The graph will be fully functionalized
1138
    (4) Any input mutations will be converted into additional outputs in the graph,
1139
        meaning whoever calls this graph is responsible for applying the mutations
1140
        back to the original inputs.
1141
    (5) If trace_joint is True, the graph will return parameter gradients in addition to user outputs.
1142
        The graph output will look like:
1143
        graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)
1144

1145
    There are also several restrictions on what modules can use this API. In particular:
1146
    (1) If trace_joint is specified, we expect the loss function to be **fused**
1147
        into the module forward. One of the outputs to the forward must be a scalar loss,
1148
        which is specified with `output_loss_index`.
1149
        All other outputs to the forward are presumed to not require gradients.
1150
    (2) This API cannot capture optimizers (although in theory we could build an API for this).
1151
    (3) Metadata mutations on params/buffers/inputs are banned.
1152
    (4) Data mutations on anything that requires gradients are banned (parameters)
1153
    (5) If an input is mutated, it is not allowed to alias any other inputs.
1154
    (6) Parameters must not be duplicated.
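
    A short usage sketch (illustrative only):

        >>> mod = nn.Linear(4, 4)
        >>> x = torch.randn(2, 4)
        >>> fx_g, signature = aot_export_module(mod, [x], trace_joint=False)
        >>> print(fx_g.code)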
1155
    """
1156
    if pre_dispatch and trace_joint:
1157
        raise RuntimeError("pre_dispatch is not supported when trace_joint is True.")
1158
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
1159
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
1160

1161
    params_and_buffers = {
1162
        **dict(named_parameters),
1163
        **dict(named_buffers),
1164
    }
1165
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
1166
    params_and_buffers_flat = tuple(params_and_buffers_flat)
1167
    params_len = len(params_and_buffers_flat)
1168

1169
    kwargs = kwargs or {}
1170

1171
    functional_call = create_functional_call(
1172
        mod, params_spec, params_len, store_orig_mod=True
1173
    )
1174

1175
    num_fw_outs = None
1176

1177
    if trace_joint:
1178
        # This helper effectively just adds some extra asserts about what the backward will look like:
1179
        # Outputs must include a scalar loss that we compute gradients w.r.t.
        # We don't compute gradients w.r.t. anything else, so just in case we detach()
        # any other output tensors.
1182
        def fn_to_trace(*args):
1183
            nonlocal num_fw_outs
1184
            out = functional_call(*args)
1185
            if output_loss_index is None:
1186
                raise RuntimeError(
1187
                    """\
1188
If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
You must specify which output (by index) is the loss with output_loss_index."""
1190
                )
1191
            if isinstance(out, torch.Tensor):
1192
                out = (out,)
1193
            if not isinstance(out, (tuple, list)):
1194
                raise RuntimeError(
1195
                    f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}"
1196
                )
1197

1198
            for i, o in enumerate(out):
1199
                # We only want to create a backward graph w.r.t. the loss that the user passed in.
1200
                # This implies that every other output should not require gradients.
1201
                # Instead of making this an error (and forcing the user to detach all other outputs
1202
                # of their forward),
1203
                # we'll automatically detach them here.
1204
                if o.requires_grad and i != output_loss_index:
                    raise RuntimeError(
                        f"""\
Found a forward output that requires gradients but is not the scalar loss.
We require every forward output other than the scalar loss to not require gradients,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients."""
                    )
            out_loss = out[output_loss_index]
            num_fw_outs = len(out)
            if not out_loss.requires_grad:
                raise RuntimeError(
                    f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients"""
                )
            if out_loss.numel() != 1:
                raise RuntimeError(
                    f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}"""
                )
            return out

        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        # However if pre_dispatch=True, we want to correctly trace set_grad_enabled calls for training.
        ctx = nullcontext if pre_dispatch else torch.no_grad
        fn_to_trace = functional_call

    full_args = []
    # First, the params
    # NB: It is REQUIRED that parameters come first. Inductor infers "fixed"
    # parameters by looking at the difference in parameter count outside
    # and inside AOTAutograd, and assumes the prefix of arguments are fixed
    # arguments.
    full_args.extend(params_and_buffers_flat)
    # Next, the input args
    full_args.extend(args)

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            fn_to_trace,
            full_args,
            decompositions=decompositions,
            num_params_buffers=params_len,
            no_tangents=True,
            pre_dispatch=pre_dispatch,
            dynamic_shapes=dynamic_shapes,
            kwargs=kwargs,
        )
    if trace_joint:

        def flattened_joint(*args):
            # The idea here is that the joint graph that AOTAutograd creates has some strict properties:
            # (1) It accepts two arguments (primals, tangents), and pytree_flattens them
            # (2) It returns a tuple of (fw_outs, gradients)
            # This is a very useful convention for anyone who wants to partition the joint graph
            # into a separate forward and backward graph.
            # However,
            # (1) for people exporting a single joint graph, it would be preferable not to have
            #     any pytrees in the graph.
            # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss,
            #     and there are therefore no tangents that are needed to run the joint graph.
            # (3) AOTAutograd creates a grad_input for every input in the forward,
            #     including None's for inputs that are not grad-requiring tensors.
            #     We don't want these in our export graph.
            # This function "fixes" the above by removing the tangent inputs and the None
            # gradient outputs, and removing pytrees from the original FX graph.
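            # Concretely (an illustrative restatement of the code below, not an extra contract):
            # the graph traced from this wrapper takes (*params_and_buffers, *user_inputs)
            # directly, with no pytree arguments, and returns (*fw_outs, *gradients), where a
            # gradient is only emitted for inputs that require grad.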
            fake_tangents = [
                None
                for _ in range(
                    metadata.num_outputs + metadata.num_mutated_inp_runtime_indices
                )
            ]
            fw_outs, gradients = fx_g(args, fake_tangents)
            assert len(gradients) == len(args)
            output_gradients = []
            for i, (a, grad) in enumerate(zip(args, gradients)):
                if isinstance(a, torch.Tensor) and a.requires_grad:
                    assert (
                        grad is not None
                    ), """\
Found a parameter that did not receive a gradient.
This is most likely a bug, but if this needs to be supported please comment on this GitHub issue:
https://github.com/pytorch/pytorch/issues/101192
"""
                    output_gradients.append(grad)
                else:
                    assert grad is None
            return *fw_outs, *output_gradients

        fx_g = make_fx(flattened_joint)(*full_args)

    user_args_flat = pytree.arg_tree_leaves(*args, **kwargs)
    return fx_g, create_graph_signature(
        fx_g,
        metadata,
        in_spec,
        out_spec,
        user_args_flat=user_args_flat,
        params_and_buffers_flat=params_and_buffers_flat,
        param_names=list(named_parameters.keys()),
        buffer_names=list(named_buffers.keys()),
        trace_joint=trace_joint,
        num_user_fw_outs=num_fw_outs,
        loss_index=output_loss_index,
    )
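# Illustrative usage sketch for the function above (the module, input, and loss
# index below are hypothetical, not part of any contract):
#
#   mod = MyModuleWithScalarLoss()
#   x = torch.randn(8, 4)
#   # Inference-only export:
#   gm, signature = aot_export_module(mod, (x,), trace_joint=False)
#   # Joint forward/backward export, where forward output 0 is a scalar loss:
#   gm, signature = aot_export_module(
#       mod, (x,), trace_joint=True, output_loss_index=0
#   )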


def aot_export_joint_simple(
    func: Callable,
    args,
    *,
    trace_joint: bool,
    # It looks like the main consequence of this API is that for dynamic shapes,
    # it will assume that params/buffers are static.
    # With the new inferred dynamic shapes API, maybe this doesn't matter?
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
) -> torch.fx.GraphModule:
    """
    A simplified version of export. Used by higher order operators.

    This function makes a high-level "no calling convention changes" guarantee:
    - If no inputs require grad (so we export an inference graph),
      there are *no* calling convention changes between the exported graph and "func".
    - If at least one input requires grad (so we trace out and export a joint fw-bw graph),
      then if you partition the graph into separate forward and backward graphs,
      the forward graph will have no calling convention changes compared to "func".

    The above also relies on some strong restrictions around which functions this API accepts:
    (1) `args` cannot contain any pytrees (they must have been pytree_flattened already)
    (2) `func` cannot mutate any inputs
    (3) The outputs of `func` cannot alias any inputs.

    Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
    """
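    # Illustrative sketch only (the function `f` and its inputs are hypothetical):
    #
    #   def f(x, w):
    #       return torch.matmul(x, w)
    #
    #   x, w = torch.randn(4, 3, requires_grad=True), torch.randn(3, 5, requires_grad=True)
    #   # `args` is already a flat list of tensors, per restriction (1) above.
    #   joint_gm = aot_export_joint_simple(f, [x, w], trace_joint=True)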
    if trace_joint:
        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        ctx = torch.no_grad

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            func,
            args,
            decompositions=decompositions,
        )
        in_spec, _kw_in_spec = in_spec.children_specs
    # At this point, we can just directly return the (joint or inference) graph that we traced.
    # First though: a bunch of assertions to make sure that our graph doesn't require
    # any calling convention changes compared to the original function.
    # These restrictions are *in addition to* the general restrictions on export.

    # No input mutations
    if (
        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
        )
    # No output aliasing
    if (
        len([x for x in metadata.output_info if x.output_type != OutputType.non_alias])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}"
        )
    # No pytrees
    if in_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
        )
    if not all(child.is_leaf() for child in in_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
        )
    if out_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
        )
    if not all(child.is_leaf() for child in out_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
        )
    # TODO: we might have to temporarily patch config.functionalize_rng
    # so that it doesn't run when we're exporting a higher order op.

    if config.debug_assert:
        # Smoke test that after partitioning, we can run the forward without any calling convention changes.
        fw_module, bw_module = aot_config.default_partition(  # noqa: F821
            fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)  # noqa: F821
        )
        # Attempt to run the fw_module with the original user inputs
        fake_mode = detect_fake_mode(args)
        if fake_mode is None:
            fake_mode = FakeTensorMode()
        with fake_mode:
            fw_module(*args)
    return fx_g


# Private for now because we aren't providing a contract on what to return
# for joint graphs (we could when there's a clearer use case)
# In the future, we may need to add more export APIs that provide their own strong guarantees.
# This is meant as a general helper function for handling various export-y use cases.
def _aot_export_function(
    func: Callable,
    args,
    *,
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
    # If we're exporting a joint graph and we don't want any tangent inputs in the graph
    # (because we are backpropping through a scalar 1 loss),
    # we need to explicitly specify not to include tangents in the graph.
    # It's not enough just to check that our tangent is a scalar, since we also
    # need to know if it is a 1 (no need to make it a graph input), or something else
    # (requiring it to be a graph input).
    # We don't know this info at trace time though, so we need to make it an explicit config.
    no_tangents: bool = False,
    pre_dispatch: bool = False,
    # If None, `dynamic_shapes` will be inferred from the inputs, but the inferred result might be wrong.
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
    kwargs = kwargs or {}

    flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs)
    flat_args, in_spec = pytree.tree_flatten((args, kwargs))

    if dynamic_shapes is None:
        # Try to infer `dynamic_shapes` from the inputs and graph nodes
        fake_mode = detect_fake_mode(flat_args)
        if (
            fake_mode is None
            and hasattr(func, "_orig_mod")
            and isinstance(func._orig_mod, torch.fx.GraphModule)
        ):
            vals = [
                node.meta["val"]
                for node in func._orig_mod.graph.nodes
                if "val" in node.meta
            ]
            fake_mode = detect_fake_mode(vals)
        dynamic_shapes = fake_mode is not None and fake_mode.shape_env is not None
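        # In other words: dynamic shapes are enabled exactly when we can find a fake mode
        # that carries a ShapeEnv, either from the example inputs or from the node metadata
        # of an underlying GraphModule.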

    # The export use case doesn't care about several bits of AOTConfig
    # (1) compilers (we just export the graph)
    # (2) partitioners (export is only full graph, user can partition themselves)
    aot_config = AOTConfig(
        fw_compiler=None,
        bw_compiler=None,
        inference_compiler=None,
        partition_fn=None,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        # For now there's no use case involving keeping input mutations in the graph
        # (which we can only do in the inference case anyway).
        # We can add this later if we need to.
        keep_inference_input_mutations=False,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=None,
        is_export=True,
        no_tangents=no_tangents,
        pre_dispatch=pre_dispatch,
    )
    fake_mode, shape_env = construct_fake_mode(flat_args, aot_config)
    fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env)

    fx_g, meta = create_aot_dispatcher_function(
        flat_fn,
        fake_flat_args,
        aot_config,
        fake_mode,
        shape_env,
    )
    return fx_g, meta, in_spec, out_spec.spec


@contextmanager
def _detect_attribute_assignment(mod: torch.nn.Module):
    # Do not allow assignment of tensor attributes during export unless
    # the attribute is registered as a buffer.

    STD_ATTRS = {
        "_backward_hooks",
        "_backward_pre_hooks",
        "_buffers",
        "_forward_hooks",
        "_forward_hooks_always_called",
        "_forward_hooks_with_kwargs",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_is_full_backward_hook",
        "_load_state_dict_post_hooks",
        "_load_state_dict_pre_hooks",
        "_modules",
        "_non_persistent_buffers_set",
        "_parameters",
        "_state_dict_hooks",
        "_state_dict_pre_hooks",
        "training",
    }

    def _get_attributes(mod):
        # return any attributes of a module that are not standard attributes
        return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}

    # save the state of the attributes before entering
    snapshot = pytree.tree_map(lambda x: x, _get_attributes(mod))
    try:
        yield
    finally:
        # after exit, compare the state of the attributes with the snapshot
        # to detect which tensor attributes were assigned
        assigned_tensor_attributes = []

        def _collect_assigned_tensor_attributes(kp, v, _v):
            if _v is not v:
                attr, *rest = kp
                if isinstance(v, torch.Tensor):
                    assigned_tensor_attributes.append(
                        f"self.{attr.key}{pytree.keystr(rest)}"
                    )
                # TODO(avik): Assigning all other types is allowed right now.
                # Maybe in the future we want to limit this to primitive types?

        pytree.tree_map_with_path(
            _collect_assigned_tensor_attributes, snapshot, _get_attributes(mod)
        )
        # restore the state of all attributes (including, e.g., of primitive types)
        mod.__dict__.update(snapshot)

        if assigned_tensor_attributes:
            if len(assigned_tensor_attributes) > 1:
                noun, verb = "attributes", "were"
            else:
                noun, verb = "attribute", "was"
            raise ValueError(
                f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
                "Such attributes must be registered as buffers using the `register_buffer` API "
                "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
            )
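# Illustrative usage of the context manager above (the module below is hypothetical):
# re-assigning a plain tensor attribute inside forward is flagged, whereas registered
# buffers are allowed.
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.state = torch.zeros(3)  # plain tensor attribute, not a buffer
#
#       def forward(self, x):
#           self.state = x.sin()  # re-assigned during tracing
#           return self.state + 1
#
#   m = M()
#   with _detect_attribute_assignment(m):
#       m(torch.randn(3))
#   # -> ValueError: The tensor attribute self.state was assigned during export ...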


compiled_function = aot_function
compiled_module = aot_module