# Owner(s): ["module: tests"]

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.utils.dlpack import from_dlpack, to_dlpack


class TestTorchDlPack(TestCase):
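    # assertEqual in these tests must also match dtypes exactly, not just values.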
    exact_dtype = True

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_capsule_conversion(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(to_dlpack(x))
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_protocol_conversion(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(x)
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    def test_dlpack_shared_storage(self, device):
        x = make_tensor((5,), dtype=torch.float64, device=device)
        z = from_dlpack(to_dlpack(x))
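        # z aliases x's storage, so the in-place update below must be visible
        # through x as well.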
        z[0] = z[0] + 20.0
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation on that stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
        # The DLPack protocol establishes the correct stream order
        # (and hence the data dependency) at the exchange boundary.
        # DLPack manages this synchronization for us, so we don't need to
        # wait explicitly until x is populated.
        if IS_JETSON:
            # The stream ordering established by the DLPack protocol
            # does not behave as expected on Jetson.
            stream.synchronize()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
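            # from_dlpack queries x.__dlpack_device__() and hands the consumer's
            # current stream to x.__dlpack__(), which is where any required
            # cross-stream synchronization is inserted.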
            z = from_dlpack(x)
        stream.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        self.assertEqual(x, y)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_noncontiguous(self, device, dtype):
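        # Views with various shapes and strides (rows, a column, a transpose)
        # must round-trip through DLPack with their values intact.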
        x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)

        y1 = x[0]
        y1_dl = torch.from_dlpack(y1)
        self.assertEqual(y1, y1_dl)

        y2 = x[:, 0]
        y2_dl = torch.from_dlpack(y2)
        self.assertEqual(y2, y2_dl)

        y3 = x[1, :]
        y3_dl = torch.from_dlpack(y3)
        self.assertEqual(y3, y3_dl)

        y4 = x[1]
        y4_dl = torch.from_dlpack(y4)
        self.assertEqual(y4, y4_dl)

        y5 = x.t()
        y5_dl = torch.from_dlpack(y5)
        self.assertEqual(y5, y5_dl)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
        stream_a = torch.cuda.Stream()
        stream_b = torch.cuda.Stream()
        # The DLPack protocol establishes the correct stream order
        # (and hence the data dependency) at the exchange boundary:
        # the `tensor.__dlpack__` method inserts a synchronization event
        # on the current stream to make sure it was correctly populated.
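        # Here the consumer stream (stream_b) is passed explicitly, so stream_b
        # is made to wait on the event recorded on the producer stream (stream_a).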
        with torch.cuda.stream(stream_a):
            x = make_tensor((5,), dtype=dtype, device=device) + 1
            z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
            stream_a.synchronize()
        stream_b.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_dtype(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        assert x.dtype == y.dtype

    @skipMeta
    @onlyCUDA
    def test_dlpack_default_stream(self, device):
        class DLPackTensor:
            def __init__(self, tensor):
                self.tensor = tensor

            def __dlpack_device__(self):
                return self.tensor.__dlpack_device__()

            def __dlpack__(self, stream=None):
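                # Per the DLPack exchange protocol, a consumer on CUDA denotes
                # the legacy default stream by 1, while on ROCm the default
                # stream is denoted by 0.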
                if torch.version.hip is None:
                    assert stream == 1
                else:
                    assert stream == 0
                capsule = self.tensor.__dlpack__(stream)
                return capsule

        # CUDA-based tests run on non-default streams
        with torch.cuda.stream(torch.cuda.default_stream()):
            x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
            from_dlpack(x)

    @skipMeta
    @onlyCUDA
    @skipCUDAIfRocm
    def test_dlpack_convert_default_stream(self, device):
        # Tests run on a non-default stream, so the _sleep call below
        # will run on a non-default stream, causing the default stream
        # to wait due to inserted syncs.
        torch.cuda.default_stream().synchronize()
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            x = torch.zeros(1, device=device)
            torch.cuda._sleep(2**20)
            self.assertTrue(torch.cuda.default_stream().query())
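            # Exporting with stream=1 (the legacy default stream) records an event
            # on side_stream and makes the default stream wait on it.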
            d = x.__dlpack__(1)
        # check that the default stream has work (a pending cudaStreamWaitEvent)
        self.assertFalse(torch.cuda.default_stream().query())

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_tensor_invalid_stream(self, device, dtype):
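        # The protocol accepts only an integer stream handle (or None);
        # anything else should raise a TypeError.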
        with self.assertRaises(TypeError):
            x = make_tensor((5,), dtype=dtype, device=device)
            x.__dlpack__(stream=object())

    # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
    @skipMeta
    def test_dlpack_export_requires_grad(self):
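        # DLPack cannot carry autograd metadata, so exporting a tensor that
        # requires grad is refused (detach() it first).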
        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
            x.__dlpack__()

    @skipMeta
    def test_dlpack_export_is_conj(self):
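        # The lazily applied conjugate bit has no DLPack representation;
        # callers must resolve_conj() before exporting.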
        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_export_non_strided(self):
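        # Only strided (dense) tensors can be exported; sparse layouts are rejected.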
        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"strided"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_normalize_strides(self):
        x = torch.rand(16)
        y = x[::3][:1]
        self.assertEqual(y.shape, (1,))
        self.assertEqual(y.stride(), (3,))
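        # y's only dimension has size 1, so its stride of 3 carries no
        # information and the exporter is free to normalize it.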
249
        z = from_dlpack(y)
250
        self.assertEqual(z.shape, (1,))
251
        # gh-83069, make sure __dlpack__ normalizes strides
252
        self.assertEqual(z.stride(), (1,))
253

254

255
instantiate_device_type_tests(TestTorchDlPack, globals())
256

257
if __name__ == "__main__":
258
    run_tests()
259
