# mypy: ignore-errors

from collections import namedtuple
from copy import deepcopy
from itertools import combinations

import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])

# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# This TorchDispatchMode subclass is used to verify op schemas.
# It currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs

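# Example usage (a minimal sketch, not part of the module): as a
# TorchDispatchMode, SchemaCheckMode is entered as a context manager, and every
# op dispatched inside the block is checked against its schema.
#
#     with SchemaCheckMode() as mode:
#         x = torch.rand(2, 2)
#         x.sin_()           # in-place op; recorded in mode.mutated
#     mode.display_ops()     # prints the recorded op names
#
# A RuntimeError is raised if an op mutates or aliases its inputs in a way its
# schema does not declare.
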
# move these 2 functions here to avoid numpy dependency in testing/_internal/common_utils.py


def is_iterable_of_tensors(iterable):
    # Tensor itself is iterable so we check this first
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return False
        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False
    except TypeError as te:
        return False
    return True


def clone_inputs(args):
    inputs = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)

    return inputs


class SchemaCheckMode(TorchDispatchMode):
    def __init__(self) -> None:
        # Information recorded for testing purposes. For example:
        # - incorrect schemas
        # - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: This is only OK if can't have NaN quantized; idk if
                # this is actually true
                return torch.equal(lhs, rhs)
            else:
                return torch.allclose(lhs, rhs, equal_nan=True)

        def has_mutated(before, after, md):
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise exception

        def standardize_name(name):
            return name if name != "self" else "input"

        def unwrap(e):
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError as t:
                    return e
            return e

        def parse_metadata(e):
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError as t:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
        }
        cloned_metadata = {
            name: [
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
            ]
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i in range(len(func._schema.arguments)):
            arg = func._schema.arguments[i]
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
                    if (
                        has_aliased(tuple_out[j], after)
                        and func._schema.name not in unsafe_ops
                    ):
                        if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i),
                        ):
                            raise RuntimeError(
                                f"Argument {name} is not defined to alias output but was aliasing"
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name, f"output_{j}")
                            )
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                        if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)
                        ) and func not in [
                            torch.ops.aten.lift.default,
                            torch.ops.aten.lift_fresh.default,
                        ]:
                            raise RuntimeError(
                                f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}"""
                            )
                if any(
                    has_mutated(a, b, c)
                    for a, b, c in zip(
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
                    )
                ):
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out