pytorch

Форк
0
/
test_jiterator.py 
174 строки · 6.5 Кб
1
# Owner(s): ["module: cuda"]
2

3
import torch
4
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
5
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
6
import sys
7
from itertools import product
8
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest
9
from torch.testing._internal.common_dtype import all_types_and_complex_and
10
from torch.testing._internal.common_device_type import (
11
    skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
12

13
if not TEST_CUDA:
    # Without CUDA, rebind TestCase to the NoTest placeholder imported from
    # common_utils so the class below is declared against it instead.
    TestCase = NoTest  # noqa: F811
    print("CUDA not available, skipping tests", file=sys.stderr)
16

17

18
# C++ source of the elementwise kernel shared by the tests below:
# it returns alpha * x + beta * y.
code_string = (
    "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) "
    "{ return alpha * x + beta * y; }"
)

# Module-level jitted function; alpha and beta are registered with default 1.
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)
20

21
def ref_fn(x, y, alpha=1, beta=1):
    """Return ``alpha * x + beta * y``, the eager reference for the kernel."""
    scaled_x = alpha * x
    scaled_y = beta * y
    return scaled_x + scaled_y
23

24
class TestPythonJiterator(TestCase):
    """Tests for functions built with ``create_jit_fn`` and
    ``create_multi_output_jit_fn``.

    Each test compares the output of a jitted kernel against an eager
    reference computed with regular tensor operations.
    """

    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        # Every pair of input dtypes on contiguous inputs.
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        # View the flat buffers with the requested (shape, strides).
        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    # On cuda 11.3, nvrtcCompileProgram is taking too long to
    # compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.
    @skipCUDAIfVersionLessThan((11, 6))
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        # Every pair of input dtypes on non-contiguous inputs.
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        # View the flat buffers with the requested (shape, strides).
        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        # Scalar keyword arguments; None means "omit it and use the default".
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        # A bool keyword argument selects between the two kernel branches.
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        conditional_fn = create_jit_fn(code_string, is_train=False)

        def conditional_ref(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = conditional_ref(a, b, is_train=is_train)
        result = conditional_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    def test_multiple_functors(self, device):
        # The source defines two templates; the last one calls the first.
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        main_fn = create_jit_fn(code_string)

        def main_ref(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = main_ref(a, b, c)
        result = main_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        # Kernel that sums ``num_inputs`` tensor arguments.
        inputs = [torch.rand(3, device='cuda').mul(10) for _ in range(num_inputs)]

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        sum_fn = create_jit_fn(code_string)

        expected = torch.sum(torch.stack(inputs), dim=0)
        result = sum_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        # Kernel that writes ``input + i`` to its i-th output reference.
        x = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = "".join(f"out{i} = input + {i};\n" for i in range(num_outputs))
        # NB: return type must be void, otherwise ROCm silently fails
        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"

        multi_output_fn = create_multi_output_jit_fn(code_string, num_outputs)

        expected = tuple(x + i for i in range(num_outputs))
        result = multi_output_fn(x)

        # A bare tensor result is wrapped so that indexing below selects a
        # whole output. Previously, with num_outputs == 1, ``result[0]``
        # indexed into the tensor and only its first element was compared.
        if isinstance(result, torch.Tensor):
            result = (result,)

        self.assertEqual(num_outputs, len(result))
        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])

    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        # Creating a jitted function from either malformed source must raise.
        with self.assertRaises(Exception):
            create_jit_fn(code_string)
169

170

171
# Instantiate the generic test class into this module's namespace,
# restricted to the "cuda" device type.
instantiate_device_type_tests(
    TestPythonJiterator,
    globals(),
    only_for="cuda",
)


if __name__ == "__main__":
    run_tests()
175

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.