# mypy: allow-untyped-defs
import re

import torch._C as C


"""
8
PythonDispatcher class is a thin python-binding to C++ dispatcher and it
9
is designed to show how dispatcher precompute works. In particular,
10
it shows for a certain op `foo`, what the computed dispatch table looks
11
like after user register their kernels to certains dispatch keys.
12

13
In the real C++ dispatcher we support many dispatch keys for different
14
functionalities. For simplicity PythonDispatcher only supports dispatch
15
keys for a single example of each use case. These use cases are listed below:
16

17
- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
18
    autograd kernel in pytorch core library.
19
    E.g. CPU, CUDA
20
- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
21
    inference kernels, but they share the same autograd kernel specified in AutogradOther.
22
    E.g. FPGA, SparseCsrCPU
23
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
24
    kernel defined in pytorch core library. Backend owner is responsible for registering both
25
    inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
26
    E.g. XLA, XPU, MPS
27
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
28
    Kernels registered to this key MUST work for inference for all backends.
29
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
30
    Kernels registered to this key MUST work for autograd for all backends.
31
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
32
    Kernels registered to this key MUST work for both inference + autograd for all backends.
33

34
Note we only allow registrations to alias keys inside pytorch core library. E.g
35
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
36
kernel from torch-xla extension, instead you should upstream the kernel into
37
pytorch/pytorch repo so that it's available for all backends and continuously
38
tested even without the extension.
39

40
Usage:
41
  dispatcher = PythonDispatcher()
42
  dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
43
  print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
44
  # For more debugging information
45
  # print(dispatcher.keys())
46
  # print(dispatcher.registrations())
47
  # print(dispatcher.rawRegistrations())
48
  # print(dispatcher.rawDispatchTable())
49
PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table.
50
This file only provides the simplified API for developers, relevant test code is located in
51
test/test_dispatch.py
52
"""


class PythonDispatcher:
    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self) -> None:
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    """
    Returns a list of dispatch keys supported by PythonDispatcher.
    You can register kernels to these keys.
    """

    def keys(self):
        return self.supported_keys

    """
    Register kernels to the target dispatchKeys.
    dispatchKeys(list[str]): a list of dispatch keys to which you want to
      register your own kernel. Note that you don't need to write the kernel
      yourself in this PythonDispatcher. E.g. for the CPU key, a kernel
      (e.g. fn_CPU) is automatically generated and registered.
    """

    def register(self, dispatchKeys):
        # Overriding is not supported and triggers a warning in the C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriding is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of in the C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    """
    Helper function to format (key, kernel).
    """

    def _format_line(self, key, kernel):
        return f"{key:<15} {kernel}\n"

    """
    Helper function to format a table header.
    """

    def _format_header(self, header):
        s = f"""
{header}
"""
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    """
    Returns raw output of all registration info for debugging only.
    Use registrations() for a simplified version.
    """

    def rawRegistrations(self):
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    """
    Returns raw output of the computed dispatch table for debugging only.
    Use dispatchTable() for a simplified version.
    """

    def rawDispatchTable(self):
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    """
    Returns a table (str) of all user registrations.
    Note this includes registrations to both runtime keys and alias keys.
    """

    def registrations(self):
        output = self._format_header("Registered Kernels")
        state = self.rawRegistrations()
        state_entries = state.split("\n")
        for line in state_entries:
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    """
    Returns the computed dispatch table (str). Note this only includes
    runtime keys; registrations to alias keys have been decoded to their
    mapped runtime keys.
    """

    def dispatchTable(self):
        output = self._format_header("Computed Dispatch Table")
        table = self.rawDispatchTable()
        table_entries = table.split("\n")
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table_entries:
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                output += self._format_line(k, entry.split(": ")[1])
        return output
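

if __name__ == "__main__":
    # A minimal runnable sketch of the Usage section in the module docstring.
    # Assumes a PyTorch build that exposes the private torch._C._dispatch_*
    # bindings used above.
    dispatcher = PythonDispatcher()
    # Register a backend inference kernel (CPU), an out-of-tree backend
    # kernel (XLA), and a composite kernel covering inference + autograd
    # for all backends.
    dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
    print(dispatcher.registrations())
    # The computed table shows, per runtime key, which kernel actually runs;
    # e.g. the direct CPU registration takes precedence over the composite
    # kernel for the CPU entry.
    print(dispatcher.dispatchTable())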