pytorch
220 lines · 7.5 KB
# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.

4import itertools
5from collections import defaultdict, namedtuple
6from dataclasses import dataclass
7from enum import IntEnum
8from typing import Dict, List, Tuple, Union
9
10from torchgen.model import (
11BackendIndex,
12BackendMetadata,
13DispatchKey,
14NativeFunction,
15NativeFunctionsGroup,
16OperatorName,
17)
18from torchgen.utils import assert_never
19
# Version tag baked into serialized kernel keys (see ETKernelKey.to_native_string,
# which emits "v<version>/..."); bump when the key format changes.
KERNEL_KEY_VERSION = 1
21
22
# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
# Subset of scalar dtypes with their numeric ids. The gaps in the values
# (5, 8-10) are deliberate: the ids presumably must line up with the full
# runtime-side enum this duplicates — confirm against gen_oplist before editing.
ScalarType = IntEnum(
    "ScalarType",
    {
        "Byte": 0,
        "Char": 1,
        "Short": 2,
        "Int": 3,
        "Long": 4,
        "Float": 6,
        "Double": 7,
        "Bool": 11,
    },
)
33
34
# Result bundle for a parsed Executorch yaml file: the parsed native functions
# together with the kernel index built from them (field semantics inferred
# from the names — confirm against the yaml parser that constructs this).
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
36
37
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Per-argument metadata (name, dtype, dim order) used to build a kernel key."""

    arg_name: str
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: Tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize as ``<numeric dtype id>;<comma-separated dim order>``."""
        type_id = ScalarType[self.dtype].value
        # Strip the parentheses from the tuple repr and remove spaces.
        # NOTE(review): a 1-element tuple keeps its trailing comma
        # ("(0,)" -> "0,"); presumably the runtime-side key format agrees —
        # confirm before changing this to a plain join.
        dims = str(self.dim_order)[1:-1].replace(" ", "")
        return f"{type_id};{dims}"
49
50
@dataclass(frozen=True)
class ETKernelKey:
    """Identifies one dtype/dim-order specialization of an operator kernel."""

    # Per-argument (name, dtype, dim order) metadata; empty for default keys.
    # Field undefined is default = True
    arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: Dict[str, Tuple[str, str]],
        type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: Dict[str, List[int]],
    ) -> List["ETKernelKey"]:
        """Generate ETKernelKeys from arg kernel specs
        Multiple ETKernelKeys are returned due to dtype permutations from utilizing
        type_alias_map (actualizing each potential type permutation as a KernelKey)

        Args:
            args: Mapping from argument name to kernel specs
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map arguments
            type_alias_map: Mapping from type alias to potential type enums
                i.e { T0 : [Double, Int] } means T0 can be either Double or Int
                Used for lookup by args
            dim_order_alias_map: Mapping from alias to a list of dimension orders
                Used for lookup by args
        """
        # Cast the dim orders to int (yaml may deliver them as strings)
        dim_order_alias_map = {
            k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
        }
        kernel_keys = []

        # Get all used Dtype Alias
        dtype_alias_used = set()
        for type_alias, dim_order in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order in dim_order_alias_map
            ), "Undefined dim_order alias: " + str(dim_order)
            dtype_alias_used.add(type_alias)

        # Generate all permutations of dtype alias values.
        # Iterate the aliases in sorted order so the permutation order (and
        # therefore the order of emitted kernel keys) is deterministic across
        # runs: bare set-of-str iteration order depends on hash randomization
        # (PYTHONHASHSEED), which would make codegen output non-reproducible.
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in sorted(dtype_alias_used)
        ]
        alias_permutations = [
            dict(permutation) for permutation in itertools.product(*alias_dtypes)
        ]

        # Using each alias value permutation, generate kernel keys
        op_arg_cache = {}
        for permutation in alias_permutations:
            arg_list = []
            for arg_name, arg_spec in args.items():
                dtype = permutation[arg_spec[0]]
                dim_order = dim_order_alias_map[arg_spec[1]]  # type: ignore[assignment]
                cache_key = (arg_name, dtype, tuple(dim_order))
                if cache_key not in op_arg_cache:
                    # Memoize the frozen per-arg metadata objects; many keys
                    # share identical (arg, dtype, dim_order) triples.
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)  # type: ignore[arg-type]
                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize as ``"default"`` or ``"v<version>/<arg>|<arg>|..."``."""
        if self.default:
            return "default"
        return (
            "v"
            + str(KERNEL_KEY_VERSION)
            + "/"
            + "|".join([arg.to_native_string() for arg in self.arg_meta])
        )
135
136
@dataclass(frozen=True)
class ETKernelIndex:
    """All kernels used by an Executorch model, keyed by operator name and
    then by kernel key."""

    index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Return True iff at least one kernel is registered for ``g``.

        NOTE: the previous implementation compared against ``None``, but
        ``get_kernels`` returns an empty dict (never ``None``) for missing
        ops, so it always returned True.
        """
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Dict[ETKernelKey, BackendMetadata]:
        """Return the kernel-key -> metadata mapping for ``g`` ({} if none)."""
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            # A group is looked up via the name of its functional variant.
            f = g.functional
        else:
            assert_never(g)
        if f.func.name not in self.index:
            return {}
        return self.index[f.func.name]

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Fold ``backend_indices`` into ``kernel_index`` in place, registering
        each backend kernel under the catch-all (default) kernel key."""
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                if op in kernel_index:
                    kernel_index[op][ETKernelKey(default=True)] = backend_metadata
                else:
                    kernel_index[op] = {ETKernelKey(default=True): backend_metadata}

    @staticmethod
    def from_backend_indices(
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Build a fresh ETKernelIndex from per-dispatch-key backend indices."""
        kernel_index: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Merge ``backend_indices`` into this index in place; returns self."""
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
        """
        index: Dict[OperatorName, BackendMetadata] = {}
        for op in self.index:
            kernel_dict = self.index[op]
            assert (
                len(kernel_dict.values()) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            # Only a default-keyed kernel survives the conversion; anything
            # else degrades to empty placeholder metadata.
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(
        index_a: "ETKernelIndex", index_b: "ETKernelIndex"
    ) -> "ETKernelIndex":
        """Return a new index combining ``index_a`` and ``index_b`` without
        mutating either input."""
        # Copy the inner per-op dicts as well: a shallow .copy() of the outer
        # dict would share them with index_a, so writing into ``combined``
        # below would mutate index_a.index as a side effect.
        combined = defaultdict(
            dict, {op: entry.copy() for op, entry in index_a.index.items()}
        )

        for op, entry in index_b.index.items():
            for key, metadata in entry.items():
                combined[op][key] = metadata

        return ETKernelIndex(combined)
221