2
from typing import Callable, List
5
from torch import Tensor
11
def __init__(self, code_string: str):
14
template_params = r"(?P<template_params>\<.+\>)"
15
return_type = r"(?P<return_type>\w+)"
16
function_name = r"(?P<function_name>\w+)"
17
function_params = r"(?P<function_params>\(.+\))"
18
function_body = r"(?P<function_body>\{.+\})"
37
pattern, code_string, re.DOTALL
38
) # DOTALL for matching multiline
42
f"Couldn't parse code, please check correctness:\n {code_string}"
45
self.template_params = result["template_params"]
46
self.return_type = result["return_type"]
47
self.function_name = result["function_name"]
48
self.function_params = result["function_params"]
49
self.function_body = result["function_body"]
54
self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
56
self.code_string = code_string
59
return_by_ref or num_outputs == 1
60
), "Return by value only works for single output. "
61
self.return_by_ref = return_by_ref
62
self.num_outputs = num_outputs
64
parsed_code = _CodeParser(code_string)
65
self.kernel_name = parsed_code.function_name
67
self.kwargs_dict = kwargs
68
self.is_cuda_available = torch.cuda.is_available()
70
def __call__(self, *tensors: Tensor, **kwargs):
71
# Jiterator follow torch.cuda's lazy initialization behavior
72
# Defer checking cuda's availability at the function invocation time
74
self.is_cuda_available
75
), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
77
assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
79
expanded_kwargs = self.kwargs_dict.copy()
80
for key, value in kwargs.items():
81
if key in self.kwargs_dict:
82
expanded_kwargs[key] = value
84
raise KeyError(f"{key} is not declared in function definition")
86
return torch._C._cuda_jiterator_compile_and_launch_kernel(
96
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """Create a jiterator-generated CUDA kernel for an elementwise op.

    The code string must be a valid CUDA function that describes the computation
    for a single element, written following the C++ template pattern shown in the
    examples below. This function is inlined into an elementwise kernel template
    and compiled on the fly; the compiled kernel is cached in memory.

    Jiterator-generated kernels accept non-contiguous tensors, and support
    broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry
            functor must return by value.
        kwargs (Dict, optional): Default values for the entry function's extra
            arguments, keyed by parameter name (e.g. ``alpha=1.0``). When the
            returned function is called, these defaults may be overridden by
            keyword; passing a keyword that was not declared here raises
            ``KeyError``.

    Returns:
        Callable: a function that takes the input tensors positionally (plus
        optional keyword overrides) and compiles and launches the kernel.

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = _create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    ``code_string`` also allows multiple function definitions; the last function
    is treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = _create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an
    operator's CUDA kernel. The following example overrides gelu's CUDA kernel
    with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = _create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output.

    .. warning::
        All input tensors must live on a CUDA device. If CUDA is unavailable,
        the error is raised when the returned function is called, not when it
        is created.
    """
    # return_by_ref=False: the entry functor returns its single result by value,
    # so exactly one output is produced.
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
157
def _create_multi_output_jit_fn(
158
code_string: str, num_outputs: int, **kwargs
161
Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
164
code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
165
num_outputs(int): number of outputs return by the kernel
166
kwargs (Dict, optional): Keyword arguments for generated function
170
code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
171
jitted_fn = create_jit_fn(code_string, alpha=1.0)
172
a = torch.rand(3, device='cuda')
173
b = torch.rand(3, device='cuda')
174
# invoke jitted function like a regular python function
175
result = jitted_fn(a, b, alpha=3.14)
178
This API is in beta and may change in future releases.
181
This API only supports up to 8 inputs and 8 outputs
183
return _JittedFunction(
184
code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs