pytorch / jiterator.py
import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []


class _CodeParser:
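    """Parse a templated CUDA functor definition into its named components.

    For example, ``"template <typename T> T my_op(T x) { return x + 1; }"``
    is split into ``template_params="<typename T>"``, ``return_type="T"``,
    ``function_name="my_op"``, ``function_params="(T x)"``, and
    ``function_body="{ return x + 1; }"``.
    """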
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )
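        # The assembled pattern matches definitions of the rough shape:
        #   template <template_params> return_type function_name(function_params) { function_body }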

        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


class _JittedFunction:
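    """Callable wrapper around a jiterator-compiled elementwise kernel.

    Keyword arguments captured at construction time serve as defaults for the
    kernel's extra (non-tensor) arguments and can be overridden per call.
    """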
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for single output. "
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # defer checking CUDA availability to function invocation time.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        # Start from the construction-time defaults; call-time keyword
        # arguments may only override names declared there.
        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
    into the elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory, as
    well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = _create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = _create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with Python registration to override an operator's CUDA kernel.
    The following example overrides gelu's CUDA kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = _create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output.

    .. warning::
        All input tensors must be on a CUDA device.
    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)


def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return its outputs by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs.
    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )
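

# Usage sketch for the multi-output path, with an arbitrary example kernel
# (square_and_cube): the entry functor writes each result through a reference
# parameter, and with num_outputs > 1 the call is expected to return the
# outputs as a sequence of tensors that can be unpacked.
#
#     code_string = """
#     template <typename T> void square_and_cube(T x, T& sq, T& cb) {
#         sq = x * x;
#         cb = x * x * x;
#     }
#     """
#     fn = _create_multi_output_jit_fn(code_string, num_outputs=2)
#     x = torch.rand(3, device="cuda")
#     sq, cb = fn(x)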