pytorch

Форк
0
/
test_binary_ufuncs.py 
4473 строки · 175.4 Кб
1
# Owner(s): ["module: tests"]
2

3
import torch
4
import numpy as np
5

6
import sys
7
import itertools
8
from itertools import chain
9
from itertools import product
10
import math
11
import random
12
from numbers import Number
13
import warnings
14
import operator
15
from functools import partial
16

17
import torch.autograd.forward_ad as fwAD
18
from torch import inf, nan
19
from torch.testing._internal.common_utils import (
20
    TestCase,
21
    slowTest,
22
    iter_indices,
23
    run_tests,
24
    gradcheck,
25
    torch_to_numpy_dtype_dict,
26
    numpy_to_torch_dtype_dict,
27
    TEST_SCIPY,
28
    set_default_dtype,
29
    skipIfTorchDynamo,
30
)
31
from torch.testing._internal.common_device_type import (
32
    expectedFailureMeta,
33
    instantiate_device_type_tests,
34
    onlyCUDA,
35
    onlyCPU,
36
    dtypes,
37
    dtypesIfCUDA,
38
    dtypesIfCPU,
39
    deviceCountAtLeast,
40
    precisionOverride,
41
    onlyNativeDeviceTypes,
42
    skipIf,
43
    ops,
44
    OpDTypes,
45
    skipMeta,
46
)
47
from torch.testing import make_tensor
48
from torch.testing._internal.common_dtype import (
49
    all_types_and_complex_and,
50
    all_types_and,
51
    integral_types,
52
    complex_types,
53
    integral_types_and,
54
    floating_types_and,
55
    floating_and_complex_types,
56
    get_all_math_dtypes,
57
    get_all_int_dtypes,
58
)
59
from torch.testing._internal.common_methods_invocations import (
60
    binary_ufuncs,
61
    binary_ufuncs_and_refs,
62
    generate_elementwise_binary_tensors,
63
    generate_elementwise_binary_small_value_tensors,
64
    generate_elementwise_binary_large_value_tensors,
65
    generate_elementwise_binary_extremal_value_tensors,
66
    generate_elementwise_binary_broadcasting_tensors,
67
    generate_elementwise_binary_with_scalar_samples,
68
    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
69
)
70

71
if TEST_SCIPY:
72
    import scipy.special
73
    import scipy.integrate
74

75
# TODO: update to use opinfos consistently
76
class TestBinaryUfuncs(TestCase):
77
    # Generic tests for elementwise binary (AKA binary universal (u) functions (funcs))
78
    # TODO: below contiguous tensor results are compared with a variety of noncontiguous results.
79
    #   It would be interesting to have the lhs and rhs have different discontiguities.
80

81
    # Helper for comparing torch tensors and NumPy arrays
    # TODO: should this or assertEqual also validate that strides are equal?
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        """Compare a torch tensor ``actual`` with a reference result ``expected``.

        ``expected`` may be a Python/NumPy number (some NumPy functions return
        scalars, not arrays), a NumPy array, or anything ``assertEqual`` accepts
        directly. ``dtype`` is accepted from callers but not read here. Any extra
        ``kwargs`` (e.g. ``rtol``/``atol``/``equal_nan``) go to ``assertEqual``.
        """
        assert isinstance(actual, torch.Tensor)

        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg=msg, **kwargs)
            return

        if not isinstance(expected, np.ndarray):
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
            return

        # Array vs tensor: optionally validate that the dtypes line up.
        if exact_dtype:
            if expected.dtype != np.float32:
                assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
            else:
                # A float32 array is accepted for float16/bfloat16 tensors too:
                # NumPy has no bfloat16, and ops like scipy.special.erf/erfc
                # promote float16 to float32.
                assert actual.dtype in (
                    torch.float16,
                    torch.bfloat16,
                    torch.float32,
                )

        reference = torch.from_numpy(expected).to(actual.dtype)
        self.assertEqual(actual, reference, msg, exact_device=False, **kwargs)
116

117
    # Tests that the function and its (array-accepting) reference produce the same
    #   values on given tensors
    def _test_reference_numerics(self, dtype, op, gen, equal_nan=True):
        """Compare ``op`` with ``op.ref`` on every sample yielded by ``gen``.

        Each sample supplies an lhs (``sample.input``) and an rhs
        (``sample.args[0]``). ``op`` runs on the torch values and ``op.ref`` runs
        on the values from ``sample.numpy()``. When ``op`` returns something
        other than a tensor (multi-output ops), the outputs are compared
        pairwise. ``equal_nan`` is forwarded to every comparison, including the
        bfloat16-vs-float32 path.
        """

        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            # If the reference dtype can't be cast to ``dtype``, skip the exact
            # dtype check and compare values only.
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False

            if dtype is torch.bfloat16 and expected.dtype == np.float32:
                # NumPy has no bfloat16 dtype, so the reference arrives as
                # float32; compare with looser tolerances.
                # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=16e-3,
                    atol=1e-5,
                    # Forward equal_nan here too so the caller's setting is
                    # honored for bfloat16 as well.
                    equal_nan=equal_nan,
                )
            else:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    equal_nan=equal_nan,
                    exact_dtype=exact_dtype,
                )

        def _numel(x):
            # Element count of a tensor; anything else is assumed to be a scalar.
            if isinstance(x, torch.Tensor):
                return x.numel()
            return 1

        for sample in gen:
            # Each sample input acquired from the generator is just one lhs tensor
            #   and one rhs tensor
            lhs = sample.input
            rhs = sample.args[0]

            numpy_sample = sample.numpy()
            lhs_numpy = numpy_sample.input
            rhs_numpy = numpy_sample.args[0]

            actual = op(lhs, rhs)
            expected = op.ref(lhs_numpy, rhs_numpy)

            # Crafts a custom error message for smaller, printable tensors
            if _numel(lhs) <= 100 and _numel(rhs) <= 100:
                msg = (
                    "Failed to produce expected results! Input lhs tensor was"
                    f" {lhs}, rhs tensor was {rhs}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                for x, y in zip(expected, actual):
                    # testing multi-outputs results
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
187

188
    # The following tests only apply to elementwise binary operators with references
    # (OpInfos whose ``ref`` is set).
    binary_ufuncs_with_references = [op for op in binary_ufuncs if op.ref is not None]
192

193
    @ops(binary_ufuncs_with_references)
    def test_reference_numerics(self, device, dtype, op):
        # Reference comparison on the default elementwise-binary sample tensors.
        samples = generate_elementwise_binary_tensors(op, device=device, dtype=dtype)
        self._test_reference_numerics(dtype, op, samples, equal_nan=True)
197

198
    @ops(binary_ufuncs_with_references)
    def test_reference_numerics_small_values(self, device, dtype, op):
        """Reference comparison on the small-value sample generator; bool is skipped."""
        if dtype is torch.bool:
            self.skipTest("Doesn't support bool!")

        small_samples = generate_elementwise_binary_small_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, small_samples, equal_nan=True)
207

208
    @ops(
        binary_ufuncs_with_references,
        allowed_dtypes=(
            torch.int16, torch.int32, torch.int64,
            torch.float16, torch.bfloat16, torch.float32, torch.float64,
            torch.complex64, torch.complex128,
        ),
    )
    def test_reference_numerics_large_values(self, device, dtype, op):
        """Reference comparison on the large-value sample generator."""
        large_samples = generate_elementwise_binary_large_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, large_samples, equal_nan=True)
227

228
    @ops(
        binary_ufuncs_with_references,
        allowed_dtypes=(
            torch.float16, torch.bfloat16, torch.float32, torch.float64,
            torch.complex64, torch.complex128,
        ),
    )
    def test_reference_numerics_extremal_values(self, device, dtype, op):
        """Reference comparison on the extremal-value sample generator (float/complex only)."""
        extremal_samples = generate_elementwise_binary_extremal_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, extremal_samples, equal_nan=True)
244

245
    # tests broadcasting and noncontiguous broadcasting behavior
    @ops(binary_ufuncs_with_references, allowed_dtypes=(torch.long, torch.float32))
    def test_broadcasting(self, device, dtype, op):
        broadcast_samples = generate_elementwise_binary_broadcasting_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, broadcast_samples, equal_nan=True)
258

259
    @ops(
        binary_ufuncs_with_references,
        allowed_dtypes=(torch.long, torch.float32, torch.complex64),
    )
    def test_scalar_support(self, device, dtype, op):
        """Reference comparison with scalar operands, first without and then with type promotion."""
        sample_factories = (
            generate_elementwise_binary_with_scalar_samples,
            generate_elementwise_binary_with_scalar_and_type_promotion_samples,
        )
        for make_samples in sample_factories:
            samples = make_samples(op, device=device, dtype=dtype)
            self._test_reference_numerics(dtype, op, samples, equal_nan=True)
272

273

274
    @ops(binary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        """The op on stride-2 views must match every other element of the contiguous result."""
        size = (1026,)
        lhs = make_tensor(size, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
        rhs = make_tensor(size, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)

        strided_lhs, strided_rhs = lhs[::2], rhs[::2]

        for dense in (lhs, rhs):
            self.assertTrue(dense.is_contiguous())
        for strided in (strided_lhs, strided_rhs):
            self.assertFalse(strided.is_contiguous())

        self.assertEqual(op(lhs, rhs)[::2], op(strided_lhs, strided_rhs))
295

296
    @ops(binary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        """The op on transposed views must equal the transpose of the contiguous result."""
        size = (789, 357)
        lhs = make_tensor(size, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
        rhs = make_tensor(size, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)

        lhs_t, rhs_t = lhs.T, rhs.T

        for dense in (lhs, rhs):
            self.assertTrue(dense.is_contiguous())
        for transposed in (lhs_t, rhs_t):
            self.assertFalse(transposed.is_contiguous())

        self.assertEqual(op(lhs, rhs).T, op(lhs_t, rhs_t))
317

318
    @ops(binary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        """Contiguous and strided copies of the same data must give the same result."""

        def _strided_copy(src, shape):
            # Take the [..., 0] view of a buffer with an extra trailing dim of
            # size 2, which is non-contiguous, and fill it with ``src``.
            view = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
            view.copy_(src)
            return view

        for shape in ((5, 7), (1024,)):
            lhs = make_tensor(
                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
            )
            rhs = make_tensor(
                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
            )

            lhs_strided = _strided_copy(lhs, shape)
            rhs_strided = _strided_copy(rhs, shape)

            for dense in (lhs, rhs):
                self.assertTrue(dense.is_contiguous())
            for strided in (lhs_strided, rhs_strided):
                self.assertFalse(strided.is_contiguous())

            self.assertEqual(op(lhs, rhs), op(lhs_strided, rhs_strided))
348

349
    @ops(binary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        """An indexed (non-contiguous) view and its contiguous copy must give the same result."""
        shape = (2, 2, 1, 2)
        lhs_full = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs_full = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        # Index 1 along dim 1; the assertions below check the view is non-contiguous.
        lhs_view, rhs_view = lhs_full[:, 1, ...], rhs_full[:, 1, ...]
        lhs_dense, rhs_dense = lhs_view.contiguous(), rhs_view.contiguous()

        for dense in (lhs_dense, rhs_dense):
            self.assertTrue(dense.is_contiguous())
        for view in (lhs_view, rhs_view):
            self.assertFalse(view.is_contiguous())

        self.assertEqual(op(lhs_dense, rhs_dense), op(lhs_view, rhs_view))
374

375
    @ops(binary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        """Each slice of an op over expanded (stride-0) inputs must equal the unexpanded result."""
        for shape in [(1, 3), (1, 7), (5, 7)]:
            lhs = make_tensor(
                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
            )
            rhs = make_tensor(
                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
            )

            # Add a new leading dim of size 3 via expand.
            lhs_expanded = lhs.clone().expand(3, -1, -1)
            rhs_expanded = rhs.clone().expand(3, -1, -1)

            for dense in (lhs, rhs):
                self.assertTrue(dense.is_contiguous())
            for expanded in (lhs_expanded, rhs_expanded):
                self.assertFalse(expanded.is_contiguous())

            reference = op(lhs, rhs)
            batched = op(lhs_expanded, rhs_expanded)
            # Iterating a tensor walks its leading dimension (3 slices here).
            for batch_slice in batched:
                self.assertEqual(reference, batch_slice)
399

400
    @ops(binary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        """A contiguous slice with a size-1 dim must match a freshly allocated dense copy."""
        full_shape = (5, 100)
        lhs = make_tensor(
            full_shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            full_shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        def _dense_copy(src):
            # New empty tensor of the same size, filled from ``src``.
            dst = torch.empty(src.size(), device=device, dtype=dtype)
            dst.copy_(src)
            return dst

        lhs, rhs = lhs[:1, :50], rhs[:1, :50]
        lhs_alt, rhs_alt = _dense_copy(lhs), _dense_copy(rhs)

        for t in (lhs, rhs, lhs_alt, rhs_alt):
            self.assertTrue(t.is_contiguous())

        self.assertEqual(op(lhs, rhs), op(lhs_alt, rhs_alt))
427

428
    @ops(binary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        """Like test_contig_size1, but with a 12-dimensional input."""
        full_shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4)
        lhs = make_tensor(
            full_shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            full_shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        def _dense_copy(src):
            # New empty tensor of the same size, filled from ``src``.
            dst = torch.empty(src.size(), device=device, dtype=dtype)
            dst.copy_(src)
            return dst

        # Keep only the first entry of dim 0; all other dims are taken whole.
        lhs, rhs = lhs[:1], rhs[:1]
        lhs_alt, rhs_alt = _dense_copy(lhs), _dense_copy(rhs)

        for t in (lhs, rhs, lhs_alt, rhs_alt):
            self.assertTrue(t.is_contiguous())

        self.assertEqual(op(lhs, rhs), op(lhs_alt, rhs_alt))
455

456
    @ops(binary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        """The op on a whole batch must equal the stack of per-row results."""
        shape = (32, 512)
        lhs = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        batched = op(lhs, rhs)
        # zip over tensors walks dim 0, i.e. the 32 rows.
        row_by_row = torch.stack(
            [op(lhs_row, rhs_row) for lhs_row, rhs_row in zip(lhs, rhs)]
        )

        self.assertEqual(batched, row_by_row)
474

475
    # Tests that elementwise binary operators participate in type promotion properly
    # NOTE: because the cross-product of all possible type promotion tests is huge, this
    #   just spot checks some handwritten cases.
    # NOTE: It may be possible to refactor this test into something simpler
    @ops(binary_ufuncs_and_refs, dtypes=OpDTypes.none)
    def test_type_promotion(self, device, op):
        """Spot-check result dtypes and ``out=`` casting rules for ``op``.

        Each section runs only when all the dtypes it needs appear in
        ``op.supported_dtypes`` for this device type. Sections branch on the
        op's ``promotes_int_to_float``, ``always_returns_bool``,
        ``supports_out``, ``supports_rhs_python_scalar`` and
        ``supports_two_python_scalars`` attributes.
        """
        supported_dtypes = op.supported_dtypes(torch.device(device).type)

        make_lhs = partial(
            make_tensor, (5,), device=device, **op.lhs_make_tensor_kwargs
        )
        make_rhs = partial(
            make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs
        )

        # Zero-dim tensors are created on the CPU regardless of ``device``.
        make_rhs_scalar_tensor = partial(
            make_tensor, (), device='cpu', **op.rhs_make_tensor_kwargs
        )

        def _supported(dtypes):
            return all(x in supported_dtypes for x in dtypes)

        # int x int type promotion
        if _supported((torch.int16, torch.int32, torch.int64)):
            lhs_i16 = make_lhs(dtype=torch.int16)
            lhs_i32 = make_lhs(dtype=torch.int32)
            lhs_i64 = make_lhs(dtype=torch.int64)

            rhs_i32 = make_rhs(dtype=torch.int32)
            rhs_i64 = make_rhs(dtype=torch.int64)

            if op.promotes_int_to_float:
                default_dtype = torch.get_default_dtype()
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i16, rhs_i32),
                    op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)),
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i32, rhs_i64),
                    op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)),
                )
            elif op.always_returns_bool:
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool)
                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool)
            else:  # standard type promotion
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32)
                self.assertEqual(
                    op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64)
                self.assertEqual(
                    op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)
                )

            if op.supports_out:
                if not op.promotes_int_to_float:
                    # Integers can be safely cast to other integer types
                    out = torch.empty_like(lhs_i64)
                    self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                    out = torch.empty_like(lhs_i16)
                    self.assertEqual(op(lhs_i32, rhs_i64, out=out).dtype, torch.int16)
                else:
                    # Float outs cannot be safely cast to integer types
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64))

                if not op.always_returns_bool:
                    # Neither integer nor float outs can be cast to bool
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_i16,
                            rhs_i32,
                            out=torch.empty_like(lhs_i64, dtype=torch.bool),
                        )

                # All these output types can be cast to any float or complex type
                out = torch.empty_like(lhs_i64, dtype=torch.float16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float16)

                out = torch.empty_like(lhs_i64, dtype=torch.bfloat16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.bfloat16)

                out = torch.empty_like(lhs_i64, dtype=torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                out = torch.empty_like(lhs_i64, dtype=torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

        # float x float type promotion
        if _supported((torch.float32, torch.float64)):
            lhs_f32 = make_lhs(dtype=torch.float32)
            lhs_f64 = make_lhs(dtype=torch.float64)

            rhs_f64 = make_rhs(dtype=torch.float64)

            if op.always_returns_bool:
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool)
            else:  # normal float type promotion
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64)
                self.assertEqual(
                    op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)
                )

            if op.supports_out:
                # All these output types can be cast to any float or complex type
                out = torch.empty_like(lhs_f64, dtype=torch.float16)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float16)

                out = torch.empty_like(lhs_f64, dtype=torch.bfloat16)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.bfloat16)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                out = torch.empty_like(lhs_f64, dtype=torch.float32)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float32)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                out = torch.empty_like(lhs_f64, dtype=torch.complex64)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.complex64)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                if not op.always_returns_bool:
                    # float outs can't be cast to an integer dtype
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_f32,
                            rhs_f64,
                            out=torch.empty_like(lhs_f64, dtype=torch.int64),
                        )
                else:
                    # bool outs can be cast to an integer dtype
                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

        # complex x complex type promotion
        if _supported((torch.complex64, torch.complex128)):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            rhs_c128 = make_rhs(dtype=torch.complex128)

            if op.always_returns_bool:
                # Use lhs vs rhs here too, matching the non-bool branch below
                # (the lhs was previously passed as both operands).
                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.bool)
            else:  # normal complex type promotion
                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128)
                self.assertEqual(
                    op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)
                )

            if op.supports_out:
                # All these output types can be cast to any complex type
                out = torch.empty_like(lhs_c64, dtype=torch.complex64)

                self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.complex64)
                result = op(lhs_c64, rhs_c128)
                self.assertEqual(result, out.to(result.dtype))

                if not op.always_returns_bool:
                    # complex outs can't be cast to float types
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_c64,
                            rhs_c128,
                            out=torch.empty_like(lhs_c64, dtype=torch.float64),
                        )
                    # complex outs can't be cast to an integer dtype
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_c64,
                            rhs_c128,
                            out=torch.empty_like(lhs_c64, dtype=torch.int64),
                        )
                else:
                    # bool outs can be cast to a float type
                    out = torch.empty_like(lhs_c64, dtype=torch.float64)
                    self.assertEqual(
                        op(lhs_c64, rhs_c128, out=out).dtype, torch.float64
                    )
                    self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)

                    # bool outs can be cast to an integer dtype. Use the complex
                    # operands; the float tensors from the section above may not
                    # exist if float32/float64 are unsupported.
                    out = torch.empty_like(lhs_c64, dtype=torch.int64)
                    self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)

        # int x float type promotion
        # Note: float type is the result dtype
        if _supported((torch.long, torch.float32)):
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f32 = make_rhs(dtype=torch.float32)

            result = op(lhs_i64, rhs_f32)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        # float x complex type promotion
        # Note: complex type with highest "value type" is the result dtype
        if _supported((torch.float64, torch.complex64)):
            lhs_f64 = make_lhs(dtype=torch.float64)
            rhs_c64 = make_rhs(dtype=torch.complex64)

            result = op(lhs_f64, rhs_c64)
            expected_dtype = (
                torch.complex128 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        # int x float scalar type promotion
        # Note: default float dtype is the result dtype
        if _supported((torch.int64, torch.float32)) and op.supports_rhs_python_scalar:
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f_scalar = 1.0

            result = op(lhs_i64, rhs_f_scalar)
            expected_dtype = (
                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # repeats with a scalar float tensor, which should set the dtype
            rhs_f32_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float32)
            result = op(lhs_i64, rhs_f32_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

            # Additional test with double
            if _supported((torch.float64,)):
                rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
                result = op(lhs_i64, rhs_f64_scalar_tensor)
                expected_dtype = (
                    torch.float64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        # float x complex scalar type promotion
        # Note: result dtype is complex with highest "value type" among all tensors
        if (
            _supported((torch.float32, torch.complex64))
            and op.supports_rhs_python_scalar
        ):
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_c_scalar = complex(1, 1)

            result = op(lhs_f32, rhs_c_scalar)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # repeats with a scalar complex tensor
            rhs_c64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex64)
            result = op(lhs_f32, rhs_c64_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # Additional test with complexdouble
            if _supported((torch.complex128,)):
                rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
                result = op(lhs_f32, rhs_c128_scalar_tensor)
                # Value type of 1D+ Tensor (lhs_f32) takes priority over scalar tensor (rhs_c128).
                expected_dtype = (
                    torch.complex64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        # float x float scalar tensor
        # Note: result dtype is the type of the float tensor
        if _supported((torch.float32, torch.float64)) and op.supports_rhs_python_scalar:
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)

            result = op(lhs_f32, rhs_f64_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        # complex x complex scalar tensor
        # Note: result dtype is the type of the complex tensor
        if (
            _supported((torch.complex64, torch.complex128))
            and op.supports_rhs_python_scalar
        ):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)

            result = op(lhs_c64, rhs_c128_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        # scalar  x scalar
        # Note: result dtype is default float type
        if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
            rhs_f_scalar = 2.
            for lhs in (1, 1.):
                result = op(lhs, rhs_f_scalar)
                expected_dtype = torch.get_default_dtype() if not op.always_returns_bool else torch.bool
                self.assertEqual(result.dtype, expected_dtype)
786

787
    # TODO: move to error input test
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    def test_not_broadcastable(self, device, dtype, op):
        """Check that ``op`` rejects operand shapes that cannot be broadcast together.

        Each incompatible (lhs, rhs) shape pair must make ``op`` raise a
        RuntimeError; if it returns a result instead, the test fails and reports
        the shape it produced.
        """
        incompatible_shape_pairs = (
            ((2,), (3,)),
            ((3, 1), (2, 1)),
            ((1, 3, 2), (3,)),
            ((3, 1, 2), (2, 1, 2)),
        )
        for lhs_shape, rhs_shape in incompatible_shape_pairs:
            lhs_tensor = make_tensor(
                lhs_shape, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
            )
            rhs_tensor = make_tensor(
                rhs_shape, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
            )

            try:
                out_shape = op(lhs_tensor, rhs_tensor).shape
            except RuntimeError:
                # Expected outcome: incompatible shapes are refused.
                pass
            else:
                raise AssertionError(
                    f"On {device}, torch.{op.name} broadcasts inputs shapes {lhs_shape} and {rhs_shape} into "
                    f"{out_shape}, although they are not broadcastable."
                )
813

814
    def test_add_broadcast_empty(self, device):
        """Exercise the broadcasting rules of ``+`` when zero-size dimensions appear."""

        def rand(*shape):
            return torch.randn(*shape, device=device)

        # Two empty operands whose shapes cannot be broadcast together.
        self.assertRaises(RuntimeError, lambda: rand(5, 0) + rand(0, 5))
        # Two empty operands that do broadcast.
        self.assertEqual(rand(5, 0), rand(0) + rand(5, 0))
        self.assertEqual(rand(5, 0, 0), rand(0) + rand(5, 0, 1))

        # A 0-dim tensor broadcast against an empty tensor.
        self.assertEqual(rand(5, 0, 6), rand(()) + rand(5, 0, 6))

        # A non-empty operand broadcast against an empty one.
        self.assertEqual(rand(0), rand(0) + rand(1))
        self.assertEqual(
            rand(0, 7, 0, 6, 5, 0, 7),
            rand(0, 7, 0, 6, 5, 0, 1) + rand(1, 1, 5, 1, 7),
        )
        self.assertRaises(RuntimeError, lambda: rand(7, 0) + rand(2, 1))
849

850
    def test_addcmul_scalars_as_floats(self, device):
851
        # zero-dim variables that don't require grad should bind to scalar arguments
852
        x = torch.tensor(2.0)
853
        y = torch.tensor(3.0, device=device)
854
        # 3 + (3 * 3) * 2
855
        self.assertEqual(y.addcmul(y, y, value=x), 21)
856

857
        x = torch.tensor(2.0, requires_grad=True)
858
        self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
859

860
    # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions)
    # work properly (AKA &, |, ^ and &=, |=, ^=)
    @dtypes(*integral_types_and(torch.bool))
    def test_bitwise_ops(self, device, dtype):
        """Check ``&``, ``|``, ``^`` and ``&=``, ``|=``, ``^=`` against NumPy.

        Every operator is exercised over several shapes with three operand
        combinations:
        * tensor x tensor
        * tensor x Python scalar
        * Python scalar x tensor

        For the in-place operators, an extra pass asserts that the left-hand
        tensor itself ends up equal to the NumPy result.
        """
        # Operators applied to tensor x tensor, tensor x scalar and scalar x tensor
        ops = (
            operator.and_,
            operator.iand,
            operator.or_,
            operator.ior,
            operator.xor,
            operator.ixor,
        )
        inplace_ops = (operator.iand, operator.ior, operator.ixor)
        shapes = ((5,), (15, 15), (500, 500))

        for op, shape in itertools.product(ops, shapes):
            # Tests tensor x tensor case
            a = make_tensor(shape, device=device, dtype=dtype)
            b = make_tensor(shape, device=device, dtype=dtype)
            a_np = a.cpu().clone().numpy()
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a, b), op(a_np, b_np))

            # Tests tensor x scalar case
            a = make_tensor(shape, device=device, dtype=dtype)
            b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            a_np = a.cpu().clone().numpy()
            self.assertEqual(op(a, b_scalar), op(a_np, b_scalar))

            # Tests scalar x tensor case. A Python int has no in-place form, so for
            # the in-place operators Python falls back to the reflected operator.
            a_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            b = make_tensor(shape, device=device, dtype=dtype)
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a_scalar, b), op(a_scalar, b_np))

            # For in-place ops, additionally check that the left-hand tensor itself is
            # updated to match NumPy (the checks above only compare return values).
            if op in inplace_ops:
                # Tests tensor x tensor case
                a = make_tensor(shape, device=device, dtype=dtype)
                b = make_tensor(shape, device=device, dtype=dtype)
                a_np = a.cpu().clone().numpy()
                b_np = b.cpu().clone().numpy()
                op(a, b)
                op(a_np, b_np)
                self.assertEqual(a, a_np)

                # Tests tensor x scalar case
                a = make_tensor(shape, device=device, dtype=dtype)
                b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
                a_np = a.cpu().clone().numpy()
                op(a, b_scalar)
                op(a_np, b_scalar)
                self.assertEqual(a, a_np)
914

915
    def test_inplace_division(self, device):
916
        t = torch.rand(5, 5, device=device)
917
        id_before = id(t)
918
        t /= 2
919
        id_after = id(t)
920
        self.assertEqual(id_before, id_after)
921

922
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_div_rounding_modes(self, device, dtype):
        """Check the identities linking torch.divide's rounding modes.

        * true division: ``(a / b) * b == a``
        * floor division: ``floor(a / b) * b + remainder(a, b) == a``
        * trunc division equals ``trunc`` of the true quotient
        """
        if dtype.is_floating_point:
            low, high = -10.0, 10.0
        else:
            info = torch.iinfo(dtype)
            low, high = info.min, info.max

        a = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)

        # Avoid division by zero so we can test (a / b) * b == a
        if dtype.is_floating_point:
            eps = 0.1
            b[(-eps < b) & (b < eps)] = eps
        else:
            b[b == 0] = 1
            # floor(a / b) * b can be < a, so fixup slightly to avoid underflow
            a = torch.where(a < 0, a + b, a)

        d_true = torch.divide(a, b, rounding_mode=None)
        self.assertTrue(d_true.is_floating_point())
        self.assertEqual(d_true * b, a.to(d_true.dtype))

        d_floor = torch.divide(a, b, rounding_mode="floor")
        if dtype not in (torch.bfloat16, torch.half):
            self.assertEqual(d_floor * b + torch.remainder(a, b), a)
        else:
            self.assertEqual(
                d_floor * b + torch.remainder(a.float(), b.float()),
                a,
                exact_dtype=False,
            )

        d_trunc = torch.divide(a, b, rounding_mode="trunc")
        # `device` can carry an index (e.g. "cuda:0"), so compare the device
        # *type*; a plain string comparison against "cuda" would never match.
        device_type = torch.device(device).type
        rounding_unsupported = (dtype == torch.half and device_type != "cuda") or (
            dtype == torch.bfloat16 and device_type != "cpu"
        )
        # For these dtype/device combinations, compute the trunc reference in
        # float32 and cast back to `dtype`.
        d_ref = d_true.float() if rounding_unsupported else d_true
        self.assertEqual(d_trunc, d_ref.trunc().to(dtype))
967

968
    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_div_rounding_nonfinite(self, device, dtype):
        """Compare torch.divide on special floating point values against NumPy.

        Operands include +/-1, 0, +/-0.1, +/-pi, +/-inf and nan. Zero is removed
        from the divisors; division by zero is tested separately. The divisions
        are also repeated on strided (non-contiguous) copies to cross-check the
        contiguous and non-contiguous code paths.
        """
        # Build the operands on `device` so that the device's kernels are the
        # ones compared against NumPy (the operands are moved to CPU only to
        # produce the NumPy reference).
        num = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
            device=device,
        )
        # Divide by zero is tested separately
        denom = num[num != 0]

        # Broadcast to an outer-product grid: every numerator against every divisor.
        a, b = num[None, :].clone(), denom[:, None].clone()

        # NumPy has no bfloat16; compare bfloat16 against NumPy float32 instead.
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an, bn = a.cpu().numpy(), b.cpu().numpy()
        else:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()

        for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)):
            expect = np_ref(an, bn)
            kwargs = dict(rounding_mode=mode) if mode is not None else {}
            with set_default_dtype(torch.double):
                actual = torch.divide(a, b, **kwargs)
            self.assertEqual(
                actual,
                torch.from_numpy(expect),
                exact_device=False,
                exact_dtype=exact_dtype,
            )

        # Compare contiguous (likely vectorized) against non-contiguous (not vectorized)
        a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        a_noncontig[:] = a
        b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        b_noncontig[:] = b

        for rounding_mode in (None, "trunc", "floor"):
            expect = torch.divide(a_noncontig, b_noncontig, rounding_mode=rounding_mode)
            actual = torch.divide(a, b, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect)
1014

1015
    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_divide_by_zero_rounding(self, device, dtype):
        """Division by zero must match NumPy's results for true and floor rounding.

        The divisor is given both as a Python scalar and as a zero tensor on the
        same device as the dividend.
        """
        # Build the dividend on `device` so the device kernels are exercised; it
        # is moved to CPU only to produce the NumPy reference.
        a = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
            device=device,
        )
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an = a.cpu().numpy()
        else:
            # NumPy has no bfloat16; compare against float32 instead.
            an = a.float().cpu().numpy()

        zero = torch.zeros_like(a)

        # NOTE: NumPy's floor_divide rounding changed in 1.20.0 to be consistent with divide
        expect = np.divide(an, 0)
        for rounding_mode in (None, "floor"):
            # CPU scalar (Python number divisor)
            actual = torch.divide(a, 0, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
            # Device tensor divisor
            actual = torch.divide(a, zero, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
1038

1039
    @dtypes(*all_types_and(torch.half))
    def test_div_rounding_numpy(self, device, dtype):
        """Compare torch.divide for each rounding mode against a NumPy reference.

        Random operands span the full range of ``dtype``. Each mode is checked on
        the contiguous tensors (likely vectorized) and on every-other-element
        strided views (not vectorized).
        """
        if dtype.is_floating_point:
            info = torch.finfo(dtype)
        else:
            info = torch.iinfo(dtype)
        low, high = info.min, info.max

        a = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)

        # Zero divisors are replaced: integer division by zero raises, and NumPy
        # 1.20 changed floor_divide to follow IEEE rules for inf/nan after
        # dividing by zero.
        b[b == 0] = 1

        # NumPy has no bfloat16, so bfloat16 would be compared against float32.
        exact_dtype = dtype != torch.bfloat16
        if not exact_dtype:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
        else:
            an, bn = a.cpu().numpy(), b.cpu().numpy()

        def np_trunc_divide(x, y):
            return np.trunc(np.true_divide(x, y)).astype(x.dtype)

        references = (
            (None, np.true_divide),
            ("floor", np.floor_divide),
            ("trunc", np_trunc_divide),
        )
        for mode, np_ref in references:
            full_expect = torch.from_numpy(np_ref(an, bn))
            kwargs = {} if mode is None else {"rounding_mode": mode}

            # stride 1: contiguous (likely vectorized); stride 2: non-contiguous
            for stride in (1, 2):
                with set_default_dtype(torch.double):
                    actual = torch.divide(a[::stride], b[::stride], **kwargs)
                self.assertEqual(
                    actual,
                    full_expect[::stride],
                    exact_device=False,
                    exact_dtype=exact_dtype,
                )
1084

1085
    @dtypes(*complex_types())
    def test_complex_div_underflow_overflow(self, device, dtype):
        """Complex division must not under/overflow in its intermediate steps.

        Operands sit near the extremes of the dtype's range: +/- ``finfo.max / 2``
        and ``finfo.tiny``. Magnitudes beyond ``finfo.max / 2`` can still produce
        an error; that region is considered unsafe to work with anyway.
        """
        finfo = torch.finfo(dtype)
        half_min, half_max, tiny = finfo.min / 2, finfo.max / 2, finfo.tiny

        # (numerator, denominator, expected quotient)
        cases = [
            (complex(half_min, half_min), complex(half_min, half_min), complex(1.0, 0.0)),
            (complex(half_max, half_max), complex(half_max, half_max), complex(1.0, 0.0)),
            (complex(tiny, tiny), complex(tiny, tiny), complex(1.0, 0.0)),
            (complex(tiny, 0.0), complex(0.0, tiny), complex(0.0, -1.0)),
            (complex(0.0, 0.0), complex(tiny, tiny), complex(0.0, 0.0)),
        ]
        numerators, denominators, quotients = zip(*cases)

        numerator = torch.tensor(numerators, dtype=dtype, device=device)
        denominator = torch.tensor(denominators, dtype=dtype, device=device)
        expected = torch.tensor(quotients, dtype=dtype, device=device)
        self.assertEqual(numerator / denominator, expected)
1112

1113
    # In-place adding a CUDA tensor into a CPU tensor must fail with a clear
    # device-mismatch error message.
    @onlyCUDA
    def test_cross_device_inplace_error_msg(self, device):
        """``cpu_tensor += cuda_tensor`` raises the same-device RuntimeError."""
        cpu_operand = torch.tensor(2.0)
        device_operand = torch.tensor(2.0, device=device)
        expected_msg = "Expected all tensors to be on the same device"
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            cpu_operand += device_operand
1123

1124
    # TODO: refactor this test into a more generic one, it's parked here currently
    @onlyNativeDeviceTypes
    def test_out_resize_warning(self, device):
        """Resizing a non-empty ``out=`` tensor of the wrong shape warns exactly once.

        A correctly shaped or empty ``out=`` tensor must not warn.
        """
        lhs = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
        rhs = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)

        unary_ops = (torch.ceil, torch.exp)
        binary_ops = (torch.add, torch.sub)
        cases = [(fn, (lhs,)) for fn in unary_ops] + [
            (fn, (lhs, rhs)) for fn in binary_ops
        ]
        for fn, args in cases:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")

                # Matching shape and empty out tensors: no warning expected.
                fn(*args, out=torch.empty(3, device=device))
                fn(*args, out=torch.empty(0, device=device))
                self.assertEqual(len(caught), 0)

                # Wrong, non-empty shape: exactly one warning expected.
                fn(*args, out=torch.empty(2, device=device))
                self.assertEqual(len(caught), 1)

        # Multi-dimensional out= tensors that need resizing must warn (and not segfault).
        resize_cases = (
            (
                torch.ones(2, 1, device=device),
                torch.ones(2, device=device),
                torch.ones(2, 1, 1, 1, device=device),
            ),
            (
                torch.ones(1, device=device),
                torch.ones(1, 1, device=device),
                torch.ones(2, 2, 2, 2, device=device),
            ),
        )
        for x, y, out in resize_cases:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                torch.mul(x, y, out=out)
                self.assertEqual(len(caught), 1)
1157

1158
    # Verifies that the inplace dunders (like idiv) actually are in place
    @expectedFailureMeta  # UserWarning not triggered
    @onlyNativeDeviceTypes
    def test_inplace_dunders(self, device):
        """Augmented assignments (+=, -=, *=, /=, **=, //=, %=) keep the same storage."""
        t = torch.randn((1,), device=device)
        storage_ptr = t.data_ptr()
        # Each operator.iXXX call is what the matching augmented assignment
        # statement (`t += 1`, `t -= 1`, ...) executes, applied in the same order.
        for augmented_op in (
            operator.iadd,
            operator.isub,
            operator.imul,
            operator.itruediv,
            operator.ipow,
            operator.ifloordiv,
            operator.imod,
        ):
            t = augmented_op(t, 1)
        self.assertEqual(storage_ptr, t.data_ptr())
1172

1173
    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        """Assert that ``inplace_op`` rejects a self tensor with aliased elements.

        ``inplace_op`` is either a callable or the name of a ``torch.Tensor``
        method. Its first argument is a 1-element tensor expanded to 3x3, so all
        nine elements share one memory location. The remaining
        ``num_inputs - 1`` arguments are fresh random tensors like it.

        When ``expected_failure`` is true the check is inverted: the assertion
        that the "single memory location" RuntimeError is raised is itself
        expected to fail.
        """
        op = getattr(torch.Tensor, inplace_op) if isinstance(inplace_op, str) else inplace_op
        aliased = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        args = [aliased, *(torch.randn_like(aliased) for _ in range(num_inputs - 1))]

        def run_expecting_overlap_error():
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                op(*args)

        if expected_failure:
            with self.assertRaises(AssertionError):
                run_expecting_overlap_error()
        else:
            run_expecting_overlap_error()
1187

1188
    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        """Check how ``op(input, out=output)`` handles aliasing of input and output.

        ``data`` is sliced into windows of length ``sz``, so it needs at least
        ``2 * sz`` elements. Two cases must match a run into a fresh output:
        * output identical to input
        * output disjoint from input

        An output that partially overlaps the input must raise an "unsupported
        operation" RuntimeError. With ``expected_failure`` that last assertion is
        itself expected to fail.
        """

        def check(fn, out, inp):
            reference = torch.empty_like(out)
            fn(inp, out=reference)
            self.assertEqual(fn(inp, out=out), reference, msg=fn.__name__)

        # Fully aliased: output and input are the same elements.
        check(op, out=data[0:sz], inp=data[0:sz])
        # Disjoint: output and input do not share elements.
        check(op, out=data[0:sz], inp=data[sz : 2 * sz])

        # Partially overlapping: must be rejected.
        def check_partial_overlap():
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                check(op, data[0:sz], data[1 : sz + 1])

        if expected_failure:
            with self.assertRaises(AssertionError):
                check_partial_overlap()
        else:
            check_partial_overlap()
1208

1209
    def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False):
        """Run the unary aliasing checks on binary ``op``, once per operand slot.

        A fixed tensor ``other`` fills one operand while the aliased input is
        placed first in the other slot, then second.
        """
        sz = 3
        data = torch.randn(2 * sz, device=device)
        other = torch.randn(sz, device=device)

        wrapped_ops = (
            lambda input, out: op(other, input, out=out),
            lambda input, out: op(input, other, out=out),
        )
        for wrapped in wrapped_ops:
            self.unary_check_input_output_mem_overlap(
                data, sz, wrapped, expected_failure=expected_failure
            )
1227

1228
    @dtypes(torch.double)
    def test_binary_op_mem_overlap(self, device, dtype):
        """Binary ops must detect internal and input/output memory overlap.

        For each op two checks run:
        * the in-place variant (``Tensor.<op>_``) rejects a self tensor whose
          elements alias a single location;
        * the ``out=`` variant (``torch.<op>``) rejects an output that partially
          overlaps one of its inputs.
        """
        # (op name, has_input_output_mem_overlap_check, has_internal_mem_overlap_check)
        # Every entry applies to both CPU and CUDA.
        ops = (
            ("add", True, True),
            ("mul", True, True),
            ("sub", True, True),
            ("div", True, True),
            ("pow", True, True),
            ("fmod", True, True),
            ("atan2", True, True),
            ("hypot", True, True),
            ("igamma", True, True),
            ("igammac", True, True),
            ("nextafter", True, True),
            ("le", True, True),
            ("lt", True, True),
            ("ge", True, True),
            ("gt", True, True),
            ("eq", True, True),
            ("ne", True, True),
            ("logical_and", True, True),
            ("logical_or", True, True),
            ("logical_xor", True, True),
        )
        supported_device_types = ("cpu", "cuda")

        # `device` may carry an index (e.g. "cuda:0"); match on the device type,
        # otherwise the CUDA entries would never be selected and the test would
        # silently check nothing there.
        if torch.device(device).type not in supported_device_types:
            return

        for fn, has_input_output_mem_overlap_check, has_internal_mem_overlap_check in ops:
            out_op = getattr(torch, fn)
            inplace_op = getattr(torch.Tensor, fn + "_")
            self.check_internal_mem_overlap(
                inplace_op,
                2,
                dtype,
                device,
                expected_failure=not has_internal_mem_overlap_check,
            )

            self.binary_check_input_output_mem_overlap(
                out_op, device, expected_failure=not has_input_output_mem_overlap_check
            )
1294

1295
    def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
        """Check torch.pow on a row and a column of ``m1`` for each scalar exponent.

        For each ``num`` in ``exponents``:
        * ``torch.pow(m1[4], num)`` (contiguous row) and ``torch.pow(m1[:, 4], num)``
          (strided column) are compared with an element-wise ``pow_fn`` loop,
          using ``atol`` (and ``rtol=0`` whenever ``atol`` is given);
        * ``num ** m1[4]`` (``Tensor.__rpow__``) is compared, value and dtype,
          with an explicit tensor base of dtype ``torch.result_type(num, m1)``.

        Integer tensors raised to a negative integer scalar must raise instead.
        """
        rtol = 0 if atol is not None else None
        is_integral = not (m1.is_floating_point() or m1.is_complex())

        for num in exponents:
            if is_integral and isinstance(num, int) and num < 0:
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"Integers to negative integer powers are not allowed\.",
                ):
                    torch.pow(m1[4], num)
                continue

            # base is a tensor, exponent a Python number: contiguous row, then
            # strided column. `math.pow` has issues with complex exponentiation,
            # which is why callers may pass the builtin `pow` as pow_fn.
            for base in (m1[4], m1[:, 4]):
                res1 = torch.pow(base, num)
                res2 = res1.clone().zero_()
                for i in range(res2.size(0)):
                    res2[i] = pow_fn(base[i], num)
                self.assertEqual(res1, res2, atol=atol, rtol=rtol)

            # scalar ** tensor exercises the dtype handling of __rpow__().
            expected_dtype = torch.result_type(num, m1)
            rpow_result = num ** m1[4]
            explicit_result = (
                torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
            )
            self.assertEqual(rpow_result, explicit_result)
            self.assertEqual(rpow_result.dtype, expected_dtype)
1334

1335
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_pow(self, device, dtype):
        """torch.pow with tensor/number bases and real/complex exponents across dtypes."""
        if dtype.is_floating_point or dtype.is_complex:
            m1 = (
                make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5
            )
        else:
            # math.pow will overflow and throw exceptions for large integers
            range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
            m1 = make_tensor(
                (100, 100), low=1, high=range_high, dtype=dtype, device=device
            )

        real_exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
        complex_exponents = [
            -2.5j,
            -1.0j,
            0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]
        complex_atol = 10e-4

        if dtype.is_complex:
            self._do_pow_for_exponents(
                m1, real_exponents + complex_exponents, pow, complex_atol
            )
        else:
            self._do_pow_for_exponents(m1, real_exponents, math.pow, None)
            if dtype is torch.half and torch.device(device).type == 'cpu':
                # On CPU a Half tensor with a complex exponent computes in
                # ComplexHalf, which this op does not support yet.
                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
                    self._do_pow_for_exponents(m1, complex_exponents, pow, complex_atol)
            else:
                self._do_pow_for_exponents(m1, complex_exponents, pow, complex_atol)

        # base is a Python number, exponent a tensor: contiguous row, then strided column.
        for exponent in (m1[4], m1[:, 4]):
            res1 = torch.pow(3, exponent)
            res2 = res1.clone().zero_()
            for i in range(res2.size(0)):
                res2[i] = pow(3, exponent[i])
            self.assertEqual(res1, res2)
1388

1389
    # TODO: refactor all these tests using opinfos properly
    def _test_pow(self, base, exponent, np_exponent=None):
        """Check every pow spelling for ``base ** exponent`` against ``np.power``.

        ``np_exponent`` optionally overrides the exponent given to NumPy; it
        defaults to ``exponent``. If NumPy raises "Integers to negative integer
        powers are not allowed.", every torch variant must raise a RuntimeError
        with that message. Otherwise four variants are compared with NumPy's
        result: ``Tensor.pow``, ``Tensor.pow_``, ``torch.pow`` and
        ``torch.pow(out=)``.
        """
        if np_exponent is None:
            np_exponent = exponent

        def as_numpy(value):
            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

        try:
            np_res = np.power(as_numpy(base), as_numpy(np_exponent))
            if isinstance(np_res, np.ndarray):
                expected = torch.from_numpy(np_res)
            else:
                expected = torch.tensor(np_res, dtype=base.dtype)
        except ValueError as e:
            err_msg = "Integers to negative integer powers are not allowed."
            self.assertEqual(str(e), err_msg)
            out = torch.empty_like(base)
            raising_calls = [
                lambda: base.pow(exponent),
                lambda: base.pow_(exponent),
                lambda: torch.pow(base, exponent),
                lambda: torch.pow(base, exponent, out=out),
            ]
            for raising_call in raising_calls:
                self.assertRaisesRegex(RuntimeError, err_msg, raising_call)
            return

        if isinstance(base, torch.Tensor):
            method_result = base.pow(exponent)
            self.assertEqual(method_result, expected.to(method_result))
            inplace_target = base.clone()
            # A 0-dim CPU base with a CUDA exponent: `pow` is expected to work, but
            # `pow_` must fail. `pow` creates its output on the CUDA device,
            # whereas `pow_` has to write into the CPU base tensor.
            cpu_scalar_base_with_cuda_exponent = (
                isinstance(exponent, torch.Tensor)
                and base.dim() == 0
                and base.device.type == "cpu"
                and exponent.device.type == "cuda"
            )
            if cpu_scalar_base_with_cuda_exponent:
                regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
                self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
            elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
                inplace_result = inplace_target.pow_(exponent)
                self.assertEqual(inplace_target, expected)
                self.assertEqual(inplace_result, expected)
            else:
                # The promoted result dtype cannot be written back into base's dtype.
                self.assertRaisesRegex(
                    RuntimeError,
                    "Found dtype \\w+ but expected \\w+",
                    lambda: inplace_target.pow_(exponent),
                )

        functional_result = torch.pow(base, exponent)
        self.assertEqual(functional_result, expected.to(functional_result))

        out_result = torch.pow(base, exponent, out=functional_result)
        self.assertEqual(functional_result, expected.to(functional_result))
        self.assertEqual(out_result, expected.to(functional_result))
1450

1451
    # We can potentially merge this into OpInfo, but one blocker is that the
1452
    # first input must be a scalar. It is not as simple as just wrapping this in
1453
    # a lambda that swaps the inputs, because we also want to test sample inputs
1454
    # where the second input is a scalar. The wrapper would need some more logic.
1455
    def test_pow_scalar_base(self, device):
        """Gradcheck ``torch.pow`` with a Python scalar base (2) and a tensor exponent."""
        exponent = torch.arange(1, 13, dtype=torch.double, device=device)
        exponent = exponent.view(3, 4).requires_grad_()

        def scalar_base_pow(t):
            return torch.pow(2, t)

        gradcheck(scalar_base_pow, (exponent,))

1463
    # Tests pow() for integral, floating-type tensors, with integral, floating-type
1464
    # exponents (tensor or scalar), respectively. noncontiguous tensors are also tested.
1465
    def test_int_and_float_pow(self, device):
        """Exercise ``pow`` across integral and floating dtypes.

        For every (base shape, scalar exponent, exponent shape) case, the base
        is raised to a Python scalar and to a tensor exponent. This is done
        first with contiguous inputs, then with non-contiguous ones.
        """
        integral_dtypes = (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        )
        cases = (
            ((4, 4), 0, (4, 1)),
            ((3, 1), 4, (3, 1)),
            ((2,), 4, (1,)),
            ((1,), 2, ()),
            ((513, 513), 4, (513,)),
            ((5, 5, 5), 5, (5,)),
            ((), 2, ()),
        )

        def check_dtype(dt, low, high, dev):
            # int tensors don't take negative exponents, so integral exponent
            # tensors use a lower bound of 0 instead of ``low``.
            exp_low = 0 if dt in integral_dtypes else low
            for base_shape, exp_scalar, exp_shape in cases:
                for noncontig in (False, True):
                    base = make_tensor(
                        base_shape,
                        dtype=dt,
                        device=dev,
                        low=low,
                        high=high,
                        noncontiguous=noncontig,
                    )
                    exponent = make_tensor(
                        exp_shape,
                        dtype=dt,
                        device=dev,
                        low=exp_low,
                        high=high,
                        noncontiguous=noncontig,
                    )
                    self._test_pow(base, exp_scalar)
                    self._test_pow(base, exponent)

        for dt, low, high in (
            (torch.int8, -2, 2),
            (torch.uint8, 0, 3),
            (torch.int16, -5, 5),
            (torch.int64, -10, 10),
            (torch.int32, -10, 10),
            (torch.float16, 0.0, 5.0),
            (torch.float32, 0.0, 10.0),
            (torch.float64, 0.0, 10.0),
            # pow's output would have some NaNs as well
            (torch.float32, -10.0, 10.0),
            (torch.float64, -10.0, 10.0),
        ):
            check_dtype(dt, low, high, device)

1545
    # Tests that a RuntimeError occurs when a base tensor cannot be resized
1546
    # by pow's inplace variant due to PyTorch's broadcasting semantics.
1547
    def test_pow_inplace_resizing_exception(self, device):
        """``Tensor.pow_`` must raise when broadcasting would require resizing the base."""
        shape_pairs = (
            ((), (3,)),
            ((2,), (2, 1)),
            ((2, 1), (2, 2)),
            ((2, 2), (2, 1, 1)),
        )

        def make(size):
            return make_tensor(
                size, dtype=torch.float64, device=device, high=10.0, low=0.0
            )

        # Build every (base, exponent) pair before running any check.
        pairs = [(make(base_size), make(exp_size)) for base_size, exp_size in shape_pairs]
        for base, exponent in pairs:
            self.assertRaisesRegex(
                RuntimeError, "doesn't match the broadcast shape", base.pow_, exponent
            )

1569
    def test_int_tensor_pow_neg_ints(self, device):
        """Raise an int32 tensor, including the int32 extremes, to negative int powers."""
        int32_info = torch.iinfo(torch.int32)
        values = [int32_info.min, -3, -2, -1, 0, 1, 2, 3, int32_info.max]
        base = torch.tensor(values, dtype=torch.int32, device=device)
        for exponent in (int32_info.min, -3, -2, -1):
            self._test_pow(base, exponent)

1586
    def test_long_tensor_pow_floats(self, device):
        """Raise an int64 tensor to non-negative float powers, including fractional ones."""
        base = torch.tensor([0, 1, 23, 4567], dtype=torch.int64, device=device)
        for exponent in (0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0):
            self._test_pow(base, exponent)

1593
    @dtypes(torch.float32, torch.float64)
    def test_float_scalar_pow_float_tensor(self, device, dtype):
        """Raise Python float bases to float tensor exponents of several shapes."""
        bases = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
        exponents = [
            make_tensor(shape, dtype=dtype, device=device, low=0)
            for shape in ((1,), (2, 2), (2, 1), (2, 2, 2))
        ]
        # The same scalars are also reused as a 1-D exponent tensor.
        bases_as_exponent = torch.tensor(bases, dtype=dtype, device=device)
        for base in bases:
            self._test_pow(base, bases_as_exponent)
            for exponent in exponents:
                self._test_pow(base, exponent)

1612
    @onlyCUDA
    def test_cuda_tensor_pow_scalar_tensor(self, device):
        """Raise CUDA tensors to 0-dim CPU tensor exponents (float and int)."""
        bases = (
            torch.randn((3, 3), device=device),
            torch.tensor(3.0, device=device),
        )
        exponents = (
            torch.tensor(5.0, device="cpu"),
            torch.tensor(-3),
            torch.tensor(1),
        )
        for base in bases:
            for exponent in exponents:
                self._test_pow(base, exponent)

1626
    @onlyCUDA
    def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
        """Check a CPU base against CUDA 0-dim exponents.

        A non-scalar CPU base must raise a device-mismatch error. A 0-dim CPU
        base is accepted.
        """
        exponents = [
            torch.tensor(5.0, device="cuda"),
            torch.tensor(-3, device="cuda"),
        ]
        mismatch = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
        for exponent in exponents:
            cpu_base = torch.randn((3, 3), device="cpu")
            self.assertRaisesRegex(RuntimeError, mismatch, torch.pow, cpu_base, exponent)
        # Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension
        for exponent in exponents:
            self._test_pow(torch.tensor(3.0, device="cpu"), exponent)

1641
    @onlyCUDA
    @dtypes(torch.complex64, torch.complex128)
    def test_pow_cuda_complex_extremal_failing(self, device, dtype):
        """Pin a known CPU/CUDA mismatch for ``complex(-1, inf) ** 2``.

        The comparison is expected to fail. If CUDA and CPU ever agree, the
        ``assertRaises`` below fails and this test should be revisited.
        """
        value = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device)
        with self.assertRaises(AssertionError):
            gpu_result = value.pow(2)
            host_result = value.cpu().pow(2)
            self.assertEqual(host_result, gpu_result)

1650
    @skipIfTorchDynamo()
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_complex_scalar_pow_tensor(self, device, dtype):
        """Raise Python complex bases to real/complex tensor exponents.

        Both a contiguous and a non-contiguous exponent tensor are checked, each
        with a few exponents forced to zero.
        """
        complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j]
        first_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2)
        second_exp = make_tensor(
            (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True
        )
        first_exp[0] = first_exp[10] = first_exp[20] = 0
        second_exp[0] = second_exp[10] = second_exp[20] = 0
        for base in complexes:
            # On CPU,
            # Half Tensor with complex base leads to computation dtype
            # of ComplexHalf for which this ops is not supported yet
            # NOTE: pow has fast-path when base is 1 which supports
            # ComplexHalf
            will_raise_error = (
                torch.device(device).type == "cpu"
                and dtype is torch.half
                and base != (1 + 0j)
            )
            for exponent in (first_exp, second_exp):
                if will_raise_error:
                    # One context per call: inside a single ``with`` the first
                    # raise would skip the second call, leaving it unchecked.
                    with self.assertRaisesRegex(
                        RuntimeError, "not implemented for 'ComplexHalf'"
                    ):
                        self._test_pow(base, exponent)
                else:
                    self._test_pow(base, exponent)

1677
    @onlyNativeDeviceTypes
    @skipMeta
    def test_pow_scalar_type_promotion(self, device):
        """``pow(2, t, out=int64)`` computes in ``t``'s dtype, not in the ``out`` dtype.

        Checked for both a 0-dim and a one-element exponent tensor.
        """
        for exponent in (17, [17]):
            results = {}
            for compute_dtype in (torch.uint8, torch.int64):
                results[compute_dtype] = torch.pow(
                    2,
                    torch.tensor(exponent, dtype=compute_dtype, device=device),
                    out=torch.tensor(0, dtype=torch.int64, device=device),
                )
            # uint8 computation overflows (2**17 wraps to 0) before the cast
            # to int64; int64 computation does not overflow.
            self.assertNotEqual(results[torch.uint8], results[torch.int64])
            self.assertEqual(
                results[torch.uint8].to(dtype=torch.uint8),
                results[torch.int64].to(dtype=torch.uint8),
            )

1705
    def test_tensor_pow_tensor(self, device):
        """Raise a tensor to every cyclic rotation of its own values, for int and float dtypes."""

        def check_rotations(values, torch_type):
            base = torch.tensor(values, dtype=torch_type, device=device)
            for shift in range(len(values)):
                # Right-rotate by ``shift``; shift == 0 leaves the list unchanged.
                rotated = values[-shift:] + values[:-shift]
                self._test_pow(
                    base, torch.tensor(rotated, dtype=torch_type, device=device)
                )

        int_values = [0, 1, 2, 3]
        for torch_type in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            check_rotations(int_values, torch_type)

        float_values = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0]
        for torch_type in (torch.float16, torch.float32, torch.float64):
            check_rotations(float_values, torch_type)

1728
    def test_logical_xor_with_nontrivial_alignment(self, device):
        """``logical_xor`` with ``out=`` on bool views taken at assorted start offsets."""
        # test tensor that is not aligned to multiple of 16 bytes
        size = 128
        lhs, rhs, dst = (torch.randn(size, device=device) > 0 for _ in range(3))
        offsets = [1, 2, 4, 8, 15]
        for i, j, k in product(offsets, repeat=3):
            a_view = lhs[i : 100 + i]
            b_view = rhs[j : 100 + j]
            out_view = dst[k : 100 + k]
            torch.logical_xor(a_view, b_view, out=out_view)
            for x, y, z in zip(a_view.tolist(), b_view.tolist(), out_view.tolist()):
                self.assertEqual(x ^ y, z)

1745
    @dtypes(torch.float)
    def test_add_with_tail(self, device, dtype):
        """Elementwise add on lengths that leave a tail past a multiple of the GPU warp size."""
        for length in (4096 + tail for tail in (1, 63, 67, 130)):
            lhs = torch.randn(length, device=device, dtype=dtype)
            rhs = torch.randn(length, device=device, dtype=dtype)
            total = lhs + rhs
            for x, y, z in zip(lhs.tolist(), rhs.tolist(), total.tolist()):
                self.assertEqual(x + y, z)

1757
    # Tests that CUDA tensors on different devices cannot be used in the same
1758
    # binary operation, and that CUDA "scalars" cannot be used in the same
1759
    # binary operation as non-scalar CPU tensors.
1760
    @deviceCountAtLeast(2)
    @onlyCUDA
    def test_cross_device_binary_ops(self, devices):
        """Binary ops must reject operands that live on incompatible devices.

        For every op and every combination of 0-dim / 1-dim operands:
          * tensors on two different CUDA devices raise (in either order);
          * a CUDA tensor combined with a non-scalar CPU tensor raises.
        """
        vals = (1.0, (2.0,))
        cpu_tensor = torch.randn(2, 2)

        def do_test(op, a, b):
            for lhs, rhs in ((a, b), (b, a), (a, cpu_tensor), (cpu_tensor, a)):
                with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                    op(lhs, rhs)

        for op in (
            operator.add,
            torch.add,
            operator.sub,
            torch.sub,
            operator.mul,
            torch.mul,
            operator.truediv,
            torch.true_divide,
            operator.floordiv,
            torch.floor_divide,
        ):
            for a_val, b_val in product(vals, vals):
                a = torch.tensor(a_val, device=devices[0])
                b = torch.tensor(b_val, device=devices[1])
                # Checked inside the loop so every (a, b) combination is
                # exercised, not just the last one.
                do_test(op, a, b)

1794
    # This test ensures that a scalar Tensor can be safely used
1795
    # in a binary operation in conjunction with a Tensor on all
1796
    # available CUDA devices
1797
    @deviceCountAtLeast(2)
    @onlyCUDA
    def test_binary_op_scalar_device_unspecified(self, devices):
        """A 0-dim CPU tensor times a CUDA tensor yields a result on the CUDA tensor's device.

        This holds regardless of which CUDA device is current, and for both
        operand orders.
        """
        cpu_scalar = torch.tensor(1.0)
        for current in devices:
            with torch.cuda.device(current):
                for target in devices:
                    expected_device = torch.device(target)
                    x = torch.rand(3, device=target)
                    right_scaled = x * cpu_scalar
                    self.assertEqual(right_scaled.device, expected_device)
                    left_scaled = cpu_scalar * x
                    self.assertEqual(left_scaled.device, expected_device)
                    self.assertEqual(right_scaled, left_scaled)

1812
    def test_div_and_floordiv_vs_python(self, device):
        """Compare torch true/floor division against Python with scalar and 0-dim tensor operands.

        Each integer pair in [-10, 10) is checked twice: halved (fractional
        floats) and as plain integers. Zero divisors are skipped.
        """
        # Each transform is applied to the original integer. The loop variables
        # are never reassigned, so the second transform does not see the
        # output of the first.
        transforms = (lambda x: x * 0.5, lambda x: math.floor(x))

        def _scalar_helper(python_op, torch_op):
            for a_int, b_int in product(range(-10, 10), range(-10, 10)):
                for transform in transforms:
                    a = transform(a_int)
                    b = transform(b_int)

                    # Skips zero divisors
                    if b == 0:
                        continue

                    expected = python_op(a, b)

                    a_t = torch.tensor(a, device=device)
                    b_t = torch.tensor(b, device=device)

                    actual_scalar = torch_op(a, b)
                    actual_tensor = torch_op(a_t, b_t)
                    actual_first_tensor = torch_op(a_t, b)
                    actual_second_tensor = torch_op(a, b_t)

                    self.assertEqual(actual_scalar, expected)
                    self.assertEqual(actual_tensor.item(), expected)
                    self.assertEqual(actual_first_tensor, actual_tensor)
                    self.assertEqual(actual_second_tensor, actual_tensor)

        _scalar_helper(operator.truediv, operator.truediv)
        _scalar_helper(operator.truediv, torch.true_divide)
        _scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv)
        _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)

1847
    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_div_and_floordiv_script_vs_python(self, device):
        """Compare TorchScript-compiled ``/`` and ``//`` against Python arithmetic.

        Covers tensor-tensor division and tensor/scalar division with the
        constant 5 in both operand orders. Operands are halved (fractional) and
        plain integers from [-10, 10).
        """
        # Each transform is applied to the original integer. The loop variables
        # are never reassigned, so the second transform does not see the
        # output of the first.
        transforms = (lambda x: x * 0.5, lambda x: math.floor(x))

        # Creates jitted functions of two tensors
        def _wrapped_div(a, b):
            return a / b

        def _wrapped_floordiv(a, b):
            return a // b

        scripted_div = torch.jit.script(_wrapped_div)
        scripted_floordiv = torch.jit.script(_wrapped_floordiv)
        for a_int, b_int in product(range(-10, 10), range(-10, 10)):
            for transform in transforms:
                a = transform(a_int)
                b = transform(b_int)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_div = a / b
                expected_floordiv = math.floor(a / b)
                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                self.assertEqual(scripted_div(a_t, b_t), expected_div)
                self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)

        # Creates jitted functions of one tensor
        def _wrapped_div_scalar(a):
            return a / 5

        # NOTE: the JIT implements division as torch.reciprocal(a) * 5
        def _wrapped_rdiv_scalar(a):
            return 5 / a

        def _wrapped_floordiv_scalar(a):
            return a // 5

        # NOTE: this fails if the input is not an integer tensor
        # See https://github.com/pytorch/pytorch/issues/45199
        def _wrapped_rfloordiv_scalar(a):
            return 5 // a

        scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
        scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
        scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
        scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)

        for a_int in range(-10, 10):
            for transform in transforms:
                a = transform(a_int)

                a_t = torch.tensor(a, device=device)

                self.assertEqual(a / 5, scripted_div_scalar(a_t))
                # Previously compiled but never checked.
                self.assertEqual(math.floor(a / 5), scripted_floordiv_scalar(a_t))

                # Skips zero divisors
                if a == 0:
                    continue

                self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

                # Handles Issue 45199 (see comment above)
                if a_t.is_floating_point():
                    with self.assertRaises(RuntimeError):
                        scripted_rfloordiv_scalar(a_t)
                else:
                    # This should emit a UserWarning, why doesn't it?
                    # See issue gh-52387
                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))

1920
    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_idiv_and_ifloordiv_vs_python(self, device):
        """Compare in-place true/floor division (eager and scripted) against Python.

        Also checks three failure modes:
          * scripting ``//=`` raises ``NotSupportedError``;
          * in-place true division into an integer tensor raises ``RuntimeError``;
          * in-place floor division of an integer tensor by a float tensor raises
            ``RuntimeError``.
        """
        # Each transform is applied to the original integer. The loop variables
        # are never reassigned, so the second transform does not see the
        # output of the first.
        transforms = (lambda x: x * 0.5, lambda x: math.floor(x))

        def _wrapped_idiv_tensor(a, b):
            a /= b
            return a

        def _wrapped_idiv_scalar(a):
            a /= 5
            return a

        def _wrapped_true_divide__tensor(a, b):
            a.true_divide_(b)
            return a

        def _wrapped_true_divide__scalar(a):
            a.true_divide_(5)
            return a

        def _wrapped_floor_divide__tensor(a, b):
            a.floor_divide_(b)
            return a

        def _wrapped_floor_divide__scalar(a):
            a.floor_divide_(5)
            return a

        # The following functions are unsupported by the JIT
        def _wrapped_ifloordiv_tensor(a, b):
            a //= b
            return a

        def _wrapped_ifloordiv_scalar(a):
            a //= 5
            return a

        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            torch.jit.script(_wrapped_ifloordiv_tensor)

        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            torch.jit.script(_wrapped_ifloordiv_scalar)

        # The two idiv functions are only compiled (scripting must succeed);
        # they are not invoked below.
        scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
        scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
        scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
        scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
        scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
        scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)

        for a_int, b_int in product(range(-10, 10), range(-10, 10)):
            for transform in transforms:
                a = transform(a_int)
                b = transform(b_int)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_idiv = a / b
                expected_ifloordiv = a // b

                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                if a_t.is_floating_point():
                    tmp0 = a_t.clone()
                    tmp0 /= b

                    tmp1 = a_t.clone()
                    tmp1 /= b_t

                    self.assertEqual(tmp0.item(), expected_idiv)
                    self.assertEqual(tmp1.item(), expected_idiv)
                    self.assertEqual(
                        scripted_true_divide__tensor(a_t.clone(), b_t).item(),
                        expected_idiv,
                    )
                    self.assertEqual(
                        scripted_true_divide__scalar(a_t.clone()).item(), a / 5
                    )
                else:
                    tmp = a_t.clone()
                    with self.assertRaises(RuntimeError):
                        tmp /= b
                    with self.assertRaises(RuntimeError):
                        tmp /= b_t
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__tensor(tmp, b_t)
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__scalar(tmp)

                if not a_t.is_floating_point() and b_t.is_floating_point():
                    # Inplace modification fails because a float tensor is required
                    #   if the divisor is a float tensor
                    # NOTE: a and b come from the same transform, so they always
                    # share a kind and this branch is not reached with these inputs.
                    with self.assertRaises(RuntimeError):
                        a_t.clone().floor_divide_(b_t)
                    with self.assertRaises(RuntimeError):
                        scripted_floor_divide__tensor(a_t.clone(), b_t)
                    with self.assertRaises(RuntimeError):
                        tmp = a_t.clone()
                        tmp //= b_t
                else:
                    # Inplace modification is OK when both or neither tensor is
                    #   a float tensor
                    self.assertEqual(
                        a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
                    )
                    self.assertEqual(
                        scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
                        expected_ifloordiv,
                    )
                    tmp = a_t.clone()
                    tmp //= b_t
                    self.assertEqual(tmp.item(), expected_ifloordiv)

                # Mutates a_t in place; it is last use of a_t in this iteration.
                self.assertEqual(
                    scripted_floor_divide__scalar(a_t), math.floor(a / 5)
                )

2036
    # Tests binary op equivalence with Python builtin ops
2037
    # Also tests that reverse operations are equivalent to forward ops
2038
    # NOTE: division ops are also tested more thoroughly in the tests above
2039
    def test_binary_ops_with_scalars(self, device):
        """Compare add/sub/mul/div on mixed scalar / tensor operands with Python builtins.

        Every operand-type combination (Python scalar, device tensor, CPU
        tensor) is tried in both orders. Values are halved (fractional) and
        plain integers from [-10, 10).
        """
        # Each transform is applied to the original integer. The loop variables
        # are never reassigned, so the second transform does not see the
        # output of the first.
        transforms = (lambda x: x * 0.5, lambda x: math.floor(x))
        for python_op, torch_op in (
            (operator.add, torch.add),
            (operator.sub, torch.sub),
            (operator.mul, torch.mul),
            (operator.truediv, torch.div),
        ):
            for a_int, b_int in product(range(-10, 10), range(-10, 10)):
                for transform in transforms:
                    a = transform(a_int)
                    b = transform(b_int)

                    # Skip zeros: each value is used as both the left and the
                    # right operand below, so either one may be a divisor.
                    if b == 0 or a == 0:
                        continue

                    a_tensor = torch.tensor(a, device=device)
                    b_tensor = torch.tensor(b, device=device)
                    a_tensor_cpu = a_tensor.cpu()
                    b_tensor_cpu = b_tensor.cpu()
                    vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu)

                    for first, second in product(vals, vals):
                        first_scalar = (
                            first.item() if isinstance(first, torch.Tensor) else first
                        )
                        second_scalar = (
                            second.item() if isinstance(second, torch.Tensor) else second
                        )
                        expected = python_op(first_scalar, second_scalar)

                        self.assertEqual(expected, python_op(first, second))
                        self.assertEqual(expected, torch_op(first, second))

2080
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_maximum_minimum_type_promotion(self, device, dtypes):
        """Max/min-style binary ops return ``torch.result_type`` of their two inputs."""
        lhs_dtype, rhs_dtype = dtypes
        lhs = torch.tensor((0, 1), device=device, dtype=lhs_dtype)
        rhs = torch.tensor((1, 0), device=device, dtype=rhs_dtype)
        expected_dtype = torch.result_type(lhs, rhs)
        ops = (
            torch.maximum,
            torch.max,
            torch.fmax,
            torch.minimum,
            torch.min,
            torch.fmin,
        )
        for op in ops:
            self.assertEqual(op(lhs, rhs).dtype, expected_dtype)

2100
    @dtypes(*integral_types_and(torch.bool))
    def test_maximum_minimum_int_and_bool(self, device, dtype):
        """Compare torch max/min ops and their aliases with NumPy on integer/bool data.

        Both the functional form and the ``out=`` form are checked.
        """
        ops = (
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )
        np_dtype = torch_to_numpy_dtype_dict[dtype]
        rng = np.random.default_rng()
        a_np, b_np = (
            np.array(rng.integers(-100, 100, size=10), dtype=np_dtype)
            for _ in range(2)
        )

        for torch_op, alias, numpy_op in ops:
            a = torch.from_numpy(a_np).to(device=device, dtype=dtype)
            b = torch.from_numpy(b_np).to(device=device, dtype=dtype)
            result = torch_op(a, b)

            out = torch.empty_like(a)
            torch_op(a, b, out=out)

            expected = numpy_op(a_np, b_np)

            if alias is not None:
                self.assertEqual(alias(a, b), result)

            self.assertEqual(result, expected)
            self.assertEqual(out, expected)

2133
    @precisionOverride({torch.bfloat16: 1e-2})
    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
    def test_maximum_minimum_float(self, device, dtype):
        """Compare torch max/min ops and their aliases with NumPy on random float inputs.

        Dtypes are not required to match exactly.
        """
        ops = (
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )

        # bfloat16 reference data is generated in float64 (presumably because
        # NumPy lacks a bfloat16 dtype) and converted to bfloat16 below.
        if dtype == torch.bfloat16:
            np_dtype = np.float64
        else:
            np_dtype = torch_to_numpy_dtype_dict[dtype]
        a_np = np.random.randn(10).astype(np_dtype)
        b_np = np.random.randn(10).astype(np_dtype)

        for torch_op, alias, numpy_op in ops:
            expected = numpy_op(a_np, b_np)

            a = torch.from_numpy(a_np).to(device=device, dtype=dtype)
            b = torch.from_numpy(b_np).to(device=device, dtype=dtype)
            result = torch_op(a, b)
            out = torch.empty_like(a)
            torch_op(a, b, out=out)

            if alias is not None:
                self.assertEqual(alias(a, b), result, exact_dtype=False)

            self.assertEqual(result, expected, exact_dtype=False)
            self.assertEqual(out, expected, exact_dtype=False)

2166
    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
    def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
        """Compare max/min style ops with NumPy on inf/NaN edge cases.

        The NumPy references define the expected NaN handling:
        np.maximum/np.minimum return NaN when either element of a pair is NaN,
        while np.fmax/np.fmin return the non-NaN element when only one is NaN.
        """
        op_table = (
            # (torch function, optional torch alias, NumPy reference)
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )
        pos_inf, neg_inf, nan_val = float("inf"), -float("inf"), float("nan")
        # Column i, i.e. the pair (lhs_vals[i], rhs_vals[i]), is one edge case.
        lhs_vals = (pos_inf, neg_inf, nan_val, pos_inf, nan_val, nan_val, 1, nan_val)
        rhs_vals = (neg_inf, pos_inf, pos_inf, nan_val, nan_val, 0, nan_val, -5)

        is_bf16 = dtype == torch.bfloat16
        # NumPy has no bfloat16, so build the reference in float64 instead.
        np_dtype = np.float64 if is_bf16 else torch_to_numpy_dtype_dict[dtype]
        lhs_np = np.array(lhs_vals, dtype=np_dtype)
        rhs_np = np.array(rhs_vals, dtype=np_dtype)

        # Only bfloat16 results are allowed to differ in dtype from the reference.
        compare_kwargs = {"exact_dtype": False} if is_bf16 else {}

        for torch_fn, alias_fn, numpy_fn in op_table:
            expected = numpy_fn(lhs_np, rhs_np)

            lhs = torch.from_numpy(lhs_np).to(device=device, dtype=dtype)
            rhs = torch.from_numpy(rhs_np).to(device=device, dtype=dtype)
            actual = torch_fn(lhs, rhs)

            actual_out = torch.empty_like(lhs)
            torch_fn(lhs, rhs, out=actual_out)

            if alias_fn is not None:
                self.assertEqual(alias_fn(lhs, rhs), actual)

            self.assertEqual(actual, expected, **compare_kwargs)
            self.assertEqual(actual_out, expected, **compare_kwargs)
2223

2224
    @dtypes(
        *product(
            complex_types(),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_maximum_minimum_complex(self, device, dtypes):
        """Max/min style binary ops must reject complex operands.

        ``dtypes`` is a (complex dtype, any dtype) pair. Every op is called with
        the pair in both argument orders and must raise "not implemented".
        """
        complex_dtype, other_dtype = dtypes
        rejecting_ops = (
            torch.maximum,
            torch.minimum,
            torch.max,
            torch.min,
            torch.fmax,
            torch.fmin,
        )
        argument_orders = (
            (complex_dtype, other_dtype),
            (other_dtype, complex_dtype),
        )
        for torch_fn in rejecting_ops:
            for first_dtype, second_dtype in argument_orders:
                with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
                    torch_fn(
                        torch.ones(1, device=device, dtype=first_dtype),
                        torch.ones(1, device=device, dtype=second_dtype),
                    )
2250

2251
    @onlyCUDA
    def test_maximum_minimum_cross_device(self, device):
        """Mixed-device maximum/minimum: 1-D CPU tensors fail, 0-dim CPU tensors work.

        A 1-D CPU tensor combined with a ``device`` tensor must raise in either
        argument order. A 0-dim CPU tensor combined with a ``device`` tensor
        must succeed and match NumPy in either order.
        """
        cpu_vec = torch.tensor((1, 2, -1))
        dev_vec = torch.tensor((3, 0, 4), device=device)
        same_device_msg = "Expected all tensors to be on the same device"

        for torch_fn in (torch.maximum, torch.minimum):
            for lhs, rhs in ((cpu_vec, dev_vec), (dev_vec, cpu_vec)):
                with self.assertRaisesRegex(RuntimeError, same_device_msg):
                    torch_fn(lhs, rhs)

        # test cuda tensor and cpu scalar
        scalar_np = np.array(1)
        vec_np = np.array([3, 0, 4])

        for torch_fn, numpy_fn in (
            (torch.maximum, np.maximum),
            (torch.minimum, np.minimum),
        ):
            scalar_t = torch.from_numpy(scalar_np)
            vec_t = torch.from_numpy(vec_np).to(device=device)
            scalar_first = torch_fn(scalar_t, vec_t)
            scalar_first_ref = numpy_fn(scalar_np, vec_np)
            scalar_last = torch_fn(vec_t, scalar_t)
            scalar_last_ref = numpy_fn(vec_np, scalar_np)

            self.assertEqual(scalar_first, scalar_first_ref)
            self.assertEqual(scalar_last, scalar_last_ref)
2283

2284
    @dtypes(
        *product(
            floating_types_and(torch.half, torch.bfloat16),
            floating_types_and(torch.half, torch.bfloat16),
        )
    )
    def test_maximum_and_minimum_subgradient(self, device, dtypes):
        """Check maximum/minimum gradients, including an exact tie.

        At a tie the expected gradient is split evenly (0.5 to each input).
        """
        lhs_dtype, rhs_dtype = dtypes

        def check_grads(fn, lhs_vals, rhs_vals, want_lhs_grad, want_rhs_grad):
            lhs = torch.tensor(
                lhs_vals, requires_grad=True, device=device, dtype=lhs_dtype
            )
            rhs = torch.tensor(
                rhs_vals, requires_grad=True, device=device, dtype=rhs_dtype
            )
            fn(lhs, rhs).sum().backward()
            self.assertEqual(lhs.grad, want_lhs_grad)
            self.assertEqual(rhs.grad, want_rhs_grad)

        lhs_vals = [0.0, 1.0, 2.0]
        rhs_vals = [1.0, 1.0, 1.0]
        # Index 1 is the tie; index 0/2 are won by rhs/lhs (max) or lhs/rhs (min).
        check_grads(
            torch.maximum, lhs_vals, rhs_vals, [0.0, 0.5, 1.0], [1.0, 0.5, 0.0]
        )
        check_grads(
            torch.minimum, lhs_vals, rhs_vals, [1.0, 0.5, 0.0], [0.0, 0.5, 1.0]
        )
2313

2314
    def test_maximum_minimum_forward_ad_float32(self, device):
        """Forward-mode AD of maximum/minimum in float32.

        The result tangent must be the tangent of whichever operand the
        comparison selects (``y``'s tangent where the strict comparison fails).
        """
        # TODO: This should really be covered by OpInfo but it isn't. The problem
        # is that our gradient tests test using float64 but it should also test
        # float32
        primal_x, primal_y, tangent_x, tangent_y = (
            torch.randn(3, device=device, dtype=torch.float32) for _ in range(4)
        )

        for torch_fn, x_selected in (
            (torch.maximum, primal_x > primal_y),
            (torch.minimum, primal_x < primal_y),
        ):
            with fwAD.dual_level():
                result = torch_fn(
                    fwAD.make_dual(primal_x, tangent_x),
                    fwAD.make_dual(primal_y, tangent_y),
                )
                _, result_tangent = fwAD.unpack_dual(result)

            self.assertEqual(
                result_tangent, torch.where(x_selected, tangent_x, tangent_y)
            )
2340

2341
    # TODO: tests like this should be generic
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float, torch.double)
    def test_mul_intertype_scalar(self, device, dtype):
        """Multiply a 0-dim float tensor with a 0-dim int32 tensor.

        The out-of-place product is float in either order. In-place
        multiplication into the int32 tensor must be rejected, while in-place
        multiplication into the float tensor succeeds.
        """
        flt = torch.tensor(1.5, dtype=dtype, device=device)
        int32 = torch.tensor(3, dtype=torch.int32, device=device)

        self.assertEqual(flt * int32, 4.5)
        self.assertEqual(int32 * flt, 4.5)

        # The float result cannot be written back into an int32 tensor...
        with self.assertRaisesRegex(
            RuntimeError, "can't be cast to the desired output type"
        ):
            int32 *= flt
        # ...but the float tensor can absorb the int32 operand.
        flt *= int32
        self.assertEqual(flt, 4.5)
2357

2358
    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_sub(self, device, dtype):
        """Tensor subtraction yields the expected difference; bool subtraction raises."""
        # Integral dtypes get integer literals. Before Python 3.10, floats were
        # implicitly converted to ints with a DeprecationWarning ("an integer is
        # required (got type float)"); since Python 3.10 that conversion errors.
        if dtype in integral_types():
            lhs_vals, rhs_vals, diff_vals = [2, 4], [1, 2], [1, 2]
        else:
            lhs_vals, rhs_vals, diff_vals = [2.34, 4.44], [1.23, 2.33], [1.11, 2.11]

        m1 = torch.tensor(lhs_vals, dtype=dtype, device=device)
        m2 = torch.tensor(rhs_vals, dtype=dtype, device=device)
        diff = torch.tensor(diff_vals, dtype=dtype)

        if dtype == torch.bool:
            self.assertRaises(RuntimeError, lambda: m1 - m2)
        elif dtype in (torch.bfloat16, torch.half):
            # Reduced-precision dtypes need a looser absolute tolerance.
            self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0)
        else:
            self.assertEqual(m1 - m2, diff)
2382

2383
    @onlyCPU
    @dtypes(torch.float)
    def test_csub(self, device, dtype):
        """In-place ``sub_`` must agree with ``torch.add`` of the negated operand.

        Covers a tensor operand (``torch.add(..., alpha=-1)``) and a Python
        scalar operand (``torch.add(..., -scalar)``).
        """
        # Tensor operand
        base = torch.randn(100, 90, dtype=dtype, device=device)
        other = base.clone().normal_()

        via_add = torch.add(base, other, alpha=-1)
        via_sub = base.clone()
        via_sub.sub_(other)
        self.assertEqual(via_add, via_sub)

        # Scalar operand
        base = torch.randn(100, 100, dtype=dtype, device=device)
        scalar = 123.5

        via_add = torch.add(base, -scalar)
        via_sub = base.clone()
        via_sub.sub_(scalar)
        self.assertEqual(via_add, via_sub)
2404

2405
    # TODO: reconcile with minimum/maximum tests
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float, torch.double)
    def test_min_max_binary_op_nan(self, device, dtype):
        """Binary torch.max/torch.min return NaN whenever either input is NaN."""
        size = 1000
        a = torch.rand(size, dtype=dtype, device=device)
        b = torch.rand(size, dtype=dtype, device=device)

        # [0, 250): NaN only in a;  [250, 500): NaN only in b;
        # [500, 750): NaN in both;  [750, 1000): no NaN.
        a[:250] = float("nan")
        b[250:500] = float("nan")
        a[500:750] = float("nan")
        b[500:750] = float("nan")
        first_clean_index = 750

        ma = torch.max(a, b)
        mi = torch.min(a, b)

        for i in range(size):
            check = self.assertTrue if i < first_clean_index else self.assertFalse
            check(
                torch.isnan(ma[i]),
                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
            )
            check(
                torch.isnan(mi[i]),
                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
            )
2443

2444
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_copysign(self, device, dtypes):
        """Compare torch.copysign with np.copysign across dtype pairs.

        Covers same-shape inputs, broadcasting in both directions, and the
        special values 0.0 / -0.0 / inf / -inf / nan in either operand.
        """
        lhs_dtype, rhs_dtype = dtypes

        def to_numpy(t):
            # NumPy has no bfloat16, so route it through float32.
            if t.dtype == torch.bfloat16:
                t = t.to(torch.float)
            return t.cpu().numpy()

        def _test_copysign_numpy(a, b):
            torch_result = torch.copysign(a, b)
            expected = torch.from_numpy(np.copysign(to_numpy(a), to_numpy(b)))

            # PyTorch and NumPy promote integral/bool/bfloat16 inputs
            # differently, so compare both results in a common promoted dtype.
            mismatched_promotion = integral_types_and(torch.bool, torch.bfloat16)
            if a.dtype in mismatched_promotion or b.dtype in mismatched_promotion:
                common = torch.promote_types(torch_result.dtype, expected.dtype)
                torch_result = torch_result.to(common)
                expected = expected.to(common)

            # Verify values
            self.assertEqual(torch_result, expected)

            # Verify signs: assertEqual treats 0.0 and -0.0 as equal, so copy
            # each result's sign onto a magnitude of 1 and compare elementwise.
            # Skipped for float16 operands because NaN conversion between
            # float32 and float16 is not bitwise-preserving.
            if torch.float16 not in (a.dtype, b.dtype):
                one = torch.tensor(1.0)
                self.assertEqual(
                    torch.copysign(one, torch_result),
                    torch.copysign(one, expected),
                )

        def make(shape, dt):
            return make_tensor(shape, device=device, dtype=dt, low=-9, high=9)

        # Same shape (exercises type promotion)
        _test_copysign_numpy(make((10, 10), lhs_dtype), make((10, 10), rhs_dtype))

        # Broadcasting in each direction
        _test_copysign_numpy(make((10, 1, 10), lhs_dtype), make((10, 10), rhs_dtype))
        _test_copysign_numpy(make((10, 10), lhs_dtype), make((10, 1, 10), rhs_dtype))

        # 0.0/-0.0/inf/-inf/nan
        special_values = [0.0, -0.0, float("inf"), float("-inf"), float("nan")]
        # torch.bfloat16 can not hold '-nan'
        # torch.half can not hold '-nan' on CUDA
        lhs_special_dtypes = [torch.float32, torch.float64]
        if device == "cpu":
            lhs_special_dtypes.append(torch.float16)
        if lhs_dtype in lhs_special_dtypes:
            rhs = make((10, 10), rhs_dtype)
            for value in special_values:
                _test_copysign_numpy(
                    torch.tensor([value], device=device, dtype=lhs_dtype), rhs
                )

        if rhs_dtype in floating_types_and(torch.half, torch.bfloat16):
            lhs = make((10, 10), lhs_dtype)
            for value in special_values:
                _test_copysign_numpy(
                    lhs, torch.tensor([value], device=device, dtype=rhs_dtype)
                )
2521

2522
    @dtypes(
        *product(
            floating_types_and(torch.half, torch.bfloat16),
            floating_types_and(torch.half, torch.bfloat16),
        )
    )
    def test_copysign_subgradient(self, device, dtypes):
        """Check copysign gradients when either operand is a signed zero.

        The gradient w.r.t. ``other`` is always zero. The gradient w.r.t.
        ``input`` is zero where the input is +/-0.0. Elsewhere it is +/-1,
        negative exactly when the result's sign differs from the input's
        (-0.0 in ``other`` counts as negative).
        """
        lhs_dtype, rhs_dtype = dtypes

        def check(lhs_vals, rhs_vals, want_lhs_grad):
            x = torch.tensor(
                lhs_vals, dtype=lhs_dtype, device=device, requires_grad=True
            )
            y = torch.tensor(
                rhs_vals, dtype=rhs_dtype, device=device, requires_grad=True
            )
            torch.copysign(x, y).sum().backward()
            self.assertEqual(x.grad.tolist(), want_lhs_grad)
            self.assertEqual(y.grad.tolist(), [0.0] * 3)

        # Input is 0.0
        check([0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0])
        # Input is -0.0
        check([-0.0, -0.0, -0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0])
        # Other is 0.0
        check([-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0])
        # Other is -0.0
        check([-1.0, 0.0, 1.0], [-0.0, -0.0, -0.0], [1.0, 0.0, -1.0])
2576

2577
    @dtypes(torch.bfloat16, torch.float)
    def test_div(self, device, dtype):
        """Check division ops on a column view and, for bfloat16, on known values.

        For each (functional, method, in-place) triple, dividing column 3 in
        place must match dividing it element by element. For bfloat16 the
        functional form is also checked against known quotients, and the
        method form against the functional form.
        """
        variants = (
            (torch.div, torch.Tensor.div, torch.Tensor.div_),
            (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_),
        )
        for op, method, inplace in variants:
            m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)

            res_inplace = m1.clone()
            inplace(res_inplace[:, 3], 2)

            res_loop = m1.clone()
            for row in range(m1.size(0)):
                res_loop[row, 3] = res_loop[row, 3] / 2
            self.assertEqual(res_inplace, res_loop)

            if dtype != torch.bfloat16:
                continue

            a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
            a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
            self.assertEqual(
                op(a1, a2),
                torch.tensor([2.1, 3.1], dtype=dtype, device=device),
                atol=0.01,
                rtol=0,
            )
            self.assertEqual(method(a1, a2), op(a1, a2))
2601

2602
    @dtypes(torch.bfloat16, torch.float)
    def test_true_divide_out(self, device, dtype):
        """torch.true_divide with ``out=`` returns the expected quotients."""
        numerator = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
        denominator = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
        out_buf = torch.empty_like(numerator)

        returned = torch.true_divide(numerator, denominator, out=out_buf)
        expected = torch.tensor([2.1, 3.1], dtype=dtype, device=device)
        self.assertEqual(returned, expected, atol=0.01, rtol=0)
2613

2614
    @dtypes(torch.half)
    def test_divmul_scalar(self, device, dtype):
        """Compare half div/mul by a large Python scalar with float32, at zero tolerance.

        100 / 1e5 and 1e-5 * 1e5 (and 1e5 * 1e-5) are computed in ``dtype``. Each
        must equal the float32 result cast back to ``dtype``.
        """
        scale = 1e5

        def assert_exact(actual, reference):
            self.assertEqual(actual, reference.to(dtype), atol=0.0, rtol=0.0)

        big = torch.tensor(100.0, device=device, dtype=dtype)
        big_ref = big.float().div(scale)
        assert_exact(big.div(scale), big_ref)

        small = torch.tensor(1e-5, device=device, dtype=dtype)
        small_ref = small.float().mul(scale)
        assert_exact(small.mul(scale), small_ref)
        assert_exact(scale * small, small_ref)
2629

2630
    @dtypesIfCUDA(
        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
    )
    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
    def test_floor_divide_tensor(self, device, dtype):
        """Tensor ``//`` keeps the dtype and matches floor(x / y) computed in float64."""
        dividend = torch.randn(10, device=device).mul(30).to(dtype)
        divisor = torch.arange(1, 11, dtype=dtype, device=device)

        quotient = dividend // divisor
        reference = (dividend.double() / divisor.double()).floor().to(dtype)

        self.assertEqual(quotient.dtype, dividend.dtype)
        self.assertEqual(quotient, reference)
2643

2644
    @dtypesIfCUDA(
        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
    )
    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
    def test_floor_divide_scalar(self, device, dtype):
        """``tensor // 3`` keeps the dtype and matches math.floor(v / 3.0) per element."""
        dividend = torch.randn(100, device=device).mul(10).to(dtype)

        quotient = dividend // 3
        floored = [math.floor(elem.item() / 3.0) for elem in dividend]
        reference = torch.tensor(floored, dtype=dividend.dtype, device=device)

        self.assertEqual(quotient.dtype, dividend.dtype)
        self.assertEqual(quotient, reference)
2658

2659
    @onlyCPU
    @dtypes(*get_all_math_dtypes("cpu"))
    def test_rdiv(self, device, dtype):
        """``30 / tensor`` matches Python division element by element.

        float16 is not exercised: the test returns early for it.
        """
        if dtype is torch.float16:
            return

        if dtype.is_complex:
            denom = torch.rand(100, dtype=dtype, device=device).add(1).mul(4)
        else:
            denom = torch.rand(100, device=device).add(1).mul(4).to(dtype)

        actual = 30 / denom
        expected = torch.tensor([30 / elem.item() for elem in denom], device=device)
        self.assertEqual(actual, expected, exact_dtype=False)
2671

2672
    @dtypes(*floating_types_and(torch.half))
    def test_fmod_remainder_by_zero_float(self, device, dtype):
        """Floating-point fmod/remainder by zero (scalar or tensor) is all NaN."""
        for fn in (torch.fmod, torch.remainder):
            dividend = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
            zeros = torch.zeros_like(dividend)
            # Same expectation on CPU and GPU, for a scalar and a tensor zero.
            for divisor in (0.0, zeros):
                self.assertTrue(torch.all(fn(dividend, divisor).isnan()))
2681

2682
    @onlyNativeDeviceTypes  # Check Issue https://github.com/pytorch/pytorch/issues/48130
    @dtypes(*integral_types())
    def test_fmod_remainder_by_zero_integral(self, device, dtype):
        """Integer fmod/remainder by a zero tensor: error on CPU, backend values on GPU."""
        for fn in (torch.fmod, torch.remainder):
            # check integral tensor fmod/remainder to zero
            dividend = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
            zeros = torch.zeros_like(dividend)

            if self.device_type == "cpu":
                # CPU raises a RuntimeError.
                with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
                    fn(dividend, zeros)
                continue

            if torch.version.hip is not None:
                # ROCm behavior: x % 0 is a no-op; x is returned.
                self.assertEqual(fn(dividend, zeros), dividend)
                continue

            # CUDA: integer division by zero is undefined behavior, and the
            # observed result depends on the dtype. For dtypes other than int64
            # every bit is set (uint8: 0xff -> 255, int32: 0xffffffff -> -1).
            # For int64, a negative dividend gives all bits set (-1), while a
            # non-negative dividend gives only the low 32 bits set (4294967295).
            if dtype == torch.int64:
                self.assertEqual(fn(dividend, zeros) == 4294967295, dividend >= 0)
                self.assertEqual(fn(dividend, zeros) == -1, dividend < 0)
            else:
                all_ones = 255 if dtype == torch.uint8 else -1
                self.assertTrue(torch.all(fn(dividend, zeros) == all_ones))
2710

2711
    @dtypes(*all_types_and(torch.half))
    def test_fmod_remainder(self, device, dtype):
        """Compare torch.fmod / torch.remainder with their NumPy counterparts.

        Each op is checked in functional, ``out=`` and in-place form. Dividends
        are a random (10, 10) tensor and its transpose. Divisors are a Python
        int, a Python float, a same-dtype tensor (zeros replaced by 1), its
        transpose and, for integral dtypes, a float tensor (zeros replaced by 1).
        torch.remainder is additionally checked with Python-scalar dividends.
        """

        # Use numpy as reference
        def _helper(x, mod, fns_list):
            for fn, inplace_fn, ref_fn in fns_list:
                np_x = x.cpu().numpy() if torch.is_tensor(x) else x
                np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod
                exp = torch.from_numpy(ref_fn(np_x, np_mod))
                res = fn(x, mod)

                self.assertEqual(res, exp, exact_dtype=False)

                if torch.is_tensor(x):
                    # out
                    out = torch.empty(0, device=device, dtype=res.dtype)
                    fn(x, mod, out=out)
                    self.assertEqual(out, exp, exact_dtype=False)
                    self.assertEqual(out.size(), torch.Size([10, 10]))
                    # in-place (Type cast runtime error)
                    # NOTE: a successful in-place call overwrites ``x`` and every
                    # view of it, so later ops/divisors see the updated dividend.
                    try:
                        inplace_fn(x, mod)
                        self.assertEqual(x, exp, exact_dtype=False)
                    except RuntimeError as e:
                        # An integral ``x`` cannot hold a floating-point result.
                        self.assertRegex(
                            str(e),
                            "result type (Half|Float|Double) "
                            "can't be cast to the desired output "
                            "type (Byte|Char|Short|Int|Long)",
                        )

        x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        # mod with same dtype as x
        mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        # Exclude 0
        mod[mod == 0] = 1

        # Mods: Integer, Float, Tensor, Non-contiguous Tensor
        mods = [3, 2.3, mod, mod.t()]
        # mod with floating-point dtype
        if dtype in integral_types():
            mod_float = make_tensor(
                (10, 10), device=device, dtype=torch.float, low=-9, high=9
            )
            # Exclude 0 from the float divisor as well.
            mod_float[mod_float == 0] = 1
            mods.append(mod_float)

        for dividend, divisor in product([x, x.t()], mods):
            _helper(
                dividend,
                divisor,
                (
                    (torch.fmod, torch.Tensor.fmod_, np.fmod),
                    (torch.remainder, torch.Tensor.remainder_, np.remainder),
                ),
            )

        # Tests for torch.remainder(scalar, tensor)
        for dividend, divisor in product([5, 3.14], mods):
            if torch.is_tensor(divisor):
                _helper(
                    dividend,
                    divisor,
                    ((torch.remainder, torch.Tensor.remainder_, np.remainder),),
                )
2776

2777
    @dtypes(torch.float, torch.double)
    def test_remainder_fmod_large_dividend(self, device, dtype):
        """Check sign and magnitude invariants of remainder/fmod for a +/-1e9 dividend.

        remainder takes the divisor's sign and fmod the dividend's sign. Both
        stay strictly smaller than |divisor|. They are equal for same-sign
        operands and otherwise differ by exactly one divisor.
        """
        alarge = 1e9
        pi = 3.14159265358979
        for avalue, bvalue in product([alarge, -alarge], [pi, -pi]):
            a = torch.tensor([avalue], dtype=dtype, device=device)
            b = torch.tensor([bvalue], dtype=dtype, device=device)
            rem = torch.remainder(a, b)[0]
            mod = torch.fmod(a, b)[0]
            a0, b0 = a[0], b[0]

            # remainder has same sign as divisor
            self.assertTrue((b0 > 0) == (rem > 0))
            # fmod has same sign as dividend
            self.assertTrue((a0 > 0) == (mod > 0))
            # remainder is within range of divisor
            self.assertTrue(abs(rem) < abs(b0))
            # fmod is within range of divisor
            self.assertTrue(abs(mod) < abs(b0))
            if (a0 > 0) == (b0 > 0):
                # remainder is same as fmod
                self.assertTrue(rem == mod)
            else:
                # differ by one divisor
                self.assertTrue(abs(rem - mod) == abs(b0))
2805

2806
    @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    def test_hypot(self, device, dtype):
        """Check torch.hypot on same-shape, broadcast and long/float inputs.

        NumPy's hypot is the reference for float32/float64. For bfloat16/half
        the reference is sqrt(x*x + y*y) computed in torch.
        """

        def rand(shape):
            return torch.randn(shape, device=device).to(dtype)

        cases = [
            (rand(10), rand(10)),
            (rand((3, 3, 3)), rand((3, 3, 3))),
            (rand((10, 1)), rand((10, 1)).transpose(0, 1)),
            (torch.randint(100, (10,), device=device, dtype=torch.long), rand(10)),
        ]
        reduced_precision = dtype in [torch.bfloat16, torch.half]
        for lhs, rhs in cases:
            actual = torch.hypot(lhs, rhs)
            if reduced_precision:
                expected = torch.sqrt(lhs * lhs + rhs * rhs)
            else:
                expected = np.hypot(lhs.cpu().numpy(), rhs.cpu().numpy())
            self.assertEqual(actual, expected, exact_dtype=False)
2834

2835
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_gcd(self, device, dtype):
        """Compare torch.gcd with NumPy, including zero operands and large uint8."""
        # gcd(0, 0), gcd(a, 0) and gcd(0, a)
        zero_lhs, zero_rhs = [0, 10, 0], [0, 0, 10]
        actual = torch.gcd(
            torch.tensor(zero_lhs, dtype=dtype, device=device),
            torch.tensor(zero_rhs, dtype=dtype, device=device),
        )
        self.assertEqual(actual, np.gcd(zero_lhs, zero_rhs), exact_dtype=False)

        if dtype == torch.uint8:
            # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128)
            lhs = torch.tensor([190, 210], device=device, dtype=dtype)
            rhs = torch.tensor([190, 220], device=device, dtype=dtype)
            expected = torch.tensor([190, 10], device=device, dtype=dtype)
            self.assertEqual(torch.gcd(lhs, rhs), expected)
            return

        # Compares with NumPy
        lhs = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        rhs = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        actual = torch.gcd(lhs, rhs)
        self.assertEqual(actual, np.gcd(lhs.cpu().numpy(), rhs.cpu().numpy()))
2859

2860
    @onlyNativeDeviceTypes
    @dtypes(torch.int16, torch.int32, torch.int64)
    def test_lcm(self, device, dtype):
        """Compare torch.lcm with NumPy, including zero operands."""
        # lcm(0, 0), lcm(a, 0) and lcm(0, a)
        zero_lhs, zero_rhs = [0, 10, 0], [0, 0, 10]
        actual = torch.lcm(
            torch.tensor(zero_lhs, dtype=dtype, device=device),
            torch.tensor(zero_rhs, dtype=dtype, device=device),
        )
        self.assertEqual(actual, np.lcm(zero_lhs, zero_rhs), exact_dtype=False)

        # Compares with NumPy
        lhs = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        rhs = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        actual = torch.lcm(lhs, rhs)
        expected = np.lcm(lhs.cpu().numpy(), rhs.cpu().numpy())
        self.assertEqual(actual, expected, exact_dtype=False)
2876

2877
    @onlyNativeDeviceTypes
    @dtypesIfCPU(torch.float32, torch.float64, torch.float16)
    @dtypes(torch.float32, torch.float64)
    def test_nextafter(self, device, dtype):
        """Compare torch.nextafter with np.nextafter bit-exactly, including inf/nan operands."""

        def assert_matches_numpy(src, dst):
            reference = np.nextafter(src.cpu().numpy(), dst.cpu().numpy())
            self.assertEqual(torch.nextafter(src, dst), reference, atol=0, rtol=0)

        # Stepping from zero towards +/-inf and a no-op step (10 -> 10), both directions
        finite = torch.tensor([0, 0, 10], device=device, dtype=dtype)
        infinite = torch.tensor([inf, -inf, 10], device=device, dtype=dtype)
        assert_matches_numpy(finite, infinite)
        assert_matches_numpy(infinite, finite)

        # A NaN in either operand must produce NaN
        nan_second = torch.tensor([0, nan], device=device, dtype=dtype)
        nan_first = torch.tensor([nan, 0], device=device, dtype=dtype)
        self.assertTrue(torch.nextafter(nan_second, nan_first).isnan().all())

        # Random finite operands
        src = torch.randn(100, device=device, dtype=dtype)
        dst = torch.randn(100, device=device, dtype=dtype)
        assert_matches_numpy(src, dst)
2901

2902
    @onlyNativeDeviceTypes
    @dtypes(torch.bfloat16)
    def test_nextafter_bfloat16(self, device, dtype):
        """Check bfloat16 nextafter against hand-computed neighbouring values."""
        # Largest finite and smallest positive (subnormal) bfloat16 values
        bf16_max = 3.3895313892515355e38
        bf16_tiny = 9.183549615799121e-41
        # Each row is (start, direction, expected result)
        table = [
            (0, 1, bf16_tiny),
            (0, -1, -bf16_tiny),
            (1, -2, 0.99609375),
            (1, 0, 0.99609375),
            (1, 2, 1.0078125),
            (-1, -2, -1.0078125),
            (-1, 0, -0.99609375),
            (2, -1, 1.9921875),
            (2, 1, 1.9921875),
            (20, 3000, 20.125),
            (20, -3000, 19.875),
            (3000, -20, 2992.0),
            (-3000, 20, -2992.0),
            (65536, 0, 65280.0),
            (65536, inf, 66048.0),
            (-65536, 0, -65280.0),
            (-65536, -inf, -66048.0),
            (nan, 0, nan),
            (0, nan, nan),
            (nan, nan, nan),
            (nan, inf, nan),
            (inf, nan, nan),
            (inf, -inf, bf16_max),
            (-inf, inf, -bf16_max),
            (inf, 0, bf16_max),
            (0, inf, bf16_tiny),
            (-inf, 0, -bf16_max),
            (0, -inf, -bf16_tiny),
        ]

        for start, direction, want in table:
            got = torch.nextafter(
                torch.tensor([start], device=device, dtype=dtype),
                torch.tensor([direction], device=device, dtype=dtype),
            ).item()
            self.assertEqual(got, want, atol=0, rtol=0)
2944

2945
    def _test_cop(self, torchfn, mathfn, dtype, device):
        """Check an element-wise binary op against a scalar Python reference.

        ``torchfn`` is applied to a contiguous pair of 2-D operands and to a
        non-contiguous (strided) pair; each result is compared element by
        element with ``mathfn`` applied to the corresponding scalars.

        Args:
            torchfn: tensor-level binary function, e.g. ``torch.div``.
            mathfn: scalar reference, e.g. ``operator.truediv``.
            dtype: dtype of the random inputs.
            device: device on which the inputs are allocated.
        """

        def reference_implementation(res2):
            # Late-bound closure: reads whatever ``sm1``/``sm2`` are bound to in
            # the enclosing scope at call time. ``sm2`` must be 1-D and is read
            # in row-major order, so the row stride is the number of *columns*
            # of ``sm1`` (size(1)), not the number of rows.
            for i, j in iter_indices(sm1):
                idx1d = i * sm1.size(1) + j
                res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
            return res2

        # contiguous: sm2 is a contiguous 1-D row viewed with sm1's shape
        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
        m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device)
        sm1 = m1[4]
        sm2 = m2[4]

        res1 = torchfn(sm1, sm2.view(sm1.shape))
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)

        # non-contiguous: both operands are strided column slices
        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
        m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device)
        sm1 = m1[:, 4]
        sm2 = m2[:, 4]
        ncols = sm1.size(1)
        # view as sm1.size(): consecutive 1-D elements of sm2 fill each row
        sm2.set_(
            sm2.storage(),
            sm2.storage_offset(),
            sm1.size(),
            (sm2.stride()[0] * ncols, sm2.stride()[0]),
        )
        res1 = torchfn(sm1, sm2)
        # reference_implementation assumes 1-d sm2, so restore the original view
        sm2.set_(
            sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()
        )
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)
2981

2982
    @onlyCPU
    @dtypes(torch.float)
    def test_cdiv(self, device, dtype):
        """torch.div must agree element-wise with Python true division."""
        self._test_cop(
            torchfn=torch.div, mathfn=operator.truediv, dtype=dtype, device=device
        )
2986

2987
    @onlyCPU
    @dtypes(torch.float)
    def test_cremainder(self, device, dtype):
        """torch.remainder must agree element-wise with Python's ``%``."""
        self._test_cop(
            torchfn=torch.remainder, mathfn=operator.mod, dtype=dtype, device=device
        )
2991

2992
    @onlyCPU
    @dtypes(torch.float)
    def test_cmul(self, device, dtype):
        """torch.mul must agree element-wise with Python multiplication."""
        self._test_cop(
            torchfn=torch.mul, mathfn=operator.mul, dtype=dtype, device=device
        )
2996

2997
    @onlyCPU
    @dtypes(torch.float)
    def test_cpow(self, device, dtype):
        """torch.pow against math.pow, with NaN expected for negative bases."""

        def scalar_pow(base, exponent):
            # math.pow raises for a negative base with a fractional exponent,
            # so the reference answers NaN for every negative base instead.
            if base < 0:
                return nan
            return math.pow(base, exponent)

        self._test_cop(torch.pow, scalar_pow, dtype, device)
3003

3004
    @onlyCPU
    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_floor_divide_zero(self, device, dtype):
        """Integer floor division by a zero element must raise (and warn once)."""
        values = [0, 1]
        numerator = torch.tensor(values, dtype=dtype, device=device)
        denominator = torch.tensor(values, dtype=dtype, device=device)
        # Equivalent to nesting the warning check inside the raise check
        with self.assertRaisesRegex(
            RuntimeError, "ZeroDivisionError"
        ), self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
            numerator // denominator
3012

3013
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_muldiv_scalar(self, device, dtype):
        """Tensor-scalar mul/div (both operand orders) must match a filled-tensor operand."""
        tensor = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None)
        scalar = make_tensor((1,), dtype=dtype, device="cpu", low=None, high=None).item()
        filled = torch.full_like(tensor, scalar)
        for op in (operator.mul, operator.truediv):
            self.assertEqual(op(tensor, scalar), op(tensor, filled))
            self.assertEqual(op(scalar, tensor), op(filled, tensor))
3022

3023
    # TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor
3024
    def _generate_input(self, shape, dtype, device, with_extremal):
3025
        if shape == ():
3026
            x = torch.tensor((), dtype=dtype, device=device)
3027
        else:
3028
            if dtype.is_floating_point or dtype.is_complex:
3029
                # work around torch.randn not being implemented for bfloat16
3030
                if dtype == torch.bfloat16:
3031
                    x = torch.randn(*shape, device=device) * random.randint(30, 100)
3032
                    x = x.to(torch.bfloat16)
3033
                else:
3034
                    x = torch.randn(
3035
                        *shape, dtype=dtype, device=device
3036
                    ) * random.randint(30, 100)
3037
                x[torch.randn(*shape) > 0.5] = 0
3038
                if with_extremal and dtype.is_floating_point:
3039
                    # Use extremal values
3040
                    x[torch.randn(*shape) > 0.5] = float("nan")
3041
                    x[torch.randn(*shape) > 0.5] = float("inf")
3042
                    x[torch.randn(*shape) > 0.5] = float("-inf")
3043
                elif with_extremal and dtype.is_complex:
3044
                    x[torch.randn(*shape) > 0.5] = complex("nan")
3045
                    x[torch.randn(*shape) > 0.5] = complex("inf")
3046
                    x[torch.randn(*shape) > 0.5] = complex("-inf")
3047
            elif dtype == torch.bool:
3048
                x = torch.zeros(shape, dtype=dtype, device=device)
3049
                x[torch.randn(*shape) > 0.5] = True
3050
            else:
3051
                x = torch.randint(15, 100, shape, dtype=dtype, device=device)
3052

3053
        return x
3054

3055
    @dtypes(
        *tuple(
            itertools.combinations_with_replacement(
                all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2
            )
        )
    )
    def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
        """Comparison and logical ops across dtype pairs and broadcast shapes.

        Regression test for issue #42660. Covers all combinations of
        broadcasting and type promotion, over a range of dtypes and input
        shapes, with and without extremal values (nan/inf). The functional
        form is compared with NumPy. The ``out=`` form is compared with the
        functional result cast to the out dtype.
        """

        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
            # working around the fact that numpy doesn't support bfloat16
            # by letting numpy treat them as float32's
            x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
            y_np = (
                y.cpu().numpy()
                if y.dtype != torch.bfloat16
                else y.to(torch.float32).cpu().numpy()
            )
            self.compare_with_numpy(
                lambda inp: torch_fn(inp, y),
                lambda inp: np_fn(inp, y_np),
                x_np,
            )

        complex_op_denylist = [
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        ]  # complex not supported
        input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)]
        op_pairs = [
            (torch.lt, np.less),
            (torch.le, np.less_equal),
            (torch.gt, np.greater),
            (torch.ge, np.greater_equal),
            (torch.eq, np.equal),
            (torch.ne, np.not_equal),
            (torch.logical_and, np.logical_and),
            (torch.logical_or, np.logical_or),
            (torch.logical_xor, np.logical_xor),
        ]

        for size1 in input_sizes:
            size2 = (2,) + size1  # perform broadcasting
            for with_extremal in [False, True]:
                a = self._generate_input(size1, dtypes[0], device, with_extremal)
                b = self._generate_input(size2, dtypes[1], device, with_extremal)
                for torch_op, numpy_op in op_pairs:
                    if (
                        dtypes[0].is_complex or dtypes[1].is_complex
                    ) and torch_op in complex_op_denylist:
                        continue
                    # functional version of op
                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b)

                    # functional comparison ops always return bool tensors
                    result = torch_op(a, b)
                    self.assertEqual(result.dtype, torch.bool)

                    # out version of op. The previous `if out` guard tested the
                    # truthiness of a zero-filled tensor and was always False,
                    # so this path never ran. The out tensor lives on `device`
                    # (a CPU out would be rejected for CUDA inputs). It starts
                    # empty so it is resized without a warning, and complex128
                    # is used because all casts to complex128 are safe.
                    out = torch.empty(0, dtype=torch.complex128, device=device)
                    torch_op(a, b, out=out)
                    self.assertEqual(out, result.to(torch.complex128))
3121

3122
    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
    def test_signed_shift(self, device, dtype):
        """Ensure that signed integer bit shifting works as expected."""

        def tensor_of(values):
            return torch.tensor(values, device=device, dtype=dtype)

        # Two's complement bit patterns: -10 = 0b11...1110110, 10 = 0b1010
        source = tensor_of([-10, 10])

        # << 2 multiplies by 4: 0b11...11011000, 0b101000
        self.assertEqual(source << 2, tensor_of([-40, 40]))
        self.compare_with_numpy(
            lambda t: t << 2, lambda arr: np.left_shift(arr, 2), source
        )

        # >> 1 is arithmetic, so the sign is kept: 0b1111...111011, 0b101
        self.assertEqual(source >> 1, tensor_of([-5, 5]))
        self.compare_with_numpy(
            lambda t: t >> 1, lambda arr: np.right_shift(arr, 1), source
        )
3137

3138
    @onlyNativeDeviceTypes
    @dtypes(*get_all_int_dtypes())
    def test_shift_limits(self, device, dtype):
        "Ensure that integer bit shifting works as expected with out-of-limits shift values."
        # Issue #70904
        # A shift amount is out of limits when it is negative or >= the bit
        # width of `dtype`. For such amounts the expected results are:
        #   x << s  ->  0
        #   x >> s  ->  clamp(x, -1, 0)  (-1 for negative x, 0 otherwise)
        iinfo = torch.iinfo(dtype)
        bits = iinfo.bits
        low = iinfo.min
        high = iinfo.max
        # numpy changes dtype from uint8 to int16 for some out-of-limits shift values
        exact_dtype = dtype != torch.uint8
        # range(-100, 0) so that -1 (the old range(-100, -1) stopped at -2) is covered too
        out_of_limits_shifts = list(chain(range(-100, 0), range(bits, 100)))
        for tensor in (
            # small inputs, for the non-vectorized operation
            torch.tensor([-1, 0, 1], device=device, dtype=dtype),
            torch.tensor([low, high], device=device, dtype=dtype),
            # large input, for the vectorized operation
            make_tensor((64, 64, 64), low=low, high=high, device=device, dtype=dtype),
        ):
            shift_left_expected = torch.zeros_like(tensor)
            shift_right_expected = torch.clamp(tensor, -1, 0)
            for shift in out_of_limits_shifts:
                self.assertEqual(
                    tensor << shift, shift_left_expected, msg=f"<< {shift}"
                )
                # `s=shift` binds the current value explicitly in each lambda
                self.compare_with_numpy(
                    lambda x, s=shift: x << s,
                    lambda x, s=shift: np.left_shift(x, s),
                    tensor,
                    exact_dtype=exact_dtype,
                    msg=f"<< {shift}",
                )
                self.assertEqual(
                    tensor >> shift, shift_right_expected, msg=f">> {shift}"
                )
                self.compare_with_numpy(
                    lambda x, s=shift: x >> s,
                    lambda x, s=shift: np.right_shift(x, s),
                    tensor,
                    exact_dtype=exact_dtype,
                    msg=f">> {shift}",
                )
3172

3173
    @onlyNativeDeviceTypes
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bfloat16, torch.bool),
            repeat=2,
        )
    )
    def test_heaviside(self, device, dtypes):
        """Compare torch.heaviside with np.heaviside for every dtype pair.

        Matching dtypes are checked in functional, method, ``out=`` and
        in-place form. Mismatched dtypes must raise in all four forms.
        """
        input_dtype, values_dtype = dtypes

        def numpy_dtype(torch_dtype):
            # NumPy has no bfloat16, so those samples are built as float64
            if torch_dtype == torch.bfloat16:
                torch_dtype = torch.float64
            return torch_to_numpy_dtype_dict[torch_dtype]

        rng = np.random.default_rng()
        input_np = np.array(
            rng.integers(-10, 10, size=10), dtype=numpy_dtype(input_dtype)
        )
        # guarantee some zeros, where heaviside takes the `values` entry
        input_np[0] = input_np[3] = input_np[7] = 0
        values_np = np.array(
            rng.integers(-10, 10, size=10), dtype=numpy_dtype(values_dtype)
        )
        np_result = torch.from_numpy(np.heaviside(input_np, values_np)).to(
            device=device, dtype=input_dtype
        )

        input_t = torch.from_numpy(input_np).to(device=device, dtype=input_dtype)
        values_t = torch.from_numpy(values_np).to(device=device, dtype=values_dtype)
        out = torch.empty_like(input_t)

        if input_dtype != values_dtype:
            message = "heaviside is not yet implemented for tensors with different dtypes."
            for call in (
                lambda: torch.heaviside(input_t, values_t),
                lambda: input_t.heaviside(values_t),
                lambda: torch.heaviside(input_t, values_t, out=out),
                lambda: input_t.heaviside_(values_t),
            ):
                with self.assertRaisesRegex(RuntimeError, message):
                    call()
            return

        self.assertEqual(np_result, torch.heaviside(input_t, values_t))
        self.assertEqual(np_result, input_t.heaviside(values_t))
        torch.heaviside(input_t, values_t, out=out)
        self.assertEqual(np_result, out)
        input_t.heaviside_(values_t)
        self.assertEqual(np_result, input_t)
3241

3242
    @onlyCUDA
    def test_heaviside_cross_device(self, device):
        """A CPU 0-dim tensor may be mixed with a device tensor.

        A non-scalar CPU tensor combined with a device tensor must raise.
        """
        data = [-9, 5, 0, 6, -2, 2]

        on_device = torch.tensor(data, device=device)
        cpu_scalar = torch.tensor(0)
        self.assertEqual(
            torch.heaviside(on_device, cpu_scalar),
            torch.tensor([0, 1, 0, 1, 0, 1], device=device),
        )
        self.assertEqual(
            torch.heaviside(cpu_scalar, on_device), torch.tensor(data, device=device)
        )

        on_cpu = torch.tensor(data)
        device_scalar = torch.tensor(0, device=device)
        for lhs, rhs in ((on_cpu, device_scalar), (device_scalar, on_cpu)):
            with self.assertRaisesRegex(
                RuntimeError, "Expected all tensors to be on the same device"
            ):
                torch.heaviside(lhs, rhs)
3265

3266
    @dtypes(*product(complex_types(), repeat=2))
    def test_heaviside_complex(self, device, dtypes):
        """heaviside must reject complex operands in every calling form."""
        input_dtype, values_dtype = dtypes
        samples = (complex(0, -6), complex(-1, 3), complex(1, 1))
        input_t = torch.tensor(samples, device=device, dtype=input_dtype)
        values_t = torch.tensor(samples, device=device, dtype=values_dtype)
        out = torch.empty_like(input_t)
        real = input_t.real

        rejected_calls = (
            lambda: torch.heaviside(input_t, real),
            lambda: real.heaviside(values_t),
            lambda: input_t.heaviside_(values_t),
            lambda: torch.heaviside(real, real, out=out),
        )
        for call in rejected_calls:
            with self.assertRaisesRegex(
                RuntimeError, "heaviside is not yet implemented for complex tensors."
            ):
                call()
3293

3294
    def _test_logical(self, device, dtypes, op, a_, b_, expected_res_):
        """Run a logical op in method, ``out=`` and in-place form.

        Args:
            device: device for all tensors.
            dtypes: ``(lhs_dtype, rhs_dtype)`` pair; the expected tensor and
                the lhs use ``lhs_dtype``.
            op: op name (e.g. ``"logical_xor"``), looked up on the tensor, on
                ``torch`` and, with a trailing ``_``, as the in-place method.
            a_, b_: Python values for the lhs and rhs.
            expected_res_: expected element values.
        """
        lhs_dtype, rhs_dtype = dtypes
        expected = torch.tensor(expected_res_, dtype=lhs_dtype, device=device)
        lhs = torch.tensor(a_, dtype=lhs_dtype, device=device)
        rhs = torch.tensor(b_, dtype=rhs_dtype, device=device)
        expected_bool = expected.bool()

        # method form: compared against the bool view of the expectation
        self.assertEqual(expected_bool, getattr(lhs, op)(rhs))

        # out= form into an initially empty bool tensor
        result = torch.empty(0, dtype=torch.bool, device=device)
        getattr(torch, op)(lhs, rhs, out=result)
        self.assertEqual(expected_bool, result)

        # in-place form writes into lhs, which keeps lhs_dtype
        getattr(lhs, f"{op}_")(rhs)
        self.assertEqual(expected, lhs)
3308

3309
    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            repeat=2,
        )
    )
    def test_logical_xor(self, device, dtypes):
        """logical_xor truth table for every (lhs, rhs) dtype pair."""
        lhs_values = [10, 0, 1, 0]
        rhs_values = [1, 0, 0, 10]
        # true exactly where one side is non-zero
        self._test_logical(
            device, dtypes, "logical_xor", lhs_values, rhs_values, [0, 0, 1, 1]
        )
3319

3320
    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            repeat=2,
        )
    )
    def test_logical_and(self, device, dtypes):
        """logical_and truth table for every (lhs, rhs) dtype pair."""
        lhs_values = [10, 0, 1, 0]
        rhs_values = [1, 0, 0, 10]
        # true only where both sides are non-zero
        self._test_logical(
            device, dtypes, "logical_and", lhs_values, rhs_values, [1, 0, 0, 0]
        )
3330

3331
    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            repeat=2,
        )
    )
    def test_logical_or(self, device, dtypes):
        """logical_or truth table for every (lhs, rhs) dtype pair."""
        lhs_values = [10, 0, 1, 0]
        rhs_values = [1, 0, 0, 10]
        # true where at least one side is non-zero
        self._test_logical(
            device, dtypes, "logical_or", lhs_values, rhs_values, [1, 0, 1, 1]
        )
3341

3342
    def test_remainder_overflow(self, device):
        """int64 remainder with a large divisor must not overflow (Check Integer Overflows)."""
        x = torch.tensor(23500, dtype=torch.int64, device=device)
        q = 392486996410368
        # (dividend, divisor, expected); the result takes the sign of the divisor
        cases = (
            (x, q, x),
            (-x, q, q - x),
            (x, -q, x - q),
            (-x, -q, -x),
        )
        for dividend, divisor, expected in cases:
            self.assertEqual(dividend % divisor, expected)
3350

3351
    def test_rpow(self, device):
        """``2 ** tensor`` (reflected pow) must agree with ``torch.pow(2, tensor)``."""
        exponent = torch.randn(10, 10, device=device)
        self.assertEqual(torch.pow(2, exponent), 2**exponent)

        # same check with a 0-dim exponent
        exponent = torch.randn(1, device=device).squeeze()
        assert exponent.dim() == 0, "m is intentionally a scalar"
        self.assertEqual(torch.pow(2, exponent), 2**exponent)
3359

3360
    @onlyCPU
    def test_ldexp(self, device):
        """torch.ldexp (function, method and in-place) against np.ldexp."""
        # Random mantissas with int32 exponents in [-31, 31)
        mantissas = torch.randn(64, device=device)
        exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)
        reference = np.ldexp(mantissas.numpy(), exponents.numpy())
        functional = torch.ldexp(mantissas, exponents)
        method = mantissas.ldexp(exponents)
        for result in (functional, method):
            self.assertEqual(reference, result)
        # In-place variant overwrites the mantissas
        mantissas.ldexp_(exponents)
        self.assertEqual(reference, mantissas)

        # Non-finite mantissas (inf, -inf, nan) with non-negative exponents
        specials = torch.tensor(
            [float("inf"), float("-inf"), float("inf"), float("nan")], device=device
        )
        exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
        self.assertEqual(
            np.ldexp(specials.numpy(), exponents.numpy()),
            torch.ldexp(specials, exponents),
        )
3383

3384
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_lerp(self, device, dtype):
        """torch.lerp function, method and ``out=`` forms agree with the formula.

        Every combination of ``()``, ``(5,)`` and ``(5, 5)`` shapes for start,
        end and a tensor weight is tried, plus a Python float weight (and two
        complex Python weights for complex dtypes). Each result must equal
        ``start + weight * (end - start)``.
        """
        start_end_weight_shapes = [(), (5,), (5, 5)]
        for shapes in product(
            start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes
        ):
            start = torch.randn(shapes[0], device=device, dtype=dtype)
            end = torch.randn(shapes[1], device=device, dtype=dtype)

            # Tensor weights
            weights = [
                torch.randn(shapes[2], device=device, dtype=dtype),
                random.random(),
            ]
            if dtype.is_complex:
                weights += [complex(0, 1), complex(0.4, 1.2)]

            for weight in weights:
                actual = torch.lerp(start, end, weight)
                actual_method = start.lerp(end, weight)
                self.assertEqual(actual, actual_method)
                # Start from an empty tensor: resizing an out= tensor that
                # already holds elements to a different (broadcast) shape
                # emits PyTorch's deprecation warning.
                actual_out = torch.empty(0, dtype=dtype, device=device)
                torch.lerp(start, end, weight, out=actual_out)
                self.assertEqual(actual, actual_out)
                expected = start + weight * (end - start)
                self.assertEqual(expected, actual)
3410

3411
    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16)
    def test_lerp_lowp(self, device, dtype):
        """Low-precision lerp must equal float32 lerp rounded to ``dtype`` exactly."""

        def full(value):
            return torch.full((4,), value, device=device, dtype=dtype)

        # (start, end, weight): a large Python weight, then a tensor weight
        cases = [
            (full(0.0), full(0.1), 70000),
            (full(-30000.0), full(-20000.0), full(8)),
        ]
        for start, end, weight in cases:
            weight_ref = weight.float() if isinstance(weight, torch.Tensor) else weight
            expected = torch.lerp(start.float(), end.float(), weight_ref).to(dtype)
            self.assertEqual(
                torch.lerp(start, end, weight), expected, atol=0.0, rtol=0.0
            )
3426

3427
    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_lerp_lowp_cpu(self, device, dtype):
        """Low-precision CPU lerp must equal float32 lerp rounded to ``dtype`` exactly.

        Checked for several input shapes.
        """
        start_vals = (0.0, -30000.0)
        end_vals = (0.1, -20000.0)
        for shape in [(4,), (20,), (3, 10, 10)]:
            # a large Python weight, then a tensor weight
            weights = [70000, torch.full(shape, 8, device=device, dtype=dtype)]
            for start_val, end_val, weight in zip(start_vals, end_vals, weights):
                start = torch.full(shape, start_val, device=device, dtype=dtype)
                end = torch.full(shape, end_val, device=device, dtype=dtype)
                weight_ref = (
                    weight.float() if isinstance(weight, torch.Tensor) else weight
                )
                expected = torch.lerp(start.float(), end.float(), weight_ref).to(dtype)
                self.assertEqual(
                    torch.lerp(start, end, weight), expected, atol=0.0, rtol=0.0
                )
3443

3444
    def _test_logaddexp(self, device, dtype, base2):
        """Compare torch.logaddexp / torch.logaddexp2 against a NumPy/SciPy reference.

        Args:
            device: device for the inputs.
            dtype: input dtype. Complex dtypes use a SciPy ``logsumexp``
                reference because NumPy has no complex ``logaddexp``.
                bfloat16 is compared in float32 with a 0.01 tolerance.
            base2: if True, test ``torch.logaddexp2`` against
                ``np.logaddexp2``; otherwise test ``torch.logaddexp``.

        Raises:
            unittest.SkipTest: for complex dtypes when SciPy is unavailable.
        """
        if base2:
            ref_func = np.logaddexp2
            our_func = torch.logaddexp2
        elif dtype in (torch.complex64, torch.complex128):
            # numpy has not implemented logaddexp for complex. `scipy` is only
            # imported at module level when TEST_SCIPY is true, so skip rather
            # than fail with a NameError when it is unavailable.
            if not TEST_SCIPY:
                self.skipTest("scipy is required for the complex logaddexp reference")

            def ref_func(x, y):
                return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0)

            our_func = torch.logaddexp
        else:
            ref_func = np.logaddexp
            our_func = torch.logaddexp

        def _test_helper(a, b):
            if dtype == torch.bfloat16:
                # NumPy has no bfloat16: compute the reference in float32
                ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy())
                v = our_func(a, b)
                self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01)
            else:
                ref = ref_func(a.cpu().numpy(), b.cpu().numpy())
                v = our_func(a, b)
                self.assertEqual(ref, v)

        # simple test, on the full tensors and on a leading slice
        a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
        b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
        _test_helper(a, b)
        _test_helper(a[:3], b[:3])

        # large value test for numerical stability
        a *= 10000
        b *= 10000
        _test_helper(a, b)
        _test_helper(a[:3], b[:3])

        # special values: matching and opposite infinities, and nan
        a = torch.tensor(
            [float("inf"), float("-inf"), float("inf"), float("nan")],
            dtype=dtype,
            device=device,
        )
        b = torch.tensor(
            [float("inf"), float("-inf"), float("-inf"), float("nan")],
            dtype=dtype,
            device=device,
        )
        _test_helper(a, b)
3491

3492
    @skipIfTorchDynamo()  # complex infs/nans differ under Dynamo/Inductor
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128)
    def test_logaddexp(self, device, dtype):
        """Natural-log variant: torch.logaddexp against its reference."""
        self._test_logaddexp(device=device, dtype=dtype, base2=False)
3497

3498
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_logaddexp2(self, device, dtype):
        """Base-2 variant: torch.logaddexp2 against np.logaddexp2."""
        self._test_logaddexp(device=device, dtype=dtype, base2=True)
3501

3502
    def test_add(self, device):
3503
        dtypes = floating_and_complex_types()
3504
        for dtype in dtypes:
3505
            # [res] torch.add([res,] tensor1, tensor2)
3506
            m1 = torch.randn(100, 100, dtype=dtype, device=device)
3507
            v1 = torch.randn(100, dtype=dtype, device=device)
3508

3509
            # contiguous
3510
            res1 = torch.add(m1[4], v1)
3511
            res2 = res1.clone().zero_()
3512
            for i in range(m1.size(1)):
3513
                res2[i] = m1[4, i] + v1[i]
3514
            self.assertEqual(res1, res2)
3515

3516
            m1 = torch.randn(100, 100, device=device)
3517
            v1 = torch.randn(100, device=device)
3518

3519
            # non-contiguous
3520
            res1 = torch.add(m1[:, 4], v1)
3521
            res2 = res1.clone().zero_()
3522
            for i in range(m1.size(0)):
3523
                res2[i] = m1[i, 4] + v1[i]
3524
            self.assertEqual(res1, res2)
3525

3526
            # [res] torch.add([res,] tensor, value)
3527
            m1 = torch.randn(10, 10, device=device)
3528

3529
            # contiguous
3530
            res1 = m1.clone()
3531
            res1[3].add_(2)
3532
            res2 = m1.clone()
3533
            for i in range(m1.size(1)):
3534
                res2[3, i] = res2[3, i] + 2
3535
            self.assertEqual(res1, res2)
3536

3537
            # non-contiguous
3538
            m1 = torch.randn(10, 10, device=device)
3539
            res1 = m1.clone()
3540
            res1[:, 3].add_(2)
3541
            res2 = m1.clone()
3542
            for i in range(m1.size(0)):
3543
                res2[i, 3] = res2[i, 3] + 2
3544
            self.assertEqual(res1, res2)
3545

3546
            # inter-type
3547
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
3548
            self.assertEqual(m1 + 3, m1 + torch.tensor(3))
3549
            self.assertEqual(3 + m1, torch.tensor(3) + m1)
3550

3551
            # contiguous + non-contiguous
3552
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
3553
            m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
3554
            res = m1 + m2
3555
            self.assertTrue(res.is_contiguous())
3556
            self.assertEqual(res, m1 + m2.contiguous())
3557

3558
            # 1d + empty
3559
            m1 = torch.tensor([1.0], dtype=dtype, device=device)
3560
            m2 = torch.tensor([], dtype=dtype, device=device)
3561
            self.assertEqual(m1 + m2, [])
3562

3563
        # inter-type unint8
3564
        one = torch.tensor(1, dtype=torch.uint8, device=device)
3565
        self.assertEqual(torch.add(one, 1), 2)
3566
        self.assertEqual(torch.add(one, 1).dtype, torch.uint8)
3567

3568
        # bool
3569
        m1 = torch.tensor(
3570
            [True, False, False, True, False, False], dtype=torch.bool, device=device
3571
        )
3572
        m2 = torch.tensor(
3573
            [True, True, False, False, False, True], dtype=torch.bool, device=device
3574
        )
3575
        expected = torch.tensor(
3576
            [True, True, False, True, False, True], dtype=torch.bool, device=device
3577
        )
3578
        self.assertEqual(m1 + m2, expected)
3579

3580
        # fused multiply add
3581
        a = torch.zeros(2, 3, dtype=torch.bool, device=device)
3582
        res = torch.add(a, a, alpha=0)
3583
        expected = torch.zeros(2, 3, device=device).bool()
3584
        self.assertEqual(res, expected)
3585

3586
        # bfloat16
3587
        m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
3588
        m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16)
3589
        self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16))
3590

3591
        # different alpha types
3592
        m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
3593
        m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
3594
        # add complex numbers with float alpha
3595
        res = torch.add(m1, m2, alpha=0.1)
3596
        expected = torch.tensor(
3597
            [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device
3598
        )
3599
        self.assertEqual(res, expected)
3600

3601
        # add complex numbers with complex alpha
3602
        res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
3603
        expected = torch.tensor(
3604
            [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device
3605
        )
3606
        self.assertEqual(res, expected)
3607

3608
        # add complex numbers with integer alpha
3609
        res = torch.add(m1, m2, alpha=2)
3610
        expected = torch.tensor(
3611
            [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device
3612
        )
3613
        self.assertEqual(res, expected)
3614

3615
        # mismatched alpha
3616
        m1 = torch.tensor([1], dtype=torch.int8, device=device)
3617
        m2 = torch.tensor([2], dtype=torch.int8, device=device)
3618
        self.assertRaisesRegex(
3619
            RuntimeError,
3620
            r"Boolean alpha only supported for Boolean results\.",
3621
            lambda: torch.add(m1, m2, alpha=True),
3622
        )
3623
        self.assertRaisesRegex(
3624
            RuntimeError,
3625
            r"For integral input tensors, argument alpha must not be a floating point number\.",
3626
            lambda: torch.add(m1, m2, alpha=1.0),
3627
        )
3628

3629
        # mismatched alpha, float / double tensor and complex alpha
3630
        msg = r"For non-complex input tensors, argument alpha must not be a complex number\."
3631
        m1 = torch.tensor([3.0, 4.0], device=device)
3632
        m2 = torch.tensor([4.0, 3.0], device=device)
3633
        self.assertRaisesRegex(
3634
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
3635
        )
3636

3637
        m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device)
3638
        m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device)
3639
        self.assertRaisesRegex(
3640
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
3641
        )
3642

3643
        # complex
3644
        m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
3645
        m2 = torch.tensor(4.0, dtype=torch.float64)
3646
        self.assertRaisesRegex(
3647
            RuntimeError,
3648
            r"result type ComplexFloat can't be cast to the desired output type Double",
3649
            lambda: torch.add(m1, m1, out=m2),
3650
        )
3651

3652
    @onlyCUDA
    def test_addsub_half_tensor(self, device):
        # Every case below has a finite float16 result (-60000 or -10000), even
        # though alpha * other, or the scalar operand itself, lies outside the
        # float16 range. The computed value must therefore stay finite.
        x = torch.tensor([60000.0], dtype=torch.half, device=device)
        cases = [
            (torch.add, torch.tensor([-60000.0], dtype=torch.half, device=device), 2),
            (torch.sub, torch.tensor([60000.0], dtype=torch.half, device=device), 2),
            (torch.add, -70000.0, 1),
            (torch.sub, 70000.0, 1),
        ]
        for fn, other, scale in cases:
            result = fn(x, other, alpha=scale)
            self.assertFalse(result.isnan() or result.isinf())

    def test_sub_typing(self, device):
        # Subtraction involving bool tensors is rejected, and the error message
        # points at the logical operators instead.
        lhs_bool = torch.tensor(
            [True, False, False, True, False, False], dtype=torch.bool, device=device
        )
        rhs_bool = torch.tensor(
            [True, True, False, False, False, True], dtype=torch.bool, device=device
        )
        xor_hint = (
            r"Subtraction, the `\-` operator, with two bool tensors is not supported. "
            r"Use the `\^` or `logical_xor\(\)` operator instead."
        )
        not_hint = (
            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead."
        )
        for pattern, fn in (
            (xor_hint, lambda: lhs_bool - rhs_bool),
            (not_hint, lambda: 1 - lhs_bool),
            (not_hint, lambda: rhs_bool - 1),
        ):
            self.assertRaisesRegex(RuntimeError, pattern, fn)

        # An alpha whose type does not fit integral inputs is rejected.
        a = torch.tensor([1], dtype=torch.int8, device=device)
        b = torch.tensor([2], dtype=torch.int8, device=device)
        for pattern, alpha in (
            (r"Boolean alpha only supported for Boolean results\.", True),
            (
                r"For integral input tensors, argument alpha must not be a floating point number\.",
                1.0,
            ),
        ):
            # The lambda is invoked immediately, so it sees this iteration's alpha.
            self.assertRaisesRegex(
                RuntimeError, pattern, lambda: torch.sub(a, b, alpha=alpha)
            )

    def test_mul(self, device):
        # An in-place mul_ through a column view must match per-element updates.
        base = torch.randn(10, 10, device=device)
        via_view = base.clone()
        via_view[:, 3].mul_(2)
        via_loop = base.clone()
        for row in range(via_view.size(0)):
            via_loop[row, 3] = via_loop[row, 3] * 2
        self.assertEqual(via_view, via_loop)

        # Multiplying bool tensors behaves like a logical AND.
        lhs = torch.tensor([True, False, False, True], dtype=torch.bool, device=device)
        rhs = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
        self.assertEqual(
            lhs * rhs,
            torch.tensor([True, False, False, False], dtype=torch.bool, device=device),
        )

        # The bfloat16 check below only runs on CPU.
        if device != "cpu":
            return
        lhs = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device)
        rhs = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device)
        self.assertEqual(
            lhs * rhs,
            torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device),
            atol=0.01,
            rtol=0,
        )
        self.assertEqual(lhs.mul(rhs), lhs * rhs)

    def test_bool_tensor_comparison_ops(self, device):
        def as_bool(values):
            return torch.tensor(values, dtype=torch.bool, device=device)

        a = as_bool([True, False, True, False, True, False])
        b = as_bool([True, False, True, True, True, True])
        # Bool tensors compare like integers, with False < True; they also
        # compare against Python bools and 0-dim bool tensors.
        expectations = [
            (a == b, [1, 1, 1, 0, 1, 0]),
            (a != b, [0, 0, 0, 1, 0, 1]),
            (a < b, [0, 0, 0, 1, 0, 1]),
            (a > b, [0, 0, 0, 0, 0, 0]),
            (a >= b, [1, 1, 1, 0, 1, 0]),
            (a <= b, [1, 1, 1, 1, 1, 1]),
            (a > False, [1, 0, 1, 0, 1, 0]),
            (a == as_bool(True), [1, 0, 1, 0, 1, 0]),
            (a == as_bool(0), [0, 1, 0, 1, 0, 1]),
        ]
        for actual, expected in expectations:
            self.assertEqual(actual, as_bool(expected))
        self.assertFalse(a.equal(b))

    @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool))
    def test_logical(self, device, dtype):
        # Checks lt/le/ge/gt/eq/ne against a Python scalar. Non-bool dtypes are
        # also checked against a one-element tensor that broadcasts.
        if dtype == torch.bool:
            x = torch.tensor([True, False, True, False], device=device)
            cases = (
                (torch.Tensor.lt, [False, True, False, True]),
                (torch.Tensor.le, [True, True, True, True]),
                (torch.Tensor.ge, [True, False, True, False]),
                (torch.Tensor.gt, [False, False, False, False]),
                (torch.Tensor.eq, [True, False, True, False]),
                (torch.Tensor.ne, [False, True, False, True]),
            )
            for op, values in cases:
                self.assertEqual(op(x, True), torch.tensor(values))
            return

        x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype)
        b = torch.tensor([2], device=device, dtype=dtype)
        cases = (
            (torch.Tensor.lt, [True, False, False, False]),
            (torch.Tensor.le, [True, True, False, False]),
            (torch.Tensor.ge, [False, True, True, True]),
            (torch.Tensor.gt, [False, False, True, True]),
            (torch.Tensor.eq, [False, True, False, False]),
            (torch.Tensor.ne, [True, False, True, True]),
        )
        for other in (2, b):
            for op, values in cases:
                self.assertEqual(op(x, other), torch.tensor(values))

    def test_atan2(self, device):
        def _check(shape, device):
            a = torch.rand(size=shape, device=device, dtype=torch.double)
            b = torch.rand(size=shape, device=device, dtype=torch.double)
            actual = a.atan2(b)
            flat_a = a.view(-1)
            flat_b = b.view(-1)
            # Element-wise reference computed with Python's math.atan2.
            reference = torch.tensor(
                [math.atan2(ya.item(), xb.item()) for ya, xb in zip(flat_a, flat_b)],
                device=device,
                dtype=torch.double,
            )
            self.assertEqual(reference, actual.view(-1), rtol=0, atol=0.02)

            # Reduced-precision dtypes, each with its own absolute tolerance.
            for lowp_dtype, atol in ((torch.bfloat16, 0.02), (torch.float16, 0.001)):
                a_lowp = a.to(dtype=lowp_dtype)
                b_lowp = b.to(dtype=lowp_dtype)
                actual_lowp = a_lowp.atan2(b_lowp)
                self.assertEqual(actual_lowp, actual.to(dtype=lowp_dtype))
                self.assertEqual(
                    reference, actual_lowp.view(-1), exact_dtype=False, rtol=0, atol=atol
                )

        for shape in ((2, 2), (3, 3), (5, 5)):
            _check(shape, device)

    def test_atan2_edgecases(self, device):
        # Each case is (x, y, expected), where expected == atan2(y, x).
        # The cases cover the axes and all four quadrants.
        cases = [
            (0, 0, 0),
            (0, 1, math.pi / 2),
            (0, -1, math.pi / -2),
            (-1, 0, math.pi),
            (1, 0, 0),
            (-1, -1, math.pi * -3 / 4),
            (1, 1, math.pi / 4),
            (1, -1, math.pi / -4),
            (-1, 1, math.pi * 3 / 4),
        ]
        for dtype in [torch.float, torch.double]:
            for x, y, expected in cases:
                expected_tensor = torch.tensor([expected], dtype=dtype, device=device)
                x_tensor = torch.tensor([x], dtype=dtype, device=device)
                y_tensor = torch.tensor([y], dtype=dtype, device=device)
                actual = torch.atan2(y_tensor, x_tensor)
                self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)

    def test_trapezoid(self, device):
        # Compares torch.trapezoid against NumPy's trapezoidal rule for both the
        # `dx` (uniform spacing) and `x` (explicit sample points) overloads.
        # Also covers the error paths: an out-of-range `dim` and a mismatched
        # number of `x` values.

        # NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old
        # name; older NumPy releases only provide np.trapz.
        np_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, dx=dx, dim=dim)
            expected = np_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False)

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
            expected = np_trapezoid(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual.cpu(), exact_dtype=False)

        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 0, [], device)
        test_x((0, 2), 1, [1.0, 2.0], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device)

        # Give every failing call its own context manager. Inside a single
        # `with` block, any statement after the first one that raises is never
        # executed, so its error path would go untested.
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_dx((2, 3), 2, 1.0, device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0], device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    # This is failing on Python 3.12. https://github.com/pytorch/pytorch/issues/119462
    @skipIf(
        sys.version_info >= (3, 12), "Failing on Python 3.12"
    )
    def test_cumulative_trapezoid(self, device):
        # Compares torch.cumulative_trapezoid against SciPy for the `dx` and
        # `x` overloads. Also covers the error paths: out-of-range `dim`,
        # mismatched `x` length, complex `dx`, and passing both `x` and `dx`.

        import scipy.integrate

        if hasattr(scipy.integrate, "cumulative_trapezoid"):
            scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid
        else:  # Older version of SciPy uses a different name
            scipy_cumulative_trapezoid = scipy.integrate.cumtrapz

        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(t, dx=dx, dim=dim)
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False, atol=1e-4, rtol=1e-4)

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(
                expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4
            )

        def test_empty_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            self.assertEqual(torch.empty(actual.shape), actual)

        test_dx((2,), -1, 1, device)
        test_dx((3, 3), -1, 1, device)
        test_dx((4, 2), 0, 1, device)
        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_dx((512, 512), 1, 1.0, device)
        test_dx((100, 100, 100), 1, 1.0, device)

        test_x((2,), -1, [100, 50], device)
        test_x((4, 2), 0, [2, 3, 4, 5], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 1, [1, 2], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)

        test_empty_x(
            (0, 2), 0, [], device
        )  # SciPy failing when x == [], but our version returns empty

        # Give every failing call its own context manager. Inside a single
        # `with` block, any statement after the first one that raises is never
        # executed, so its error path would go untested.
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_dx((2, 3), 2, 1.0, device)
        for sizes, dim, x in (
            ((2, 3), 1, [1.0, 2.0]),
            ((0, 2), 0, [1.0, 2.0]),
            ((2, 3), 1, [1.0, 2.0, 3.0, 4.0]),
        ):
            with self.assertRaisesRegex(
                RuntimeError, "There must be one `x` value for each sample point"
            ):
                test_x(sizes, dim, x, device)
        with self.assertRaisesRegex(
            RuntimeError, "Currently, we only support dx as a real number"
        ):
            test_dx((2, 2), -1, complex(1, 1), device)
        # Passing both `x` and `dx` matches no overload.
        with self.assertRaisesRegex(
            TypeError, "received an invalid combination of arguments"
        ):
            torch.cumulative_trapezoid(
                torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3
            )

    @skipMeta
    @dtypes(torch.double)
    def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
        half_len = 3
        values = torch.randn(2 * half_len, dtype=dtype, device=device)
        # In-place pow_ with a scalar exponent on an internally overlapping tensor.
        self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device)
        # out= variants with possible input/output aliasing: tensor ** scalar,
        # then scalar ** tensor.
        for fn in (
            lambda input, out: torch.pow(input, 42, out=out),
            lambda input, out: torch.pow(42, input, out=out),
        ):
            self.unary_check_input_output_mem_overlap(values, half_len, fn)

    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16),
            all_types_and_complex_and(torch.half, torch.bfloat16),
        )
    )
    def test_float_power(self, device, dtypes):
        # float_power computes in double precision: complex128 when any operand
        # is complex, float64 otherwise. Results are compared with
        # np.float_power for tensor**tensor, tensor**scalar and scalar**tensor,
        # across the functional, method and in-place forms.
        def to_np(value):
            if not isinstance(value, torch.Tensor):
                return value
            if value.dtype == torch.bfloat16:
                # NumPy has no bfloat16; upcast before converting.
                value = value.to(torch.float)
            return value.cpu().numpy()

        def result_dtype(*complex_flags):
            return torch.complex128 if any(complex_flags) else torch.float64

        base_dtype, exp_dtype = dtypes
        out_dtype = result_dtype(base_dtype.is_complex, exp_dtype.is_complex)

        base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100)
        # Zero bases are deliberately avoided (low=1): PyTorch and NumPy
        # disagree on negative and zero powers of 0, for both complex and real
        # inputs. Related: https://github.com/pytorch/pytorch/issues/48000
        exp = make_tensor((30,), dtype=exp_dtype, device=device, low=-2, high=2)
        exp[0] = exp[4] = exp[6] = 0

        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

        real_exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
        complex_exponents = real_exponents + [
            -2.5j,
            -1.0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]

        def check_op(op, exponent, reference, required_dtype):
            # The in-place form cannot change base's dtype. It must raise
            # unless base already has the required result dtype.
            if op is torch.Tensor.float_power_ and base_dtype != required_dtype:
                with self.assertRaisesRegex(
                    RuntimeError, "operation's result requires dtype"
                ):
                    op(base.clone(), exponent)
            else:
                self.assertEqual(reference, op(base.clone(), exponent))

            # Only the functional form is exercised with out=.
            if op is torch.float_power:
                out = torch.empty_like(base).to(device=device, dtype=required_dtype)
                op(base, exponent, out=out)
                self.assertEqual(reference, out)

        scalar_exps = complex_exponents if exp_dtype.is_complex else real_exponents
        for op in (
            torch.float_power,
            torch.Tensor.float_power,
            torch.Tensor.float_power_,
        ):
            # Case of Tensor x Tensor
            check_op(op, exp, expected, out_dtype)

            # Case of Tensor x Scalar
            for scalar in scalar_exps:
                reference = torch.from_numpy(np.float_power(to_np(base), scalar))
                required = result_dtype(
                    base_dtype.is_complex, isinstance(scalar, complex)
                )
                check_op(op, scalar, reference, required)

        # Case of Scalar x Tensor
        scalar_bases = complex_exponents if base_dtype.is_complex else real_exponents
        for scalar in scalar_bases:
            required = result_dtype(exp_dtype.is_complex, isinstance(scalar, complex))
            reference = torch.from_numpy(np.float_power(scalar, to_np(exp)))

            self.assertEqual(reference, torch.float_power(scalar, exp))

            out = torch.empty_like(exp).to(device=device, dtype=required)
            torch.float_power(scalar, exp, out=out)
            self.assertEqual(reference, out)

    def test_float_power_exceptions(self, device):
        # float_power's result dtype is fixed: complex128 if either operand is
        # complex, double otherwise. An out= tensor, or an in-place self tensor,
        # of any other dtype must be rejected.
        def _promo_helper(x, y):
            # Return the dtype float_power requires for operands x and y.
            for i in (x, y):
                if isinstance(i, complex):
                    return torch.complex128
                if isinstance(i, torch.Tensor) and i.is_complex():
                    return torch.complex128
            return torch.double

        test_cases = (
            (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25),
            (
                torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device),
                2.0,
            ),
        )
        for base, exp in test_cases:
            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
                out = torch.empty(1, device=device, dtype=out_dtype)
                required_dtype = _promo_helper(base, exp)

                # out= variant: only the required dtype is accepted.
                if out.dtype == required_dtype:
                    torch.float_power(base, exp, out=out)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.float_power(base, exp, out=out)

                # In-place variant: base itself must already have the required dtype.
                if base.dtype == required_dtype:
                    torch.Tensor.float_power_(base.clone(), exp)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.Tensor.float_power_(base.clone(), exp)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool)
        )
    )
    def test_xlogy_xlog1py(self, device, dtypes):
        # Compares torch.xlogy / torch.special.xlog1py against SciPy and checks
        # that the out= and in-place (xlogy_) variants agree with the
        # functional form.
        x_dtype, y_dtype = dtypes

        def check_out(torch_fn, lhs, rhs):
            # The out= variant must match the functional result.
            reference = torch_fn(lhs, rhs)
            out = torch.empty_like(reference)
            torch_fn(lhs, rhs, out=out)
            self.assertEqual(reference, out)

        def check_xlogy_inplace(lhs, rhs):
            # An integral or bool self cannot hold xlogy's result dtype.
            if lhs.dtype in integral_types_and(torch.bool):
                with self.assertRaisesRegex(
                    RuntimeError, "can't be cast to the desired output type"
                ):
                    lhs.clone().xlogy_(rhs)
                return
            reference = torch.empty_like(lhs)
            torch.xlogy(lhs, rhs, out=reference)
            self.assertEqual(reference, lhs.clone().xlogy_(rhs))

        def check_against_reference(torch_fn, reference_fn, inputs, scalar=None):
            # The first input is always bound as the left operand for the
            # SciPy comparison. `scalar`, when given, replaces it only for the
            # out= checks.
            first = inputs[0]
            bound_torch = partial(torch_fn, first)
            bound_reference = partial(reference_fn, first.cpu().numpy())
            for other in inputs:
                self.compare_with_numpy(
                    bound_torch, bound_reference, other, exact_dtype=False
                )

            lhs = first if scalar is None else scalar
            for other in inputs:
                check_out(torch_fn, lhs, other)

        # Tensor-Tensor Test (tensor of same and different shape)
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000)

        inputs = (x, y, z)
        inputs_1p = (x_1p, y_1p, z_1p)
        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        check_against_reference(*xlogy_fns, inputs)
        for other in inputs:
            check_xlogy_inplace(x, other)
        check_against_reference(*xlog1py_fns, inputs_1p)

        # Scalar-Tensor Test
        check_against_reference(*xlogy_fns, inputs, 3.14)
        check_against_reference(*xlog1py_fns, inputs_1p, 3.14)

        # Special Values Tensor-Tensor
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        zeros = torch.zeros(7, dtype=y_dtype, device=device)

        def check_zero_lhs(torch_fn, reference_fn, scalar=False):
            # A zero left operand, either a tensor or the Python scalar 0,
            # against special right-hand values.
            lhs = 0 if scalar else zeros
            lhs_np = 0 if scalar else zeros.cpu().numpy()
            self.compare_with_numpy(
                partial(torch_fn, lhs),
                partial(reference_fn, lhs_np),
                t,
                exact_dtype=False,
            )
            check_out(torch_fn, lhs, t)

        check_zero_lhs(*xlogy_fns)
        check_xlogy_inplace(zeros, t)
        check_zero_lhs(*xlog1py_fns)

        # Special Values Scalar-Tensor
        check_zero_lhs(*xlogy_fns, scalar=True)
        check_zero_lhs(*xlog1py_fns, scalar=True)

    @dtypes(torch.float64)
    def test_xlogy_xlog1py_gradients(self, device, dtype):
        make_arg = partial(torch.tensor, dtype=dtype, device=device, requires_grad=True)
        zeros = torch.zeros((2,), dtype=dtype, device=device)

        # With x == 0, the gradient w.r.t. x must be 0 even at y values where
        # log(y) (resp. log1p(y)) is NaN (y below the domain) or -inf (y at
        # the domain edge).
        for fn, y_values in (
            (torch.special.xlogy, [-1.5, 0.0]),
            (torch.special.xlog1py, [-2.5, -1.0]),
        ):
            x = make_arg([0.0, 0.0])
            y = make_arg(y_values)
            fn(x, y).sum().backward()
            self.assertEqual(x.grad, zeros)

    def test_xlogy_xlog1py_scalar_type_promotion(self, device):
        # Python numbers must not take part in type promotion at the same
        # priority as 0-dim tensors. The float32 0-dim tensor's dtype must win,
        # whichever side the scalar is on.
        t = torch.randn((), dtype=torch.float32, device=device)

        for scalar_first in (False, True):
            for fn in (torch.xlogy, torch.special.xlog1py):
                for scalar in (5, 5.0):
                    args = (scalar, t) if scalar_first else (t, scalar)
                    self.assertEqual(t.dtype, fn(*args).dtype)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4262
    def test_xlogy_xlog1py_bfloat16(self, device):
4263
        def _compare_helper(x, y, torch_fn, reference_fn):
4264
            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
4265
            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
4266
            expected = torch.from_numpy(reference_fn(x_np, y_np))
4267
            actual = torch_fn(x, y)
4268
            self.assertEqual(expected, actual, exact_dtype=False)
4269

4270
        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
4271

4272
        # Tensor-Tensor Test (tensor of same and different shape)
4273
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
4274
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4275
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
4276

4277
        x_1p = make_tensor(
4278
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000
4279
        )
4280
        y_1p = make_tensor(
4281
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000
4282
        )
4283
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000)
4284

4285
        xlogy_fns = torch.xlogy, scipy.special.xlogy
4286
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py
4287

4288
        _compare_helper(x, x, *xlogy_fns)
4289
        _compare_helper(x, y, *xlogy_fns)
4290
        _compare_helper(x, z, *xlogy_fns)
4291
        _compare_helper(x, 3.14, *xlogy_fns)
4292
        _compare_helper(y, 3.14, *xlogy_fns)
4293
        _compare_helper(z, 3.14, *xlogy_fns)
4294

4295
        _compare_helper(x_1p, x_1p, *xlog1py_fns)
4296
        _compare_helper(x_1p, y_1p, *xlog1py_fns)
4297
        _compare_helper(x_1p, z_1p, *xlog1py_fns)
4298
        _compare_helper(x_1p, 3.14, *xlog1py_fns)
4299
        _compare_helper(y_1p, 3.14, *xlog1py_fns)
4300
        _compare_helper(z_1p, 3.14, *xlog1py_fns)
4301

4302
        # Special Values Tensor-Tensor
4303
        t = torch.tensor(
4304
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
4305
            device=device,
4306
        )
4307
        zeros = torch.tensor(7, dtype=y_dtype, device=device)
4308

4309
        _compare_helper(t, zeros, *xlogy_fns)
4310
        _compare_helper(t, 0.0, *xlogy_fns)
4311

4312
        _compare_helper(t, zeros, *xlog1py_fns)
4313
        _compare_helper(t, 0.0, *xlog1py_fns)
4314

4315
    @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool)))
4316
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
4317
    @slowTest
4318
    def test_zeta(self, device, dtypes):
4319
        x_dtype, q_dtype = dtypes
4320

4321
        def test_helper(x, q):
4322
            x_np = x if isinstance(x, float) else x.cpu().numpy()
4323
            q_np = q if isinstance(q, float) else q.cpu().numpy()
4324
            expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
4325
            actual = torch.special.zeta(x, q)
4326

4327
            rtol, atol = None, None
4328
            if self.device_type == "cpu":
4329
                rtol, atol = 1e-6, 1e-6
4330
            self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
4331

4332
        # x tensor - q tensor same size
4333
        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4334
        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4335
        test_helper(x, q)
4336

4337
        # x tensor - q tensor broadcast lhs
4338
        x = make_tensor((2, 1, 4), dtype=x_dtype, device=device)
4339
        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4340
        test_helper(x, q)
4341

4342
        # x tensor - q tensor broadcast rhs
4343
        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4344
        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4345
        test_helper(x, q)
4346

4347
        # x tensor - q tensor broadcast all
4348
        x = make_tensor((2, 3, 1), dtype=x_dtype, device=device)
4349
        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
4350
        test_helper(x, q)
4351

4352
        # x scalar - q tensor
4353
        for x in np.linspace(-5, 5, num=10).tolist():
4354
            if not q_dtype.is_floating_point:
4355
                q_dtype = torch.get_default_dtype()
4356
            q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
4357
            test_helper(x, q)
4358

4359
        # x tensor - q scalar
4360
        for q in np.linspace(-5, 5, num=10).tolist():
4361
            if not x_dtype.is_floating_point:
4362
                x_dtype = torch.get_default_dtype()
4363
            x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
4364
            test_helper(x, q)
4365

4366
    @onlyCUDA
4367
    @dtypes(
4368
        torch.chalf,
4369
    )
4370
    def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype):
4371
        # Tests that Tensor and CPU Scalar work for `mul` for chalf.
4372
        # Ideally, this should be covered by `test_complex_half_reference_testing`
4373
        # from test_ops.py by checking reference_samples from the OpInfo.
4374
        # But currently that doesn't work as sample generation requires support of
4375
        # `index_select` which is not implemented for `complex32` at the
4376
        # time of writing this test.
4377
        # TODO: Remove this test once above issue is fixed.
4378
        # Ref: https://github.com/pytorch/pytorch/pull/76364
4379
        x = make_tensor((2, 2), device=device, dtype=dtype)
4380
        self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype))
4381

4382

4383
tensor_binary_ops = [
4384
    "__lt__",
4385
    "__le__",
4386
    "__gt__",
4387
    "__ge__",
4388
    "__eq__",
4389
    "__ne__",
4390
    "__add__",
4391
    "__radd__",
4392
    "__iadd__",
4393
    "__sub__",
4394
    "__rsub__",
4395
    "__isub__",
4396
    "__mul__",
4397
    "__rmul__",
4398
    "__imul__",
4399
    "__matmul__",
4400
    "__rmatmul__",
4401
    "__truediv__",
4402
    "__rtruediv__",
4403
    "__itruediv__",
4404
    "__floordiv__",
4405
    "__rfloordiv__",
4406
    "__ifloordiv__",
4407
    "__mod__",
4408
    "__rmod__",
4409
    "__imod__",
4410
    "__pow__",
4411
    "__rpow__",
4412
    "__ipow__",
4413
    "__lshift__",
4414
    "__rlshift__",
4415
    "__ilshift__",
4416
    "__rshift__",
4417
    "__rrshift__",
4418
    "__irshift__",
4419
    "__and__",
4420
    "__rand__",
4421
    "__iand__",
4422
    "__xor__",
4423
    "__rxor__",
4424
    "__ixor__",
4425
    "__or__",
4426
    "__ror__",
4427
    "__ior__",
4428
    # Unsupported operators
4429
    # '__imatmul__',
4430
    # '__divmod__', '__rdivmod__', '__idivmod__',
4431
]
4432

4433
# Test that binary math operations return NotImplemented for unknown types.
4434
def generate_not_implemented_tests(cls):
4435
    class UnknownType:
4436
        pass
4437

4438
    # TODO: refactor to inline these
4439
    _types = [
4440
        torch.half,
4441
        torch.float,
4442
        torch.double,
4443
        torch.int8,
4444
        torch.short,
4445
        torch.int,
4446
        torch.long,
4447
        torch.uint8,
4448
    ]
4449

4450
    def create_test_func(op):
4451
        @dtypes(*_types)
4452
        def test(self, device, dtype):
4453
            # Generate the inputs
4454
            tensor = torch.empty((), device=device, dtype=dtype)
4455

4456
            # Runs the tensor op on the device
4457
            result = getattr(tensor, op)(UnknownType())
4458
            self.assertEqual(result, NotImplemented)
4459

4460
        return test
4461

4462
    for op in tensor_binary_ops:
4463
        test_name = f"test_{op}_not_implemented"
4464
        assert not hasattr(cls, test_name), f"{test_name} already in {cls.__name__}"
4465

4466
        setattr(cls, test_name, create_test_func(op))
4467

4468

4469
generate_not_implemented_tests(TestBinaryUfuncs)
4470
instantiate_device_type_tests(TestBinaryUfuncs, globals())
4471

4472
if __name__ == "__main__":
4473
    run_tests()
4474

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.