pytorch

Форк
0
/
test_numpy_interop.py 
631 строка · 23.6 Кб
1
# mypy: ignore-errors
2

3
# Owner(s): ["module: numpy"]
4

5
import sys
6
from itertools import product
7

8
import numpy as np
9

10
import torch
11
from torch.testing import make_tensor
12
from torch.testing._internal.common_device_type import (
13
    dtypes,
14
    instantiate_device_type_tests,
15
    onlyCPU,
16
    skipMeta,
17
)
18
from torch.testing._internal.common_dtype import all_types_and_complex_and
19
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
20

21

22
# Tests for NumPy interoperability: handling NumPy objects (arrays, scalars,
# dtypes), converting tensors to NumPy arrays, and accepting arrays from NumPy.
class TestNumPyInterop(TestCase):
    """Device-generic NumPy interop tests.

    Methods take a ``device`` argument; the class is meant to be expanded into
    concrete per-device test classes by ``instantiate_device_type_tests``.
    """

    # Note: the warning this tests for only appears once per program, so
    # other instances of this warning should be addressed to avoid
    # the tests depending on the order in which they're run.
    @onlyCPU
    def test_numpy_non_writeable(self, device):
        """torch.from_numpy on a read-only array must emit a UserWarning."""
        arr = np.zeros(5)
        arr.flags["WRITEABLE"] = False
        self.assertWarns(UserWarning, lambda: torch.from_numpy(arr))

    @onlyCPU
    def test_numpy_unresizable(self, device) -> None:
        """Buffers shared between a tensor and an ndarray must refuse resizing.

        NumPy is expected to raise ValueError and torch RuntimeError, in both
        directions of sharing (from_numpy and Tensor.numpy).
        """
        x = np.zeros((2, 2))
        # Keep the tensor alive for the duration of the check: presumably its
        # reference to ``x`` is what makes ``x`` non-resizable.
        y = torch.from_numpy(x)
        with self.assertRaises(ValueError):
            x.resize((5, 5))

        z = torch.randn(5, 5)
        w = z.numpy()
        with self.assertRaises(RuntimeError):
            z.resize_(10, 10)
        with self.assertRaises(ValueError):
            w.resize((10, 10))

    @onlyCPU
    def test_to_numpy(self, device) -> None:
        """Tensor.numpy() must preserve values and share memory with the tensor.

        Covers 1D, empty, contiguous 2D, transposed (non-contiguous) 2D, views
        with a non-zero storage offset, views with holes, and writes made
        through the returned array (including through a transposed view).
        """

        def get_castable_tensor(shape, dtype):
            # Sample in a wide dtype with bounds clipped to ``dtype``'s range,
            # then cast down to ``dtype``.
            if dtype.is_floating_point:
                dtype_info = torch.finfo(dtype)
                # can't directly use min and max, because for double, max - min
                # is greater than double range and sampling always gives inf.
                low = max(dtype_info.min, -1e10)
                high = min(dtype_info.max, 1e10)
                t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
            else:
                # can't directly use min and max, because for int64_t, max - min
                # is greater than int64_t range and triggers UB.
                low = max(torch.iinfo(dtype).min, int(-1e10))
                high = min(torch.iinfo(dtype).max, int(1e10))
                t = torch.empty(shape, dtype=torch.int64).random_(low, high)
            return t.to(dtype)

        def check2d(x, y):
            # Element-wise comparison over the tensor's own 2D shape (rather
            # than over sizes bound later in the enclosing loop).
            rows, cols = x.shape
            for i in range(rows):
                for j in range(cols):
                    self.assertEqual(x[i][j], y[i][j])

        tested_dtypes = [
            torch.uint8,
            torch.int8,
            torch.short,
            torch.int,
            torch.half,
            torch.float,
            torch.double,
            torch.long,
        ]

        for dtp in tested_dtypes:
            # 1D
            sz = 10
            x = get_castable_tensor(sz, dtp)
            y = x.numpy()
            for i in range(sz):
                self.assertEqual(x[i], y[i])

            # 1D > 0 storage offset
            xm = get_castable_tensor(sz * 2, dtp)
            x = xm.narrow(0, sz - 1, sz)
            self.assertTrue(x.storage_offset() > 0)
            y = x.numpy()
            for i in range(sz):
                self.assertEqual(x[i], y[i])

            # empty
            x = torch.tensor([]).to(dtp)
            y = x.numpy()
            self.assertEqual(y.size, 0)

            # contiguous 2D
            sz1 = 3
            sz2 = 5
            x = get_castable_tensor((sz1, sz2), dtp)
            y = x.numpy()
            check2d(x, y)
            self.assertTrue(y.flags["C_CONTIGUOUS"])

            # with storage offset
            xm = get_castable_tensor((sz1 * 2, sz2), dtp)
            x = xm.narrow(0, sz1 - 1, sz1)
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)
            self.assertTrue(y.flags["C_CONTIGUOUS"])

            # non-contiguous 2D
            x = get_castable_tensor((sz2, sz1), dtp).t()
            y = x.numpy()
            check2d(x, y)
            self.assertFalse(y.flags["C_CONTIGUOUS"])

            # with storage offset
            xm = get_castable_tensor((sz2 * 2, sz1), dtp)
            x = xm.narrow(0, sz2 - 1, sz2).t()
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)

            # non-contiguous 2D with holes
            xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp)
            x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t()
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)

            if dtp != torch.half:
                # check writeable
                x = get_castable_tensor((3, 4), dtp)
                y = x.numpy()
                self.assertTrue(y.flags.writeable)
                y[0][1] = 3
                self.assertTrue(x[0][1] == 3)
                # Element [0][1] of x.t() aliases x[1][0]. Checking x[0][1]
                # here would be vacuous: it was already set to 3 just above.
                y = x.t().numpy()
                self.assertTrue(y.flags.writeable)
                y[0][1] = 3
                self.assertTrue(x[1][0] == 3)

    def test_to_numpy_bool(self, device) -> None:
        """Bool tensors convert to np.bool_ arrays with matching values."""
        x = torch.tensor([True, False], dtype=torch.bool)
        self.assertEqual(x.dtype, torch.bool)

        y = x.numpy()
        self.assertEqual(y.dtype, np.bool_)
        for i in range(len(x)):
            self.assertEqual(x[i], y[i])

        x = torch.tensor([True], dtype=torch.bool)
        self.assertEqual(x.dtype, torch.bool)

        y = x.numpy()
        self.assertEqual(y.dtype, np.bool_)
        self.assertEqual(x[0], y[0])

    @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet")
    def test_to_numpy_force_argument(self, device) -> None:
        """Check Tensor.numpy(force=...) across grad/sparse/conj/device cases.

        Without ``force``, tensors that require grad, are sparse, have the conj
        bit set, or live off-CPU must raise. With ``force=True`` only sparse
        tensors still raise (TypeError). Otherwise the result must equal the
        tensor with any conjugation resolved.
        """
        for force in [False, True]:
            for requires_grad in [False, True]:
                for sparse in [False, True]:
                    for conj in [False, True]:
                        data = [[1 + 2j, -2 + 3j], [-1 - 2j, 3 - 2j]]
                        x = torch.tensor(
                            data, requires_grad=requires_grad, device=device
                        )
                        y = x
                        if sparse:
                            if requires_grad:
                                continue
                            x = x.to_sparse()
                        if conj:
                            x = x.conj()
                            y = x.resolve_conj()
                        expect_error = (
                            requires_grad or sparse or conj or device != "cpu"
                        )
                        error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
                        if not force and expect_error:
                            self.assertRaisesRegex(
                                (RuntimeError, TypeError), error_msg, lambda: x.numpy()
                            )
                            self.assertRaisesRegex(
                                (RuntimeError, TypeError),
                                error_msg,
                                lambda: x.numpy(force=False),
                            )
                        elif force and sparse:
                            self.assertRaisesRegex(
                                TypeError, error_msg, lambda: x.numpy(force=True)
                            )
                        else:
                            self.assertEqual(x.numpy(force=force), y)

    def test_from_numpy(self, device) -> None:
        """torch.from_numpy must accept supported dtypes and arbitrary strides.

        Also checks that unsupported dtypes (strings) raise TypeError and that
        ill-sized strides raise ValueError.
        """
        np_dtypes = [
            np.double,
            np.float64,
            np.float16,
            np.complex64,
            np.complex128,
            np.int64,
            np.int32,
            np.int16,
            np.int8,
            np.uint8,
            np.longlong,
            np.bool_,
        ]
        complex_dtypes = [
            np.complex64,
            np.complex128,
        ]

        for dtype in np_dtypes:
            array = np.array([1, 2, 3, 4], dtype=dtype)
            tensor_from_array = torch.from_numpy(array)
            # TODO: change to tensor equality check once HalfTensor
            # implements `==`
            for i in range(len(array)):
                self.assertEqual(tensor_from_array[i], array[i])
            # ufunc 'remainder' not supported for complex dtypes
            if dtype not in complex_dtypes:
                # This is a special test case for Windows
                # https://github.com/pytorch/pytorch/issues/22615
                array2 = array % 2
                tensor_from_array2 = torch.from_numpy(array2)
                for i in range(len(array2)):
                    self.assertEqual(tensor_from_array2[i], array2[i])

        # Test unsupported type
        array = np.array(["foo", "bar"], dtype=np.dtype(np.str_))
        with self.assertRaises(TypeError):
            torch.from_numpy(array)

        # check storage offset
        x = np.linspace(1, 125, 125)
        x.shape = (5, 5, 5)
        x = x[1]
        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1]
        self.assertEqual(torch.from_numpy(x), expected)

        # check noncontiguous
        x = np.linspace(1, 25, 25)
        x.shape = (5, 5)
        expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t()
        self.assertEqual(torch.from_numpy(x.T), expected)

        # check noncontiguous with holes
        x = np.linspace(1, 125, 125)
        x.shape = (5, 5, 5)
        x = x[:, 1]
        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1]
        self.assertEqual(torch.from_numpy(x), expected)

        # check zero dimensional
        x = np.zeros((0, 2))
        self.assertEqual(torch.from_numpy(x).shape, (0, 2))
        x = np.zeros((2, 0))
        self.assertEqual(torch.from_numpy(x).shape, (2, 0))

        # check ill-sized strides raise exception
        x = np.array([3.0, 5.0, 8.0])
        x.strides = (3,)
        self.assertRaises(ValueError, lambda: torch.from_numpy(x))

    @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.")
    def test_from_numpy_no_leak_on_invalid_dtype(self):
        """A failing from_numpy call must not leak a reference to its input."""
        # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
        # object. See https://github.com/pytorch/pytorch/issues/121138
        x = np.array("value".encode("ascii"))
        for _ in range(1000):
            try:
                torch.from_numpy(x)
            except TypeError:
                pass
        # 2 == the local name ``x`` + the temporary argument reference held by
        # getrefcount itself.
        self.assertEqual(sys.getrefcount(x), 2)

    @skipMeta
    def test_from_list_of_ndarray_warning(self, device):
        """Building a tensor from a list of ndarrays warns (once) about speed."""
        warning_msg = (
            r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
        )
        with self.assertWarnsOnceRegex(UserWarning, warning_msg):
            torch.tensor([np.array([0]), np.array([1])], device=device)

    def test_ctor_with_invalid_numpy_array_sequence(self, device):
        """Ragged (nested) sequences of ndarrays must be rejected by torch.tensor."""
        # Invalid list of numpy array
        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
            torch.tensor(
                [np.random.random(size=(3, 3)), np.random.random(size=(3, 0))],
                device=device,
            )

        # Invalid list of list of numpy array
        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
            torch.tensor(
                [[np.random.random(size=(3, 3)), np.random.random(size=(3, 2))]],
                device=device,
            )

        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
            torch.tensor(
                [
                    [np.random.random(size=(3, 3)), np.random.random(size=(3, 3))],
                    [np.random.random(size=(3, 3)), np.random.random(size=(3, 2))],
                ],
                device=device,
            )

        # expected shape is `[1, 2, 3]`, hence we try to iterate over 0-D array
        # leading to type error : not a sequence.
        with self.assertRaisesRegex(TypeError, "not a sequence"):
            torch.tensor(
                [[np.random.random(size=(3)), np.random.random()]], device=device
            )

        # list of list or numpy array.
        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
            torch.tensor([[1, 2, 3], np.random.random(size=(2,))], device=device)

    @onlyCPU
    def test_ctor_with_numpy_scalar_ctor(self, device) -> None:
        """torch.tensor(np_scalar).item() round-trips the scalar's value."""
        np_dtypes = [
            np.double,
            np.float64,
            np.float16,
            np.int64,
            np.int32,
            np.int16,
            np.uint8,
            np.bool_,
        ]
        for dtype in np_dtypes:
            self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())

    @onlyCPU
    def test_numpy_index(self, device):
        """NumPy integer scalars index a tensor like the equivalent Python int."""
        i = np.array([0, 1, 2], dtype=np.int32)
        x = torch.randn(5, 5)
        for idx in i:
            self.assertFalse(isinstance(idx, int))
            self.assertEqual(x[idx], x[int(idx)])

    @onlyCPU
    def test_numpy_index_multi(self, device):
        """A NumPy boolean mask selects as many elements as it has True entries."""
        for dim_sz in [2, 8, 16, 32]:
            i = np.zeros((dim_sz, dim_sz, dim_sz), dtype=np.int32)
            i[: dim_sz // 2, :, :] = 1
            x = torch.randn(dim_sz, dim_sz, dim_sz)
            self.assertEqual(x[i == 1].numel(), int(np.sum(i)))

    @onlyCPU
    def test_numpy_array_interface(self, device):
        """Tensors support __array__ (with/without dtype) and __array_wrap__.

        NumPy ufuncs applied to a tensor must return a tensor of the expected
        legacy type, with values matching the same ufunc applied to an ndarray.
        """
        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.ByteTensor,
        ]
        np_dtypes = [
            np.float64,
            np.float32,
            np.float16,
            np.int64,
            np.int32,
            np.int16,
            np.uint8,
        ]
        for tp, dtype in zip(types, np_dtypes):
            # Only concrete class can be given where "Type[number[_64Bit]]" is expected
            if np.dtype(dtype).kind == "u":  # type: ignore[misc]
                # .type expects a XxxTensor, which have no type hints on
                # purpose, so ignore during mypy type checking
                x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
                array = np.array([1, 2, 3, 4], dtype=dtype)
            else:
                x = torch.tensor([1, -2, 3, -4]).type(tp)  # type: ignore[call-overload]
                array = np.array([1, -2, 3, -4], dtype=dtype)

            # Test __array__ w/o dtype argument
            asarray = np.asarray(x)
            self.assertIsInstance(asarray, np.ndarray)
            self.assertEqual(asarray.dtype, dtype)
            for i in range(len(x)):
                self.assertEqual(asarray[i], x[i])

            # Test __array_wrap__, same dtype
            abs_x = np.abs(x)
            abs_array = np.abs(array)
            self.assertIsInstance(abs_x, tp)
            for i in range(len(x)):
                self.assertEqual(abs_x[i], abs_array[i])

        # Test __array__ with dtype argument
        for dtype in np_dtypes:
            x = torch.IntTensor([1, -2, 3, -4])
            asarray = np.asarray(x, dtype=dtype)
            self.assertEqual(asarray.dtype, dtype)
            # Only concrete class can be given where "Type[number[_64Bit]]" is expected
            if np.dtype(dtype).kind == "u":  # type: ignore[misc]
                wrapped_x = np.array([1, -2, 3, -4], dtype=dtype)
                for i in range(len(x)):
                    self.assertEqual(asarray[i], wrapped_x[i])
            else:
                for i in range(len(x)):
                    self.assertEqual(asarray[i], x[i])

        # Test some math functions with float types
        float_types = [torch.DoubleTensor, torch.FloatTensor]
        float_dtypes = [np.float64, np.float32]
        for tp, dtype in zip(float_types, float_dtypes):
            x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
            array = np.array([1, 2, 3, 4], dtype=dtype)
            for func in ["sin", "sqrt", "ceil"]:
                ufunc = getattr(np, func)
                res_x = ufunc(x)
                res_array = ufunc(array)
                self.assertIsInstance(res_x, tp)
                for i in range(len(x)):
                    self.assertEqual(res_x[i], res_array[i])

        # Test functions with boolean return value
        for tp, dtype in zip(types, np_dtypes):
            x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
            array = np.array([1, 2, 3, 4], dtype=dtype)
            geq2_x = np.greater_equal(x, 2)
            geq2_array = np.greater_equal(array, 2).astype("uint8")
            self.assertIsInstance(geq2_x, torch.ByteTensor)
            for i in range(len(x)):
                self.assertEqual(geq2_x[i], geq2_array[i])

    @onlyCPU
    def test_multiplication_numpy_scalar(self, device) -> None:
        """tensor * np_scalar (both orders) keeps the tensor's dtype and grad."""
        for np_dtype in [
            np.float32,
            np.float64,
            np.int32,
            np.int64,
            np.int16,
            np.uint8,
        ]:
            for t_dtype in [torch.float, torch.double]:
                # mypy raises an error when np.floatXY(2.0) is called
                # even though this is valid code
                np_sc = np_dtype(2.0)  # type: ignore[abstract, arg-type]
                t = torch.ones(2, requires_grad=True, dtype=t_dtype)
                r1 = t * np_sc
                self.assertIsInstance(r1, torch.Tensor)
                self.assertTrue(r1.dtype == t_dtype)
                self.assertTrue(r1.requires_grad)
                r2 = np_sc * t
                self.assertIsInstance(r2, torch.Tensor)
                self.assertTrue(r2.dtype == t_dtype)
                self.assertTrue(r2.requires_grad)

    @onlyCPU
    @skipIfTorchDynamo()
    def test_parse_numpy_int_overflow(self, device):
        """A NumPy integer too large for an int64 ``dim`` argument must raise.

        The value is built from ``np.iinfo`` rather than ``np.uint64(-1)``.
        NumPy >= 2 raises OverflowError when converting the out-of-bounds
        Python int -1 to uint64, which happens inside the lambda before torch
        is ever called. The max uint64 value is the same bit pattern that
        older NumPy produced for -1, and every NumPy version accepts it.
        """
        # assertRaises uses a try-except which dynamo has issues with
        # Only concrete class can be given where "Type[number[_64Bit]]" is expected
        huge_uint = np.uint64(np.iinfo(np.uint64).max)
        self.assertRaisesRegex(
            RuntimeError,
            "(Overflow|an integer is required)",
            lambda: torch.mean(torch.randn(1, 1), huge_uint),  # type: ignore[call-overload]
        )

    @onlyCPU
    def test_parse_numpy_int(self, device):
        """NumPy integer scalars must parse like Python ints in torch APIs."""
        # https://github.com/pytorch/pytorch/issues/29252
        for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
            scalar = 3
            np_arr = np.array([scalar], dtype=nptype)
            np_val = np_arr[0]

            # np integral type can be treated as a python int in native functions with
            # int parameters:
            self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val))
            self.assertEqual(
                torch.ones([2, 2, 2, 2]).mean(scalar),
                torch.ones([2, 2, 2, 2]).mean(np_val),
            )

            # numpy integral type parses like a python int in custom python bindings:
            self.assertEqual(torch.Storage(np_val).size(), scalar)  # type: ignore[attr-defined]

            tensor = torch.tensor([2], dtype=torch.int)
            tensor[0] = np_val
            self.assertEqual(tensor[0], np_val)

            # Original reported issue, np integral type parses to the correct
            # PyTorch integral type when passed for a `Scalar` parameter in
            # arithmetic operations:
            t = torch.from_numpy(np_arr)
            self.assertEqual((t + np_val).dtype, t.dtype)
            self.assertEqual((np_val + t).dtype, t.dtype)

    def test_has_storage_numpy(self, device):
        """Tensors built from ndarrays, with any target dtype, have a storage."""
        target_dtypes = (torch.float32, torch.double, torch.int, torch.long, torch.uint8)
        for dtype in [np.float32, np.float64, np.int64, np.int32, np.int16, np.uint8]:
            arr = np.array([1], dtype=dtype)
            for target_dtype in target_dtypes:
                self.assertIsNotNone(
                    torch.tensor(arr, device=device, dtype=target_dtype).storage()
                )

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_numpy_scalar_cmp(self, device, dtype):
        """Tensor elements compare equal to the matching NumPy / Python scalars."""
        if dtype.is_complex:
            tensors = (
                torch.tensor(complex(1, 3), dtype=dtype, device=device),
                torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device),
                torch.tensor(
                    [[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device
                ),
            )
        else:
            tensors = (
                torch.tensor(3, dtype=dtype, device=device),
                torch.tensor([1, 0, -3], dtype=dtype, device=device),
                torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device),
            )

        for tensor in tensors:
            if dtype == torch.bfloat16:
                # NumPy has no bfloat16, so conversion is expected to fail.
                with self.assertRaises(TypeError):
                    np_array = tensor.cpu().numpy()
                continue

            np_array = tensor.cpu().numpy()
            for t, a in product(
                (tensor.flatten()[0], tensor.flatten()[0].item()),
                (np_array.flatten()[0], np_array.flatten()[0].item()),
            ):
                self.assertEqual(t, a)
                if (
                    dtype == torch.complex64
                    and torch.is_tensor(t)
                    and isinstance(a, np.complex64)
                ):
                    # TODO: Imaginary part is dropped in this case. Need fix.
                    # https://github.com/pytorch/pytorch/issues/43579
                    self.assertFalse(t == a)
                else:
                    self.assertTrue(t == a)

    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test___eq__(self, device, dtype):
        """tensor == ndarray (both operand orders) compares element-wise.

        ``b_np`` shares memory with ``b``, so in-place edits to ``b`` below
        are visible through ``b_np``.
        """
        a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
        b = a.clone().detach()
        b_np = b.numpy()

        # Check all elements equal
        res_check = torch.ones_like(a, dtype=torch.bool)
        self.assertEqual(a == b_np, res_check)
        self.assertEqual(b_np == a, res_check)

        # Check one element unequal
        if dtype == torch.bool:
            b[1][3] = not b[1][3]
        else:
            b[1][3] += 1
        res_check[1][3] = False
        self.assertEqual(a == b_np, res_check)
        self.assertEqual(b_np == a, res_check)

        # Check random elements unequal
        rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
        res_check = rand.logical_not()
        b.copy_(a)

        if dtype == torch.bool:
            b[rand] = b[rand].logical_not()
        else:
            b[rand] += 1

        self.assertEqual(a == b_np, res_check)
        self.assertEqual(b_np == a, res_check)

        # Check all elements unequal
        if dtype == torch.bool:
            b.copy_(a.logical_not())
        else:
            b.copy_(a + 1)
        res_check.fill_(False)
        self.assertEqual(a == b_np, res_check)
        self.assertEqual(b_np == a, res_check)

    @onlyCPU
    def test_empty_tensors_interop(self, device):
        """Ops between a 0-d half tensor and an empty one yield the empty shape."""
        x = torch.rand((), dtype=torch.float16)
        y = torch.tensor(np.random.rand(0), dtype=torch.float16)
        # Same can be achieved by running
        # y = torch.empty_strided((0,), (0,), dtype=torch.float16)

        # Regression test for https://github.com/pytorch/pytorch/issues/115068
        self.assertEqual(torch.true_divide(x, y).shape, y.shape)
        # Regression test for https://github.com/pytorch/pytorch/issues/115066
        self.assertEqual(torch.mul(x, y).shape, y.shape)
        # Regression test for https://github.com/pytorch/pytorch/issues/113037
        self.assertEqual(torch.div(x, y, rounding_mode="floor").shape, y.shape)

# Expand the device-generic test class into concrete per-device classes,
# writing them into this module's global namespace.
instantiate_device_type_tests(TestNumPyInterop, scope=globals())

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.