# Owner(s): ["module: unknown"]

import unittest

import torch.testing._internal.common_utils as common
from torch.testing._internal.common_utils import TEST_NUMPY
from torch.testing._internal.common_cuda import TEST_NUMBA_CUDA, TEST_CUDA, TEST_MULTIGPU

import torch

if TEST_NUMPY:
    import numpy

if TEST_NUMBA_CUDA:
    import numba.cuda


class TestNumbaIntegration(common.TestCase):
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_cuda_array_interface(self):
        """torch.Tensor exposes __cuda_array_interface__ for cuda tensors.

        An object t is considered a cuda-tensor if:
            hasattr(t, '__cuda_array_interface__')

        A cuda-tensor provides a tensor description dict:
            shape: (integer, ...) Tensor shape.
            strides: (integer, ...) Tensor strides, in bytes.
            typestr: (str) A numpy-style typestr.
            data: (int, boolean) A (data_ptr, read-only) tuple.
            version: (int) Interface version; torch currently exports version 2.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """

        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.CharTensor,
            torch.ByteTensor,
        ]
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.float16,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for tp, npt in zip(types, dtypes):

            # CPU tensors do not implement the interface.
            cput = tp(10)

            self.assertFalse(hasattr(cput, "__cuda_array_interface__"))
            self.assertRaises(AttributeError, lambda: cput.__cuda_array_interface__)

            # Sparse CPU/CUDA tensors do not implement the interface
            if tp not in (torch.HalfTensor,):
                indices_t = torch.empty(1, cput.size(0), dtype=torch.long).clamp_(min=0)
                sparse_t = torch.sparse_coo_tensor(indices_t, cput)

                self.assertFalse(hasattr(sparse_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_t.__cuda_array_interface__
                )

                sparse_cuda_t = torch.sparse_coo_tensor(indices_t, cput).cuda()

                self.assertFalse(hasattr(sparse_cuda_t, "__cuda_array_interface__"))
                self.assertRaises(
                    AttributeError, lambda: sparse_cuda_t.__cuda_array_interface__
                )

            # CUDA tensors have the attribute and v2 interface
            cudat = tp(10).cuda()

            self.assertTrue(hasattr(cudat, "__cuda_array_interface__"))

            ar_dict = cudat.__cuda_array_interface__

            self.assertEqual(
                set(ar_dict.keys()), {"shape", "strides", "typestr", "data", "version"}
            )

            self.assertEqual(ar_dict["shape"], (10,))
            self.assertIs(ar_dict["strides"], None)
            # typestr from numpy, cuda-native little-endian
            self.assertEqual(ar_dict["typestr"], numpy.dtype(npt).newbyteorder("<").str)
            self.assertEqual(ar_dict["data"], (cudat.data_ptr(), False))
            self.assertEqual(ar_dict["version"], 2)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_array_adaptor(self):
        """torch.Tensor's __cuda_array_interface__ exposes tensor data to numba.cuda."""

        torch_dtypes = [
            torch.complex64,
            torch.complex128,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        ]

        for dt in torch_dtypes:

            # CPU tensors of all types do not register as cuda arrays;
            # attempts to convert them raise a TypeError.
            cput = torch.arange(10).to(dt)
            npt = cput.numpy()

            self.assertTrue(not numba.cuda.is_cuda_array(cput))
            with self.assertRaises(TypeError):
                numba.cuda.as_cuda_array(cput)

            # Any cuda tensor is a cuda array.
            cudat = cput.to(device="cuda")
            self.assertTrue(numba.cuda.is_cuda_array(cudat))

            numba_view = numba.cuda.as_cuda_array(cudat)
            self.assertIsInstance(numba_view, numba.cuda.devicearray.DeviceNDArray)

            # The reported type of the cuda array matches the numpy type of the cpu tensor.
            self.assertEqual(numba_view.dtype, npt.dtype)
            self.assertEqual(numba_view.strides, npt.strides)
            self.assertEqual(numba_view.shape, cudat.shape)

            # Pass back to cuda from host for all equality checks below, needed for
            # float16 comparisons, which aren't supported cpu-side.

            # The data is identical in the view.
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Writes to the torch.Tensor are reflected in the numba array.
            cudat[:5] = 11
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))

            # Strided tensors are supported.
            strided_cudat = cudat[::2]
            strided_npt = cput[::2].numpy()
            strided_numba_view = numba.cuda.as_cuda_array(strided_cudat)

            self.assertEqual(strided_numba_view.dtype, strided_npt.dtype)
            self.assertEqual(strided_numba_view.strides, strided_npt.strides)
            self.assertEqual(strided_numba_view.shape, strided_cudat.shape)

            # As of numba 0.40.0, support for strided views is ...limited...
            # Cannot verify correctness of strided view operations.

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_conversion_errors(self):
        """Numba properly detects the array interface for torch.Tensor variants."""

        # CPU tensors are not cuda arrays.
        cput = torch.arange(100)

        self.assertFalse(numba.cuda.is_cuda_array(cput))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cput)

        # Sparse tensors are not cuda arrays, regardless of device.
        sparset = torch.sparse_coo_tensor(cput[None, :], cput)

        self.assertFalse(numba.cuda.is_cuda_array(sparset))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparset)

        sparse_cuda_t = sparset.cuda()

        self.assertFalse(numba.cuda.is_cuda_array(sparse_cuda_t))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(sparse_cuda_t)

        # Device-status overrides gradient status.
        # CPU+gradient isn't a cuda array.
        cpu_gradt = torch.zeros(100).requires_grad_(True)

        self.assertFalse(numba.cuda.is_cuda_array(cpu_gradt))
        with self.assertRaises(TypeError):
            numba.cuda.as_cuda_array(cpu_gradt)

        # CUDA+gradient raises a RuntimeError on check or conversion.
        #
        # Using hasattr for interface detection would behave differently under
        # Python 2, where hasattr swallows all exceptions, not just AttributeError.
        cuda_gradt = torch.zeros(100).requires_grad_(True).cuda()

        # Both the check and the conversion raise RuntimeError.
        with self.assertRaises(RuntimeError):
            numba.cuda.is_cuda_array(cuda_gradt)
        with self.assertRaises(RuntimeError):
            numba.cuda.as_cuda_array(cuda_gradt)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_active_device(self):
        """'as_cuda_array' tensor device must match active numba context."""

        # Both torch and numba default to device 0 and can interop freely.
        cudat = torch.arange(10, device="cuda")
        self.assertEqual(cudat.device.index, 0)
        self.assertIsInstance(
            numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
        )

        # Tensors on a non-default device raise an API error if converted.
        cudat = torch.arange(10, device=torch.device("cuda", 1))

        with self.assertRaises(numba.cuda.driver.CudaAPIError):
            numba.cuda.as_cuda_array(cudat)

        # ...but can be converted when switching to the device's context.
        with numba.cuda.devices.gpus[cudat.device.index]:
            self.assertIsInstance(
                numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
            )

    @unittest.skip("Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418")
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface(self):
        """torch.as_tensor() and torch.tensor() support the __cuda_array_interface__ protocol.

        If an object exposes __cuda_array_interface__, torch.as_tensor() can use the
        exposed device memory directly, while torch.tensor() always copies it.

        See:
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        """

        dtypes = [
            numpy.complex64,
            numpy.complex128,
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            numpy_arys = [
                numpy.arange(6).reshape(2, 3).astype(dtype),
                numpy.arange(6).reshape(2, 3).astype(dtype)[1:],  # View offset should be ignored
                numpy.arange(6).reshape(2, 3).astype(dtype)[:, None],  # changes the strides but stays contiguous
            ]
            # Zero-copy when using `torch.as_tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cuda")
                self.assertEqual(numba_ary.__cuda_array_interface__, torch_ary.__cuda_array_interface__)
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` point to the same device memory
                torch_ary += 42
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

            # Implicit copy because `torch_ary` is a CPU tensor
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.as_tensor(numba_ary, device="cpu")
                self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` point to different memory
                torch_ary += 42
                self.assertEqual(torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)

            # Explicit copy when using `torch.tensor()`
            for numpy_ary in numpy_arys:
                numba_ary = numba.cuda.to_device(numpy_ary)
                torch_ary = torch.tensor(numba_ary, device="cuda")
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype))

                # Check that `torch_ary` and `numba_ary` point to different memory
                torch_ary += 42
                self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42)

    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_inferred_strides(self):
        """torch.as_tensor(numba_ary) should have correct inferred (contiguous) strides"""
        # This could, in theory, be combined with test_from_cuda_array_interface but that test
        # is overly strict: it checks that the exported protocols are exactly the same, which
        # cannot handle differing exported protocol versions.
        dtypes = [
            numpy.float64,
            numpy.float32,
            numpy.int64,
            numpy.int32,
            numpy.int16,
            numpy.int8,
            numpy.uint8,
        ]
        for dtype in dtypes:
            numpy_ary = numpy.arange(6).reshape(2, 3).astype(dtype)
            numba_ary = numba.cuda.to_device(numpy_ary)
            self.assertTrue(numba_ary.is_c_contiguous())
            torch_ary = torch.as_tensor(numba_ary, device="cuda")
            self.assertTrue(torch_ary.is_contiguous())

    @unittest.skip("Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418")
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    def test_from_cuda_array_interface_lifetime(self):
        """torch.as_tensor(obj) grabs a reference to obj so that obj lives at least as long as the tensor."""
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__)  # No copy
        del numba_ary
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.arange(6))  # `torch_ary` is still alive

    @unittest.skip("Test is temporarily disabled, see https://github.com/pytorch/pytorch/issues/54418")
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
    def test_from_cuda_array_interface_active_device(self):
        """torch.as_tensor() tensor device must match active numba context."""

        # Zero-copy: both torch and numba default to device 0 and can interop freely.
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
        self.assertEqual(torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__)

        # Implicit copy when the numba and torch devices differ.
        numba_ary = numba.cuda.to_device(numpy.arange(6))
        torch_ary = torch.as_tensor(numba_ary, device=torch.device("cuda", 1))
        self.assertEqual(torch_ary.get_device(), 1)
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
        if1 = torch_ary.__cuda_array_interface__
        if2 = numba_ary.__cuda_array_interface__
        self.assertNotEqual(if1["data"], if2["data"])
        del if1["data"]
        del if2["data"]
        self.assertEqual(if1, if2)


if __name__ == "__main__":
    common.run_tests()