# torch/_dispatch/python.py
# mypy: allow-untyped-defs
import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator

import torch
import torch._C
import torch._ops
import torch._prims_common  # used below by check_tensor_metadata_matches
import torch.utils._python_dispatch
import torch.utils._pytree as pytree


__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

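# Usage sketch (illustrative only, not part of the original module): all three
# names above are context managers implemented in C++, e.g.:
#
#     with enable_python_dispatcher():
#         ...  # torch.ops.* calls in this block resolve via the Python dispatcher
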
CROSSREF_FUNCTIONALIZE = False


def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle.  It is
    precisely the set of torch.ops functions that have actually been accessed
    from Python (e.g., we actually called torch.ops.aten.blah at some point).
    This is DIFFERENT from the set of registered operators, which will in
    general be a larger set, as it would include all operators for which we
    ran C++ static initializers or Python operator registration.  This does
    not eagerly populate the list on torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python.  We use it for cache invalidation,
    but don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library
    which triggers more registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)

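# A minimal usage sketch (illustrative only): per the docstring above, an
# overload only shows up in this iterator after it has been touched from Python:
#
#     torch.ops.aten.add.Tensor  # forces lazy creation of the OpOverload
#     assert any(o is torch.ops.aten.add.Tensor for o in all_py_loaded_overloads())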

@contextmanager
def suspend_functionalization():
    """Temporarily disable functionalization in this thread (if it is enabled),
    restoring the prior state, including the reapply_views setting, on exit."""
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)

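# Usage sketch (illustrative only): run code on the underlying tensors without
# functionalization interposing; the previous state is restored on exit:
#
#     with suspend_functionalization():
#         ...  # functionalization is off for this block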

def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test that the specs match; empirically, sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")

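# A minimal sketch of what these checks catch (illustrative only): equal sizes
# and dtype but significantly different strides trip the assertion:
#
#     a = torch.empty_strided((2, 3), (3, 1))
#     b = torch.empty_strided((2, 3), (1, 2))
#     check_metadata_matches([a], [b], lambda: "example")  # AssertionError on strides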

class Lit:
    """Wraps a string so that its repr is the string itself, letting _fmt
    render tensors as constructor expressions."""

    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a

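# Example of the output this produces (illustrative only):
#
#     >>> _fmt(torch.empty(2, 3))
#     torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)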

def make_crossref_functionalize(op, final_key):
    """Return a handler that re-runs `op` on fake, defunctionalized copies of
    its arguments and cross-checks that the fake outputs' metadata (sizes,
    strides, dtype) matches the real outputs computed at `final_key`."""
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird; suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
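
# How this gets wired up (a hedged sketch, inferred from
# enable_crossref_functionalize below, which patches CROSSREF_FUNCTIONALIZE and
# uncaches the Functionalize dispatch key): when the flag is True, the Python
# dispatcher resolves the Functionalize key to this handler, roughly:
#
#     handler = make_crossref_functionalize(op, final_key)
#     out = handler(*args, **kwargs)  # fake run + real run + metadata cross-check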


# NB: enabling this is slow; don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)

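# Usage sketch (illustrative only): enable the cross-check while exercising a
# functionalized program, e.g. in a debugging session or test:
#
#     with enable_crossref_functionalize():
#         ...  # ops dispatched through Functionalize are metadata-checked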