# Owner(s): ["module: unknown"]
from torch.testing._internal.check_kernel_launches import (
    check_code_for_cuda_kernel_launches,
    check_cuda_kernel_launches,
)
from torch.testing._internal.common_utils import TestCase, run_tests
class AlwaysCheckCudaLaunchTest(TestCase):
10
def test_check_code(self):
11
"""Verifies that the regex works for a few different situations"""
13
# Try some different spacings
14
self.assertEqual(2, check_code_for_cuda_kernel_launches("""
15
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
16
C10_CUDA_KERNEL_LAUNCH_CHECK();
17
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
19
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
20
C10_CUDA_KERNEL_LAUNCH_CHECK();
21
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
23
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
24
C10_CUDA_KERNEL_LAUNCH_CHECK();
25
some_function_call<TemplateArg><<<1,2,0,stream>>> (arg1,arg2,arg3);
26
C10_CUDA_KERNEL_LAUNCH_CHECK();
27
some_function_call<TemplateArg><<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ;
29
C10_CUDA_KERNEL_LAUNCH_CHECK();
32
# Does it work for macros?
33
self.assertEqual(0, check_code_for_cuda_kernel_launches(r"""
34
#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \
35
C10_CUDA_KERNEL_LAUNCH_CHECK();
37
#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
38
indexAddSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
39
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
40
selfInfo, sourceInfo, indexInfo, \
41
selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \
42
C10_CUDA_KERNEL_LAUNCH_CHECK();
45
# Does it work for lambdas?
46
self.assertEqual(1, check_code_for_cuda_kernel_launches(r"""
47
rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
55
[] __device__ (curandStatePhilox4_32_10_t* state) {
56
return curand_uniform2_double(state);
58
C10_CUDA_KERNEL_LAUNCH_CHECK();
60
rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
68
[] __device__ (curandStatePhilox4_32_10_t* state) {
69
return curand_uniform2_double(state);
72
C10_CUDA_KERNEL_LAUNCH_CHECK();
75
def test_check_cuda_launches(self):
76
unsafeLaunchesCount = check_cuda_kernel_launches()
77
self.assertTrue(unsafeLaunchesCount == 0)
80
if __name__ == '__main__':