import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol)

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811

code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)
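
# Minimal usage sketch (illustrative, not part of the tests): the C++ template
# above is compiled into an elementwise CUDA kernel the first time jitted_fn
# runs, and alpha/beta default to the values given to create_jit_fn unless
# overridden per call:
#
#   x = torch.rand(4, device='cuda')
#   y = torch.rand(4, device='cuda')
#   out = jitted_fn(x, y, alpha=2.0)  # elementwise 2.0 * x + 1 * y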

def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y


class TestPythonJiterator(TestCase):
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfVersionLessThan((11, 6))
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        # None means "use the default baked into create_jit_fn", so only pass
        # the extra arg through when a concrete value is parametrized.
        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)
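        # The False default is just a placeholder; every call below overrides
        # it, so both branches of the C++ ternary are exercised.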

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)
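        # The code string may define helper functors; jiterator treats the last
        # function defined in the string (main_fn here) as the kernel entry
        # point, so fn is only ever called from device code.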

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
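        # e.g. for num_inputs=2 the generated string is:
        #   template <typename T> T my_kernel(T i0,T i1) { return i0+i1; }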
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"

        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"
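        # e.g. for num_outputs=2 the generated kernel is (newlines elided):
        #   template <typename T> void my_kernel(T input, T& out0, T& out1)
        #   { out0 = input + 0; out1 = input + 1; }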
145
jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)
149
for i in range(num_outputs):
150
outputs.append(input + i)
154
return tuple(outputs)
156
expected = ref_fn(input)
157
result = jitted_fn(input)
159
for i in range(num_outputs):
160
self.assertEqual(expected[i], result[i])

    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",  # space inside the function name
        "template <typename T> Tmy_kernel(T x) { return x; }",    # no space between return type and name
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()