4
from torch.testing import make_tensor
5
from torch.testing._internal.common_device_type import (
7
instantiate_device_type_tests,
13
from torch.testing._internal.common_dtype import all_types_and_complex_and
14
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
15
from torch.utils.dlpack import from_dlpack, to_dlpack
18
class TestTorchDlPack(TestCase):
22
@onlyNativeDeviceTypes
24
*all_types_and_complex_and(
33
def test_dlpack_capsule_conversion(self, device, dtype):
34
x = make_tensor((5,), dtype=dtype, device=device)
35
z = from_dlpack(to_dlpack(x))
36
self.assertEqual(z, x)
39
@onlyNativeDeviceTypes
41
*all_types_and_complex_and(
50
def test_dlpack_protocol_conversion(self, device, dtype):
51
x = make_tensor((5,), dtype=dtype, device=device)
53
self.assertEqual(z, x)
56
@onlyNativeDeviceTypes
57
def test_dlpack_shared_storage(self, device):
58
x = make_tensor((5,), dtype=torch.float64, device=device)
59
z = from_dlpack(to_dlpack(x))
61
self.assertEqual(z, x)
65
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
66
def test_dlpack_conversion_with_streams(self, device, dtype):
68
stream = torch.cuda.Stream()
69
with torch.cuda.stream(stream):
71
x = make_tensor((5,), dtype=dtype, device=device) + 1
80
stream = torch.cuda.Stream()
81
with torch.cuda.stream(stream):
84
self.assertEqual(z, x)
87
@onlyNativeDeviceTypes
89
*all_types_and_complex_and(
98
def test_from_dlpack(self, device, dtype):
99
x = make_tensor((5,), dtype=dtype, device=device)
100
y = torch.from_dlpack(x)
101
self.assertEqual(x, y)
104
@onlyNativeDeviceTypes
106
*all_types_and_complex_and(
115
def test_from_dlpack_noncontinguous(self, device, dtype):
116
x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
119
y1_dl = torch.from_dlpack(y1)
120
self.assertEqual(y1, y1_dl)
123
y2_dl = torch.from_dlpack(y2)
124
self.assertEqual(y2, y2_dl)
127
y3_dl = torch.from_dlpack(y3)
128
self.assertEqual(y3, y3_dl)
131
y4_dl = torch.from_dlpack(y4)
132
self.assertEqual(y4, y4_dl)
135
y5_dl = torch.from_dlpack(y5)
136
self.assertEqual(y5, y5_dl)
140
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
141
def test_dlpack_conversion_with_diff_streams(self, device, dtype):
142
stream_a = torch.cuda.Stream()
143
stream_b = torch.cuda.Stream()
148
with torch.cuda.stream(stream_a):
149
x = make_tensor((5,), dtype=dtype, device=device) + 1
150
z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
151
stream_a.synchronize()
152
stream_b.synchronize()
153
self.assertEqual(z, x)
156
@onlyNativeDeviceTypes
158
*all_types_and_complex_and(
167
def test_from_dlpack_dtype(self, device, dtype):
168
x = make_tensor((5,), dtype=dtype, device=device)
169
y = torch.from_dlpack(x)
170
assert x.dtype == y.dtype
174
def test_dlpack_default_stream(self, device):
176
def __init__(self, tensor):
179
def __dlpack_device__(self):
180
return self.tensor.__dlpack_device__()
182
def __dlpack__(self, stream=None):
183
if torch.version.hip is None:
187
capsule = self.tensor.__dlpack__(stream)
191
with torch.cuda.stream(torch.cuda.default_stream()):
192
x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
198
def test_dlpack_convert_default_stream(self, device):
202
torch.cuda.default_stream().synchronize()
205
side_stream = torch.cuda.Stream()
206
with torch.cuda.stream(side_stream):
207
x = torch.zeros(1, device=device)
208
torch.cuda._sleep(2**20)
209
self.assertTrue(torch.cuda.default_stream().query())
212
self.assertFalse(torch.cuda.default_stream().query())
215
@onlyNativeDeviceTypes
216
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
217
def test_dlpack_tensor_invalid_stream(self, device, dtype):
218
with self.assertRaises(TypeError):
219
x = make_tensor((5,), dtype=dtype, device=device)
220
x.__dlpack__(stream=object())
224
def test_dlpack_export_requires_grad(self):
225
x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
226
with self.assertRaisesRegex(RuntimeError, r"require gradient"):
230
def test_dlpack_export_is_conj(self):
231
x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
233
with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
237
def test_dlpack_export_non_strided(self):
238
x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
240
with self.assertRaisesRegex(RuntimeError, r"strided"):
244
def test_dlpack_normalize_strides(self):
247
self.assertEqual(y.shape, (1,))
248
self.assertEqual(y.stride(), (3,))
250
self.assertEqual(z.shape, (1,))
252
self.assertEqual(z.stride(), (1,))
255
instantiate_device_type_tests(TestTorchDlPack, globals())
257
if __name__ == "__main__":