pytorch

Форк
0
/
schema_check_mode.py 
230 строк · 8.5 Кб
1
# mypy: ignore-errors
2

3
from collections import namedtuple
4
from copy import deepcopy
5
from itertools import combinations
6

7
import torch
8
from torch.fx.operator_schemas import normalize_function
9
from torch.utils import _pytree as pytree
10
from torch.utils._python_dispatch import TorchDispatchMode
11
from torch.utils._pytree import tree_map
12

13

14
# Named Tuples used within SchemaCheckMode
15
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
16
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])
17

18
# Simplified naming for C++ classes
19
SchemaArgument = torch._C._SchemaArgument
20
SchemaArgType = torch._C._SchemaArgType
21
SchemaInfo = torch._C._SchemaInfo
22

23
# This TorchDispatchMode Subclass is used to verify op schemas
24
# This TorchDispatchMode Subclass currently:
25
#  - Records the called ops
26
#  - Checks for mutations on all inputs
27
#  - Checks for aliasing on all inputs
28

29

30
# These 2 functions were moved here to avoid a numpy dependency via testing/_internal/common_utils.py
31

32

33
def is_iterable_of_tensors(iterable):
    """Return True if *iterable* is a non-empty sized collection of tensors.

    A bare ``torch.Tensor`` is rejected up front, because tensors are iterable
    themselves. Values that raise ``TypeError`` from ``len``/iteration (e.g.
    ints, generators) yield False instead of propagating the error.
    """
    # Tensor itself is iterable, so it must be excluded before the generic check.
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if not len(iterable):
            return False
        return all(isinstance(item, torch.Tensor) for item in iterable)
    except TypeError:
        return False
46

47

48
def clone_inputs(args):
    """Return a new list mirroring *args* with tensors replaced by detached copies.

    - A ``torch.Tensor`` becomes ``tensor.detach().clone()``.
    - A non-empty sized iterable of tensors becomes a *list* of detached clones.
    - Any other value is passed through as-is (not copied).
    """

    def _clone_one(value):
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        if is_iterable_of_tensors(value):
            return [tensor.detach().clone() for tensor in value]
        return value

    return [_clone_one(value) for value in args]
60

61

62
class SchemaCheckMode(TorchDispatchMode):
63
    def __init__(self) -> None:
64
        # Information recorded for testing purposes. For example:
65
        #  - incorrect schemas
66
        #  - overly conservative schemas
67
        self.ops = []
68
        self.mutated = []
69
        self.aliasing = []
70

71
    def reset_cache(self):
72
        self.ops.clear()
73
        self.mutated.clear()
74
        self.aliasing.clear()
75

76
    def display_ops(self):
77
        print(*self.ops, sep=",")
78

79
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
80
        def bitwise_equal(lhs, rhs):
81
            if lhs.is_quantized:
82
                # TODO: This is only OK if can't have NaN quantized; idk if
83
                # this is actually true
84
                return torch.equal(lhs, rhs)
85
            else:
86
                return torch.allclose(lhs, rhs, equal_nan=True)
87

88
        def has_mutated(before, after, md):
89
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
90
            if (
91
                are_tensors
92
                and before.layout != torch.sparse_csr
93
                and after.layout != torch.sparse_csr
94
            ):
95
                return not (
96
                    before.size() == after.size()
97
                    and bitwise_equal(before, after)
98
                    and md[0] == after.stride()
99
                    and md[1] == after._typed_storage()._cdata
100
                )
101
            return False
102

103
        def has_aliased(lhs, rhs):
104
            try:
105
                return torch._C._overlaps(lhs, rhs)
106
            except Exception as exception:
107
                if str(exception).startswith("Cannot inspect value of type "):
108
                    return False
109
                else:
110
                    raise exception
111

112
        def standardize_name(name):
113
            return name if name != "self" else "input"
114

115
        def unwrap(e):
116
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
117
                try:
118
                    return e.elem
119
                except AttributeError as t:
120
                    return e
121
            return e
122

123
        def parse_metadata(e):
124
            if isinstance(e, torch.Tensor):
125
                if not type(e) == torch.Tensor:
126
                    try:
127
                        current = e.elem
128
                        return (
129
                            deepcopy(current.stride()),
130
                            current._typed_storage()._cdata,
131
                        )
132
                    except AttributeError as t:
133
                        return None
134
                # Sparse CSR tensors do not have strides or storage
135
                elif e.layout != torch.sparse_csr:
136
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
137
            return None
138

139
        self.ops.append(func._schema.name)
140

141
        # Clone and process arguments and outputs
142
        pre_arguments = normalize_function(
143
            func, args, kwargs, normalize_to_only_use_kwargs=True
144
        ).kwargs
145

146
        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
147
        cloned_arguments = {
148
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
149
        }
150
        cloned_metadata = {
151
            name: [
152
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
153
            ]
154
            for name in pre_arguments
155
        }
156

157
        out = func(*args, **kwargs)
158
        arguments = {
159
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
160
        }
161
        tuple_out = out if isinstance(out, tuple) else (out,)
162
        tuple_out = tree_map(unwrap, tuple_out)
163

164
        schema_info = SchemaInfo(func._schema)
165
        schema_info.add_argument_values(pre_arguments)
166

167
        # Process arguments with outputs
168
        for i in range(len(func._schema.arguments)):
169
            arg = func._schema.arguments[i]
170
            name = standardize_name(arg.name)
171
            if arguments.get(name) is not None:
172
                before = cloned_arguments.get(name)
173
                md = cloned_metadata.get(name)
174
                after = arguments.get(name)
175
                for j in range(len(tuple_out)):
176
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
177
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
178
                    if (
179
                        has_aliased(tuple_out[j], after)
180
                        and func._schema.name not in unsafe_ops
181
                    ):
182
                        if not schema_info.may_contain_alias(
183
                            SchemaArgument(SchemaArgType.output, j),
184
                            SchemaArgument(SchemaArgType.input, i),
185
                        ):
186
                            raise RuntimeError(
187
                                f"Argument {name} is not defined to alias output but was aliasing"
188
                            )
189
                        else:
190
                            self.aliasing.append(
191
                                Aliasing(func._schema.name, name, f"output_{j}")
192
                            )
193
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
194
                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
195
                        if not schema_info.is_mutable(
196
                            SchemaArgument(SchemaArgType.input, i)
197
                        ) and func not in [
198
                            torch.ops.aten.lift.default,
199
                            torch.ops.aten.lift_fresh.default,
200
                        ]:
201
                            raise RuntimeError(
202
                                f"""\
203
Dispatcher operators below autograd are not allowed to directly return inputs.
204
However, we found that `outputs[{str(j)}] is {name}"""
205
                            )
206
                if any(
207
                    has_mutated(a, b, c)
208
                    for a, b, c in zip(
209
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
210
                    )
211
                ):
212
                    if not schema_info.is_mutable(
213
                        SchemaArgument(SchemaArgType.input, i)
214
                    ):
215
                        raise RuntimeError(
216
                            f"Argument {name} is not defined as mutable but was mutated"
217
                        )
218
                    else:
219
                        self.mutated.append(Mutation(func._schema.name, name))
220

221
        # Aliasing between outputs
222
        for i, j in combinations(range(len(func._schema.returns)), 2):
223
            if has_aliased(tuple_out[i], tuple_out[j]):
224
                if not schema_info.may_contain_alias(
225
                    SchemaArgument(SchemaArgType.output, i),
226
                    SchemaArgument(SchemaArgType.output, j),
227
                ):
228
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
229

230
        return out
231

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.