# pytorch — schema_check_mode (198 lines · 7.7 KB)
# mypy: ignore-errors

from collections import namedtuple
from copy import deepcopy
from itertools import combinations

import torch
from torch.fx.operator_schemas import normalize_function
from torch.testing._internal.jit_utils import clone_inputs
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
13
# Named tuples recorded by SchemaCheckMode when it observes a mutated
# argument or an input/output alias during dispatch.
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])

# Short aliases for the C++ schema-introspection classes.
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# This TorchDispatchMode subclass is used to verify op schemas.
# It currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs
29
30class SchemaCheckMode(TorchDispatchMode):
31def __init__(self):
32# Information recorded for testing purposes. For example:
33# - incorrect schemas
34# - overly conservative schemas
35self.ops = []
36self.mutated = []
37self.aliasing = []
38
39def reset_cache(self):
40self.ops.clear()
41self.mutated.clear()
42self.aliasing.clear()
43
44def display_ops(self):
45print(*self.ops, sep=",")
46
47def __torch_dispatch__(self, func, types, args=(), kwargs=None):
48def bitwise_equal(lhs, rhs):
49if lhs.is_quantized:
50# TODO: This is only OK if can't have NaN quantized; idk if
51# this is actually true
52return torch.equal(lhs, rhs)
53else:
54return torch.allclose(lhs, rhs, equal_nan=True)
55
56def has_mutated(before, after, md):
57are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
58if (
59are_tensors
60and before.layout != torch.sparse_csr
61and after.layout != torch.sparse_csr
62):
63return not (
64before.size() == after.size()
65and bitwise_equal(before, after)
66and md[0] == after.stride()
67and md[1] == after._typed_storage()._cdata
68)
69return False
70
71def has_aliased(lhs, rhs):
72try:
73return torch._C._overlaps(lhs, rhs)
74except Exception as exception:
75if str(exception).startswith("Cannot inspect value of type "):
76return False
77else:
78raise exception
79
80def standardize_name(name):
81return name if name != "self" else "input"
82
83def unwrap(e):
84if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
85try:
86return e.elem
87except AttributeError as t:
88return e
89return e
90
91def parse_metadata(e):
92if isinstance(e, torch.Tensor):
93if not type(e) == torch.Tensor:
94try:
95current = e.elem
96return (
97deepcopy(current.stride()),
98current._typed_storage()._cdata,
99)
100except AttributeError as t:
101return None
102# Sparse CSR tensors do not have strides or storage
103elif e.layout != torch.sparse_csr:
104return (deepcopy(e.stride()), e._typed_storage()._cdata)
105return None
106
107self.ops.append(func._schema.name)
108
109# Clone and process arguments and outputs
110pre_arguments = normalize_function(
111func, args, kwargs, normalize_to_only_use_kwargs=True
112).kwargs
113
114c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
115cloned_arguments = {
116name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
117}
118cloned_metadata = {
119name: [
120parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
121]
122for name in pre_arguments
123}
124
125out = func(*args, **kwargs)
126arguments = {
127name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
128}
129tuple_out = out if isinstance(out, tuple) else (out,)
130tuple_out = tree_map(unwrap, tuple_out)
131
132schema_info = SchemaInfo(func._schema)
133schema_info.add_argument_values(pre_arguments)
134
135# Process arguments with outputs
136for i in range(len(func._schema.arguments)):
137arg = func._schema.arguments[i]
138name = standardize_name(arg.name)
139if arguments.get(name) is not None:
140before = cloned_arguments.get(name)
141md = cloned_metadata.get(name)
142after = arguments.get(name)
143for j in range(len(tuple_out)):
144# aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
145unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
146if (
147has_aliased(tuple_out[j], after)
148and func._schema.name not in unsafe_ops
149):
150if not schema_info.may_contain_alias(
151SchemaArgument(SchemaArgType.output, j),
152SchemaArgument(SchemaArgType.input, i),
153):
154raise RuntimeError(
155f"Argument {name} is not defined to alias output but was aliasing"
156)
157else:
158self.aliasing.append(
159Aliasing(func._schema.name, name, f"output_{j}")
160)
161if after is tuple_out[j] and isinstance(after, torch.Tensor):
162# Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
163if not schema_info.is_mutable(
164SchemaArgument(SchemaArgType.input, i)
165) and func not in [
166torch.ops.aten.lift.default,
167torch.ops.aten.lift_fresh.default,
168]:
169raise RuntimeError(
170f"""\
171Dispatcher operators below autograd are not allowed to directly return inputs.
172However, we found that `outputs[{str(j)}] is {name}"""
173)
174if any(
175has_mutated(a, b, c)
176for a, b, c in zip(
177pytree.tree_leaves(before), pytree.tree_leaves(after), md
178)
179):
180if not schema_info.is_mutable(
181SchemaArgument(SchemaArgType.input, i)
182):
183raise RuntimeError(
184f"Argument {name} is not defined as mutable but was mutated"
185)
186else:
187self.mutated.append(Mutation(func._schema.name, name))
188
189# Aliasing between outputs
190for i, j in combinations(range(len(func._schema.returns)), 2):
191if has_aliased(tuple_out[i], tuple_out[j]):
192if not schema_info.may_contain_alias(
193SchemaArgument(SchemaArgType.output, i),
194SchemaArgument(SchemaArgType.output, j),
195):
196raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
197
198return out
199