pytorch

Форк
0
/
test_numba_integration.py 
399 строк · 15.4 Кб
1
# Owner(s): ["module: unknown"]
2

3
import unittest
4

5
import torch
6
import torch.testing._internal.common_utils as common
7
from torch.testing._internal.common_cuda import (
8
    TEST_CUDA,
9
    TEST_MULTIGPU,
10
    TEST_NUMBA_CUDA,
11
)
12
from torch.testing._internal.common_utils import TEST_NUMPY
13

14

15
if TEST_NUMPY:
16
    import numpy
17

18
if TEST_NUMBA_CUDA:
19
    import numba.cuda
20

21

22
class TestNumbaIntegration(common.TestCase):
23
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
24
    @unittest.skipIf(not TEST_CUDA, "No cuda")
25
    def test_cuda_array_interface(self):
26
        """torch.Tensor exposes __cuda_array_interface__ for cuda tensors.
27

28
        An object t is considered a cuda-tensor if:
29
            hasattr(t, '__cuda_array_interface__')
30

31
        A cuda-tensor provides a tensor description dict:
32
            shape: (integer, ...) Tensor shape.
33
            strides: (integer, ...) Tensor strides, in bytes.
34
            typestr: (str) A numpy-style typestr.
35
            data: (int, boolean) A (data_ptr, read-only) tuple.
36
            version: (int) Version 0
37

38
        See:
39
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
40
        """
41

42
        types = [
43
            torch.DoubleTensor,
44
            torch.FloatTensor,
45
            torch.HalfTensor,
46
            torch.LongTensor,
47
            torch.IntTensor,
48
            torch.ShortTensor,
49
            torch.CharTensor,
50
            torch.ByteTensor,
51
        ]
52
        dtypes = [
53
            numpy.float64,
54
            numpy.float32,
55
            numpy.float16,
56
            numpy.int64,
57
            numpy.int32,
58
            numpy.int16,
59
            numpy.int8,
60
            numpy.uint8,
61
        ]
62
        for tp, npt in zip(types, dtypes):
63
            # CPU tensors do not implement the interface.
64
            cput = tp(10)
65

66
            self.assertFalse(hasattr(cput, "__cuda_array_interface__"))
67
            self.assertRaises(AttributeError, lambda: cput.__cuda_array_interface__)
68

69
            # Sparse CPU/CUDA tensors do not implement the interface
70
            if tp not in (torch.HalfTensor,):
71
                indices_t = torch.empty(1, cput.size(0), dtype=torch.long).clamp_(min=0)
72
                sparse_t = torch.sparse_coo_tensor(indices_t, cput)
73

74
                self.assertFalse(hasattr(sparse_t, "__cuda_array_interface__"))
75
                self.assertRaises(
76
                    AttributeError, lambda: sparse_t.__cuda_array_interface__
77
                )
78

79
                sparse_cuda_t = torch.sparse_coo_tensor(indices_t, cput).cuda()
80

81
                self.assertFalse(hasattr(sparse_cuda_t, "__cuda_array_interface__"))
82
                self.assertRaises(
83
                    AttributeError, lambda: sparse_cuda_t.__cuda_array_interface__
84
                )
85

86
            # CUDA tensors have the attribute and v2 interface
87
            cudat = tp(10).cuda()
88

89
            self.assertTrue(hasattr(cudat, "__cuda_array_interface__"))
90

91
            ar_dict = cudat.__cuda_array_interface__
92

93
            self.assertEqual(
94
                set(ar_dict.keys()), {"shape", "strides", "typestr", "data", "version"}
95
            )
96

97
            self.assertEqual(ar_dict["shape"], (10,))
98
            self.assertIs(ar_dict["strides"], None)
99
            # typestr from numpy, cuda-native little-endian
100
            self.assertEqual(ar_dict["typestr"], numpy.dtype(npt).newbyteorder("<").str)
101
            self.assertEqual(ar_dict["data"], (cudat.data_ptr(), False))
102
            self.assertEqual(ar_dict["version"], 2)
103

104
    @unittest.skipIf(not TEST_CUDA, "No cuda")
105
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
106
    def test_array_adaptor(self):
107
        """Torch __cuda_array_adaptor__ exposes tensor data to numba.cuda."""
108

109
        torch_dtypes = [
110
            torch.complex64,
111
            torch.complex128,
112
            torch.float16,
113
            torch.float32,
114
            torch.float64,
115
            torch.uint8,
116
            torch.int8,
117
            torch.uint16,
118
            torch.int16,
119
            torch.uint32,
120
            torch.int32,
121
            torch.uint64,
122
            torch.int64,
123
            torch.bool,
124
        ]
125

126
        for dt in torch_dtypes:
127
            # CPU tensors of all types do not register as cuda arrays,
128
            # attempts to convert raise a type error.
129
            cput = torch.arange(10).to(dt)
130
            npt = cput.numpy()
131

132
            self.assertTrue(not numba.cuda.is_cuda_array(cput))
133
            with self.assertRaises(TypeError):
134
                numba.cuda.as_cuda_array(cput)
135

136
            # Any cuda tensor is a cuda array.
137
            cudat = cput.to(device="cuda")
138
            self.assertTrue(numba.cuda.is_cuda_array(cudat))
139

140
            numba_view = numba.cuda.as_cuda_array(cudat)
141
            self.assertIsInstance(numba_view, numba.cuda.devicearray.DeviceNDArray)
142

143
            # The reported type of the cuda array matches the numpy type of the cpu tensor.
144
            self.assertEqual(numba_view.dtype, npt.dtype)
145
            self.assertEqual(numba_view.strides, npt.strides)
146
            self.assertEqual(numba_view.shape, cudat.shape)
147

148
            # Pass back to cuda from host for all equality checks below, needed for
149
            # float16 comparisons, which aren't supported cpu-side.
150

151
            # The data is identical in the view.
152
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))
153

154
            # Writes to the torch.Tensor are reflected in the numba array.
155
            cudat[:5] = 11
156
            self.assertEqual(cudat, torch.tensor(numba_view.copy_to_host()).to("cuda"))
157

158
            # Strided tensors are supported.
159
            strided_cudat = cudat[::2]
160
            strided_npt = cput[::2].numpy()
161
            strided_numba_view = numba.cuda.as_cuda_array(strided_cudat)
162

163
            self.assertEqual(strided_numba_view.dtype, strided_npt.dtype)
164
            self.assertEqual(strided_numba_view.strides, strided_npt.strides)
165
            self.assertEqual(strided_numba_view.shape, strided_cudat.shape)
166

167
            # As of numba 0.40.0 support for strided views is ...limited...
168
            # Cannot verify correctness of strided view operations.
169

170
    @unittest.skipIf(not TEST_CUDA, "No cuda")
171
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
172
    def test_conversion_errors(self):
173
        """Numba properly detects array interface for tensor.Tensor variants."""
174

175
        # CPU tensors are not cuda arrays.
176
        cput = torch.arange(100)
177

178
        self.assertFalse(numba.cuda.is_cuda_array(cput))
179
        with self.assertRaises(TypeError):
180
            numba.cuda.as_cuda_array(cput)
181

182
        # Sparse tensors are not cuda arrays, regardless of device.
183
        sparset = torch.sparse_coo_tensor(cput[None, :], cput)
184

185
        self.assertFalse(numba.cuda.is_cuda_array(sparset))
186
        with self.assertRaises(TypeError):
187
            numba.cuda.as_cuda_array(sparset)
188

189
        sparse_cuda_t = sparset.cuda()
190

191
        self.assertFalse(numba.cuda.is_cuda_array(sparset))
192
        with self.assertRaises(TypeError):
193
            numba.cuda.as_cuda_array(sparset)
194

195
        # Device-status overrides gradient status.
196
        # CPU+gradient isn't a cuda array.
197
        cpu_gradt = torch.zeros(100).requires_grad_(True)
198

199
        self.assertFalse(numba.cuda.is_cuda_array(cpu_gradt))
200
        with self.assertRaises(TypeError):
201
            numba.cuda.as_cuda_array(cpu_gradt)
202

203
        # CUDA+gradient raises a RuntimeError on check or conversion.
204
        #
205
        # Use of hasattr for interface detection causes interface change in
206
        # python2; it swallows all exceptions not just AttributeError.
207
        cuda_gradt = torch.zeros(100).requires_grad_(True).cuda()
208

209
        # conversion raises RuntimeError
210
        with self.assertRaises(RuntimeError):
211
            numba.cuda.is_cuda_array(cuda_gradt)
212
        with self.assertRaises(RuntimeError):
213
            numba.cuda.as_cuda_array(cuda_gradt)
214

215
    @unittest.skipIf(not TEST_CUDA, "No cuda")
216
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
217
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
218
    def test_active_device(self):
219
        """'as_cuda_array' tensor device must match active numba context."""
220

221
        # Both torch/numba default to device 0 and can interop freely
222
        cudat = torch.arange(10, device="cuda")
223
        self.assertEqual(cudat.device.index, 0)
224
        self.assertIsInstance(
225
            numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
226
        )
227

228
        # Tensors on non-default device raise api error if converted
229
        cudat = torch.arange(10, device=torch.device("cuda", 1))
230

231
        with self.assertRaises(numba.cuda.driver.CudaAPIError):
232
            numba.cuda.as_cuda_array(cudat)
233

234
        # but can be converted when switching to the device's context
235
        with numba.cuda.devices.gpus[cudat.device.index]:
236
            self.assertIsInstance(
237
                numba.cuda.as_cuda_array(cudat), numba.cuda.devicearray.DeviceNDArray
238
            )
239

240
    @unittest.skip(
241
        "Test is temporary disabled, see https://github.com/pytorch/pytorch/issues/54418"
242
    )
243
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
244
    @unittest.skipIf(not TEST_CUDA, "No cuda")
245
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
246
    def test_from_cuda_array_interface(self):
247
        """torch.as_tensor() and torch.tensor() supports the __cuda_array_interface__ protocol.
248

249
        If an object exposes the __cuda_array_interface__, .as_tensor() and .tensor()
250
        will use the exposed device memory.
251

252
        See:
253
        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
254
        """
255

256
        dtypes = [
257
            numpy.complex64,
258
            numpy.complex128,
259
            numpy.float64,
260
            numpy.float32,
261
            numpy.int64,
262
            numpy.int32,
263
            numpy.int16,
264
            numpy.int8,
265
            numpy.uint8,
266
        ]
267
        for dtype in dtypes:
268
            numpy_arys = [
269
                numpy.ones((), dtype=dtype),
270
                numpy.arange(6).reshape(2, 3).astype(dtype),
271
                numpy.arange(6)
272
                .reshape(2, 3)
273
                .astype(dtype)[1:],  # View offset should be ignored
274
                numpy.arange(6)
275
                .reshape(2, 3)
276
                .astype(dtype)[:, None],  # change the strides but still contiguous
277
            ]
278
            # Zero-copy when using `torch.as_tensor()`
279
            for numpy_ary in numpy_arys:
280
                numba_ary = numba.cuda.to_device(numpy_ary)
281
                torch_ary = torch.as_tensor(numba_ary, device="cuda")
282
                self.assertEqual(
283
                    numba_ary.__cuda_array_interface__,
284
                    torch_ary.__cuda_array_interface__,
285
                )
286
                self.assertEqual(
287
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
288
                )
289

290
                # Check that `torch_ary` and `numba_ary` points to the same device memory
291
                torch_ary += 42
292
                self.assertEqual(
293
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
294
                )
295

296
            # Implicit-copy because `torch_ary` is a CPU array
297
            for numpy_ary in numpy_arys:
298
                numba_ary = numba.cuda.to_device(numpy_ary)
299
                torch_ary = torch.as_tensor(numba_ary, device="cpu")
300
                self.assertEqual(
301
                    torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
302
                )
303

304
                # Check that `torch_ary` and `numba_ary` points to different memory
305
                torch_ary += 42
306
                self.assertEqual(
307
                    torch_ary.data.numpy(), numpy.asarray(numba_ary, dtype=dtype) + 42
308
                )
309

310
            # Explicit-copy when using `torch.tensor()`
311
            for numpy_ary in numpy_arys:
312
                numba_ary = numba.cuda.to_device(numpy_ary)
313
                torch_ary = torch.tensor(numba_ary, device="cuda")
314
                self.assertEqual(
315
                    torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary, dtype=dtype)
316
                )
317

318
                # Check that `torch_ary` and `numba_ary` points to different memory
319
                torch_ary += 42
320
                self.assertEqual(
321
                    torch_ary.cpu().data.numpy(),
322
                    numpy.asarray(numba_ary, dtype=dtype) + 42,
323
                )
324

325
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
326
    @unittest.skipIf(not TEST_CUDA, "No cuda")
327
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
328
    def test_from_cuda_array_interface_inferred_strides(self):
329
        """torch.as_tensor(numba_ary) should have correct inferred (contiguous) strides"""
330
        # This could, in theory, be combined with test_from_cuda_array_interface but that test
331
        # is overly strict: it checks that the exported protocols are exactly the same, which
332
        # cannot handle differing exported protocol versions.
333
        dtypes = [
334
            numpy.float64,
335
            numpy.float32,
336
            numpy.int64,
337
            numpy.int32,
338
            numpy.int16,
339
            numpy.int8,
340
            numpy.uint8,
341
        ]
342
        for dtype in dtypes:
343
            numpy_ary = numpy.arange(6).reshape(2, 3).astype(dtype)
344
            numba_ary = numba.cuda.to_device(numpy_ary)
345
            self.assertTrue(numba_ary.is_c_contiguous())
346
            torch_ary = torch.as_tensor(numba_ary, device="cuda")
347
            self.assertTrue(torch_ary.is_contiguous())
348

349
    @unittest.skip(
350
        "Test is temporary disabled, see https://github.com/pytorch/pytorch/issues/54418"
351
    )
352
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
353
    @unittest.skipIf(not TEST_CUDA, "No cuda")
354
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
355
    def test_from_cuda_array_interface_lifetime(self):
356
        """torch.as_tensor(obj) tensor grabs a reference to obj so that the lifetime of obj exceeds the tensor"""
357
        numba_ary = numba.cuda.to_device(numpy.arange(6))
358
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
359
        self.assertEqual(
360
            torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__
361
        )  # No copy
362
        del numba_ary
363
        self.assertEqual(
364
            torch_ary.cpu().data.numpy(), numpy.arange(6)
365
        )  # `torch_ary` is still alive
366

367
    @unittest.skip(
368
        "Test is temporary disabled, see https://github.com/pytorch/pytorch/issues/54418"
369
    )
370
    @unittest.skipIf(not TEST_NUMPY, "No numpy")
371
    @unittest.skipIf(not TEST_CUDA, "No cuda")
372
    @unittest.skipIf(not TEST_NUMBA_CUDA, "No numba.cuda")
373
    @unittest.skipIf(not TEST_MULTIGPU, "No multigpu")
374
    def test_from_cuda_array_interface_active_device(self):
375
        """torch.as_tensor() tensor device must match active numba context."""
376

377
        # Zero-copy: both torch/numba default to device 0 and can interop freely
378
        numba_ary = numba.cuda.to_device(numpy.arange(6))
379
        torch_ary = torch.as_tensor(numba_ary, device="cuda")
380
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
381
        self.assertEqual(
382
            torch_ary.__cuda_array_interface__, numba_ary.__cuda_array_interface__
383
        )
384

385
        # Implicit-copy: when the Numba and Torch device differ
386
        numba_ary = numba.cuda.to_device(numpy.arange(6))
387
        torch_ary = torch.as_tensor(numba_ary, device=torch.device("cuda", 1))
388
        self.assertEqual(torch_ary.get_device(), 1)
389
        self.assertEqual(torch_ary.cpu().data.numpy(), numpy.asarray(numba_ary))
390
        if1 = torch_ary.__cuda_array_interface__
391
        if2 = numba_ary.__cuda_array_interface__
392
        self.assertNotEqual(if1["data"], if2["data"])
393
        del if1["data"]
394
        del if2["data"]
395
        self.assertEqual(if1, if2)
396

397

398
if __name__ == "__main__":
    # Run this module's tests through torch's common test runner.
    common.run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.