# mypy: ignore-errors

import itertools
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
from unittest.mock import patch

import torch
import torch._dynamo.logging
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import dynamo_timed, preserve_rng_state
from torch._guards import detect_fake_mode
from torch._inductor.utils import BoxedBool
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)

from . import config
from ._aot_autograd.autograd_cache import (  # noqa: F401
    AOTAutogradCache,
    autograd_cache_key,
)
from ._aot_autograd.collect_metadata_analysis import (  # noqa: F401
    run_functionalized_fw_and_collect_metadata,
)
from ._aot_autograd.functional_utils import (  # noqa: F401
    _check_if_mutation_can_be_in_graph,
    are_all_mutations_hidden_from_autograd,
    are_all_mutations_under_no_grad_or_inference_mode,
    assert_functional_graph,
    from_fun,
    gen_alias_from_base,
    has_data_mutation,
    has_metadata_mutation,
    is_fun,
    sync_functional_tensor,
    to_fun,
)
from ._aot_autograd.input_output_analysis import (  # noqa: F401
    _tensors_definitely_do_not_overlap,
    compute_overlapping_inputs,
    create_graph_signature,
    create_synthetic_base_metadata,
    remove_dupe_metadata,
)
from ._aot_autograd.jit_compile_runtime_wrappers import (  # noqa: F401
    aot_dispatch_autograd,
    aot_dispatch_base,
    aot_dispatch_export,
)
from ._aot_autograd.logging_utils import (  # noqa: F401
    callback_set,
    describe_input,
    format_guard_bug_msg,
    get_aot_compilation_context,
    get_aot_graph_name,
    get_graph_being_compiled,
    graph_being_compiled,
    model_name,
    nth_graph,
    set_model_name,
    setup_stacktrace_preservation_hooks,
    track_graph_compiling,
)
from ._aot_autograd.runtime_wrappers import (  # noqa: F401
    AOTDedupeWrapper,
    AOTSyntheticBaseWrapper,
)
from ._aot_autograd.schemas import (  # noqa: F401
    AOTConfig,
    BackwardSignature,
    FQN,
    GraphInputName,
    GraphOutputName,
    GraphSignature,
    InputAliasInfo,
    MutationType,
    OutputAliasInfo,
    OutputType,
    SubclassCreationMeta,
    SubclassMeta,
    TensorAlias,
    ViewAndMutationMeta,
)
from ._aot_autograd.subclass_utils import (  # noqa: F401
    create_metadata_for_subclass,
    requires_subclass_dispatch,
    unwrap_tensor_subclasses,
    wrap_tensor_subclasses,
    wrap_tensor_subclasses_maybe_joint,
)
from ._aot_autograd.traced_function_transforms import (  # noqa: F401
    aot_dispatch_subclass,
    create_functional_call,
    create_functionalized_fn,
    create_functionalized_rng_ops_wrapper,
    create_joint,
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
)
from ._aot_autograd.utils import (  # noqa: F401
    _get_autocast_states,
    _get_symint_hints,
    call_func_at_runtime_with_args,
    create_tree_flattened_fn,
    KNOWN_TYPES,
    make_boxed_compiler,
    make_boxed_func,
    maybe_to_fresh_input,
    normalize_as_list,
    partial_flatten_asdict,
    root_module_when_exporting_non_strict,
    strict_zip,
)
from .partitioners import default_partition


zip = strict_zip

# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further at compile time.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
#     x.mul_(2)
#     out = x.mul(3)
#     return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The outputs of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
#     x_updated = x.mul(2)
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.copy_(x_updated)
#     return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.


# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph.
#
# Example: original user code:
# def f(x):
#     x.t_()
#     out = x.mul(3)
#     return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     x_updated = x.t()
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.as_strided_(x_updated)
#     return out


# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
#     in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (making the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
#     when it has a different set of strides.
#     By including the view op directly in the graph, inductor takes that into account when deciding what memory format
#     the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
#  *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     # the compiled graph also returns the intermediate
#     return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     out1, out2, intermediate = compiled_forward_graph(x)
#     return out1, out2, intermediate
#
# def compiled_wrapper(x):
#     out1, out2, intermediate = autograd.Function.apply(x)
#     # regenerate out1 from the input
#     out1_regenerated = out1._view_func(x)
#     # regenerate out2 from the intermediate
#     out2_regenerated = out2._view_func(intermediate)
#     return out1_regenerated, out2_regenerated


# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
# def f(x, x_view):
#     x.mul_(2)
#     out = x * x_view
#     return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
#     x = generate_x(base)
#     x_view = generate_x_view(base)
#     x_updated = x.mul(2)
#     x_view_updated = x_updated.view(-1)
#     out = x_updated * x_view_updated
#     return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
#     grad_base = ...
#     return grad_base
#
# def autograd.Function.forward(base):
#     x_updated, out = compiled_forward_graph(base)
#     return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
#     base = merge_view_inputs(x, x_view)
#     x_updated, out = autograd.Function.apply(base)
#     # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
#     x.copy_(x_updated)
#     return out


# Note [AOT Autograd: Views to avoid tangents aliasing inputs]
#
# We view every forward output when creating output tangent tensors to handle the problematic
# case in which a subclass does extra aliasing between graph outputs/inputs in a way that
# is not visible above the subclass.
#
# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd,
# we're guaranteed that the tangent tensors that we pass
# into the joint are distinct tensors from the primals. This is because when
# we decide which forward outputs to create tangents for, we only create tangents
# for forward outputs that are not aliases of inputs (See Note
# [AOT Autograd: outputs aliasing inputs or intermediates!]).
#
# However, when wrapper tensor subclasses enter the picture, it is possible
# to have an output of the forward that is a subclass that is not an
# input / alias of an input, but one of its inner tensors is an alias!
# NestedTensor is an example: Performing an out-of-place pointwise op on a
# NestedTensor constructs a fresh NestedTensor that holds onto the input's
# offsets tensor directly.
#
# Having tangent tensors that are the same as the (primal) forward inputs
# can cause problems during tracing as make_fx() will specialize on our
# duplicate inputs: If we passed in the same tensor for primals_1 and
# tangents_1 during tracing, make_fx() will happily sub out all usages of
# tangents_1 with primals_1 in the graph, which is not what we want.
#
# To work around this, we view every forward output when creating output tangent
# tensors so that tangents can never be the same as forward inputs even if
# forward inputs alias forward outputs.

# Note [Side-Effectful Tokens in AOTAutograd]
#
# We allow some side-effectful operators in
# the post-AOTAutograd (functional) graph, such as prints and torchbind operations.
# To ensure that these side effects are compatible with future graph passes that
# assume that the graph is functional, we will thread "effect tokens" to show
# data dependence between these side-effectful operators. Practically speaking,
# effect tokens are just dummy values (torch.tensor([])). The graph would look
# like the following:
#
# def gm(self, token0, reader):
#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
#     frame = frame * 2
#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
#     frame2 = frame2 * 2
#     return token2, frame, frame2
#
# We will pass the token as an input to the graph, thread it through
# side-effectful operators using the `with_effects` high order operator, and then
# return the updated token as an output.
# So the signature of the graph input would look something like
# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
# output would look something like (*tokens, *outputs).
#
# However, Inductor does not want the concept of tokens in the final generated
# code's input and output. Since changing the graph signature inside of inductor
# is difficult, after generating the forward graph, we will run a pass to
# remove the tokens from the inputs/outputs and generate the following graph for Inductor, where
# the tokens are created and sunk within the graph, rather than as inputs and
# outputs:
#
# def gm(self, reader):
#     token0 = torch.ops.prims._make_token()
#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
#     frame = frame * 2
#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
#     frame2 = frame2 * 2
#     sink_token = torch.ops.prims._sink_tokens([token2])
#     return frame, frame2

#
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


aot_autograd_decompositions = {}

FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])


def process_inputs(
    flat_args: List[Any],
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> FakifiedFlatArgs:
    with fake_mode:

        def convert(idx, x):
            if shape_env is not None:
                from torch._dynamo.source import ConstantSource

                if isinstance(x, int):
                    # We always specialize on scalar values in export.
                    if aot_config.is_export:
                        return x
                    source = ConstantSource(f"sym_{idx}")
                    return shape_env.create_symintnode(
                        shape_env.create_symbol(x, source), hint=x, source=source
                    )
            if isinstance(x, torch.ScriptObject):
                return torch._library.fake_class_registry.maybe_to_fake_obj(
                    fake_mode, x
                )
            if not isinstance(x, torch.Tensor):
                return x
            if isinstance(x, FakeTensor):
                assert x.fake_mode is fake_mode
                return x
            if is_traceable_wrapper_subclass(x):
                attrs, _ = x.__tensor_flatten__()
                if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
                    assert all(
                        getattr(x, attr).fake_mode is fake_mode for attr in attrs
                    )
                    return x

            # see note [Tensor Fakification and Symbol Caching]
            symbolic_context = None
            source = None
            trace = True
            if tracing_context := torch._guards.TracingContext.try_get():
                if x in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[x]
                    source = symbolic_context.tensor_source
                    # We already fakeified this tensor in Dynamo, don't
                    # dump the trace for it again
                    trace = False
            if (
                idx < aot_config.num_params_buffers
                and config.static_weight_shapes
                and not symbolic_context
            ):
                # TODO: Ensure that this codepath is never exercised from
                # Dynamo
                return fake_mode.from_tensor(x, static_shapes=True)

            return fake_mode.from_tensor(
                x,
                static_shapes=False,
                symbolic_context=symbolic_context,
                source=source,
                trace=trace,
            )

        return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])


def construct_fake_mode(
    flat_args: List[Any], aot_config: AOTConfig
) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]:
    fake_mode = detect_fake_mode(flat_args)
    if fake_mode is None:
        shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
        fake_mode = FakeTensorMode(shape_env=shape_env)
    else:
        shape_env = fake_mode.shape_env
    return (fake_mode, shape_env)


def create_aot_dispatcher_function(
    flat_fn,
    fake_flat_args: FakifiedFlatArgs,
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
    with dynamo_timed("create_aot_dispatcher_function"):
        return _create_aot_dispatcher_function(
            flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
        )


def _create_aot_dispatcher_function(
    flat_fn,
    fake_flat_args: FakifiedFlatArgs,
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
    """
    Traces the forward and backward graphs of :attr:`flat_fn` to generate a
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
    the tracing mechanism to understand the graph capturing details.

    The joint graph is then passed through :attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided :attr:`fw_compiler` and :attr:`bw_compiler`.

    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.

    The calling convention here is that the first aot_config.num_params_buffers
    inputs in flat_args are parameters and buffers, and the rest are inputs.

    We use this to assume that parameters/buffers' shapes don't change.

    Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
    When aot_config.is_export is True, we return an FX graph + metadata
    When aot_config.is_export is False, we return an ordinary runtime function
    """

    # This is the main entry point.
    # TODO: Chillee argues that dynamo itself should pass in fake tensors to
    # the list of arguments when compiling; at the moment we do not do this

    if aot_config.decompositions is None:
        aot_config.decompositions = {}

    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **aot_config.decompositions,
    }

    if config.functionalize_rng_ops:
        # Update the decompositions with functionalized random decompositions
        aot_config.decompositions = {
            **rng_decompositions,
            **aot_config.decompositions,
        }

    # Check flat_args to see if they're already fake. If so, use that fake
    # mode instead.

    python_dispatcher_mode = (
        enable_python_dispatcher() if shape_env is not None else nullcontext()
    )

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    # If any saved tensor hooks are active, we **don't** want to trace them.
    # Instead, we'll let them run at runtime, around the custom autograd.Function
    # that we generate in torch.compile.
    with torch.autograd.set_multithreading_enabled(
        False
    ), preserve_rng_state(), (
        fake_mode
    ), (
        python_dispatcher_mode
    ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        from torch._library.fake_class_registry import (
            FakeScriptObject,
            maybe_to_fake_obj,
        )

        # Tracing may mutate the states of the fake script object,
        # so we need to duplicate the fake script objects so that subsequent tracing
        # won't be affected.
        def _dup_fake_script_obj(fake_flat_args):
            return [
                maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
                if isinstance(arg, FakeScriptObject)
                else arg
                for arg in fake_flat_args
            ]

        needs_autograd = any(
            x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
        )

        with enable_python_dispatcher():
            # Patch set_rng_state as set_rng_state with fake tensors is
            # nonsensical. This does not affect the collection of metadata.
            with patch("torch.cuda.set_rng_state", lambda *args: None):
                mod = root_module_when_exporting_non_strict(flat_fn)
                if mod is not None:
                    ctx = _detect_attribute_assignment(mod)
                else:
                    ctx = nullcontext()
                with ctx:
                    fw_metadata = run_functionalized_fw_and_collect_metadata(
                        flat_fn,
                        static_input_indices=aot_config.static_input_indices,
                        keep_input_mutations=aot_config.keep_inference_input_mutations,
                        is_train=needs_autograd,
                        pre_dispatch=aot_config.pre_dispatch,
                    )(*_dup_fake_script_obj(fake_flat_args))

        req_subclass_dispatch = requires_subclass_dispatch(
            fake_flat_args, fw_metadata
        )

        output_and_mutation_safe = not any(
            x.requires_grad
            # view-type operations preserve requires_grad even in no_grad.
            # Do not count aliases of inputs with requires_grad as reason to make a training graph,
            # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
            # setting their grad_fn properly.
            and not (
                x.output_type
                in (OutputType.alias_of_input, OutputType.is_input)
                and fw_metadata.input_info[x.base_idx].requires_grad
            )
            for x in fw_metadata.output_info
        ) and not any(
            x.requires_grad
            and x.mutates_data
            and not x.mutations_under_no_grad_or_inference_mode
            and not x.mutations_hidden_from_autograd
            for x in fw_metadata.input_info
        )

        if needs_autograd and output_and_mutation_safe:
            # We realized that none of the outputs require grad,
            # and none of the inputs that require grad are mutated,
            # so we actually have an inference graph.
            needs_autograd = False
            # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta
            # changes depending on whether we pass in is_train / keep_input_mutations,
            # so we're forced to recompute the metadata.
            # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata
            # so that this is unnecessary.
            if req_subclass_dispatch:
                fw_metadata = run_functionalized_fw_and_collect_metadata(
                    flat_fn,
                    keep_input_mutations=aot_config.keep_inference_input_mutations,
                    is_train=False,
                    pre_dispatch=aot_config.pre_dispatch,
                    static_input_indices=aot_config.static_input_indices,
                )(*fake_flat_args)
            else:
                fw_metadata = ViewAndMutationMeta(
                    input_info=fw_metadata.input_info,
                    output_info=fw_metadata.output_info,
                    num_intermediate_bases=fw_metadata.num_intermediate_bases,
                    keep_input_mutations=aot_config.keep_inference_input_mutations,
                    traced_tangents=fw_metadata.traced_tangents,
                    subclass_inp_meta=fw_metadata.subclass_inp_meta,
                    subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
                    subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                    is_train=False,
                    tokens=fw_metadata.tokens,
                    static_input_indices=fw_metadata.static_input_indices,
                )

        if fw_metadata.num_intermediate_bases > 0:
            assert not req_subclass_dispatch, f"""\
torch.compile is currently being used with tensor subclass inputs:
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to compile a graph with two graph outputs
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
please file a github issue"""

        if aot_config.is_export:
            # aot_export: ban input metadata mutations for now to keep shared code paths simpler.
            # Keeping .resize_() in the graph will require some work
            # Allowing it but keeping the graph functional will require some calling convention changes.
            if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
                raise RuntimeError(
                    f"""\
Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}"""
                )
            # In export, banning data mutations on inputs that require grad for now.
            # This should be rare, and is tricky to get right. When we trace the backward,
            # we currently trace with autograd.grad instead of .backward(), which makes it difficult
            # to ensure that we run autograd all the way through the input **before** it saw the mutation.
            if (
                len(
                    [
                        x
                        for x in fw_metadata.input_info
                        if x.requires_grad and x.mutates_data
                    ]
                )
                != 0
            ):
                raise RuntimeError(
                    f"""\
Found a graph input that requires gradients, and received a mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}"""
                )
            if req_subclass_dispatch:
                raise RuntimeError(
                    """\
aot_export is not currently supported with traceable tensor subclass.
If you need this feature, please comment on <CREATE_ISSUE_LINK>"""
                )

            # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad,
            # and turning it on will require a non-trivial calling convention change for any export runtime.
            if config.functionalize_rng_ops:
                raise RuntimeError(
                    """\
Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
                )

        def choose_dispatcher(needs_autograd, aot_config):
            """
            Pick a dispatcher based on the config rules.
            """
            if aot_config.is_export:
                # export uses just the "graph bits", whereas the other
                # two dispatchers include some extra work around handling a runtime epilogue
                return partial(aot_dispatch_export, needs_autograd=needs_autograd)
            elif needs_autograd and not aot_config.pre_dispatch:
                return aot_dispatch_autograd
            else:
                return aot_dispatch_base

        compiler_fn = choose_dispatcher(needs_autograd, aot_config)

        compiled_fn, fw_metadata = compiler_fn(
            flat_fn,
            _dup_fake_script_obj(fake_flat_args),
            aot_config,
            fw_metadata=fw_metadata,
        )
        return compiled_fn, fw_metadata


def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    num_params_buffers: int = 0,
    keep_inference_input_mutations: bool = False,
    inference_compiler: Optional[Callable] = None,
    *,
    # Whether or not to trace with dynamic shapes
    dynamic=False,
    enable_log=True,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph. :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One can
    set `decompositions` dictionary to decompose the operators into a sequence
    of core or simpler operators supported by the backend compilers.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Dict): A dictionary to define the decomposition of
            larger Aten ops into simpler or core Aten ops.
        inference_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. inference_compiler is invoked
            if no autograd is needed. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

        >>> fn = lambda x : x.sin().cos()
        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> aot_fn = aot_function(fn, print_compile_fn)
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)
    """

    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic,
        aot_autograd_arg_pos_to_source=None,
        is_export=False,
        no_tangents=False,
        enable_log=enable_log,
    )
    cached_res = None

    @wraps(fn)
    def returned_function(*args, **kwargs):
        nonlocal cached_res
        # Now flatten the tensor args
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        # Compile the function and save it in the cache
        if cached_res is None:
            flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
            (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config)
            fake_flat_args: FakifiedFlatArgs = process_inputs(
                flat_args, aot_config, fake_mode, shape_env
            )
            compiled_fn, _ = create_aot_dispatcher_function(
                flat_fn,
                fake_flat_args,
                aot_config,
                fake_mode,
                shape_env,
            )
            cached_res = (compiled_fn, out_spec)

        cached_fn, out_spec = cached_res
        out = cached_fn(flat_args)
        return out_spec.unflatten(out)

    return returned_function


def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is a wrapper function that underneath uses
    :func:`aot_function` to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.

    """
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

    def functional_call(named_params, named_buffers, *args, **kwargs):
        params_and_buffers = {**named_params, **named_buffers}
        return torch.func.functional_call(mod, params_and_buffers, args, kwargs)

    named_params = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
    num_params_buffers = len(named_params) + len(named_buffers)
    compiled_f = aot_function(
        functional_call, *args, num_params_buffers=num_params_buffers, **kwargs
    )

    class AOTModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            return compiled_f(
                named_params,
                named_buffers,
                *args,
                **kwargs,
            )

    return AOTModule()


def aot_module_simplified(
    mod: nn.Module,
    args,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    keep_inference_input_mutations=False,
    inference_compiler: Optional[Callable] = None,
    cudagraphs: Optional[BoxedBool] = None,
) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) Reading of params/buffers in every forward call

    :func:`aot_module_simplified` removes these overheads.
    """
    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    if cudagraphs is None:
        cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)

    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler

    seen_sources = set()

    full_args = []
    # First, the params
    full_args.extend(params_flat)

    if tracing_context := torch._guards.TracingContext.try_get():
        tracing_context.params_flat = params_flat

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.
    if hasattr(mod, "_param_name_to_source"):
        aot_autograd_arg_pos_to_source = []
        # We now know this came from dynamo, and (1) we care about guards,
        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
        for name in params.keys():
            assert name in mod._param_name_to_source, f"{name} not found."
            source = mod._param_name_to_source[name]
            assert source not in seen_sources, source
            seen_sources.add(source)
            aot_autograd_arg_pos_to_source.append(source)

    # Next, the input args
    full_args.extend(args)

    static_input_indices = []
    if hasattr(mod, "graph"):
        # Non dynamo entrypoints can get to here...
        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
            if hasattr(node, "_dynamo_source"):
                # ... but not here!
                if aot_autograd_arg_pos_to_source is None:
                    aot_autograd_arg_pos_to_source = []
                source = node._dynamo_source
                assert source not in seen_sources, source
                seen_sources.add(source)
                aot_autograd_arg_pos_to_source.append(source)
                source_name = source.name() if source else str(source)

                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
                    "_dynamo_static_input_type", None
                ):
                    static_inputs_log.debug(
                        "Adding static input pos %s for source %s", pos, source_name
                    )
                    static_input_indices.append(pos)
                else:
                    static_inputs_log.debug(
                        "Non-static input pos %s for source %s", pos, source_name
                    )

    if aot_autograd_arg_pos_to_source is not None:
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)

    dynamic_shapes = False
    for x in full_args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
        static_input_indices=static_input_indices,
        is_export=False,
        no_tangents=False,
        cache_key=None,
    )
    fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
    fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env)

    def dispatch_and_compile():
        functional_call = create_functional_call(mod, params_spec, params_len)
        with compiled_autograd.disable():
            compiled_fn, _ = create_aot_dispatcher_function(
                functional_call,
                fake_flat_args,
                aot_config,
                fake_mode,
                shape_env,
            )
        return compiled_fn

    # Autograd cache stuff
    if config.enable_autograd_cache:
        compiled_fn = AOTAutogradCache.load(
            dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs
        )
    else:
        compiled_fn = dispatch_and_compile()

    if isinstance(mod, torch._dynamo.utils.GmWrapper):
        # This function is called by the flatten_graph_inputs wrapper, which boxes
        # the inputs so that they can be freed before the end of this scope.
        # For overhead reasons, this is not the default wrapper, see comment:
        # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
        def boxed_forward(runtime_args: List[Any]):
            flat_args = []
            flat_args.extend(params_flat)
            flat_args.extend(runtime_args)
            runtime_args.clear()
            return compiled_fn(flat_args)

        # Just for convenience
        boxed_forward.zero_grad = mod.zero_grad
        boxed_forward.named_parameters = mod.named_parameters
        boxed_forward.named_buffers = mod.named_buffers
        return boxed_forward

    # TODO: There is something deeply wrong here; compiled_fn running with
    # the boxed calling convention, but aot_module_simplified somehow
    # historically returned a function that was not the boxed calling
    # convention. This should get fixed...
    # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
    def forward(*runtime_args: Tuple[Any]):
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward


def aot_export_module(
    mod: nn.Module,
    args,
    *,
    decompositions: Optional[Dict] = None,
    # If true, we'll return a joint forward-backward graph,
    # as well as metadata on the loss + gradients in the backward.
    trace_joint: bool,
    # If trace_joint is True, we expect your module to return a scalar loss.
    # Your module can return multiple outputs, so you must specify which output the loss is.
    output_loss_index: Optional[int] = None,
    pre_dispatch: bool = False,
    # If None, will be inferred from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong.
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
    """
    This function takes in a module, and returns:
    (1) an FX graph that can be exported
    (2) some metadata about the graph

    If `trace_joint=True` we will return a joint graph of the forward + backward.

    The traced FX graph will have the following properties compared to the original module:
    (1) Inputs and outputs to the module will be pytree-flattened
    (2) Parameters and buffers on the module will be lifted into graph inputs,
        graph_inputs = (*parameters, *buffers, *user_inputs)
    (3) The graph will be fully functionalized
    (4) Any input mutations will be converted into additional outputs in the graph,
        meaning whoever calls this graph is responsible for applying the mutations
        back to the original inputs.
    (5) If trace_joint is provided the graph will return parameter gradients in addition to user outputs.
        The graph output will look like:
        graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)

    There are also several restrictions on what modules can use this API. In particular:
    (1) If trace_joint is specified, we expect the loss function to be **fused**
        into the module forward. One of the outputs to the forward must be a scalar loss,
        which is specified with `output_loss_index`.
        All other outputs to the forward are presumed to not require gradients.
    (2) This API cannot capture optimizers (although in theory we could build an API for this).
    (3) Metadata mutations on params/buffers/inputs are banned.
    (4) Data mutations on anything that requires gradients are banned (parameters)
    (5) If an input is mutated, it is not allowed to alias any other inputs.
    (6) Parameters must not be duplicated.
    """
    if pre_dispatch and trace_joint:
        raise RuntimeError("pre_dispatch is not supported when trace_joint is True.")
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {
        **dict(named_parameters),
        **dict(named_buffers),
    }
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_and_buffers_flat = tuple(params_and_buffers_flat)
    params_len = len(params_and_buffers_flat)

    kwargs = kwargs or {}

    functional_call = create_functional_call(
        mod, params_spec, params_len, store_orig_mod=True
    )

    num_fw_outs = None

    if trace_joint:
        # This helper effectively just adds some extra asserts about what the backward will look like:
        # Outputs must include a scalar loss, that we compute gradients w.r.t.
        # We don't compute gradients w.r.t. anything else: so just in case we detach()
        # any other output tensors.
        def fn_to_trace(*args):
            nonlocal num_fw_outs
            out = functional_call(*args)
            if output_loss_index is None:
                raise RuntimeError(
                    """\
If trace_joint=True, it is required that one of your forward outputs must be a scalar loss.
You must specify which output (by index) is the loss with output_loss_index."""
                )
            if isinstance(out, (torch.Tensor)):
                out = (out,)
            if not isinstance(out, (tuple, list)):
                raise RuntimeError(
                    f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}"
                )

            for i, o in enumerate(out):
                # We only want to create a backward graph w.r.t. the loss that the user passed in.
                # This implies that every other output should not require gradients.
                # Instead of making this an error (and forcing the user to detach all other outputs
                # of their forward),
                # we'll automatically detach them here.
                if o.requires_grad and i != output_loss_index:
                    raise RuntimeError(
                        f"""\
Found an output of the forward that requires gradients, that was not the scalar loss.
We require all outputs to the forward that are not the scalar loss to not require gradient,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients."""
                    )
            out_loss = out[output_loss_index]
            num_fw_outs = len(out)
            if not out_loss.requires_grad:
                raise RuntimeError(
                    f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients"""
                )
            if out_loss.numel() != 1:
                raise RuntimeError(
                    f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}"""
                )
            return out

        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        # However if pre_dispatch=True, we want to correctly trace set_grad_enabled calls for training.
        ctx = nullcontext if pre_dispatch else torch.no_grad
        fn_to_trace = functional_call

    full_args = []
    # First, the params
    # NB: It is REQUIRED that parameters come first, Inductor infers "fixed"
    # parameters by looking at the difference in parameter count outside
    # and inside AOTAutograd, and assumes the prefix of arguments are fixed
    # arguments
    full_args.extend(params_and_buffers_flat)
    # Next, the input args
    full_args.extend(args)

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            fn_to_trace,
            full_args,
            decompositions=decompositions,
            num_params_buffers=params_len,
            no_tangents=True,
            pre_dispatch=pre_dispatch,
            dynamic_shapes=dynamic_shapes,
            kwargs=kwargs,
        )
    if trace_joint:

        def flattened_joint(*args):
            # The idea here is that the joint graph that AOTAutograd creates has some strict properties:
            # (1) It accepts two arguments (primals, tangents), and pytree_flattens them
            # (2) It returns a tuple of (fw_outs, gradients)
            # This is a very useful convention for anyone who wants to partition the joint graph
            # into a separate forward and backward graph.
            # However,
            # (1) for people exporting a single joint graph, it would be preferable not to have
            #     any pytrees in the graph.
            # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss,
            #     and there are therefore no tangents that are needed to run the joint graph.
            # (3) AOTAutograd creates a grad_input for every input in the forward,
            #     including None's for inputs that are not grad-requiring tensors.
            #     We don't want these in our export graph.
            # This function "fixes" the above by removing any tangent inputs,
            # and removing pytrees from the original FX graph.
            fake_tangents = [
                None
                for _ in range(
                    metadata.num_outputs + metadata.num_mutated_inp_runtime_indices
                )
            ]
            fw_outs, gradients = fx_g(args, fake_tangents)
            assert len(gradients) == len(args)
            output_gradients = []
            for i, (a, grad) in enumerate(zip(args, gradients)):
                if isinstance(a, torch.Tensor) and a.requires_grad:
                    assert (
                        grad is not None
                    ), """\
Found a parameter that did not receive a gradient.
This is most likely a bug, but if this needs to be supported please comment on this Github issue:
https://github.com/pytorch/pytorch/issues/101192
"""
                    output_gradients.append(grad)
                else:
                    assert grad is None
            return *fw_outs, *output_gradients

        fx_g = make_fx(flattened_joint)(*full_args)

    user_args_flat = pytree.arg_tree_leaves(*args, **kwargs)
    return fx_g, create_graph_signature(
        fx_g,
        metadata,
        in_spec,
        out_spec,
        user_args_flat=user_args_flat,
        params_and_buffers_flat=params_and_buffers_flat,
        param_names=list(named_parameters.keys()),
        buffer_names=list(named_buffers.keys()),
        trace_joint=trace_joint,
        num_user_fw_outs=num_fw_outs,
        loss_index=output_loss_index,
    )


def aot_export_joint_simple(
    func: Callable,
    args,
    *,
    trace_joint: bool,
    # It looks like the main consequence of this API is that for dynamic shapes,
    # it will assume that params/buffers are static.
    # With the new inferred dynamic shapes API, maybe this doesn't matter?
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
) -> torch.fx.GraphModule:
    """
    A simplified version of export. Used by higher order operators.

    This function makes a high-level "no calling convention changes" guarantee:
    - If no inputs require grad (so we export an inference graph),
      there are *no* calling convention changes between the exported graph, and "func".
    - If at least one input requires grad (so we trace out and export a joint fw-bw graph),
      then if you were to partition the graph into a separate forward and backward graph,
      the forward graph will have no calling convention changes compared to "func".

    The above also relies on some strong restrictions around which functions this API accepts:
    (1) `args` cannot contain any pytrees (they must have been pytree_flattened already)
    (2) `func` cannot mutate any inputs
    (3) The outputs of `func` cannot alias any inputs.

    Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
    """
    if trace_joint:
        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        ctx = torch.no_grad

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            func,
            args,
            decompositions=decompositions,
        )
        in_spec, _kw_in_spec = in_spec.children_specs
    # At this point, we can just directly return the (joint or inference graph) that we traced.
    # First though: a bunch of assertions to make sure that our graph doesn't require
    # any calling convention changes compared to the original function.
    # These restrictions are *in addition to* the general restrictions on export.

    # No input mutations
    if (
        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
        )
    # No output aliasing
    if (
        len([x for x in metadata.output_info if x.output_type != OutputType.non_alias])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}"
        )
    # No pytrees
    if in_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
        )
    if not all(child.is_leaf() for child in in_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
        )
    if out_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
        )
    if not all(child.is_leaf() for child in out_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
        )
    # TODO: we might have to temporarily patch config.functionalize_rng
    # so that it doesn't run when we're exporting a higher order op.

    if config.debug_assert:
        # Smoke test that after partitioning, we can run the forward without any calling convention changes.
        fw_module, bw_module = aot_config.default_partition(  # noqa: F821
            fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)  # noqa: F821
        )
        # Attempt to run the fw_module with the original user inputs
        fake_mode = detect_fake_mode(args)
        if fake_mode is None:
            fake_mode = FakeTensorMode()
        with fake_mode:
            fw_module(*args)
    return fx_g


# Private for now because we aren't providing a contract on what to return
# for joint graphs (we could when there's a clearer use case)
# In the future, we may need to add more export APIs that provide their own strong guarantees.
# This is meant as a general helper function for handling various export-y use cases.
def _aot_export_function(
    func: Callable,
    args,
    *,
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
    # If we're exporting a joint graph and we don't want any tangent inputs in the graph
    # (because we are backpropping through a scalar 1 loss),
    # we need to explicitly specify not to include tangents in the graph.
    # It's not enough just to check that our tangent is a scalar, since we also
    # need to know if it is a 1 (no need to make it a graph input), or something else
    # (requiring it to be a graph input).
    # We don't know this info at trace time though, so we need to make it an explicit config.
    no_tangents: bool = False,
    pre_dispatch: bool = False,
    # If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong.
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
    kwargs = kwargs or {}

    flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs)
    flat_args, in_spec = pytree.tree_flatten((args, kwargs))

    if dynamic_shapes is None:
        # Try to infer `dynamic_shapes` from inputs and graph nodes
        fake_mode = detect_fake_mode(flat_args)
        if (
            fake_mode is None
            and hasattr(func, "_orig_mod")
            and isinstance(func._orig_mod, torch.fx.GraphModule)
        ):
            vals = [
                node.meta["val"]
                for node in func._orig_mod.graph.nodes
                if "val" in node.meta
            ]
            fake_mode = detect_fake_mode(vals)
        dynamic_shapes = fake_mode is not None and fake_mode.shape_env is not None

    # The export use case doesn't care about several bits of AOTConfig
    # (1) compilers (we just export the graph)
    # (2) partitioners (export is only full graph, user can partition themselves)
    aot_config = AOTConfig(
        fw_compiler=None,
        bw_compiler=None,
        inference_compiler=None,
        partition_fn=None,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        # For now there's no use case involving keeping input mutations in the graph
        # (which we can only do in the inference case anyway).
        # We can add this later if we need to.
        keep_inference_input_mutations=False,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=None,
        is_export=True,
        no_tangents=no_tangents,
        pre_dispatch=pre_dispatch,
    )
    fake_mode, shape_env = construct_fake_mode(flat_args, aot_config)
    fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env)

    fx_g, meta = create_aot_dispatcher_function(
        flat_fn,
        fake_flat_args,
        aot_config,
        fake_mode,
        shape_env,
    )
    return fx_g, meta, in_spec, out_spec.spec


@contextmanager
def _detect_attribute_assignment(mod: torch.nn.Module):
    # Do not allow assignment of tensor attributes during export unless
    # the attribute is registered as a buffer.

    STD_ATTRS = {
        "_backward_hooks",
        "_backward_pre_hooks",
        "_buffers",
        "_forward_hooks",
        "_forward_hooks_always_called",
        "_forward_hooks_with_kwargs",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_is_full_backward_hook",
        "_load_state_dict_post_hooks",
        "_load_state_dict_pre_hooks",
        "_modules",
        "_non_persistent_buffers_set",
        "_parameters",
        "_state_dict_hooks",
        "_state_dict_pre_hooks",
        "training",
    }

    def _get_attributes(mod):
        # return any attributes of a module that are not standard attributes
        return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}

    # save state of attributes before enter
    snapshot = pytree.tree_map(lambda x: x, _get_attributes(mod))
    try:
        yield
    finally:
        # after exit, compare state of attributes with snapshot
        # to detect which tensor attributes were assigned
        assigned_tensor_attributes = []

        def _collect_assigned_tensor_attributes(kp, v, _v):
            if _v is not v:
                attr, *rest = kp
                if isinstance(v, torch.Tensor):
                    assigned_tensor_attributes.append(
                        f"self.{attr.key}{pytree.keystr(rest)}"
                    )
                # TODO(avik): Assigning all other types is allowed right now.
                # Maybe in the future we want to limit this to primitive types?

        pytree.tree_map_with_path(
            _collect_assigned_tensor_attributes, snapshot, _get_attributes(mod)
        )
        # restore state of all attributes (including, e.g., of primitive types)
        mod.__dict__.update(snapshot)

        if assigned_tensor_attributes:
            if len(assigned_tensor_attributes) > 1:
                noun, verb = "attributes", "were"
            else:
                noun, verb = "attribute", "was"
            raise ValueError(
                f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
                "Such attributes must be registered as buffers using the `register_buffer` API "
                "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
            )


compiled_function = aot_function
compiled_module = aot_module