# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
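
# Usage sketch (illustrative only; `my_hop` is a hypothetical HigherOrderOperator).
# A common pattern is to register this as the Autograd dispatch rule of an op whose
# backward is not implemented, so that backward fails loudly instead of silently:
#
#   from torch._C import DispatchKey
#
#   my_hop.py_impl(DispatchKey.Autograd)(
#       autograd_not_implemented(my_hop, deferred_error=True)
#   )
#
# With deferred_error=True the forward call still succeeds under grad mode, but the
# "Autograd not implemented" error is raised once backward reaches the output.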


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
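
# Usage sketch (illustrative): callers wrap a nested tracing/compilation call so the
# symbolic-trace flag is cleared for its duration and restored afterwards:
#
#   with _set_compilation_env():
#       ...  # e.g. trace or compile the body of a higher-order op here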


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph mutates any of its inputs. This is a bit
    restrictive because the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)
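
# Illustrative sketch (hypothetical branch functions, not part of this module):
#
#   def mutating_branch(x):
#       return x.add_(1)  # in-place op on the traced input
#
#   def pure_branch(x):
#       return x + 1
#
#   inputs = [torch.randn(3)]
#   _has_potential_branch_input_mutation(mutating_branch, inputs)  # expected: True
#   _has_potential_branch_input_mutation(pure_branch, inputs)      # expected: False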


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph has an output that aliases a branch input. This is
    a bit restrictive because the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check existence of "val" because we reuse the logic here
            # for map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
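
# Illustrative sketch (hypothetical branch functions):
#
#   def aliasing_branch(x):
#       return x.view(-1)  # output shares storage with the input
#
#   def cloning_branch(x):
#       return x.view(-1).clone()
#
#   inputs = [torch.randn(2, 3)]
#   _has_potential_branch_input_alias(aliasing_branch, inputs)  # expected: True
#   _has_potential_branch_input_alias(cloning_branch, inputs)   # expected: False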


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name
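
# Example sketch: if the tracer root already has attributes "body_graph_0" and
# "body_graph_1", then unique_graph_id(proxy_mode, "body_graph") returns
# (2, "body_graph_2") - the first free index together with its attribute name.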


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone
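
# Illustrative sketch (hypothetical tensors):
#
#   x = torch.randn(4)
#   maybe_clone = clone_outputs_aliasing_inputs((x,))
#   maybe_clone(x.view(2, 2))   # shares storage with x -> returns a clone
#   maybe_clone(torch.ones(2))  # unrelated storage -> returned unchanged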


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks
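
# Illustrative sketch (hypothetical forward fn): the wrapper returns the forward
# outputs together with a boolean mask marking which outputs require grad:
#
#   def fw(x, y):
#       return (x * y, y.detach())
#
#   fw_with_masks = prepare_fw_with_masks(fw)
#   out, masks = fw_with_masks(torch.randn(2, requires_grad=True), torch.randn(2))
#   # masks == [True, False]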


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
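
# Usage sketch (illustrative; `body_fn`, `xs`, and the example tensors are hypothetical).
# A caller would typically build unwrapped example inputs/outputs with _from_fun and
# then keep the returned forward graph and joint (backward) graph for later use:
#
#   fw_inputs = [_from_fun(x) for x in xs]
#   fw_outputs = [_from_fun(o) for o in body_fn(*fw_inputs)]
#   fw_graph, joint_graph = create_fw_bw_graph(body_fn, False, fw_inputs, fw_outputs)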


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for tuple in a:
        pytrees.append(pytree.tree_unflatten(tuple, inspec))
    return pytrees
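
# Illustrative sketch: splitting a pytree of stacked tensors along dim 0.
#
#   xs = {"a": torch.arange(6).reshape(3, 2), "b": torch.zeros(3, 4)}
#   _unstack_pytree(xs)
#   # -> list of 3 dicts, each {"a": tensor of shape (2,), "b": tensor of shape (4,)}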


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return a None output when a forward input doesn't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output;
            # therefore we need to deal with None outputs.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
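
# Illustrative sketch: _stack_pytree is the inverse of _unstack_pytree for matching
# pytrees, restacking the leaves along a new leading dimension:
#
#   pts = [{"a": torch.zeros(2)}, {"a": torch.ones(2)}]
#   _stack_pytree(pts)  # -> {"a": tensor of shape (2, 2)}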