pytorch

Форк
0
/
fake_utils.py 
197 строк · 7.2 Кб
1
# mypy: ignore-errors
2

3
import functools
4
import warnings
5
from typing import Callable, Union
6

7
import torch
8
import torch.utils._pytree as pytree
9
from torch._ops import OpOverload
10
from torch._subclasses.fake_tensor import (
11
    FakeTensorMode,
12
    tree_flatten_only,
13
    UnsupportedFakeTensorException,
14
)
15
from torch.utils._python_dispatch import TorchDispatchMode
16

17

18
# Short alias for the ATen operator namespace; ops below are referenced as
# e.g. ``aten.lift_fresh.default``.
aten = torch.ops.aten
19

20

21
def outputs_alias_inputs(outputs, inputs):
    """Report whether any output tensor shares storage with an input tensor.

    ``outputs`` and ``inputs`` are arbitrary pytrees. Only their
    ``torch.Tensor`` leaves are inspected, and tensors for which
    ``torch._C._has_storage`` is false are ignored on both sides.

    Returns:
        True if the typed-storage handle (``_cdata``) of some output matches
        that of some input, otherwise False.
    """

    def storage_handles(tree):
        # Lazily yield the storage handle of every storage-backed tensor leaf.
        for leaf in tree_flatten_only(torch.Tensor, tree):
            if torch._C._has_storage(leaf):
                yield leaf._typed_storage()._cdata

    known = set(storage_handles(inputs))
    for handle in storage_handles(outputs):
        if handle in known:
            return True
    return False
31

32

33
def outputs_are_inputs(outputs, inputs):
    """Report whether any output tensor is the very same object as an input tensor.

    Only object identity (``id``) of the ``torch.Tensor`` leaves of the two
    pytrees is compared. A distinct tensor object, even one sharing storage
    with an input, does not count.
    """
    input_object_ids = set()
    for leaf in tree_flatten_only(torch.Tensor, inputs):
        input_object_ids.add(id(leaf))

    for leaf in tree_flatten_only(torch.Tensor, outputs):
        if id(leaf) in input_object_ids:
            return True
    return False
36

37

38
def output_alias_each_other(outputs):
    """Report whether two or more output tensors share the same storage.

    Only storage-backed ``torch.Tensor`` leaves of the ``outputs`` pytree are
    considered. The scan stops at the first repeated storage handle.
    """
    backed_tensors = (
        t for t in tree_flatten_only(torch.Tensor, outputs) if torch._C._has_storage(t)
    )
    seen_handles = set()
    for tensor in backed_tensors:
        handle = tensor._typed_storage()._cdata
        if handle in seen_handles:
            return True
        seen_handles.add(handle)
    return False
48

49

50
def is_sdpa_error(func, idx, e):
    """Decide whether a metadata-comparison failure is a known SDPA device mismatch.

    Args:
        func: the op whose fake and real outputs were being compared.
        idx: position of the offending output in the flattened output list.
        e: the exception raised by the comparison.

    Returns:
        True when ``func`` is one of the scaled-dot-product-attention kernels
        listed below, ``idx`` is one of that kernel's listed output positions,
        and ``repr(e)`` contains ``"Devices"``. False otherwise.
    """
    flash_positions = (6, 7)
    efficient_positions = (2, 3)
    # (kernel, output positions at which a "Devices" mismatch is reported as
    # an SDPA error). The cuDNN kernel uses the same positions as flash.
    known_kernels = (
        (aten._scaled_dot_product_flash_attention.default, flash_positions),
        (aten._flash_attention_forward.default, flash_positions),
        (aten._scaled_dot_product_efficient_attention.default, efficient_positions),
        (aten._efficient_attention_forward.default, efficient_positions),
        (aten._scaled_dot_product_cudnn_attention.default, flash_positions),
    )
    for kernel, positions in known_kernels:
        if func is kernel:
            return idx in positions and "Devices" in repr(e)
    return False
76

77

78
class CrossRefFakeMode(TorchDispatchMode):
    """Dispatch mode that cross-checks ops against a FakeTensor execution.

    For each dispatched op that is eligible (see ``__torch_dispatch__``):

    1. The op runs under a fresh ``FakeTensorMode`` (with its own
       ``ShapeEnv``) on fake copies of the arguments.
    2. The op then runs for real.
    3. The two results are compared for:
       - the number of returned pytree leaves;
       - optionally, aliasing between outputs and inputs and among outputs;
       - tensor-vs-non-tensor leaves;
       - ``requires_grad``;
       - storage offset;
       - tensor metadata via ``torch._prims.utils.compare_tensor_meta``.

    Mismatches surface as ``AssertionError`` (from ``assert``) or
    ``RuntimeError``. The real result is always what gets returned.

    Args:
        ignore_op_fn: optional predicate. Ops for which it returns True are
            not cross-checked. ``None`` means no op is ignored.
        check_strides: forwarded to ``compare_tensor_meta``.
        check_aliasing: if True, also compare the output/input aliasing
            relationships of the fake and real runs.
    """

    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        super().__init__()
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` for real and, when eligible, cross-check it on FakeTensors."""
        kwargs = kwargs or {}

        fake_r = None
        fake_args, fake_kwargs = None, None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    # Replace every real tensor in (args, kwargs) with a fake
                    # tensor created with static shapes.
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                # The op cannot run on FakeTensors; skip the cross-check.
                pass

        # The real computation always runs, and its result is returned.
        r = func(*args, **kwargs)
        if fake_r is None:
            return r

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        # Flatten both results once and reuse the leaves for every check below.
        r_flat = pytree.tree_leaves(r)
        f_flat = pytree.tree_leaves(fake_r)
        assert len(f_flat) == len(
            r_flat
        ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

        if self.check_aliasing:
            r_aliasing = outputs_alias_inputs(r, (args, kwargs))
            f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
            assert (
                r_aliasing == f_aliasing
            ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

            r_identity_eq = outputs_are_inputs(r, (args, kwargs))
            f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
            assert (
                r_identity_eq == f_identity_eq
            ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

            r_output_alias_each_other = output_alias_each_other(r)
            f_output_alias_each_other = output_alias_each_other(fake_r)
            assert r_output_alias_each_other == f_output_alias_each_other, (
                f"{context} mismatch in outputs_alias_each_other check "
                f"{f_output_alias_each_other} != {r_output_alias_each_other}"
            )

        for idx, (r_out, fake_out) in enumerate(zip(r_flat, f_flat)):
            r_is_ten = isinstance(r_out, torch.Tensor)
            assert r_is_ten == isinstance(
                fake_out, torch.Tensor
            ), f"{context} mismatched number of tensor outputs"
            if not r_is_ten:
                # Non-tensor leaves only need to agree on being non-tensors.
                continue

            assert r_out.requires_grad == fake_out.requires_grad, (
                f"{context} mismatched requires_grad-ness of outputs. "
                f"This usually means that you have added autograd support "
                f"for your operator at a dispatch key other than Autograd, "
                f"which will lead to problems"
            )
            if torch._C._has_storage(r_out):
                r_offset = r_out.storage_offset()
                f_offset = fake_out.storage_offset()
                assert (
                    r_offset == f_offset
                ), f"{context} mismatched storage offset"

            try:
                torch._prims.utils.compare_tensor_meta(
                    r_out,
                    fake_out,
                    check_strides=self.check_strides,
                    allow_rhs_unbacked=True,
                )
            except Exception as e:
                # Certain SDPA outputs are exempt from "Devices" mismatches.
                if is_sdpa_error(func, idx, e):
                    continue
                error_message = (
                    f"{context} mismatched tensor metadata: {e}"
                    if len(r_flat) == 1
                    else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                )
                raise RuntimeError(error_message) from e
        return r
198

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.