# Owner(s): ["module: mta"]

from contextlib import nullcontext
from numbers import Number
import random
import re
import torch
import unittest
import itertools
import weakref

from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import \
    TestCase, run_tests, TEST_WITH_ROCM, skipIfTorchDynamo, parametrize, gradcheck, skipIfRocmVersionLessThan
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, onlyCUDA, ops, OpDTypes)
from torch.testing._internal.common_methods_invocations import (
    foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db,
    foreach_reduce_op_db, foreach_other_op_db)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and, floating_types_and, floating_types, integral_types_and,
)


_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"


class RegularFuncWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, scalars=None, **kwargs):
        if scalars is not None:
            assert len(inputs) == 3
            # We need to distribute each scalar to the regular func and it needs
            # special consideration as it is a keyword only argument to the
            # regular func. (Strangely, it is not a keyword only argument to the
            # foreach func)
            return [self.func(*i, value=scalars[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
        if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
            # binary op with tensorlist and scalar.
            inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
        return [self.func(*i, **kwargs) for i in zip(*inputs)]


class ForeachFuncWrapper:
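    # Wraps a foreach variant of an op. When CUDA profiling is available, `__call__`
    # inspects the profiler trace to confirm whether the fused fast path
    # (`multi_tensor_apply_kernel`) was taken, matching `expect_fastpath`.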
    def __init__(self, func):
        self.func = func
        # Some foreach functions don't have in-place implementations.
        self.is_inplace = False if func is None else func.__name__.endswith('_')

    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
        actual = None
        zero_size = kwargs.pop("zero_size", False)
        if (
            is_cuda and
            torch.autograd.kineto_available() and
            torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities()
        ):
            with torch.profiler.profile() as p:
                actual = self.func(*inputs, **kwargs)
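            # The fused fast path shows up in the trace as `multi_tensor_apply_kernel`;
            # zero-size inputs must never launch it, even when the fast path is expected.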
            keys = tuple([e.key for e in p.key_averages()])
            mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
            assert mta_called == (expect_fastpath and (not zero_size))
        else:
            actual = self.func(*inputs, **kwargs)
        # note(mkozuki): inplace foreach functions are void functions.
        return inputs[0] if self.is_inplace else actual


class InplaceForeachVersionBumpCheck:
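    # Context manager asserting that the `_version` counters of the given tensors
    # have been bumped (compared lexicographically against their values at
    # construction) by the in-place foreach calls made inside the block.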

    def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None:  # noqa: F821
        self._testcase = testcase
        self._tensorlist = tensorlist
        self._orig_version_counts = [t._version for t in tensorlist]

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
        self._testcase.assertGreaterEqual([t._version for t in self._tensorlist], self._orig_version_counts)


def get_transform_func(num_tensors, dtype, device, is_fastpath):
    def transform(t):
        if not torch.is_tensor(t):
            return t
        if torch.is_tensor(t) and t.ndim == 0:
            return t
        return make_tensor(
            (num_tensors, num_tensors), dtype=dtype, device=device,
            requires_grad=True, noncontiguous=not is_fastpath,
        )

    return transform


# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
class TestForeach(TestCase):
    @property
    def is_cuda(self):
        return self.device_type == 'cuda'

    def _get_funcs(self, op):
        return (
            ForeachFuncWrapper(op.method_variant),
            RegularFuncWrapper(op.ref),
            ForeachFuncWrapper(op.inplace_variant),
            RegularFuncWrapper(op.ref_inplace),
        )

    # note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
    # which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
    # rel:
    #   - https://github.com/pytorch/pytorch/pull/94655
    #   - https://github.com/pytorch/pytorch/issues/100701
    #   - https://github.com/pytorch/pytorch/pull/100811
    @onlyCUDA
    @ops(
        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
        dtypes=(torch.float32,)
    )
    def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
        wrapped_op, _, inplace_op, _ = self._get_funcs(op)

        for sample in op.sample_zero_size_inputs(device, dtype):
            if op.supports_out:
                wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
            with InplaceForeachVersionBumpCheck(self, sample.input):
                inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)

    @skipIfRocmVersionLessThan((6, 0))
    @ops(
        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
    )
    @parametrize(
        "noncontiguous,inplace",
        [(False, False), (False, True), (True, False), (True, True)],
        name_fn=lambda x, y: '{}_{}'.format(
            'fastpath' if not x else 'slowpath', 'inplace' if y else 'outplace'
        )
    )
    def test_parity(self, device, dtype, op, noncontiguous, inplace):
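        # Runs each sample through the foreach op and the per-tensor reference and
        # checks that the results match; when the foreach call raises, the reference
        # is expected to raise too (with the same message where comparable).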
        if inplace:
            _, _, func, ref = self._get_funcs(op)
        else:
            func, ref, _, _ = self._get_funcs(op)
        for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
            ref_kwargs = sample.kwargs
            # div promotes ints to floats, so we cannot go on the fastpath there
            div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
            expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
            ref_input, ctxmgr = sample.input, nullcontext()
            if inplace:
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
            try:
                with ctxmgr:
                    actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **sample.kwargs)
            except Exception as e:
                with (
                    self.assertRaisesRegex(type(e), re.escape(str(e)))
                    if not (op.has_no_in_place or not op.supports_out)
                    else self.assertRaises(type(e))
                ):
                    ref([ref_input, *sample.ref_args], **ref_kwargs)
            else:
                expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
                self.assertEqual(expected, actual)

    def _binary_test(
        self,
        dtype, op, ref, inputs, is_fastpath, is_inplace,
        *,
        alpha, scalar_self_arg: bool,
    ):
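        # Helper shared by the binary-op tests: runs the wrapped foreach op against
        # the per-tensor reference (optionally re-running with `alpha`), expecting
        # either matching results or matching RuntimeErrors.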
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
        try:
            with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
                actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                if not scalar_self_arg:
                    ref(ref_inputs)
                else:
                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        else:
            expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
            self.assertEqual(actual, expected)
        if alpha is not None and not scalar_self_arg:
            kwargs = {'alpha': alpha}
            ref_inputs = inputs
            try:
                op_kwargs = {}
                op_kwargs.update(kwargs)
                with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
                    actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e))):
                    ref(ref_inputs, **kwargs)
            else:
                expected = ref(ref_inputs, **kwargs)
                if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                    self.assertEqual(expected, actual, atol=1.e-3, rtol=default_tolerances(dtype)[0])
                else:
                    self.assertEqual(expected, actual)

    @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
    @parametrize("is_fastpath", (True, False))
    def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):

        def clone(arg):
            if isinstance(arg, (list, tuple)):
                return [clone(a) for a in arg]
            if torch.is_tensor(arg):
                return arg.clone().detach().requires_grad_()
            else:
                return arg

        scalar_self_arg_test_complete = False
        for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
            (rhs_arg,) = sample.args
            kwargs = {} or sample.kwargs
            alpha = kwargs.pop("alpha", None)
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
                scalar_self_arg_test_complete = True
                self._binary_test(
                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
                    alpha=alpha, scalar_self_arg=True,
                )
                if op.supports_autograd and dtype == torch.float32:
                    transformed_sample = sample.transform(
                        get_transform_func(len(sample.input), dtype, device, is_fastpath))
                    tensors = transformed_sample.input
                    (rhs_arg,) = transformed_sample.args
                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                    sum(wrapped_op(
                        [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
                    )).mean().backward()
                    sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])

    @ops(foreach_pointwise_op_db)
    @parametrize("is_fastpath", (True, False))
    def test_pointwise_op_with_tensor_of_scalarlist_overload(self, device, dtype, op, is_fastpath):
        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
            assert isinstance(sample.args, tuple)
            assert len(sample.args) == 2
            inputs = [sample.input, *sample.args]
            kwargs = sample.kwargs.copy()
            disable_fastpath = sample.disable_fastpath and is_fastpath
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            scalars = kwargs.pop("scalars", None)

            if is_fastpath and scalars:
                sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
                inputs = [sample.input, *sample.args]
                tensor_values = torch.tensor(scalars)
                # 1D Tensor of scalars
                for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
                    self._pointwise_test(
                        op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                        scalars=tensor_values, **kwargs)
                    self._pointwise_test(
                        op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                        scalars=tensor_values[0],
                        custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
                        **kwargs,
                    )
                    if self.is_cuda:
                        self._pointwise_test(
                            op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                            scalars=tensor_values.cuda(),
                            custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
                            **kwargs,
                        )
                    self._pointwise_test(
                        op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                        scalars=tensor_values[:2],
                        custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
                        scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
                        custom_values_err="Expected scalars to be contiguous.",
                        **kwargs,
                    )

            # Tests of implicit broadcasting
            N = len(sample.input)
            inputs = [
                [make_tensor((N, N), device=device, dtype=dtype, noncontiguous=not is_fastpath) for _ in range(N)],
                [
                    make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=not is_fastpath)
                    for i in range(N)
                ],
                [
                    make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=not is_fastpath)
                    for i in range(N)
                ],
            ]
            self._pointwise_test(
                wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
                scalars=scalars, **kwargs)
            self._pointwise_test(
                inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
                is_inplace=True, scalars=scalars, **kwargs)

    def _pointwise_test(
        self,
        op, ref, inputs, is_fastpath, is_inplace,
        *,
        scalars=None, custom_values_err=None, **kwargs
    ):
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs
        try:
            with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()):
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            self.assertEqual(expected, actual)
        if scalars is not None:
            kwargs = kwargs.copy()
            kwargs["scalars"] = scalars
            try:
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
            except RuntimeError as e:
                # Match with error messages from regular non-foreach reference if no
                # custom error message was provided.
                if custom_values_err is None:
                    with self.assertRaisesRegex(type(e), re.escape(str(e))):
                        ref(ref_inputs, **kwargs)
                else:
                    self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
            else:
                expected = ref(ref_inputs, **kwargs)
                self.assertEqual(expected, actual)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [[torch.randn([0], device=device, dtype=dtype)],
                        [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)]]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

            # Regression test for https://github.com/pytorch/pytorch/issues/113156
            torch._foreach_mul_(tensors, 1)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
        foreach_op, ref = op.method_variant, op.ref
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

        if ref == torch.sub and dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                [ref(t, 1) for t in tensors]
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                foreach_op(tensors, 1)
            return

        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        allowed_dtypes=[torch.float],
    )
    def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
        foreach_op = op.method_variant
        tensors = [
            torch.tensor([1.1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.long, device=device),
        ]
        runtime_error = None
        try:
            foreach_op(tensors, 1)
        except RuntimeError as e:
            runtime_error = e
        self.assertIsNone(runtime_error)

    @skipIfTorchDynamo("Different error msgs, TODO")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_error_cases(self, device, dtype, op):
        foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
        tensors1 = []
        tensors2 = []
        ops_to_test = [foreach_op, foreach_op_]

        # Empty lists
        for fop in ops_to_test:
            with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
                fop(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device, dtype=dtype))
        for fop in ops_to_test:
            with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
                fop(tensors1, tensors2)

        # Lists have different numbers of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        for fop in ops_to_test:
            with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
                fop(tensors1, tensors2)
            with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 2 and 1"):
                fop(tensors2, tensors1)

        # Corresponding tensors with different sizes that aren't compatible with broadcast
        # If sizes are different then foreach chooses the slow path, so the error messages are
        # expected to be the same as for the regular torch function.
        tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
        try:
            foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [ref(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [ref_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

        # different devices
        if self.device_type == "cuda" and torch.cuda.device_count() > 1:
            tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
            tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
            if dtype == torch.bool and foreach_op == torch._foreach_sub:
                for fop in ops_to_test:
                    with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                        fop([tensor1], [tensor2])
                return
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                foreach_op([tensor1], [tensor2])
            if dtype in integral_types_and(torch.bool) and foreach_op == torch._foreach_div:
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    foreach_op_([tensor1], [tensor2])
            else:
                with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                    foreach_op_([tensor1], [tensor2])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_slow_path(self, device, dtype, op):
        foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
        # 0-strides
        tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
        tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
            alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
            alpha=None, scalar_self_arg=False)

        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        inputs = ([tensor1], [tensor2.t()])
        self._binary_test(
            dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
            alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
            alpha=None, scalar_self_arg=False)

        # non contiguous
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
        tensor2 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
            alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
            alpha=None, scalar_self_arg=False)

        # sliced tensor
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
        tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[:, :, :, ::7]
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
            alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
            alpha=None, scalar_self_arg=False)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=floating_types_and(torch.half, torch.bfloat16),
    )
    def test_binary_op_float_inf_nan(self, device, dtype, op):
        inputs = (
            [
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
            [
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
        )
        op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        self._binary_test(dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False)
        self._binary_test(
            dtype, inplace_op, inplace_ref, inputs, True, True, alpha=None, scalar_self_arg=False
        )

    # note: The three tests below (suffixed with `_tensors_on_different_devices`)
    # check whether foreach works with lists of tensors on different devices,
    # where tensors at the same index across lists share a device, e.g., ['cuda', 'cpu'].
    @onlyCUDA
    @ops(foreach_unary_op_db)
    def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
        method, ref, inplace_method, ref_inplace = self._get_funcs(op)
        # tensors: ['cuda', 'cpu']
        tensors = next(iter(op.sample_inputs(device, dtype, num_input_tensors=[2]))).input
        tensors[1] = tensors[1].to("cpu")
        if not op.supports_out:
            try:
                actual = method((tensors,), False, False, zero_size=False)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), str(e)):
                    ref((tensors,))
            else:
                expected = ref((tensors,))
                self.assertEqual(expected, actual)

        try:
            inplace_method((tensors,), False, False, zero_size=False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), str(e)):
                ref_inplace((tensors,))
        else:
            if not op.supports_out:
                self.assertEqual(expected, tensors)
            else:
                self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)

    @onlyCUDA
    @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        # `tensors1`: ['cuda', 'cpu']
        # `tensors2`: ['cuda', 'cpu']
        _cuda_tensors = next(iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))).input
        _cpu_tensors = next(iter(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))).input
        tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
        native_op, native_op_ = op.ref, op.ref_inplace
        try:
            actual = foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
            self.assertEqual(expected, actual)
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            self.assertEqual(actual, tensors1)

    @onlyCUDA
    @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
    def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
        # tensors1: ['cuda', 'cpu']
        # tensors2: ['cuda', 'cpu']
        # tensors3: ['cuda', 'cpu']
        # first tensorlist is zero-size when float32
        _cuda_tensors = list(
            op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True)
        )[int(dtype == torch.float32)].input
        _cpu_tensors = next(iter(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))).input
        tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
        actual = foreach_op(tensors1, tensors2, tensors3)
        expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
        self.assertEqual(expected, actual)

        # note(mkozuki): Limiting dtypes to FP32 & FP64, we can safely run inplace ops.
        foreach_op_(tensors1, tensors2, tensors3)
        self.assertEqual(expected, tensors1)

    # note: BFloat16 has the same number of exponent bits as FP32
    # so if squared L2 norm overflows in BF16, then it also overflows in FP32.
    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16))
    def test_foreach_l2_large_value_input(self, device, dtype, op):
        ord, N = 2, 10
        max_value = torch.finfo(dtype).max
        scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
        inputs = ([
            t * scaler for t in next(iter(op.sample_inputs(device, dtype, requires_grad=True, num_input_tensors=[N], low=1))).input
        ],)
        # make sure that the min. of the squared L2 norm per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)
        actual = fn(inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
        expect = ref_fn(inputs, ord=ord)

        if dtype == torch.float16:
            # making sure the reference L2 norm values are in the range of FP16.
            self.assertFalse(any(torch.isinf(e) for e in expect))
        else:
            self.assertTrue(all(
                inputs[0][i].numel() == 0 or torch.isinf(e)
                for i, e in enumerate(expect)))
        self.assertEqual(expect, actual, equal_nan=False)

    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
    def test_big_num_tensors(self, device, dtype, op):
        N = 600
        tensorlist = [make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) for _ in range(N)]
        fn, ref_fn, *_ = self._get_funcs(op)

        import math
        for ord in (1, 2, math.inf):
            actual = fn(inputs=[tensorlist], is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
            expect = ref_fn(inputs=[tensorlist], ord=ord)

            self.assertEqual(expect, actual, equal_nan=True)

    @onlyCUDA
    @ops(foreach_reduce_op_db)
    def test_foreach_reduce_large_input(self, device, dtype, op):
        # test inputs larger than kChunkSize = 65536
        ord, N = 2, 65536 * 2
        disable_fastpath = True
        if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
            disable_fastpath = False
        inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
        wrapped_op, ref, _, _ = self._get_funcs(op)
        self.assertEqual(
            ref(inputs, ord=ord),
            wrapped_op(inputs, self.is_cuda, not disable_fastpath, ord=ord, zero_size=False),
        )

    @onlyCUDA
    @ops(
        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
        dtypes=(torch.float,),
    )
    def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
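        # In-place foreach ops must refuse to modify leaf tensors that require grad,
        # and after the op only tensors derived from a requires-grad leaf get a grad_fn.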
        inplace_op = op.inplace_variant
        if inplace_op is None:
            self.skipTest("no in-place op available")

        sample = next(iter(op.sample_inputs(dtype=dtype, device=device, num_input_tensors=[2], same_size=True)))
        sample.input[0].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)
        sample.input[1].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)

        _tensors = [t.clone().detach().requires_grad_(i == 0) for i, t in enumerate(sample.input)]
        tensors = [t.clone() for t in _tensors]
        inplace_op(tensors, *sample.args)
        self.assertIsNotNone(tensors[0].grad_fn)
        self.assertIsNone(tensors[1].grad_fn)

    @onlyCUDA
    @ops(
        filter(
            lambda op: op.supports_out,
            foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
        ),
        dtypes=(torch.float,),
    )
    def test_outplace_with_invalid_grads(self, device, dtype, op):
        func, *_ = self._get_funcs(op)
        sample = next(iter(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True)))
        self.assertTrue(all(t.requires_grad for t in sample.input))
        (out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
        out1.backward(torch.ones_like(out1))
        self.assertIsNotNone(sample.input[0].grad)
        self.assertIsNone(sample.input[1].grad)

    @ops(
        filter(
            lambda op: op.backward_requires_result,
            foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
        ),
        dtypes=(torch.float32,),
    )
    def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
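        # When backward needs the forward result, the result keeps the grad_fn (and
        # anything stashed in its metadata) alive; deleting the output should free both.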

        def get_ref(func, sample):
            class Foo:
                pass

            out = func((sample.input, *sample.args), is_cuda=False, expect_fastpath=False, **sample.kwargs)
            foo = Foo()
            meta_dict = out[0].grad_fn.metadata
            meta_dict[0] = foo
            ref = weakref.ref(foo)
            return out, ref

        def _test(func, sample):
            out, ref = get_ref(func, sample)
            self.assertIsNotNone(ref())
            del out
            self.assertIsNone(ref())

        func = self._get_funcs(op)[0]
        for sample in op.sample_inputs(device, dtype, requires_grad=True, num_input_tensors=[1]):
            for key in ("is_fastpath", "disable_fastpath"):
                if key in sample.kwargs:
                    del sample.kwargs[key]
            # note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
            # see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
            if op.name == "_foreach_pow":
                if (
                    (isinstance(sample.args[0], list) and isinstance(sample.args[0][0], Number))
                    or (isinstance(sample.args[0], Number) and not isinstance(sample.args[0], float))
                ):
                    continue
                if isinstance(sample.args[0], float):
                    new_args = (sample.input,)
                    sample.input = sample.args[0]
                    sample.args = new_args
            _test(func, sample)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensors_grouping(self):
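        # `_group_tensors_by_device_and_dtype` buckets parallel tensor lists by
        # (device, dtype); with `with_indices=True` it also returns each tensor's
        # index in the original lists so results can be scattered back.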
        num_tensors_per_list = 10
        num_devices = torch.cuda.device_count()
        dtypes = (torch.float16, torch.float32, torch.float64)
        list1 = [
            torch.tensor(
                i,
                device=torch.device("cuda", random.randint(0, num_devices - 1)),
                dtype=dtypes[random.randint(0, 2)],
            ) for i in range(num_tensors_per_list)
        ]
        list2 = [None for _ in list1]
        list3 = [torch.rand_like(t) for t in list1]
        nested_tensorlists = [list1, list2, list3]
        grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(nested_tensorlists, with_indices=True)
        num_tensors_seen = 0
        for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
            for t in itertools.chain(l1, l3):
                self.assertEqual(t.device, device)
                self.assertEqual(t.dtype, dtype)
                num_tensors_seen += 1
            self.assertEqual(len(l1), len(l2))
            self.assertTrue(all(p is None for p in l2))
            for i, index in enumerate(indices):
                self.assertEqual(l1[i], list1[index])
                self.assertEqual(l2[i], list2[index])
                self.assertEqual(l3[i], list3[index])
        self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

    @onlyCUDA
    def test_0dim_tensor_overload_cpu_ok(self):
        tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
        scalar_cpu_tensor = torch.tensor(4.0, device="cpu")

        # For mul and div, the scalar is allowed to be on CPU too
        actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])


    @onlyCUDA
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
            torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)

        tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
            torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
        foreach_copy_ = op.inplace_variant
        copy_ = op.ref_inplace
        for non_blocking in (False, True):
            for sample in op.sample_inputs(device, dtype, noncontiguous=False):
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                foreach_copy_(sample.input, sample.args[0], non_blocking)
                for t, s in zip(ref_input, sample.args[0]):
                    copy_(t, s, non_blocking)
                self.assertEqual(sample.input, ref_input)
                if torch.cuda.device_count() > 1:
                    device = torch.device("cuda", 1)
                    rhs_tensors = [t.to(device) for t in sample.args[0]]
                    foreach_copy_(sample.input, rhs_tensors, non_blocking)
                    for t, s in zip(ref_input, rhs_tensors):
                        copy_(t, s, non_blocking)
                    self.assertEqual(ref_input, sample.input)

    # Test reverse-mode & forward-mode AD if supported.
    @onlyCUDA
    @ops(
        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    @parametrize("inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace")
    def test_autodiff(self, device, dtype, op, inplace):
        if not (op.supports_autograd or op.supports_forward_ad):
            self.skipTest("neither reverse mode nor forward mode supported")
        if (not inplace) and not op.supports_out:
            self.skipTest("out-of-place not implemented")
        if inplace and op.has_no_in_place:
            self.skipTest("in-place not implemented")

        # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
        if (not inplace) and dtype == torch.float64 and op.name in (
            "_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
            "_foreach_log", "_foreach_pow", "_foreach_sqrt",
        ):
            value_range = {"low": 0.5, "high": 1.0}
        else:
            value_range = {}
        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[5], **value_range,
        ):
            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
                continue

            func = None
            if inplace:
                # Call `clone` to avoid in-place modifications, similarly to
                # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
                def inplace_func(*tensorlist):
                    kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
                    op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
                    return tensorlist
                func = inplace_func
            else:
                def outplace_func(*tensorlist):
                    kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
                    return op.method_variant(tensorlist, *sample.args, **kwargs)
                func = outplace_func

            working_sample, err_msg_pattern = check_autodiff_sample(op, sample, dtype, inplace)

            def call_gradcheck():
                gradcheck(
                    func,
                    sample.input,
                    raise_exception=True,
                    check_forward_ad=op.supports_forward_ad,
                    check_batched_forward_grad=False,
                    check_backward_ad=op.supports_autograd,
                    check_batched_grad=False,
                )

            if not working_sample:
                if not err_msg_pattern:
                    # lhs of float64 and rhs of complex.
                    continue
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    call_gradcheck()
                continue
            call_gradcheck()

            # Test per-tensor `grad_fn` behavior.
            if inplace and op.supports_inplace_autograd:
                # per-tensor `grad_fn` check.
                hook_buffer = []

                def get_grad_fn_hook(i):

                    def hook(grad_inputs, grad_outputs) -> None:
                        hook_buffer.append(i)

                    return hook

                _inputs = [t.clone().detach().requires_grad_() for t in sample.input]
                inputs = [t.clone() for t in _inputs]
                kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
                op.inplace_variant(inputs, *sample.args, **kwargs)

                self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))

                for i, t in enumerate(inputs):
                    t.grad_fn.register_hook(get_grad_fn_hook(i))

                torch.autograd.grad(
                    inputs[0],
                    inputs=(_inputs[0],),
                    grad_outputs=(torch.rand_like(inputs[0]),),
                    retain_graph=True,
                )
                self.assertEqual(hook_buffer, [0])
                hook_buffer.clear()

                # tensors have different shapes.
                sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
                grad_output = torch.rand_like(sum_of_cloned_tensors)
                torch.autograd.grad(
                    sum_of_cloned_tensors,
                    inputs=tuple(_inputs),
                    grad_outputs=(grad_output,),
                    retain_graph=False,
                )
                self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))


# TODO(crcrpar): Hide this inside torch/testing/_internal.
# would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that we can use this function as something like the first argument of `filter` function.
# Even after moving this function to testing, I personally think it'd be better to check the error message.
def check_autodiff_sample(op, sample, dtype, is_inplace):
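    # Returns (working_sample, expected_error_message). A False flag with an empty
    # message means "skip silently"; otherwise gradcheck is expected to raise an
    # error matching the message.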
    if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
        return False, "In-place abs is not supported for complex tensors."
    if (
        op.name == "_foreach_sub"
        and (
            (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
            or isinstance(sample.args[0], bool)
        )
    ):
        return False, _BOOL_SUB_ERR_MSG
    if op.name == "_foreach_norm" and (not is_inplace):
        return (
            False,
            "Trying to set a forward gradient that has a different size than that of the original Tensor, "
            "this is not supported. Tensor is of size [] while the given forward gradient is of size [1, 1]."
        )
    rhs_arg_has_complex_number = sample.args and ((
        isinstance(sample.args[0], list)
        and any(isinstance(a, complex) for a in sample.args[0])
    ) or (
        isinstance(sample.args[0], complex)
    ))
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name in ("_foreach_clamp_max", "_foreach_clamp_min", "_foreach_maximum", "_foreach_minimum"):
            return False, "clamp is not supported for complex types"
        if not is_inplace:
            return False, ""
        else:
            if op.name == "_foreach_pow":
                return False, "Found dtype Double but expected ComplexDouble"
            if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
                return False, "result type ComplexDouble can't be cast to the desired output type Double"
    return True, ""


instantiate_device_type_tests(TestForeach, globals())


if __name__ == "__main__":
    run_tests()
