pytorch

Форк
0
220 строк · 7.5 Кб
1
# Represents all kernels used by an ExecuTorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
3

4
import itertools
5
from collections import defaultdict, namedtuple
6
from dataclasses import dataclass
7
from enum import IntEnum
8
from typing import Dict, List, Tuple, Union
9

10
from torchgen.model import (
11
    BackendIndex,
12
    BackendMetadata,
13
    DispatchKey,
14
    NativeFunction,
15
    NativeFunctionsGroup,
16
    OperatorName,
17
)
18
from torchgen.utils import assert_never
19

20
# Kernel key format version: default value of ETKernelKey.version, emitted as the
# "v<N>/" prefix of the native string of non-default kernel keys.
KERNEL_KEY_VERSION: int = 1


# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    """Subset of scalar dtypes that may appear in an ExecuTorch kernel key.

    Members are looked up by name (``ScalarType["Float"]``) and the integer
    value is what is written into a kernel key's native string. The values
    appear to follow the c10/ExecuTorch ScalarType numbering; gaps (e.g. 5,
    8-10) are dtypes not listed in this subset — confirm against the runtime
    enum before adding members.
    """

    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11


# Named pair (native_functions, kernel_index). By its name this is presumably
# the result of parsing an ExecuTorch functions YAML — confirm against callers.
ETParsedYaml = namedtuple(
    typename="ETParsedYaml",
    field_names=("native_functions", "kernel_index"),
)


@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Per-argument component of an ETKernelKey: name, dtype and dim order."""

    arg_name: str
    # Name of a ScalarType member (e.g. "Float"); resolved via ScalarType[dtype].
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: Tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize as ``"<dtype enum value>;<comma-separated dim order>"``.

        For example dtype "Float" with dim_order (0, 1) gives "6;0,1", and
        dim_order (0,) gives "6;0". Only dtype and dim order are encoded;
        ``arg_name`` is not part of the string.

        Raises:
            KeyError: if ``dtype`` is not a ScalarType member name.
        """
        dtype_str = ScalarType[self.dtype].value
        # Join explicitly instead of slicing str(tuple): str((0,)) is "(0,)",
        # which previously left a trailing comma ("0,") for 1-element dim orders.
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_str};{dim_str}"


@dataclass(frozen=True)
class ETKernelKey:
    """Key selecting one kernel of an operator by its arguments' dtype/dim order.

    A key with ``default=True`` is the catch-all kernel and serializes to
    ``"default"`` regardless of ``arg_meta``.
    """

    # Per-argument metadata; left empty for default (catch-all) keys.
    arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    # Kernel key format version, emitted as the "v<N>/" prefix of the native string.
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: Dict[str, Tuple[str, str]],
        type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: Dict[str, List[int]],
    ) -> List["ETKernelKey"]:
        """Generate ETKernelKeys from arg kernel specs.

        Multiple ETKernelKeys are returned because each type alias may stand for
        several dtypes: one key is produced for every combination (cartesian
        product) of the dtypes of the aliases referenced by ``args``.

        Args:
            args: Mapping from argument name to kernel specs.
                Kernel specs are a tuple of (dtype alias, dim_order alias).
                Currently tuple entries must be aliased via the alias map arguments.
            type_alias_map: Mapping from type alias to potential type enums,
                i.e. { T0 : [Double, Int] } means T0 can be either Double or Int.
                Used for lookup by args.
            dim_order_alias_map: Mapping from alias to a list of dimension orders.
                Entries are cast to int. Used for lookup by args.

        Returns:
            One ETKernelKey per dtype combination. Each key's ``arg_meta``
            follows the iteration order of ``args``. The order of the keys is
            deterministic: aliases are combined in the order they are first
            referenced by ``args``.

        Raises:
            AssertionError: (when assertions are enabled) if an arg references
                an undefined type alias or dim_order alias.
        """
        # Cast dim order entries to int
        dim_order_alias_map = {
            k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
        }
        kernel_keys = []

        # Collect the dtype aliases used by args, in first-use order. A dict is
        # used as an ordered set: iterating a plain set of strings depends on
        # hash randomization, which made the order of generated keys vary
        # between runs.
        dtype_alias_used: Dict[str, None] = {}
        for type_alias, dim_order in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order in dim_order_alias_map
            ), "Undefined dim_order alias: " + str(dim_order)
            dtype_alias_used[type_alias] = None

        # Generate every combination of dtype alias values
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]
        alias_permutations = [
            dict(permutation) for permutation in itertools.product(*alias_dtypes)
        ]

        # Using each alias value combination, generate kernel keys. Identical
        # (arg_name, dtype, dim_order) triples share one ETKernelKeyOpArgMeta.
        op_arg_cache = {}
        for permutation in alias_permutations:
            arg_list = []
            for arg_name, arg_spec in args.items():
                dtype = permutation[arg_spec[0]]
                dim_order = dim_order_alias_map[arg_spec[1]]  # type: ignore[assignment]
                if (
                    cache_key := (arg_name, dtype, tuple(dim_order))
                ) not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)  # type: ignore[arg-type]

                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize this key to its native string form.

        Returns ``"default"`` for a catch-all key, otherwise
        ``"v<version>/<arg>|<arg>|..."`` where each ``<arg>`` is that
        argument's ``to_native_string()``.
        """
        if self.default:
            return "default"
        # Emit this key's own version. It defaults to KERNEL_KEY_VERSION, so keys
        # built without an explicit version serialize exactly as before.
        args_str = "|".join(arg.to_native_string() for arg in self.arg_meta)
        return f"v{self.version}/{args_str}"


@dataclass(frozen=True)
class ETKernelIndex:
    """Index of the kernels known for each operator.

    Maps operator name -> {ETKernelKey -> BackendMetadata}. The dataclass is
    frozen, but ``index`` itself is a mutable dict and ``grow`` mutates it in
    place.
    """

    index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Return True if at least one kernel is registered for ``g``."""
        # get_kernels returns {} (never None) for unknown ops, so test for
        # emptiness; an ``is not None`` check would always be True.
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``, or ``{}`` if there are none.

        For a NativeFunctionsGroup the lookup uses its functional variant.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # .get() does not invoke a defaultdict's factory, so lookups never add
        # entries to the index.
        return self.index.get(f.func.name, {})

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Add every backend kernel to ``kernel_index`` (mutated in place).

        Each kernel is stored under the catch-all ``ETKernelKey(default=True)``.
        If several dispatch keys provide the same op, the one iterated last wins.
        Kernels already registered under other keys are kept.
        """
        default_key = ETKernelKey(default=True)
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                kernel_index.setdefault(op, {})[default_key] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Build a new index whose kernels all come from ``backend_indices``."""
        kernel_index: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Add ``backend_indices`` to this index in place and return ``self``."""
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Convert to a CPU BackendIndex. Every op must have exactly one kernel.
        That kernel is taken from the default key; an op whose only kernel is
        under a non-default key maps to an empty placeholder BackendMetadata.
        """
        index: Dict[OperatorName, BackendMetadata] = {}
        for op, kernel_dict in self.index.items():
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(
        index_a: "ETKernelIndex", index_b: "ETKernelIndex"
    ) -> "ETKernelIndex":
        """Return a new index containing the kernels of both inputs.

        Neither input is modified. Ops present only in ``index_b`` with no
        kernels are not added.
        """
        # Copy the per-op dicts as well: a shallow copy of index_a.index would
        # share them, and the writes below would then leak into index_a.
        combined: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(
            dict, {op: dict(kernels) for op, kernels in index_a.index.items()}
        )

        for op, entry in index_b.index.items():
            for key, metadata in entry.items():
                combined[op][key] = metadata

        return ETKernelIndex(combined)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.