pytorch

Форк
0
/
schema_check_mode.py 
198 строк · 7.7 Кб
1
# mypy: ignore-errors
2

3
from collections import namedtuple
4
from copy import deepcopy
5
from itertools import combinations
6

7
import torch
8
from torch.fx.operator_schemas import normalize_function
9
from torch.testing._internal.jit_utils import clone_inputs
10
from torch.utils import _pytree as pytree
11
from torch.utils._python_dispatch import TorchDispatchMode
12
from torch.utils._pytree import tree_map
13

14
# Record types that SchemaCheckMode appends to its `mutated` and `aliasing` logs.
Mutation = namedtuple("Mutation", ("op_name", "arg_name"))
Aliasing = namedtuple("Aliasing", ("op_name", "arg_name", "output_number"))
17

18
# Short module-level aliases for the schema classes exposed on torch._C.
SchemaArgument, SchemaArgType, SchemaInfo = (
    torch._C._SchemaArgument,
    torch._C._SchemaArgType,
    torch._C._SchemaInfo,
)
22

23
# This TorchDispatchMode Subclass is used to verify op schemas
24
# This TorchDispatchMode Subclass currently:
25
#  - Records the called ops
26
#  - Checks for mutations on all inputs
27
#  - Checks for aliasing on all inputs
28

29

30
# Ops exempt from the input/output aliasing check: aten::_unsafe_view is
# intended to have incorrect aliasing notation (hence "unsafe").
_UNSAFE_ALIASING_OPS = ("aten::_unsafe_view", "aten::unsafe_split")


def _bitwise_equal(lhs, rhs):
    """Return True when two tensors hold the same values.

    NOTE: despite the name, non-quantized tensors are compared with
    ``torch.allclose`` (tolerance based, NaNs compare equal); only quantized
    tensors go through ``torch.equal``.
    """
    if lhs.is_quantized:
        # TODO: This is only OK if can't have NaN quantized; idk if
        # this is actually true
        return torch.equal(lhs, rhs)
    return torch.allclose(lhs, rhs, equal_nan=True)


def _has_mutated(before, after, md):
    """Return True if ``after`` differs from its pre-call snapshot.

    ``before`` is the cloned pre-call value, ``after`` the value seen once the
    op has run, and ``md`` the ``(stride, storage _cdata)`` pair captured
    before the call. Only plain ``torch.Tensor`` instances (exact type) that
    are not sparse CSR are compared; anything else is reported as unchanged.
    """
    # Exact type check on purpose: tensor subclasses are not compared here.
    are_tensors = type(before) is torch.Tensor and type(after) is torch.Tensor
    if (
        are_tensors
        and before.layout != torch.sparse_csr
        and after.layout != torch.sparse_csr
    ):
        return not (
            before.size() == after.size()
            and _bitwise_equal(before, after)
            and md[0] == after.stride()
            and md[1] == after._typed_storage()._cdata
        )
    return False


def _has_aliased(lhs, rhs):
    """Return whether ``torch._C._overlaps`` reports that the values overlap.

    Values that ``_overlaps`` refuses to inspect are treated as not aliasing;
    every other exception propagates unchanged.
    """
    try:
        return torch._C._overlaps(lhs, rhs)
    except Exception as exception:
        if str(exception).startswith("Cannot inspect value of type "):
            return False
        raise


def _standardize_name(name):
    """Map the schema argument name ``self`` to ``input``; keep other names."""
    return name if name != "self" else "input"


def _unwrap(e):
    """Return ``e.elem`` for a tensor subclass that has it, otherwise ``e``."""
    if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor:
        try:
            return e.elem
        except AttributeError:
            return e
    return e


def _parse_metadata(e):
    """Capture ``(stride, storage _cdata)`` for a tensor, or None.

    For tensor subclasses the metadata is read from ``e.elem``; None is
    returned when that raises AttributeError, for sparse CSR tensors and for
    non-tensor values.
    """
    if not isinstance(e, torch.Tensor):
        return None
    if type(e) is not torch.Tensor:
        try:
            current = e.elem
            return (
                deepcopy(current.stride()),
                current._typed_storage()._cdata,
            )
        except AttributeError:
            return None
    # Sparse CSR tensors do not have strides or storage
    if e.layout != torch.sparse_csr:
        return (deepcopy(e.stride()), e._typed_storage()._cdata)
    return None


class SchemaCheckMode(TorchDispatchMode):
    """Dispatch mode that checks each dispatched op against its schema.

    For every op that goes through ``__torch_dispatch__`` the mode:
      - records the op's schema name in ``ops``;
      - raises ``RuntimeError`` if an input was mutated although the schema
        does not mark it mutable, otherwise records a ``Mutation``;
      - raises ``RuntimeError`` if an output aliases an input (or another
        output) although the schema does not allow it, otherwise records an
        ``Aliasing`` entry for input/output aliasing.
    """

    def __init__(self):
        # Let the base class set up whatever state it keeps per instance.
        super().__init__()
        # Information recorded for testing purposes. For example:
        #  - incorrect schemas
        #  - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        """Empty the recorded ops, mutations and aliasing lists in place."""
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        """Print the recorded op names separated by commas."""
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and verify its mutation/aliasing behavior.

        Args:
            func: the operator being dispatched; its ``_schema`` is checked.
            types: unused here.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; None means no kwargs.

        Returns:
            Whatever ``func(*args, **kwargs)`` returns, unmodified.

        Raises:
            RuntimeError: if an input is mutated without being declared
                mutable, an output aliases an input or another output without
                the schema allowing it, or a non-mutable input is returned
                directly as an output.
        """
        # The signature allows kwargs=None, but `**None` raises a TypeError.
        kwargs = {} if kwargs is None else kwargs

        op_name = func._schema.name
        self.ops.append(op_name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        # Snapshots taken BEFORE the op runs: cloned values plus the
        # stride/storage metadata of the original (uncloned) arguments.
        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(_unwrap, value) for name, value in c_p_args.items()
        }
        cloned_metadata = {
            name: [_parse_metadata(leaf) for leaf in pytree.tree_leaves(value)]
            for name, value in pre_arguments.items()
        }

        out = func(*args, **kwargs)

        # The same argument objects as seen AFTER the op ran.
        arguments = {
            name: tree_map(_unwrap, value) for name, value in pre_arguments.items()
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(_unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i, arg in enumerate(func._schema.arguments):
            name = _standardize_name(arg.name)
            after = arguments.get(name)
            if after is None:
                continue
            before = cloned_arguments.get(name)
            md = cloned_metadata.get(name)
            input_arg = SchemaArgument(SchemaArgType.input, i)

            for j, output in enumerate(tuple_out):
                if _has_aliased(output, after) and op_name not in _UNSAFE_ALIASING_OPS:
                    if not schema_info.may_contain_alias(
                        SchemaArgument(SchemaArgType.output, j), input_arg
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined to alias output but was aliasing"
                        )
                    self.aliasing.append(Aliasing(op_name, name, f"output_{j}"))
                if after is output and isinstance(after, torch.Tensor):
                    # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                    if not schema_info.is_mutable(input_arg) and func not in [
                        torch.ops.aten.lift.default,
                        torch.ops.aten.lift_fresh.default,
                    ]:
                        raise RuntimeError(
                            f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}"""
                        )

            # Compare every leaf of the argument with its pre-call snapshot.
            if any(
                _has_mutated(a, b, c)
                for a, b, c in zip(
                    pytree.tree_leaves(before), pytree.tree_leaves(after), md
                )
            ):
                if not schema_info.is_mutable(input_arg):
                    raise RuntimeError(
                        f"Argument {name} is not defined as mutable but was mutated"
                    )
                self.mutated.append(Mutation(op_name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased_outputs := _has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.