pytorch
197 строк · 7.2 Кб
1# mypy: ignore-errors
2
3import functools
4import warnings
5from typing import Callable, Union
6
7import torch
8import torch.utils._pytree as pytree
9from torch._ops import OpOverload
10from torch._subclasses.fake_tensor import (
11FakeTensorMode,
12tree_flatten_only,
13UnsupportedFakeTensorException,
14)
15from torch.utils._python_dispatch import TorchDispatchMode
16
17
# Shorthand for the ``aten`` operator namespace; the functions and the mode
# below reference specific OpOverloads through it.
aten = torch.ops.aten
19
20
def outputs_alias_inputs(outputs, inputs):
    """Report whether any output tensor shares a storage with any input tensor.

    Both ``outputs`` and ``inputs`` may be arbitrary pytrees; only their
    ``torch.Tensor`` leaves are inspected. Tensors for which
    ``torch._C._has_storage`` is False are ignored on both sides. Storages are
    identified by ``_typed_storage()._cdata``.

    Returns:
        bool: True as soon as one output storage matches an input storage.
    """

    def _storage_keys(tree):
        # Lazily yield storage identities so ``any`` below can stop early.
        return (
            tensor._typed_storage()._cdata
            for tensor in tree_flatten_only(torch.Tensor, tree)
            if torch._C._has_storage(tensor)
        )

    known = set(_storage_keys(inputs))
    return any(key in known for key in _storage_keys(outputs))
31
32
def outputs_are_inputs(outputs, inputs):
    """Report whether any output tensor is the very same Python object as an input.

    This is an identity check (``id``), not a storage-aliasing check: a view
    of an input does not count. Only ``torch.Tensor`` leaves of the two
    pytrees are considered.
    """
    input_ids = set(map(id, tree_flatten_only(torch.Tensor, inputs)))
    for candidate in tree_flatten_only(torch.Tensor, outputs):
        if id(candidate) in input_ids:
            return True
    return False
36
37
def output_alias_each_other(outputs):
    """Report whether two tensor leaves of ``outputs`` share the same storage.

    Tensors for which ``torch._C._has_storage`` is False are skipped. Storages
    are identified by ``_typed_storage()._cdata``.
    """
    storage_keys = (
        tensor._typed_storage()._cdata
        for tensor in tree_flatten_only(torch.Tensor, outputs)
        if torch._C._has_storage(tensor)
    )
    seen = set()
    for key in storage_keys:
        if key in seen:
            return True
        seen.add(key)
    return False
48
49
50def is_sdpa_error(func, idx, e):
51if (
52(
53func is aten._scaled_dot_product_flash_attention.default
54or func is aten._flash_attention_forward.default
55)
56and idx in (6, 7)
57and "Devices" in repr(e)
58):
59return True
60if (
61(
62func is aten._scaled_dot_product_efficient_attention.default
63or func is aten._efficient_attention_forward.default
64)
65and idx in (2, 3)
66and "Devices" in repr(e)
67):
68return True
69if (
70func is aten._scaled_dot_product_cudnn_attention.default
71and idx in (6, 7)
72and "Devices" in repr(e)
73):
74return True
75return False
76
77
78class CrossRefFakeMode(TorchDispatchMode):
79def __init__(
80self,
81ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
82*,
83check_strides=True,
84check_aliasing=True,
85):
86super().__init__()
87self.ignore_op_fn = (
88ignore_op_fn if ignore_op_fn is not None else lambda fn: False
89)
90self.check_strides = check_strides
91self.check_aliasing = check_aliasing
92
93def __torch_dispatch__(self, func, types, args=(), kwargs=None):
94kwargs = kwargs or {}
95
96fake_r = None
97
98# empty_like excluded for now due to sparse complex
99# aten._to_dense.default this one is getting called with csc
100if (
101func
102not in (
103aten.lift_fresh.default,
104aten.lift_fresh_copy.default,
105aten.set_.source_Storage_storage_offset,
106)
107and not self.ignore_op_fn(func)
108and torch.Tag.dynamic_output_shape not in func.tags
109and torch.Tag.inplace_view not in func.tags
110and torch.Tag.data_dependent_output not in func.tags
111):
112# Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
113from torch.fx.experimental.symbolic_shapes import ShapeEnv
114
115try:
116# TODO: enable_python_dispatcher() here
117with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
118fake_args, fake_kwargs = pytree.tree_map_only(
119torch.Tensor,
120functools.partial(fake_mode.from_tensor, static_shapes=True),
121(args, kwargs),
122)
123with warnings.catch_warnings():
124fake_r = func(*fake_args, **fake_kwargs)
125except UnsupportedFakeTensorException:
126pass
127
128context = (
129f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
130f"found"
131)
132r = func(*args, **kwargs)
133if fake_r is not None:
134r_flat = pytree.tree_leaves(r)
135f_flat = pytree.tree_leaves(fake_r)
136assert len(f_flat) == len(
137r_flat
138), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
139
140if self.check_aliasing:
141r_aliasing = outputs_alias_inputs(r, (args, kwargs))
142f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
143assert (
144r_aliasing == f_aliasing
145), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"
146
147r_identity_eq = outputs_are_inputs(r, (args, kwargs))
148f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
149assert (
150r_identity_eq == f_identity_eq
151), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"
152
153r_output_alias_each_other = output_alias_each_other(r)
154f_output_alias_each_other = output_alias_each_other(fake_r)
155assert r_output_alias_each_other == f_output_alias_each_other, (
156f"{context} mismatch in outputs_alias_each_other check "
157f"{f_output_alias_each_other} != {r_output_alias_each_other}"
158)
159
160for idx, (r_out, fake_out) in enumerate(
161zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
162):
163r_is_ten = isinstance(r_out, torch.Tensor)
164assert r_is_ten == isinstance(
165fake_out, torch.Tensor
166), f"{context} mismatched number of tensor outputs"
167if r_is_ten:
168assert r_out.requires_grad == fake_out.requires_grad, (
169f"{context} mismatched requires_grad-ness of outputs. "
170f"This usually means that you have added autograd support "
171f"for your operator at a dispatch key other than Autograd, "
172f"which will lead to problems"
173)
174if torch._C._has_storage(r_out):
175r_offset = r_out.storage_offset()
176f_offset = fake_out.storage_offset()
177assert (
178r_offset == f_offset
179), f"{context} mismatched storage offset"
180
181try:
182torch._prims.utils.compare_tensor_meta(
183r_out,
184fake_out,
185check_strides=self.check_strides,
186allow_rhs_unbacked=True,
187)
188except Exception as e:
189if is_sdpa_error(func, idx, e):
190continue
191error_message = (
192f"{context} mismatched tensor metadata: {e}"
193if len(r_flat) == 1
194else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
195)
196raise RuntimeError(error_message) from e
197return r
198