4
from contextlib import contextmanager
5
from typing import Iterator
10
import torch.utils._python_dispatch
11
import torch.utils._pytree as pytree
14
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
16
# Thin aliases over the C++ RAII guard classes that toggle the Python
# dispatcher / pre-dispatch on the current thread's dispatch TLS.
no_python_dispatcher = torch._C._DisablePythonDispatcher
17
enable_python_dispatcher = torch._C._EnablePythonDispatcher
18
enable_pre_dispatch = torch._C._EnablePreDispatch
20
# Module-level flag; flipped to True via unittest.mock.patch on
# "torch._dispatch.python.CROSSREF_FUNCTIONALIZE" in
# enable_crossref_functionalize further down in this file.
CROSSREF_FUNCTIONALIZE = False
23
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Yield every OpOverload object that has been allocated in Python.

    Warning: the set of overloads this will report is very subtle.  It is
    precisely the set of torch.ops functions that have actually been accessed
    from Python (e.g., we actually called torch.ops.aten.blah at some point).
    This is DIFFERENT from the set of registered operators, which will in
    general be a larger set, as this would include all operators which we ran
    C++ static initializers or Python operator registration on.  This does not
    eagerly populate the list on torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python.  We use it for cache invalidation,
    but don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library
    which triggers more registrations could add more things to the set.
    """
    # Iterating torch.ops only reports namespaces Python has touched — this
    # is exactly the laziness the docstring warns about.
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
@contextmanager
def suspend_functionalization():
    """
    Temporarily disable functionalization on the current thread's dispatch
    TLS, restoring the previous state (including the reapply_views setting)
    on exit.
    """
    # Snapshot whether the Functionalize dispatch key is currently active...
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    # ...and the reapply_views TLS flag, so both can be restored afterwards.
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            # Re-enable with the originally observed reapply_views setting.
            torch._enable_functionalization(reapply_views=f_rv)
def check_tensor_metadata_matches(nv, rv, desc):
    """
    Assert that two tensors agree on size, dtype, and significant strides.

    desc is a zero-argument callable producing a human-readable description
    of where the tensors came from; it is only invoked lazily when building
    an assertion message.
    """
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    # Compare only "significant" strides via the prims helper rather than a
    # raw stride() equality check — presumably this skips layout-irrelevant
    # dimensions; see torch._prims_common.check_significant_strides.
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
def check_metadata_matches(n, r, desc):
    """
    Flatten two pytrees and check that corresponding tensor leaves agree on
    metadata (size/dtype/strides) via check_tensor_metadata_matches.

    Leaves that are not tensors on the reference side (r) are skipped.
    desc is a zero-argument callable used lazily for assertion messages.
    """
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # NOTE(review): only the leaf counts are compared; the tree specs
    # (n_spec/r_spec) are not checked against each other here.
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        # The lambda is invoked within this same iteration (inside the
        # metadata check), so capturing the loop variable i is safe.
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
def __init__(self, s):
97
def _fmt(a: object) -> object:
98
if isinstance(a, torch.Tensor):
100
f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
106
# NOTE(review): this block is a garbled extraction fragment — the interleaved
# bare integers are leftover line numbers from the original file, and several
# structural lines (else branches, the maybe_detach and desc helper headers,
# return statements) are missing from this view.  Code below is preserved
# byte-for-byte; only reviewer comments have been added.
def make_crossref_functionalize(op, final_key):
107
# Imported lazily inside the function, presumably to avoid an import
# cycle with torch._subclasses — TODO confirm.
from torch._subclasses.fake_tensor import FakeTensorMode
110
# lift_fresh appears to be special-cased; the branch body is not visible here.
if op == torch.ops.aten.lift_fresh.default:
113
def handler(*args, **kwargs):
114
fake_mode = FakeTensorMode()
116
# Convert a (possibly functional) tensor argument into a fake tensor.
def fakeify_defun(t):
117
if isinstance(t, torch.Tensor):
118
if torch._is_functional_tensor(t):
119
r = torch._from_functional_tensor(t)
124
# Sanity check: unwrapping the functional tensor must not change metadata.
assert t.size() == r.size()
125
assert t.stride() == r.stride()
129
return fake_mode.from_tensor(r)
133
if isinstance(t, torch.Tensor):
140
# Run the fake reference computation with dispatch modes and
# functionalization suspended.
with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
141
f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
142
orig_f_args, orig_f_kwargs = pytree.tree_map(
143
maybe_detach, (f_args, f_kwargs)
146
f_r = op(*f_args, **f_kwargs)
147
# Real execution at the original (final) dispatch key, for comparison.
r = op._op_dk(final_key, *args, **kwargs)
150
# Build a human-readable description of the call lazily, for use in
# assertion messages when the metadata cross-check fails.
fmt_args = ", ".join(
152
(repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
154
f"{k}={pytree.tree_map(_fmt, v)}"
155
for k, v in orig_f_kwargs.items()
159
return f"{op}({fmt_args})"
161
# Cross-check the fake result's metadata against the real result's.
check_metadata_matches(f_r, r, desc)
170
@contextmanager
def enable_crossref_functionalize():
    """
    Context manager enabling crossref functionalization: flushes the
    Functionalize dispatch cache for every Python-loaded overload, runs the
    body with the Python dispatcher enabled and CROSSREF_FUNCTIONALIZE
    patched to True, then flushes the caches again on exit so later uses
    re-resolve without the crossref handlers.
    """
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        # Drop any cached crossref dispatch entries installed while active.
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)