pytorch
142 строки · 5.3 Кб
1# mypy: allow-untyped-defs
2import logging3from contextlib import contextmanager4
5import torch6from torch._C import DispatchKey # @manual7from torch._functorch._aot_autograd.utils import KNOWN_TYPES8from torch._higher_order_ops.utils import autograd_not_implemented9from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject10from torch._ops import HigherOrderOperator11from torch._subclasses.fake_tensor import FakeTensorMode12from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree13from torch.fx.node import has_side_effect14from torch.utils import _pytree as pytree15
16
# Module-level logger; used below to warn when tracing runs against a real
# (non-fake) ScriptObject.
log = logging.getLogger(__name__)
19
20# The call_torchbind operator represents a method invocation on a torchbind
21# object. The calling convention is:
22# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
23# We do not expect users to write this operator directly. Instead it will be
24# emitted by Dynamo when tracing encounters a torchbind object.
25class CallTorchBind(HigherOrderOperator):26def __init__(self):27super().__init__("call_torchbind")28
29def __call__(self, obj, method, *args, **kwargs):30return super().__call__(obj, method, *args, **kwargs)31
32
# Module-level singleton instance of the operator; the py_impl /
# py_functionalize_impl registrations in this module attach to it.
call_torchbind = CallTorchBind()

# Register this operator as side-effectful with FX.
# TODO: this is not really sufficient. While passes (hopefully) check
# Node.is_impure() and make good decisions, we also assume we can execute the
# graph as many times as we want without changing behavior, which is NOT true of
# ops that mutate torchbind object state.
has_side_effect(call_torchbind)

# Saved at import time, before enable_torchbind_tracing() can patch
# ScriptMethod.__call__. It is used to restore the patch on exit, and it lets
# the eager implementation and the non-torchbind fallback call a script method
# without redispatching through call_torchbind.
_orig_scriptmethod_call = torch.ScriptMethod.__call__
def torchbind_method_redispatch(self, *args, **kwargs):
    """Stand-in for ``torch.ScriptMethod.__call__`` while torchbind tracing is on.

    A method whose owner is a torchbind ``torch.ScriptObject`` is routed
    through the ``call_torchbind`` operator. Any other ScriptMethod is invoked
    through the saved, unpatched ``ScriptMethod.__call__``.
    """
    owner = self.raw_owner
    if not isinstance(owner, torch.ScriptObject):
        return _orig_scriptmethod_call(self, *args, **kwargs)
    return call_torchbind(owner, self.name, *args, **kwargs)
@contextmanager
def enable_torchbind_tracing():
    """Context manager that acts as a feature flag to enable torchbind tracing
    behavior. Once torchbind tracing has been stabilized, we can remove this and
    turn it always on.

    While active:
      * ``torch.ScriptObject`` is appended to ``KNOWN_TYPES``.
      * ``torch.ScriptMethod.__call__`` is replaced by
        ``torchbind_method_redispatch``.

    Both changes are undone on exit, including when the body raises.

    Raises:
        AssertionError: on exit, if the entry popped from ``KNOWN_TYPES`` is
            not ``torch.ScriptObject``, i.e. something else mutated the list
            during tracing. ``ScriptMethod.__call__`` is restored before the
            check, so a failed check leaves no patch behind.
    """
    KNOWN_TYPES.append(torch.ScriptObject)
    try:
        torch.ScriptMethod.__call__ = torchbind_method_redispatch  # type: ignore[method-assign]
        yield
    finally:
        # Restore first, so neither a failed invariant check nor an error
        # from pop() can leave ScriptMethod permanently monkeypatched.
        torch.ScriptMethod.__call__ = _orig_scriptmethod_call  # type: ignore[method-assign]
        # The pop must NOT live inside the assert. `python -O` strips assert
        # statements entirely, which would skip the pop and make KNOWN_TYPES
        # grow by one entry on every use of this context manager.
        popped = KNOWN_TYPES.pop()
        assert (
            popped is torch.ScriptObject
        ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
    """Eager implementation of ``call_torchbind``: run ``obj.<method>(...)``.

    * A real ``torch.ScriptObject`` is called through the saved, unpatched
      ``ScriptMethod.__call__``. This avoids redispatching back into
      ``call_torchbind`` while tracing has patched ``__call__``.
    * A ``FakeScriptObject`` forwards to the same-named method on its
      ``wrapped_obj``.

    Raises:
        RuntimeError: if ``obj`` is neither of the above.
    """
    if isinstance(obj, torch.ScriptObject):
        script_method = getattr(obj, method)
        return _orig_scriptmethod_call(script_method, *args, **kwargs)
    if isinstance(obj, FakeScriptObject):
        fake_method = getattr(obj.wrapped_obj, method)
        return fake_method(*args, **kwargs)
    raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")
@call_torchbind.py_impl(ProxyTorchDispatchMode)
def inner(mode, *args, **kwargs):
    """Proxy-mode (make_fx) implementation of ``call_torchbind``.

    Records a ``call_function`` node targeting ``call_torchbind`` in
    ``mode.tracer``, executes the op to get a concrete output, and attaches
    that output to the new node.

    If the receiver is a real ``torch.ScriptObject`` (not a fake one), a
    warning is logged because the traced call runs against the real object.

    Raises:
        AssertionError: if the node received no ``"val"`` metadata from
            ``track_tensor_tree`` and the output is not None/int/float/bool.
    """
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        call_torchbind,
        proxy_args,
        proxy_kwargs,
    )
    out = call_torchbind(*args, **kwargs)

    # Only the receiver and the method name are needed below; the remaining
    # method arguments were previously bound to an unused local.
    obj, method = args[:2]
    if isinstance(obj, torch.ScriptObject):
        ns, class_name = _ns_and_class_name(
            obj._type().qualified_name()  # type: ignore[attr-defined]
        )
        log.warning(
            "Tracing torchbind method %s.%s with real ScriptObject. This may"
            " cause the original object being mutated. If this is not intended,"
            ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
            class_name,
            method,
            ns,
            class_name,
        )

    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    if "val" not in out_proxy.node.meta:
        # track_tensor_tree recorded no value. Presumably this means the output
        # holds no tensors, so store the constant result directly.
        assert out is None or isinstance(
            out, (int, float, bool)
        ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
        out_proxy.node.meta["val"] = out
    return ret
# When tracing with a fake script object, the call_torchbind op returns fake
# tensors. When tracing with a real script object, it may return real tensors,
# which must be converted to fake tensors manually. That conversion uses
# static shapes (static_shapes=True); tensors that are already FakeTensors
# are passed through unchanged, keeping whatever shapes they carry.
@call_torchbind.py_impl(FakeTensorMode)
def call_torchbind_fake(mode, *args, **kwargs):
    """FakeTensorMode implementation of ``call_torchbind``.

    Runs the eager implementation with ``mode`` active. Every ``torch.Tensor``
    leaf of the pytree-structured result that is not already a FakeTensor is
    converted into a fake tensor of ``mode``.
    """

    def _to_fake(t):
        # Already fake: leave untouched.
        if isinstance(t, torch._subclasses.fake_tensor.FakeTensor):
            return t
        # Real tensor produced by a real ScriptObject method.
        return mode.from_tensor(t, static_shapes=True)

    with mode:
        out = call_torchbind_impl(*args, **kwargs)
    return pytree.tree_map_only(torch.Tensor, _to_fake, out)
# Autograd is not implemented for call_torchbind. The Autograd dispatch key is
# bound to autograd_not_implemented(...). NOTE(review): deferred_error=True
# presumably defers the error until gradients actually flow through the op,
# rather than raising on the forward call; confirm against
# torch._higher_order_ops.utils.autograd_not_implemented.
call_torchbind.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(call_torchbind, deferred_error=True)
)
@call_torchbind.py_functionalize_impl
def call_torchbind_func(ctx, *args, **kwargs):
    """Functionalization implementation of ``call_torchbind``.

    Delegates to ``handle_effects``, passing the functionalization mode's
    token-discovery flag and token state along with the op and its arguments.
    """
    # Function-scope import, kept as in the original; presumably to avoid an
    # import cycle with the effects module (NOTE(review): confirm).
    from torch._higher_order_ops.effects import handle_effects

    functional_mode = ctx.mode
    allow_token_discovery = functional_mode._allow_token_discovery
    tokens = functional_mode._tokens
    return handle_effects(allow_token_discovery, tokens, call_torchbind, args, kwargs)