# mypy: allow-untyped-defs
import logging
from contextlib import contextmanager

import torch
from torch._C import DispatchKey  # @manual
from torch._functorch._aot_autograd.utils import KNOWN_TYPES
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.fx.node import has_side_effect
from torch.utils import _pytree as pytree


log = logging.getLogger(__name__)


# The call_torchbind operator represents a method invocation on a torchbind
# object. The calling convention is:
#   call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
# We do not expect users to write this operator directly. Instead it will be
# emitted by Dynamo when tracing encounters a torchbind object.
class CallTorchBind(HigherOrderOperator):
    def __init__(self):
        super().__init__("call_torchbind")

    def __call__(self, obj, method, *args, **kwargs):
        return super().__call__(obj, method, *args, **kwargs)


call_torchbind = CallTorchBind()

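# Illustrative example of the calling convention above. "my_ns::Accumulator" and
# its "add" method are hypothetical (not part of PyTorch); the point is only how
# a torchbind method call maps onto the operator:
#
#   acc = torch.classes.my_ns.Accumulator()
#   out = call_torchbind(acc, "add", torch.ones(3))  # roughly acc.add(torch.ones(3))
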
# Register this operator as side-effectful with FX.
# TODO: this is not really sufficient. While passes (hopefully) check
# Node.is_impure() and make good decisions, we also assume we can execute the
# graph as many times as we want without changing behavior, which is NOT true of
# ops that mutate torchbind object state.
has_side_effect(call_torchbind)

_orig_scriptmethod_call = torch.ScriptMethod.__call__


def torchbind_method_redispatch(self, *args, **kwargs):
    if isinstance(self.raw_owner, torch.ScriptObject):
        return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
    return _orig_scriptmethod_call(self, *args, **kwargs)


@contextmanager
def enable_torchbind_tracing():
    """Context manager that acts as a feature flag to enable torchbind tracing
    behavior. Once torchbind tracing has been stabilized, we can remove this and
    turn it on by default.
    """
    try:
        KNOWN_TYPES.append(torch.ScriptObject)
        torch.ScriptMethod.__call__ = torchbind_method_redispatch  # type: ignore[method-assign]
        yield
    finally:
        assert (
            KNOWN_TYPES.pop() is torch.ScriptObject
        ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
        torch.ScriptMethod.__call__ = _orig_scriptmethod_call  # type: ignore[method-assign]


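# Illustrative usage of enable_torchbind_tracing (a sketch, not a recipe from this
# module): ``mod`` is a hypothetical nn.Module holding a ScriptObject attribute and
# ``x`` a sample input.
#
#   with enable_torchbind_tracing():
#       ep = torch.export.export(mod, (x,), strict=False)
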
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
    if isinstance(obj, torch.ScriptObject):
        return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
    elif isinstance(obj, FakeScriptObject):
        return getattr(obj.wrapped_obj, method)(*args, **kwargs)
    else:
        raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")


@call_torchbind.py_impl(ProxyTorchDispatchMode)
def inner(mode, *args, **kwargs):
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        call_torchbind,
        proxy_args,
        proxy_kwargs,
    )
    out = call_torchbind(*args, **kwargs)

    obj, method, *rest_args = args
    if isinstance(obj, torch.ScriptObject):
        ns, class_name = _ns_and_class_name(
            obj._type().qualified_name()  # type: ignore[attr-defined]
        )
        log.warning(
            "Tracing torchbind method %s.%s with a real ScriptObject. This may"
            " cause the original object to be mutated. If this is not intended,"
            ' you can register a fake class with torch._library.register_fake_class("%s::%s").',
            class_name,
            method,
            ns,
            class_name,
        )

    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    if "val" not in out_proxy.node.meta:
        assert out is None or isinstance(
            out, (int, float, bool)
        ), "Currently, only constant int, float, bool, or None returns are supported from torchbind methods."
        out_proxy.node.meta["val"] = out
    return ret


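# Illustrative sketch of avoiding the warning above by registering a fake class.
# "my_ns::Accumulator" and its methods are hypothetical; the real script class is
# assumed to define __obj_flatten__ so the registry can build the fake object.
#
#   @torch._library.register_fake_class("my_ns::Accumulator")
#   class FakeAccumulator:
#       def __init__(self, state):
#           self.state = state
#
#       @classmethod
#       def __obj_unflatten__(cls, flattened):
#           return cls(**dict(flattened))
#
#       def add(self, x):
#           return x + self.state
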
# When tracing with a fake script object, the call_torchbind op will return fake tensors.
# When tracing with a real script object, the call_torchbind op may return real tensors,
# which we need to convert to fake tensors manually (with static shapes).
@call_torchbind.py_impl(FakeTensorMode)
def call_torchbind_fake(mode, *args, **kwargs):
    with mode:
        out = call_torchbind_impl(*args, **kwargs)
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: mode.from_tensor(x, static_shapes=True)
            if not isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            else x,
            out,
        )


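# A minimal sketch of the conversion performed by call_torchbind_fake above,
# assuming a real tensor was returned by a ScriptObject method during tracing:
#
#   fake_mode = FakeTensorMode()
#   real = torch.zeros(2, 2)  # stand-in for a tensor returned by a real ScriptObject
#   fake = fake_mode.from_tensor(real, static_shapes=True)  # now a FakeTensor
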
call_torchbind.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(call_torchbind, deferred_error=True)
)


@call_torchbind.py_functionalize_impl
def call_torchbind_func(ctx, *args, **kwargs):
    from torch._higher_order_ops.effects import handle_effects

    return handle_effects(
        ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs
    )