import itertools
import math
import operator
import warnings
from itertools import chain
from itertools import product
from numbers import Number
from functools import partial

import numpy as np

import torch
import torch.autograd.forward_ad as fwAD
from torch import inf, nan
from torch.testing._internal.common_utils import (
    TestCase,
    gradcheck,
    set_default_dtype,
    skipIfTorchDynamo,
    torch_to_numpy_dtype_dict,
    numpy_to_torch_dtype_dict,
)
from torch.testing._internal.common_device_type import (
    expectedFailureMeta,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
    dtypes,
    dtypesIfCUDA,
    deviceCountAtLeast,
    precisionOverride,
    ops,
    OpDTypes,
)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
    integral_types,
    integral_types_and,
    floating_types_and,
    floating_and_complex_types,
    complex_types,
)
from torch.testing._internal.common_methods_invocations import (
    binary_ufuncs,
    binary_ufuncs_and_refs,
    generate_elementwise_binary_tensors,
    generate_elementwise_binary_small_value_tensors,
    generate_elementwise_binary_large_value_tensors,
    generate_elementwise_binary_extremal_value_tensors,
    generate_elementwise_binary_broadcasting_tensors,
    generate_elementwise_binary_with_scalar_samples,
    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
)

import scipy.integrate

class TestBinaryUfuncs(TestCase):
    # Accepts a NumPy ndarray or a Number as the expected result and compares
    # it against a torch.Tensor, tolerating NumPy's float32 results for
    # low-precision torch dtypes.
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        assert isinstance(actual, torch.Tensor)

        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg=msg, **kwargs)
        elif isinstance(expected, np.ndarray):
            # NumPy promotes bfloat16/float16 inputs to float32
            if expected.dtype == np.float32:
                assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32)
            else:
                assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
            self.assertEqual(
                actual, torch.from_numpy(expected).to(actual.dtype), msg,
                exact_device=False, **kwargs
            )
        else:
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
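
    # Compares a binary ufunc's torch results against reference values
    # computed with the op's NumPy-style reference on the same inputs.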
    def _test_reference_numerics(self, dtype, op, gen, equal_nan=True):
        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False
            if dtype is torch.bfloat16 and expected.dtype == np.float32:
                # bfloat16 results are compared against float32 references
                # with loosened tolerances.
                self.assertEqualHelper(
                    actual, expected, msg, dtype=dtype,
                    exact_dtype=exact_dtype, rtol=16e-3, atol=1e-5,
                )
            else:
                self.assertEqualHelper(
                    actual, expected, msg, dtype=dtype, equal_nan=equal_nan,
                    exact_dtype=exact_dtype,
                )

        for sample in gen:
            l, r = sample.input, sample.args[0]
            numpy_sample = sample.numpy()
            l_numpy = numpy_sample.input
            r_numpy = numpy_sample.args[0]
            actual = op(l, r)
            expected = op.ref(l_numpy, r_numpy)

            def _numel(x):
                if isinstance(x, torch.Tensor):
                    return x.numel()
                return 1  # scalars

            # Crafts a custom error message only for small, printable tensors
            if _numel(l) <= 100 and _numel(r) <= 100:
                msg = (
                    "Failed to produce expected results! Input lhs tensor was"
                    f" {l}, rhs tensor was {r}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                for x, y in zip(expected, actual):
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
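
    # Only ufuncs with a NumPy-compatible reference implementation are
    # compared against reference values in the tests below.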
    binary_ufuncs_with_references = list(
        filter(lambda op: op.ref is not None, binary_ufuncs)
    )

    @ops(binary_ufuncs_with_references)
    def test_reference_numerics(self, device, dtype, op):
        gen = generate_elementwise_binary_tensors(op, device=device, dtype=dtype)
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)

    @ops(binary_ufuncs_with_references)
    def test_reference_numerics_small_values(self, device, dtype, op):
        if dtype is torch.bool:
            self.skipTest("Doesn't support bool!")

        gen = generate_elementwise_binary_small_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)

    @ops(
        binary_ufuncs_with_references,
    )
    def test_reference_numerics_large_values(self, device, dtype, op):
        gen = generate_elementwise_binary_large_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)

    @ops(
        binary_ufuncs_with_references,
    )
    def test_reference_numerics_extremal_values(self, device, dtype, op):
        gen = generate_elementwise_binary_extremal_value_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)

    @ops(
        binary_ufuncs_with_references,
    )
    def test_broadcasting(self, device, dtype, op):
        gen = generate_elementwise_binary_broadcasting_tensors(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)

    @ops(
        binary_ufuncs_with_references,
        allowed_dtypes=(torch.long, torch.float32, torch.complex64),
    )
    def test_scalar_support(self, device, dtype, op):
        gen = generate_elementwise_binary_with_scalar_samples(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
        gen = generate_elementwise_binary_with_scalar_and_type_promotion_samples(
            op, device=device, dtype=dtype
        )
        self._test_reference_numerics(dtype, op, gen, equal_nan=True)
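
    # The tests below verify that ops produce identical results on contiguous
    # and noncontiguous (strided, transposed, expanded) inputs.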
    @ops(binary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        lhs = make_tensor(
            (1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            (1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
        )

        lhs_non_contig = lhs[::2]
        rhs_non_contig = rhs[::2]

        self.assertTrue(lhs.is_contiguous())
        self.assertTrue(rhs.is_contiguous())

        self.assertFalse(lhs_non_contig.is_contiguous())
        self.assertFalse(rhs_non_contig.is_contiguous())

        expected = op(lhs, rhs)[::2]
        actual = op(lhs_non_contig, rhs_non_contig)
        self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        lhs = make_tensor(
            (789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            (789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
        )

        lhs_non_contig = lhs.T
        rhs_non_contig = rhs.T

        self.assertTrue(lhs.is_contiguous())
        self.assertTrue(rhs.is_contiguous())

        self.assertFalse(lhs_non_contig.is_contiguous())
        self.assertFalse(rhs_non_contig.is_contiguous())

        expected = op(lhs, rhs).T
        actual = op(lhs_non_contig, rhs_non_contig)
        self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        shapes = ((5, 7), (1024,))
        for shape in shapes:
            lhs = make_tensor(
                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
            )
            rhs = make_tensor(
                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
            )

            lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
                ..., 0
            ]
            lhs_non_contig.copy_(lhs)

            rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[
                ..., 0
            ]
            rhs_non_contig.copy_(rhs)

            self.assertTrue(lhs.is_contiguous())
            self.assertTrue(rhs.is_contiguous())

            self.assertFalse(lhs_non_contig.is_contiguous())
            self.assertFalse(rhs_non_contig.is_contiguous())

            expected = op(lhs, rhs)
            actual = op(lhs_non_contig, rhs_non_contig)
            self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        shape = (2, 2, 1, 2)
        lhs = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        lhs_non_contig = lhs[:, 1, ...]
        lhs = lhs_non_contig.contiguous()

        rhs_non_contig = rhs[:, 1, ...]
        rhs = rhs_non_contig.contiguous()

        self.assertTrue(lhs.is_contiguous())
        self.assertTrue(rhs.is_contiguous())

        self.assertFalse(lhs_non_contig.is_contiguous())
        self.assertFalse(rhs_non_contig.is_contiguous())

        expected = op(lhs, rhs)
        actual = op(lhs_non_contig, rhs_non_contig)
        self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            lhs = make_tensor(
                shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
            )
            rhs = make_tensor(
                shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
            )

            lhs_non_contig = lhs.clone().expand(3, -1, -1)
            rhs_non_contig = rhs.clone().expand(3, -1, -1)

            self.assertTrue(lhs.is_contiguous())
            self.assertTrue(rhs.is_contiguous())

            self.assertFalse(lhs_non_contig.is_contiguous())
            self.assertFalse(rhs_non_contig.is_contiguous())

            expected = op(lhs, rhs)
            actual = op(lhs_non_contig, rhs_non_contig)
            for i in range(3):
                self.assertEqual(expected, actual[i])

    @ops(binary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        shape = (5, 100)
        lhs = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        lhs = lhs[:1, :50]
        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
        lhs_alt.copy_(lhs)

        rhs = rhs[:1, :50]
        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
        rhs_alt.copy_(rhs)

        self.assertTrue(lhs.is_contiguous())
        self.assertTrue(rhs.is_contiguous())

        self.assertTrue(lhs_alt.is_contiguous())
        self.assertTrue(rhs_alt.is_contiguous())

        expected = op(lhs, rhs)
        actual = op(lhs_alt, rhs_alt)
        self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4)
        lhs = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        lhs = lhs[:1, :, :, :, :, :, :, :, :, :, :, :]
        lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype)
        lhs_alt.copy_(lhs)

        rhs = rhs[:1, :, :, :, :, :, :, :, :, :, :, :]
        rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype)
        rhs_alt.copy_(rhs)

        self.assertTrue(lhs.is_contiguous())
        self.assertTrue(rhs.is_contiguous())

        self.assertTrue(lhs_alt.is_contiguous())
        self.assertTrue(rhs_alt.is_contiguous())

        expected = op(lhs, rhs)
        actual = op(lhs_alt, rhs_alt)
        self.assertEqual(expected, actual)

    @ops(binary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        shape = (32, 512)
        lhs = make_tensor(
            shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs
        )
        rhs = make_tensor(
            shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs
        )

        expected = op(lhs, rhs)

        actual = []
        for idx in range(32):
            actual.append(op(lhs[idx], rhs[idx]))
        actual = torch.stack(actual)

        self.assertEqual(expected, actual)
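
    # Exhaustively checks dtype promotion pairs, including out= casting:
    # results may always be cast to a higher-category out dtype, but casting
    # down (e.g. a float result into an integer out tensor) must raise.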
    @ops(binary_ufuncs_and_refs, dtypes=OpDTypes.none)
    def test_type_promotion(self, device, op):
        supported_dtypes = op.supported_dtypes(torch.device(device).type)

        make_lhs = partial(
            make_tensor, (5,), device=device, **op.lhs_make_tensor_kwargs
        )
        make_rhs = partial(
            make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs
        )
        make_rhs_scalar_tensor = partial(
            make_tensor, (), device="cpu", **op.rhs_make_tensor_kwargs
        )

        def _supported(dtypes):
            return all(x in supported_dtypes for x in dtypes)

        if _supported((torch.int16, torch.int32, torch.int64)):
            lhs_i16 = make_lhs(dtype=torch.int16)
            lhs_i32 = make_lhs(dtype=torch.int32)
            lhs_i64 = make_lhs(dtype=torch.int64)

            rhs_i16 = make_rhs(dtype=torch.int16)
            rhs_i32 = make_rhs(dtype=torch.int32)
            rhs_i64 = make_rhs(dtype=torch.int64)

            if op.promotes_int_to_float:
                default_dtype = torch.get_default_dtype()
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i16, rhs_i32),
                    op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)),
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i32, rhs_i64),
                    op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)),
                )
            elif op.always_returns_bool:
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool)
                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool)
            else:
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32)
                self.assertEqual(
                    op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64)
                self.assertEqual(
                    op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)
                )

            if not op.promotes_int_to_float:
                out = torch.empty_like(lhs_i64)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.int64)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                out = torch.empty_like(lhs_i16)
                self.assertEqual(op(lhs_i32, rhs_i64, out=out).dtype, torch.int16)
            else:
                # A float result can't be cast into an integer out tensor
                with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                    op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64))

            if not op.always_returns_bool:
                with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                    op(
                        lhs_i16,
                        rhs_i32,
                        out=torch.empty_like(lhs_i64, dtype=torch.bool),
                    )

                out = torch.empty_like(lhs_i64, dtype=torch.float16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float16)

                out = torch.empty_like(lhs_i64, dtype=torch.bfloat16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.bfloat16)

                out = torch.empty_like(lhs_i64, dtype=torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                out = torch.empty_like(lhs_i64, dtype=torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

        if _supported((torch.float32, torch.float64)):
            lhs_f32 = make_lhs(dtype=torch.float32)
            lhs_f64 = make_lhs(dtype=torch.float64)

            rhs_f32 = make_rhs(dtype=torch.float32)
            rhs_f64 = make_rhs(dtype=torch.float64)

            if op.always_returns_bool:
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool)
            else:
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64)
                self.assertEqual(
                    op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)
                )

            out = torch.empty_like(lhs_f64, dtype=torch.float16)
            self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float16)

            out = torch.empty_like(lhs_f64, dtype=torch.bfloat16)
            self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.bfloat16)
            self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

            out = torch.empty_like(lhs_f64, dtype=torch.float32)
            self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float32)
            self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

            out = torch.empty_like(lhs_f64, dtype=torch.complex64)
            self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.complex64)
            self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

            if not op.always_returns_bool:
                with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                    op(
                        lhs_f32,
                        rhs_f64,
                        out=torch.empty_like(lhs_f64, dtype=torch.int64),
                    )
            else:
                # bool results can be safely cast to int64
                out = torch.empty_like(lhs_f64, dtype=torch.int64)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

        if _supported((torch.complex64, torch.complex128)):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            lhs_c128 = make_lhs(dtype=torch.complex128)

            rhs_c64 = make_rhs(dtype=torch.complex64)
            rhs_c128 = make_rhs(dtype=torch.complex128)

            if op.always_returns_bool:
                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.bool)
            else:
                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128)
                self.assertEqual(
                    op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)
                )

            out = torch.empty_like(lhs_c64, dtype=torch.complex64)
            self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.complex64)
            result = op(lhs_c64, rhs_c128)
            self.assertEqual(result, out.to(result.dtype))

            if not op.always_returns_bool:
                with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                    op(
                        lhs_c64,
                        rhs_c128,
                        out=torch.empty_like(lhs_c64, dtype=torch.float64),
                    )

                with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                    op(
                        lhs_c64,
                        rhs_c128,
                        out=torch.empty_like(lhs_c64, dtype=torch.int64),
                    )
            else:
                out = torch.empty_like(lhs_c64, dtype=torch.float64)
                self.assertEqual(
                    op(lhs_c64, rhs_c128, out=out).dtype, torch.float64
                )
                self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)

                out = torch.empty_like(lhs_f64, dtype=torch.int64)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

        if _supported((torch.long, torch.float32)):
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f32 = make_rhs(dtype=torch.float32)

            result = op(lhs_i64, rhs_f32)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        if _supported((torch.float64, torch.complex64)):
            lhs_f64 = make_lhs(dtype=torch.float64)
            rhs_c64 = make_rhs(dtype=torch.complex64)

            result = op(lhs_f64, rhs_c64)
            expected_dtype = (
                torch.complex128 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        if _supported((torch.int64, torch.float32)) and op.supports_rhs_python_scalar:
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f_scalar = 2.0

            result = op(lhs_i64, rhs_f_scalar)
            expected_dtype = (
                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            rhs_f32_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float32)
            result = op(lhs_i64, rhs_f32_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

            if _supported((torch.float64,)):
                rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
                result = op(lhs_i64, rhs_f64_scalar_tensor)
                expected_dtype = (
                    torch.float64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        if (
            _supported((torch.float32, torch.complex64))
            and op.supports_rhs_python_scalar
        ):
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_c_scalar = complex(1, 1)

            result = op(lhs_f32, rhs_c_scalar)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            rhs_c64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex64)
            result = op(lhs_f32, rhs_c64_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            if _supported((torch.complex128,)):
                rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
                result = op(lhs_f32, rhs_c128_scalar_tensor)
                # Scalar tensors are treated like scalars for promotion
                expected_dtype = (
                    torch.complex64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        if _supported((torch.float32, torch.float64)) and op.supports_rhs_python_scalar:
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)

            result = op(lhs_f32, rhs_f64_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        if (
            _supported((torch.complex64, torch.complex128))
            and op.supports_rhs_python_scalar
        ):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)

            result = op(lhs_c64, rhs_c128_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
            lhs = 1
            rhs_f_scalar = 2.0

            result = op(lhs, rhs_f_scalar)
            expected_dtype = (
                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)
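
    # Ops must reject operand shapes that cannot broadcast together,
    # e.g. (3, 1, 2) with (2, 1, 2).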
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    def test_not_broadcastable(self, device, dtype, op):
        for shape_lhs, shape_rhs in (
            ((3, 1, 2), (2, 1, 2)),
        ):
            lhs = make_tensor(
                shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs
            )
            rhs = make_tensor(
                shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs
            )

            try:
                broadcasted_shape = op(lhs, rhs).shape
            except RuntimeError:
                continue

            msg = (
                f"On {device}, torch.{op.name} broadcasts inputs shapes {shape_lhs} and {shape_rhs} into "
                f"{broadcasted_shape}, although they are not broadcastable."
            )
            raise AssertionError(msg)

    def test_add_broadcast_empty(self, device):
        # empty + empty
        self.assertRaises(
            RuntimeError,
            lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device),
        )
        self.assertEqual(
            torch.randn(5, 0, device=device),
            torch.randn(0, device=device) + torch.randn(5, 0, device=device),
        )
        self.assertEqual(
            torch.randn(5, 0, 0, device=device),
            torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device),
        )

        # scalar + empty
        self.assertEqual(
            torch.randn(5, 0, 6, device=device),
            torch.randn((), device=device) + torch.randn(5, 0, 6, device=device),
        )

        # non-empty + empty
        self.assertEqual(
            torch.randn(0, device=device),
            torch.randn(0, device=device) + torch.randn(1, device=device),
        )
        self.assertEqual(
            torch.randn(0, 7, 0, 6, 5, 0, 7, device=device),
            torch.randn(0, 7, 0, 6, 5, 0, 1, device=device)
            + torch.randn(1, 1, 5, 1, 7, device=device),
        )
        self.assertRaises(
            RuntimeError,
            lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device),
        )

    def test_addcmul_scalars_as_floats(self, device):
        # zero-dim variables that don't require grad should bind to scalar arguments
        x = torch.tensor(2.0)
        y = torch.tensor(3.0, device=device)
        # 3 + (3 * 3) * 2
        self.assertEqual(y.addcmul(y, y, value=x), 21)

        x = torch.tensor(2.0, requires_grad=True)
        self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))

    @dtypes(*integral_types_and(torch.bool))
    def test_bitwise_ops(self, device, dtype):
        ops = (
            operator.and_,
            operator.iand,
            operator.or_,
            operator.ior,
            operator.xor,
            operator.ixor,
        )
        inplace_ops = (operator.iand, operator.ior, operator.ixor)
        shapes = ((5,), (15, 15), (500, 500))

        for op, shape in itertools.product(ops, shapes):
            # Tests tensor x tensor case
            a = make_tensor(shape, device=device, dtype=dtype)
            b = make_tensor(shape, device=device, dtype=dtype)
            a_np = a.cpu().clone().numpy()
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a, b), op(a_np, b_np))

            # Tests tensor x scalar case
            a = make_tensor(shape, device=device, dtype=dtype)
            b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            a_np = a.cpu().clone().numpy()
            self.assertEqual(op(a, b_scalar), op(a_np, b_scalar))

            # Tests scalar x tensor case
            a_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            b = make_tensor(shape, device=device, dtype=dtype)
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a_scalar, b), op(a_scalar, b_np))

            # In-place variants must mutate their first argument
            if op in inplace_ops:
                # Tests tensor x tensor case
                a = make_tensor(shape, device=device, dtype=dtype)
                b = make_tensor(shape, device=device, dtype=dtype)
                a_np = a.cpu().clone().numpy()
                b_np = b.cpu().clone().numpy()
                op(a, b)
                op(a_np, b_np)
                self.assertEqual(a, a_np)

                # Tests tensor x scalar case
                a = make_tensor(shape, device=device, dtype=dtype)
                b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
                a_np = a.cpu().clone().numpy()
                op(a, b_scalar)
                op(a_np, b_scalar)
                self.assertEqual(a, a_np)
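
    # In-place division must keep writing through the same tensor object.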
    def test_inplace_division(self, device):
        t = torch.rand(5, 5, device=device)
        id_before = id(t)
        t /= 2
        id_after = id(t)
        self.assertEqual(id_before, id_after)

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_div_rounding_modes(self, device, dtype):
        if dtype.is_floating_point:
            low, high = -10.0, 10.0
        else:
            info = torch.iinfo(dtype)
            low, high = info.min, info.max

        a = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)

        # Avoid division by zero so we can test (a / b) * b == a
        if dtype.is_floating_point:
            eps = 0.1
            b[(-eps < b) & (b < eps)] = eps
        else:
            b[b == 0] = 1

        if not dtype.is_floating_point:
            a = torch.where(a < 0, a + b, a)

        d_true = torch.divide(a, b, rounding_mode=None)
        self.assertTrue(d_true.is_floating_point())
        self.assertEqual(d_true * b, a.to(d_true.dtype))

        d_floor = torch.divide(a, b, rounding_mode="floor")
        if dtype not in (torch.bfloat16, torch.half):
            self.assertEqual(d_floor * b + torch.remainder(a, b), a)
        else:
            self.assertEqual(
                d_floor * b + torch.remainder(a.float(), b.float()),
                a,
                exact_dtype=False,
            )

        d_trunc = torch.divide(a, b, rounding_mode="trunc")
        rounding_unsupported = (
            dtype == torch.half
            and device != "cpu"
            or dtype == torch.bfloat16
            and device == "cpu"
        )
        d_ref = d_true.float() if rounding_unsupported else d_true
        self.assertEqual(d_trunc, d_ref.trunc().to(dtype))

    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_div_rounding_nonfinite(self, device, dtype):
        # Compare division of special floating point values against NumPy
        num = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
            device=device,
        )
        # Divide by zero is tested separately
        denom = num[num != 0]

        a, b = num[None, :].clone(), denom[:, None].clone()

        # Compare bfloat16 against the NumPy float reference
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an, bn = a.cpu().numpy(), b.cpu().numpy()
        else:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()

        for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)):
            expect = np_ref(an, bn)
            kwargs = dict(rounding_mode=mode) if mode is not None else {}
            with set_default_dtype(torch.double):
                actual = torch.divide(a, b, **kwargs)
            self.assertEqual(
                actual,
                torch.from_numpy(expect),
                exact_device=False,
                exact_dtype=exact_dtype,
            )

        # Compare contiguous (likely vectorized) against non-contiguous inputs
        a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        a_noncontig[:] = a
        b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        b_noncontig[:] = b

        for rounding_mode in (None, "trunc", "floor"):
            expect = torch.divide(a_noncontig, b_noncontig, rounding_mode=rounding_mode)
            actual = torch.divide(a, b, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect)
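
    # Division by zero should produce inf/-inf/nan for float dtypes,
    # matching NumPy's IEEE-style semantics.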
    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_divide_by_zero_rounding(self, device, dtype):
        a = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
        )
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an = a.cpu().numpy()
        else:
            an = a.float().cpu().numpy()

        zero = torch.zeros_like(a)

        expect = np.divide(an, 0)
        for rounding_mode in (None, "floor"):
            # Divide by a zero scalar
            actual = torch.divide(a, 0, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
            # Divide by a zero tensor
            actual = torch.divide(a, zero, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)

    @dtypes(*all_types_and(torch.half))
    def test_div_rounding_numpy(self, device, dtype):
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        low, high = info.min, info.max

        # Compare division of random values against NumPy
        a = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)

        # Avoid division by zero
        b[b == 0] = 1

        # Compare bfloat16 against the NumPy float reference
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an, bn = a.cpu().numpy(), b.cpu().numpy()
        else:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()

        for mode, np_ref in (
            (None, np.true_divide),
            ("floor", np.floor_divide),
            ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)),
        ):
            expect = torch.from_numpy(np_ref(an, bn))

            kwargs = dict(rounding_mode=mode) if mode is not None else {}
            with set_default_dtype(torch.double):
                actual = torch.divide(a, b, **kwargs)
            self.assertEqual(
                actual, expect, exact_device=False, exact_dtype=exact_dtype
            )

            # Compare contiguous (likely vectorized) against non-contiguous
            expect = expect[::2]
            with set_default_dtype(torch.double):
                actual = torch.divide(a[::2], b[::2], **kwargs)
            self.assertEqual(
                actual, expect, exact_device=False, exact_dtype=exact_dtype
            )

    @dtypes(*complex_types())
    def test_complex_div_underflow_overflow(self, device, dtype):
        # Checks that complex division does not spuriously underflow or overflow
        finfo = torch.finfo(dtype)
        nom_lst = [complex(finfo.min / 2, finfo.min / 2),
                   complex(finfo.max / 2, finfo.max / 2),
                   complex(finfo.tiny, finfo.tiny),
                   complex(finfo.tiny, 0.0),
                   complex(0.0, finfo.tiny)]
        denom_lst = [complex(finfo.min / 2, finfo.min / 2),
                     complex(finfo.max / 2, finfo.max / 2),
                     complex(finfo.tiny, finfo.tiny),
                     complex(0.0, finfo.tiny),
                     complex(finfo.tiny, finfo.tiny)]
        expected_lst = [complex(1.0, 0.0),
                        complex(1.0, 0.0),
                        complex(1.0, 0.0),
                        complex(0.0, -1.0),
                        complex(0.5, 0.5)]
        nom = torch.tensor(nom_lst, dtype=dtype, device=device)
        denom = torch.tensor(denom_lst, dtype=dtype, device=device)
        expected = torch.tensor(expected_lst, dtype=dtype, device=device)
        res = nom / denom
        self.assertEqual(res, expected)

    def test_cross_device_inplace_error_msg(self, device):
        a = torch.tensor(2.0)
        b = torch.tensor(2.0, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            a += b

    @onlyNativeDeviceTypes
    def test_out_resize_warning(self, device):
        a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
        b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)

        unary_inputs = (a,)
        binary_inputs = (a, b)
        unary_ops = (torch.ceil, torch.exp)
        binary_ops = (torch.add, torch.sub)
        for op in unary_ops + binary_ops:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                inputs = unary_inputs if op in unary_ops else binary_inputs

                # No warnings when the out shape matches or is zero-sized
                op(*inputs, out=torch.empty(3, device=device))
                op(*inputs, out=torch.empty(0, device=device))
                self.assertEqual(len(w), 0)

                # Cases that warn because out must be resized
                op(*inputs, out=torch.empty(2, device=device))
                self.assertEqual(len(w), 1)

        arg1 = (torch.ones(2, 1, device=device), torch.ones(1, device=device))
        arg2 = (torch.ones(2, device=device), torch.ones(1, 1, device=device))
        outs = (torch.ones(2, 1, 1, 1, device=device), torch.ones(2, 2, 2, 2, device=device))

        for a1, a2, o in zip(arg1, arg2, outs):
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                torch.mul(a1, a2, out=o)
                self.assertEqual(len(w), 1)

    @expectedFailureMeta
    @onlyNativeDeviceTypes
    def test_inplace_dunders(self, device):
        t = torch.randn((1,), device=device)
        expected = t.data_ptr()
        t += 1
        t -= 1
        t *= 1
        t /= 1
        t **= 1
        t //= 1
        t %= 1
        self.assertEqual(expected, t.data_ptr())
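
    # Memory-overlap checks: in-place ops on self-overlapping views, and
    # out= tensors that partially overlap an input, should raise.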
    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                inplace_op(*inputs)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "single memory location"):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                _test(op, data[0:sz], data[1 : sz + 1])
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                    _test(op, data[0:sz], data[1 : sz + 1])

    def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False):
        sz = 3
        data = torch.randn(2 * sz, device=device)
        other = torch.randn(sz, device=device)

        self.unary_check_input_output_mem_overlap(
            data,
            sz,
            lambda input, out: op(other, input, out=out),
            expected_failure=expected_failure,
        )

        self.unary_check_input_output_mem_overlap(
            data,
            sz,
            lambda input, out: op(input, other, out=out),
            expected_failure=expected_failure,
        )

    @dtypes(torch.double)
    def test_binary_op_mem_overlap(self, device, dtype):
        ops = [
            ("add", True, True, "cpu"),
            ("add", True, True, "cuda"),
            ("mul", True, True, "cpu"),
            ("mul", True, True, "cuda"),
            ("sub", True, True, "cpu"),
            ("sub", True, True, "cuda"),
            ("div", True, True, "cpu"),
            ("div", True, True, "cuda"),
            ("pow", True, True, "cpu"),
            ("pow", True, True, "cuda"),
            ("fmod", True, True, "cpu"),
            ("fmod", True, True, "cuda"),
            ("atan2", True, True, "cpu"),
            ("atan2", True, True, "cuda"),
            ("hypot", True, True, "cpu"),
            ("hypot", True, True, "cuda"),
            ("igamma", True, True, "cpu"),
            ("igamma", True, True, "cuda"),
            ("igammac", True, True, "cpu"),
            ("igammac", True, True, "cuda"),
            ("nextafter", True, True, "cpu"),
            ("nextafter", True, True, "cuda"),
            ("le", True, True, "cpu"),
            ("le", True, True, "cuda"),
            ("lt", True, True, "cpu"),
            ("lt", True, True, "cuda"),
            ("ge", True, True, "cpu"),
            ("ge", True, True, "cuda"),
            ("gt", True, True, "cpu"),
            ("gt", True, True, "cuda"),
            ("eq", True, True, "cpu"),
            ("eq", True, True, "cuda"),
            ("ne", True, True, "cpu"),
            ("ne", True, True, "cuda"),
            ("logical_and", True, True, "cpu"),
            ("logical_and", True, True, "cuda"),
            ("logical_or", True, True, "cpu"),
            ("logical_or", True, True, "cuda"),
            ("logical_xor", True, True, "cpu"),
            ("logical_xor", True, True, "cuda"),
        ]

        for (
            fn,
            has_input_output_mem_overlap_check,
            has_internal_mem_overlap_check,
            dev,
        ) in ops:
            if dev != device:
                continue
            out_op = getattr(torch, fn)
            inplace_op = getattr(torch.Tensor, fn + "_")
            self.check_internal_mem_overlap(
                inplace_op,
                2,
                dtype,
                device,
                expected_failure=not has_internal_mem_overlap_check,
            )

            self.binary_check_input_output_mem_overlap(
                out_op, device, expected_failure=not has_input_output_mem_overlap_check
            )

    def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
        for num in exponents:
            if (
                isinstance(num, int)
                and num < 0
                and not m1.is_floating_point()
                and not m1.is_complex()
            ):
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"Integers to negative integer powers are not allowed\.",
                ):
                    torch.pow(m1[4], num)
            else:
                # base - tensor, exponent - number
                # contiguous
                res1 = torch.pow(m1[4], num)
                res2 = res1.clone().zero_()
                for i in range(res2.size(0)):
                    res2[i] = pow_fn(m1[4][i], num)
                rtol = 0 if atol is not None else None
                self.assertEqual(res1, res2, atol=atol, rtol=rtol)

                # non-contiguous
                res1 = torch.pow(m1[:, 4], num)
                res2 = res1.clone().zero_()
                for i in range(res2.size(0)):
                    res2[i] = pow_fn(m1[i, 4], num)
                self.assertEqual(res1, res2, atol=atol, rtol=rtol)

                # scalar ** tensor, checking dtype handling of __rpow__
                expected_dtype = torch.result_type(num, m1)
                res1 = num ** m1[4]
                res2 = (
                    torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
                )
                self.assertEqual(res1, res2)
                self.assertEqual(res1.dtype, expected_dtype)
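
    # torch.pow is validated against Python's built-in pow and math.pow for
    # tensor-tensor, tensor-scalar, and scalar-tensor combinations.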
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_pow(self, device, dtype):
        m1 = torch.empty(0, dtype=dtype, device=device)
        if m1.is_floating_point() or m1.is_complex():
            m1 = (
                make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5
            )
        else:
            # keep integer bases small so math.pow does not overflow
            range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
            m1 = make_tensor(
                (100, 100), low=1, high=range_high, dtype=dtype, device=device
            )

        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
        complex_exponents = [
            -2.5j,
            -1.0j,
            0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]
        if m1.is_complex():
            self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
        else:
            self._do_pow_for_exponents(m1, exponents, math.pow, None)
            will_raise_error = dtype is torch.half and torch.device(device).type == 'cpu'
            if will_raise_error:
                # On CPU, half inputs with complex exponents route through
                # ComplexHalf, which is not implemented.
                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
                    self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
            else:
                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)

        # base - number, exponent - tensor
        # contiguous
        res1 = torch.pow(3, m1[4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = pow(3, m1[4, i])
        self.assertEqual(res1, res2)

        # non-contiguous
        res1 = torch.pow(3, m1[:, 4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = pow(3, m1[i][4])
        self.assertEqual(res1, res2)

    def _test_pow(self, base, exponent, np_exponent=None):
        if np_exponent is None:
            np_exponent = exponent

        def to_np(value):
            if isinstance(value, torch.Tensor):
                return value.cpu().numpy()
            return value

        try:
            np_res = np.power(to_np(base), to_np(np_exponent))
            expected = (
                torch.from_numpy(np_res)
                if isinstance(np_res, np.ndarray)
                else torch.tensor(np_res, dtype=base.dtype)
            )
        except ValueError as e:
            err_msg = "Integers to negative integer powers are not allowed."
            self.assertEqual(str(e), err_msg)
            out = torch.empty_like(base)
            test_cases = [
                lambda: base.pow(exponent),
                lambda: base.pow_(exponent),
                lambda: torch.pow(base, exponent),
                lambda: torch.pow(base, exponent, out=out),
            ]
            for test_case in test_cases:
                self.assertRaisesRegex(RuntimeError, err_msg, test_case)
        else:
            if isinstance(base, torch.Tensor):
                actual = base.pow(exponent)
                self.assertEqual(actual, expected.to(actual))

                actual = base.clone()
                if (
                    isinstance(exponent, torch.Tensor)
                    and base.dim() == 0
                    and base.device.type == "cpu"
                    and exponent.device.type == "cuda"
                ):
                    regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
                    self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
                elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
                    actual2 = actual.pow_(exponent)
                    self.assertEqual(actual, expected)
                    self.assertEqual(actual2, expected)
                else:
                    self.assertRaisesRegex(
                        RuntimeError,
                        "Found dtype \\w+ but expected \\w+",
                        lambda: actual.pow_(exponent),
                    )

            actual = torch.pow(base, exponent)
            self.assertEqual(actual, expected.to(actual))

            actual2 = torch.pow(base, exponent, out=actual)
            self.assertEqual(actual, expected.to(actual))
            self.assertEqual(actual2, expected.to(actual))

    def test_pow_scalar_base(self, device):
        a = (
            torch.arange(1, 13, dtype=torch.double, device=device)
            .view(3, 4)
            .requires_grad_()
        )
        gradcheck(lambda a: torch.pow(2, a), (a,))

    def test_int_and_float_pow(self, device):
        def _test_int_and_float_pow(dt, low, high, dev):
            test_cases = (
                ((4, 4), 0, (4, 1)),
                ((3, 1), 4, (3, 1)),
                ((513, 513), 4, (513,)),
                ((5, 5, 5), 5, (5,)),
            )
            for base_shape, exp_scalar, exp_shape in test_cases:
                base_tensor = make_tensor(
                    base_shape, dtype=dt, device=dev, low=low, high=high
                )
                # int tensors don't take negative exponents
                if dt in integral_types():
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=0, high=high
                    )
                else:
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=low, high=high
                    )
                self._test_pow(base_tensor, exp_scalar)
                self._test_pow(base_tensor, exp_tensor)

                # Repeats the checks with non-contiguous tensors
                base_tensor = make_tensor(
                    base_shape, dtype=dt, device=dev, low=low, high=high,
                    noncontiguous=True,
                )
                if dt in integral_types():
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=0, high=high,
                        noncontiguous=True,
                    )
                else:
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=low, high=high,
                        noncontiguous=True,
                    )
                self._test_pow(base_tensor, exp_scalar)
                self._test_pow(base_tensor, exp_tensor)

        _test_int_and_float_pow(torch.int8, -2, 2, device)
        _test_int_and_float_pow(torch.uint8, 0, 3, device)
        _test_int_and_float_pow(torch.int16, -5, 5, device)
        _test_int_and_float_pow(torch.int64, -10, 10, device)
        _test_int_and_float_pow(torch.int32, -10, 10, device)
        _test_int_and_float_pow(torch.float16, 0.0, 5.0, device)
        _test_int_and_float_pow(torch.float32, 0.0, 10.0, device)
        _test_int_and_float_pow(torch.float64, 0.0, 10.0, device)
        _test_int_and_float_pow(torch.float32, -10.0, 10.0, device)
        _test_int_and_float_pow(torch.float64, -10.0, 10.0, device)

    def test_pow_inplace_resizing_exception(self, device):
        test_cases = (
            ((2, 2), (2, 1, 1)),
        )
        test_inputs = [
            (
                make_tensor(
                    base_size, dtype=torch.float64, device=device, high=10.0, low=0.0
                ),
                make_tensor(
                    exp_size, dtype=torch.float64, device=device, high=10.0, low=0.0
                ),
            )
            for base_size, exp_size in test_cases
        ]
        for base, exponent in test_inputs:
            regex = "doesn't match the broadcast shape"
            self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)

    def test_int_tensor_pow_neg_ints(self, device):
        ints = [
            torch.iinfo(torch.int32).min,
            -3,
            -2,
            -1,
            0,
            1,
            2,
            3,
            torch.iinfo(torch.int32).max,
        ]
        neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1]
        tensor = torch.tensor(ints, dtype=torch.int32, device=device)
        for pow in neg_ints:
            self._test_pow(tensor, pow)

    def test_long_tensor_pow_floats(self, device):
        ints = [0, 1, 23, 4567]
        floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
        tensor = torch.tensor(ints, dtype=torch.int64, device=device)
        for pow in floats:
            self._test_pow(tensor, pow)

    @dtypes(*[torch.float32, torch.float64])
    def test_float_scalar_pow_float_tensor(self, device, dtype):
        floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
        exponent_shapes = (
            (1,),
            (2, 2),
            (2, 1),
            (2, 2, 2),
        )
        tensors = [
            make_tensor(shape, dtype=dtype, device=device, low=0)
            for shape in exponent_shapes
        ]
        floats_tensor = torch.tensor(floats, dtype=dtype, device=device)
        for base in floats:
            self._test_pow(base, floats_tensor)
            for tensor in tensors:
                self._test_pow(base, tensor)

    @onlyCUDA
    def test_cuda_tensor_pow_scalar_tensor(self, device):
        cuda_tensors = [
            torch.randn((3, 3), device=device),
            torch.tensor(3.0, device=device),
        ]
        scalar_tensors = [
            torch.tensor(5.0, device="cpu"),
            torch.tensor(-3),
            torch.tensor(1),
        ]
        for base, exp in product(cuda_tensors, scalar_tensors):
            self._test_pow(base, exp)

    @onlyCUDA
    def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
        cuda_tensors = [
            torch.tensor(5.0, device="cuda"),
            torch.tensor(-3, device="cuda"),
        ]
        for exp in cuda_tensors:
            base = torch.randn((3, 3), device="cpu")
            regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
            self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp)
        for exp in cuda_tensors:
            # Scalar base on CPU works with a CUDA scalar-tensor exponent
            base = torch.tensor(3.0, device="cpu")
            self._test_pow(base, exp)

    @onlyCUDA
    @dtypes(torch.complex64, torch.complex128)
    def test_pow_cuda_complex_extremal_failing(self, device, dtype):
        t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device)
        with self.assertRaises(AssertionError):
            cuda_out = t.pow(2)
            cpu_out = t.cpu().pow(2)
            self.assertEqual(cpu_out, cuda_out)

    @skipIfTorchDynamo()
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_complex_scalar_pow_tensor(self, device, dtype):
        complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j]
        first_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2)
        second_exp = make_tensor(
            (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True
        )
        first_exp[0] = first_exp[10] = first_exp[20] = 0
        second_exp[0] = second_exp[10] = second_exp[20] = 0
        for base in complexes:
            # On CPU, half inputs with complex bases route through ComplexHalf,
            # which is not implemented (except for the trivial base 1 + 0j).
            will_raise_error = torch.device(device).type == 'cpu' and \
                dtype is torch.half and base != (1 + 0j)
            if will_raise_error:
                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
                    self._test_pow(base, first_exp)
                    self._test_pow(base, second_exp)
            else:
                self._test_pow(base, first_exp)
                self._test_pow(base, second_exp)

    @onlyNativeDeviceTypes
    def test_pow_scalar_type_promotion(self, device):
        # Test against a known type-promotion issue between uint8 and int64
        inputs = [17, [17]]
        for input in inputs:
            # the uint8 computation overflows before widening to int64
            input_tensor_uint8 = torch.tensor(input, dtype=torch.uint8, device=device)
            out_uint8_computation = torch.pow(
                2,
                input_tensor_uint8,
                out=torch.tensor(0, dtype=torch.int64, device=device),
            )

            input_tensor_int64 = torch.tensor(input, dtype=torch.int64, device=device)
            out_int64_computation = torch.pow(
                2,
                input_tensor_int64,
                out=torch.tensor(0, dtype=torch.int64, device=device),
            )

            self.assertNotEqual(out_uint8_computation, out_int64_computation)
            self.assertEqual(
                out_uint8_computation.to(dtype=torch.uint8),
                out_int64_computation.to(dtype=torch.uint8),
            )

    def test_tensor_pow_tensor(self, device):
        def rotate(l, n):
            return l[-n:] + l[:-n]

        def test_tensor_pow_tensor(values, torch_type, numpy_type):
            vals_tensor = torch.tensor(values, dtype=torch_type, device=device)
            for i in range(len(values)):
                pows = rotate(values, i)
                pows_tensor = torch.tensor(pows, dtype=torch_type, device=device)
                self._test_pow(vals_tensor, pows_tensor)

        ints = [0, 1, 2, 3]
        test_tensor_pow_tensor(ints, torch.uint8, np.uint8)
        test_tensor_pow_tensor(ints, torch.int8, np.int8)
        test_tensor_pow_tensor(ints, torch.int16, np.int16)
        test_tensor_pow_tensor(ints, torch.int32, np.int32)
        test_tensor_pow_tensor(ints, torch.int64, np.int64)

        floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0]
        test_tensor_pow_tensor(floats, torch.float16, np.float16)
        test_tensor_pow_tensor(floats, torch.float32, np.float32)
        test_tensor_pow_tensor(floats, torch.float64, np.float64)

    def test_logical_xor_with_nontrivial_alignment(self, device):
        # test tensors that are not aligned to a multiple of 16 bytes
        size = 128
        a = torch.randn(size, device=device) > 0
        b = torch.randn(size, device=device) > 0
        c = torch.randn(size, device=device) > 0
        non_trivial_alignment = [1, 2, 4, 8, 15]
        for i in non_trivial_alignment:
            for j in non_trivial_alignment:
                for k in non_trivial_alignment:
                    a_ = a[i : 100 + i]
                    b_ = b[j : 100 + j]
                    c_ = c[k : 100 + k]
                    torch.logical_xor(a_, b_, out=c_)
                    for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()):
                        self.assertEqual(x ^ y, z)
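
    # Sizes just past 4096 exercise the scalar "tail" path of vectorized
    # elementwise kernels in addition to the full SIMD lanes.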
    @dtypes(torch.float)
    def test_add_with_tail(self, device, dtype):
        for tail_size in [1, 63, 67, 130]:
            size = 4096 + tail_size
            a = torch.randn(size, device=device, dtype=dtype)
            b = torch.randn(size, device=device, dtype=dtype)
            c = a + b
            for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()):
                self.assertEqual(x + y, z)

    @onlyCUDA
    @deviceCountAtLeast(2)
    def test_cross_device_binary_ops(self, devices):
        vals = (1.0, (2.0,))
        cpu_tensor = torch.randn(2, 2)

        def do_test(op, a, b):
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(a, b)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(b, a)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(a, cpu_tensor)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(cpu_tensor, a)

        for op in (
            operator.add,
            torch.add,
            operator.sub,
            torch.sub,
            operator.mul,
            torch.mul,
            operator.truediv,
            torch.true_divide,
            operator.floordiv,
            torch.floor_divide,
        ):
            for a, b in product(vals, vals):
                a = torch.tensor(a, device=devices[0])
                b = torch.tensor(b, device=devices[1])
                do_test(op, a, b)

    # This test ensures that a scalar Tensor can be safely used
    # in a binary operation with a Tensor on any available device
    @onlyCUDA
    @deviceCountAtLeast(2)
    def test_binary_op_scalar_device_unspecified(self, devices):
        scalar_val = torch.tensor(1.0)
        for default_device in devices:
            with torch.cuda.device(default_device):
                for device in devices:
                    device_obj = torch.device(device)
                    x = torch.rand(3, device=device)
                    y0 = x * scalar_val
                    self.assertEqual(y0.device, device_obj)
                    y1 = scalar_val * x
                    self.assertEqual(y1.device, device_obj)
                    self.assertEqual(y0, y1)

    def test_div_and_floordiv_vs_python(self, device):
        def _scalar_helper(python_op, torch_op):
            for a, b in product(range(-10, 10), range(-10, 10)):
                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                    a = op(a)
                    b = op(b)

                    # Skips zero divisors
                    if b == 0:
                        continue

                    expected = python_op(a, b)

                    for op in (operator.truediv, torch.true_divide):
                        actual_scalar = torch_op(a, b)

                        a_t = torch.tensor(a, device=device)
                        b_t = torch.tensor(b, device=device)

                        actual_tensor = torch_op(a_t, b_t)
                        actual_first_tensor = torch_op(a_t, b)
                        actual_second_tensor = torch_op(a, b_t)

                        self.assertEqual(actual_scalar, expected)
                        self.assertEqual(actual_tensor.item(), expected)
                        self.assertEqual(actual_first_tensor, actual_tensor)
                        self.assertEqual(actual_second_tensor, actual_tensor)

        _scalar_helper(operator.truediv, operator.truediv)
        _scalar_helper(operator.truediv, torch.true_divide)
        _scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv)
        _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)

    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_div_and_floordiv_script_vs_python(self, device):
        def _wrapped_div(a, b):
            return a / b

        def _wrapped_floordiv(a, b):
            return a // b

        scripted_div = torch.jit.script(_wrapped_div)
        scripted_floordiv = torch.jit.script(_wrapped_floordiv)
        for a, b in product(range(-10, 10), range(-10, 10)):
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)
                b = op(b)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_div = a / b
                expected_floordiv = math.floor(a / b)
                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                self.assertEqual(scripted_div(a_t, b_t), expected_div)
                self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)

        def _wrapped_div_scalar(a):
            return a / 5

        def _wrapped_rdiv_scalar(a):
            return 5 / a

        def _wrapped_floordiv_scalar(a):
            return a // 5

        def _wrapped_rfloordiv_scalar(a):
            return 5 // a

        scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
        scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
        scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
        scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)

        for a in range(-10, 10):
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)
                a_t = torch.tensor(a, device=device)

                self.assertEqual(a / 5, scripted_div_scalar(a_t))

                # Skips zero divisors
                if a == 0:
                    continue

                self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

                if a_t.is_floating_point():
                    with self.assertRaises(RuntimeError):
                        scripted_rfloordiv_scalar(a_t)
                else:
                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))

    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_idiv_and_ifloordiv_vs_python(self, device):
        def _wrapped_idiv_tensor(a, b):
            a /= b
            return a

        def _wrapped_idiv_scalar(a):
            a /= 5
            return a

        def _wrapped_true_divide__tensor(a, b):
            a.true_divide_(b)
            return a

        def _wrapped_true_divide__scalar(a):
            a.true_divide_(5)
            return a

        def _wrapped_floor_divide__tensor(a, b):
            a.floor_divide_(b)
            return a

        def _wrapped_floor_divide__scalar(a):
            a.floor_divide_(5)
            return a

        # The inplace floordiv dunders are not supported by TorchScript
        def _wrapped_ifloordiv_tensor(a, b):
            a //= b
            return a

        def _wrapped_ifloordiv_scalar(a):
            a //= 5
            return a

        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor)

        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar)

        scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
        scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
        scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
        scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
        scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
        scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)

        for a, b in product(range(-10, 10), range(-10, 10)):
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)
                b = op(b)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_idiv = a / b
                expected_ifloordiv = a // b

                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                if a_t.is_floating_point():
                    tmp0 = a_t.clone()
                    tmp0 /= b

                    tmp1 = a_t.clone()
                    tmp1 /= b_t

                    self.assertEqual(tmp0.item(), expected_idiv)
                    self.assertEqual(tmp1.item(), expected_idiv)
                    self.assertEqual(
                        scripted_true_divide__tensor(a_t.clone(), b_t).item(),
                        expected_idiv,
                    )
                    self.assertEqual(
                        scripted_true_divide__scalar(a_t.clone()).item(), a / 5
                    )
                else:
                    tmp = a_t.clone()
                    with self.assertRaises(RuntimeError):
                        tmp /= b
                    with self.assertRaises(RuntimeError):
                        tmp /= b_t
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__tensor(tmp, b_t)
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__scalar(tmp)

                if not a_t.is_floating_point() and b_t.is_floating_point():
                    # Inplace modification via a float tensor is forbidden
                    with self.assertRaises(RuntimeError):
                        a_t.clone().floor_divide_(b_t)
                        scripted_floor_divide__tensor(a_t.clone(), b_t)
                else:
                    self.assertEqual(
                        a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
                    )
                    self.assertEqual(
                        scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
                        expected_ifloordiv,
                    )

                    tmp = a_t.clone()
                    tmp //= b
                    self.assertEqual(tmp.item(), expected_ifloordiv)

                    self.assertEqual(
                        scripted_floor_divide__scalar(a_t), math.floor(a / 5)
                    )

    def test_binary_ops_with_scalars(self, device):
        for python_op, torch_op in (
            (operator.add, torch.add),
            (operator.sub, torch.sub),
            (operator.mul, torch.mul),
            (operator.truediv, torch.div),
        ):
            for a, b in product(range(-10, 10), range(-10, 10)):
                for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                    a = op(a)
                    b = op(b)

                    # Skips zero divisors and zero operands
                    if b == 0 or a == 0:
                        continue

                    a_tensor = torch.tensor(a, device=device)
                    b_tensor = torch.tensor(b, device=device)
                    a_tensor_cpu = a_tensor.cpu()
                    b_tensor_cpu = b_tensor.cpu()
                    vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu)

                    for args in product(vals, vals):
                        first, second = args

                        first_scalar = (
                            first
                            if not isinstance(first, torch.Tensor)
                            else first.item()
                        )
                        second_scalar = (
                            second
                            if not isinstance(second, torch.Tensor)
                            else second.item()
                        )
                        expected = python_op(first_scalar, second_scalar)

                        self.assertEqual(expected, python_op(first, second))
                        self.assertEqual(expected, torch_op(first, second))
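
    # maximum/minimum propagate NaNs while fmax/fmin prefer the non-NaN
    # operand, mirroring np.maximum/np.minimum vs. np.fmax/np.fmin.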
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_maximum_minimum_type_promotion(self, device, dtypes):
        a = torch.tensor((0, 1), device=device, dtype=dtypes[0])
        b = torch.tensor((1, 0), device=device, dtype=dtypes[1])
        for op in (
            torch.maximum,
            torch.max,
            torch.fmax,
            torch.minimum,
            torch.min,
            torch.fmin,
        ):
            result = op(a, b)
            self.assertEqual(result.dtype, torch.result_type(a, b))

    @dtypes(*integral_types_and(torch.bool))
    def test_maximum_minimum_int_and_bool(self, device, dtype):
        ops = (
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )
        rng = np.random.default_rng()
        a_np = np.array(
            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
        )
        b_np = np.array(
            rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]
        )

        for torch_op, alias, numpy_op in ops:
            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
            tensor_result = torch_op(a_tensor, b_tensor)

            out = torch.empty_like(a_tensor)
            torch_op(a_tensor, b_tensor, out=out)

            numpy_result = numpy_op(a_np, b_np)

            if alias is not None:
                alias_result = alias(a_tensor, b_tensor)
                self.assertEqual(alias_result, tensor_result)

            self.assertEqual(tensor_result, numpy_result)
            self.assertEqual(out, numpy_result)

    @precisionOverride({torch.bfloat16: 1e-2})
    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
    def test_maximum_minimum_float(self, device, dtype):
        ops = (
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )

        if dtype == torch.bfloat16:
            a_np = np.random.randn(10).astype(np.float64)
            b_np = np.random.randn(10).astype(np.float64)
        else:
            a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
            b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])

        for torch_op, alias, numpy_op in ops:
            numpy_result = numpy_op(a_np, b_np)

            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
            tensor_result = torch_op(a_tensor, b_tensor)
            out = torch.empty_like(a_tensor)
            torch_op(a_tensor, b_tensor, out=out)

            if alias is not None:
                alias_result = alias(a_tensor, b_tensor)
                self.assertEqual(alias_result, tensor_result, exact_dtype=False)

            self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
            self.assertEqual(out, numpy_result, exact_dtype=False)

    @dtypes(*(floating_types_and(torch.half, torch.bfloat16)))
    def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
        # np.maximum and np.minimum propagate NaNs; np.fmax and np.fmin
        # return the non-NaN operand when exactly one input is NaN.
        ops = (
            (torch.maximum, torch.max, np.maximum),
            (torch.minimum, torch.min, np.minimum),
            (torch.fmax, None, np.fmax),
            (torch.fmin, None, np.fmin),
        )
        a_vals = (
            float("inf"),
            -float("inf"),
            float("nan"),
            float("inf"),
            float("nan"),
            float("nan"),
            1,
            float("nan"),
        )
        b_vals = (
            -float("inf"),
            float("inf"),
            float("inf"),
            float("nan"),
            float("nan"),
            0,
            float("nan"),
            1,
        )
        if dtype == torch.bfloat16:
            a_np = np.array(a_vals, dtype=np.float64)
            b_np = np.array(b_vals, dtype=np.float64)
        else:
            a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype])
            b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype])

        for torch_op, alias, numpy_op in ops:
            numpy_result = numpy_op(a_np, b_np)

            a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
            b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
            tensor_result = torch_op(a_tensor, b_tensor)

            out = torch.empty_like(a_tensor)
            torch_op(a_tensor, b_tensor, out=out)

            if alias is not None:
                alias_result = alias(a_tensor, b_tensor)
                self.assertEqual(alias_result, tensor_result)

            if dtype == torch.bfloat16:
                self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
                self.assertEqual(out, numpy_result, exact_dtype=False)
            else:
                self.assertEqual(tensor_result, numpy_result)
                self.assertEqual(out, numpy_result)

    @dtypes(
        *product(
            complex_types(),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_maximum_minimum_complex(self, device, dtypes):
        for torch_op in (
            torch.maximum,
            torch.minimum,
            torch.max,
            torch.min,
            torch.fmax,
            torch.fmin,
        ):
            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
                torch_op(
                    torch.ones(1, device=device, dtype=dtypes[0]),
                    torch.ones(1, device=device, dtype=dtypes[1]),
                )

            with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"):
                torch_op(
                    torch.ones(1, device=device, dtype=dtypes[1]),
                    torch.ones(1, device=device, dtype=dtypes[0]),
                )

    @onlyCUDA
    def test_maximum_minimum_cross_device(self, device):
        a = torch.tensor((1, 2, -1))
        b = torch.tensor((3, 0, 4), device=device)
        ops = (torch.maximum, torch.minimum)

        for torch_op in ops:
            with self.assertRaisesRegex(
                RuntimeError, "Expected all tensors to be on the same device"
            ):
                torch_op(a, b)

            with self.assertRaisesRegex(
                RuntimeError, "Expected all tensors to be on the same device"
            ):
                torch_op(b, a)

        # CPU scalar (0-dim) tensors are allowed with device tensors
        ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum))

        a_np = np.array(1)
        b_np = np.array([3, 0, 4])

        for torch_op, numpy_op in ops:
            a_tensor = torch.from_numpy(a_np)
            b_tensor = torch.from_numpy(b_np).to(device=device)
            tensor_result_1 = torch_op(a_tensor, b_tensor)
            numpy_result_1 = numpy_op(a_np, b_np)
            tensor_result_2 = torch_op(b_tensor, a_tensor)
            numpy_result_2 = numpy_op(b_np, a_np)

            self.assertEqual(tensor_result_1, numpy_result_1)
            self.assertEqual(tensor_result_2, numpy_result_2)

    @dtypes(
        *product(
            floating_types_and(torch.half, torch.bfloat16),
            floating_types_and(torch.half, torch.bfloat16),
        )
    )
    def test_maximum_and_minimum_subgradient(self, device, dtypes):
        def run_test(f, a, b, expected_a_grad, expected_b_grad):
            a = torch.tensor(a, requires_grad=True, device=device, dtype=dtypes[0])
            b = torch.tensor(b, requires_grad=True, device=device, dtype=dtypes[1])
            z = f(a, b)
            z.sum().backward()
            self.assertEqual(a.grad, expected_a_grad)
            self.assertEqual(b.grad, expected_b_grad)

        # At ties the subgradient is split evenly between the two inputs.
        run_test(
            torch.maximum, [0.0, 1.0, 2.0], [1.0, 1.0, 1.0],
            [0.0, 0.5, 1.0], [1.0, 0.5, 0.0],
        )
        run_test(
            torch.minimum, [0.0, 1.0, 2.0], [1.0, 1.0, 1.0],
            [1.0, 0.5, 0.0], [0.0, 0.5, 1.0],
        )

    def test_maximum_minimum_forward_ad_float32(self, device):
        x = torch.randn(3, device=device, dtype=torch.float32)
        y = torch.randn(3, device=device, dtype=torch.float32)
        tx = torch.randn(3, device=device, dtype=torch.float32)
        ty = torch.randn(3, device=device, dtype=torch.float32)

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, tx)
            y_dual = fwAD.make_dual(y, ty)
            result = torch.maximum(x_dual, y_dual)
            _, result_tangent = fwAD.unpack_dual(result)

        expected = torch.where(x > y, tx, ty)
        self.assertEqual(result_tangent, expected)

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, tx)
            y_dual = fwAD.make_dual(y, ty)
            result = torch.minimum(x_dual, y_dual)
            _, result_tangent = fwAD.unpack_dual(result)

        expected = torch.where(x < y, tx, ty)
        self.assertEqual(result_tangent, expected)
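
    # Mixed float/int binary ops promote to the float dtype; in-place
    # variants additionally require the result to be castable back to the
    # in-place operand's dtype.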
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float, torch.double)
    def test_mul_intertype_scalar(self, device, dtype):
        x = torch.tensor(1.5, dtype=dtype, device=device)
        y = torch.tensor(3, dtype=torch.int32, device=device)

        self.assertEqual(x * y, 4.5)
        self.assertEqual(y * x, 4.5)

        with self.assertRaisesRegex(
            RuntimeError, "can't be cast to the desired output type"
        ):
            y *= x
        x *= y
        self.assertEqual(x, 4.5)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_sub(self, device, dtype):
        if dtype in integral_types():
            # small values avoid overflow for the integral dtypes
            m1 = torch.tensor([2, 4], dtype=dtype, device=device)
            m2 = torch.tensor([1, 2], dtype=dtype, device=device)
            diff = torch.tensor([1, 2], dtype=dtype)
        else:
            m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device)
            m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device)
            diff = torch.tensor([1.11, 2.11], dtype=dtype)

        if dtype == torch.bool:
            self.assertRaises(RuntimeError, lambda: m1 - m2)
        elif dtype == torch.bfloat16 or dtype == torch.half:
            # lower-precision dtypes need a looser tolerance
            self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0)
        else:
            self.assertEqual(m1 - m2, diff)

    @dtypes(torch.float)
    def test_csub(self, device, dtype):
        # with a tensor
        a = torch.randn(100, 90, dtype=dtype, device=device)
        b = a.clone().normal_()

        res_add = torch.add(a, b, alpha=-1)
        res_csub = a.clone()
        res_csub.sub_(b)
        self.assertEqual(res_add, res_csub)

        # with a scalar
        a = torch.randn(100, 100, dtype=dtype, device=device)

        scalar = 123.5
        res_add = torch.add(a, -scalar)
        res_csub = a.clone()
        res_csub.sub_(scalar)
        self.assertEqual(res_add, res_csub)

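    # The identity relied on above, for reference: torch.add(a, b, alpha=-1)
    # computes a + (-1) * b, so it should agree exactly with a.sub_(b):
    #   >>> torch.add(torch.tensor([3.0]), torch.tensor([1.0]), alpha=-1)
    #   tensor([2.])
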
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float, torch.double)
    def test_min_max_binary_op_nan(self, device, dtype):
        a = torch.rand(1000, dtype=dtype, device=device)
        b = torch.rand(1000, dtype=dtype, device=device)

        # 0:250: a -- nan, b -- not nan
        a[:250] = float("nan")
        # 250:500: a -- not nan, b -- nan
        b[250:500] = float("nan")
        # 500:750: both a and b are nan
        a[500:750] = float("nan")
        b[500:750] = float("nan")
        # 750:1000: neither is nan

        ma = torch.max(a, b)
        mi = torch.min(a, b)

        for i in range(750):
            self.assertTrue(
                torch.isnan(ma[i]),
                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
            )
            self.assertTrue(
                torch.isnan(mi[i]),
                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
            )

        for i in range(750, 1000):
            self.assertFalse(
                torch.isnan(ma[i]),
                f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}",
            )
            self.assertFalse(
                torch.isnan(mi[i]),
                f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}",
            )

    @dtypes(
        *product(
            all_types_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_copysign(self, device, dtypes):
        def _test_copysign_numpy(a, b):
            torch_result = torch.copysign(a, b)

            if a.dtype == torch.bfloat16:
                np_a = a.to(torch.float).cpu().numpy()
            else:
                np_a = a.cpu().numpy()

            if b.dtype == torch.bfloat16:
                np_b = b.to(torch.float).cpu().numpy()
            else:
                np_b = b.cpu().numpy()
            expected = torch.from_numpy(np.copysign(np_a, np_b))
            # promote both results to a common dtype to work around
            # inconsistencies between PyTorch's and NumPy's promotion rules
            # for integral and bfloat16 inputs
            types = integral_types_and(torch.bool, torch.bfloat16)
            if a.dtype in types or b.dtype in types:
                promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
                torch_result = torch_result.to(promoted_type)
                expected = expected.to(promoted_type)

            # verify the value
            self.assertEqual(torch_result, expected)

            # verify the sign: copysign onto a magnitude of 1.0 distinguishes
            # 0.0 from -0.0, which assertEqual alone does not; float16 is
            # skipped since its NaN conversions are not bitwise equivalent
            if a.dtype != torch.float16 and b.dtype != torch.float16:
                self.assertEqual(
                    torch.copysign(torch.tensor(1.0), torch_result),
                    torch.copysign(torch.tensor(1.0), expected),
                )

        # compare results with NumPy, including type promotion
        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
        _test_copysign_numpy(a, b)

        # broadcast
        a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
        _test_copysign_numpy(a, b)

        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
        b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
        _test_copysign_numpy(a, b)

        # 0.0, -0.0, inf, -inf, nan
        cases = [0.0, -0.0, float("inf"), float("-inf"), float("nan")]
        # torch.bfloat16 cannot hold "-nan", and torch.half cannot on CUDA
        types = [torch.float32, torch.float64]
        if device == "cpu":
            types.append(torch.float16)
        if dtypes[0] in types:
            b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
            for case in cases:
                _test_copysign_numpy(
                    torch.tensor([case], device=device, dtype=dtypes[0]), b
                )

        if dtypes[1] in floating_types_and(torch.half, torch.bfloat16):
            a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
            for case in cases:
                _test_copysign_numpy(
                    a, torch.tensor([case], device=device, dtype=dtypes[1])
                )

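    # A small reminder of why the sign check above routes through copysign onto
    # 1.0: copysign distinguishes -0.0 from 0.0 even though they compare equal.
    #   >>> torch.copysign(torch.tensor([1.0, -2.0]), torch.tensor([-0.0, 3.0]))
    #   tensor([-1.,  2.])
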
    @dtypes(
        *product(
            floating_types_and(torch.half, torch.bfloat16),
            floating_types_and(torch.half, torch.bfloat16),
        )
    )
    def test_copysign_subgradient(self, device, dtypes):
        # input is 0.0
        x = torch.tensor(
            [0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True
        )
        y = torch.tensor(
            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
        )
        out = torch.copysign(x, y)
        out.sum().backward()
        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
        self.assertEqual(y.grad.tolist(), [0.0] * 3)

        # input is -0.0
        x = torch.tensor(
            [-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True
        )
        y = torch.tensor(
            [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True
        )
        out = torch.copysign(x, y)
        out.sum().backward()
        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
        self.assertEqual(y.grad.tolist(), [0.0] * 3)

        # other is 0.0
        x = torch.tensor(
            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
        )
        y = torch.tensor(
            [0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True
        )
        out = torch.copysign(x, y)
        out.sum().backward()
        self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
        self.assertEqual(y.grad.tolist(), [0.0] * 3)

        # other is -0.0
        x = torch.tensor(
            [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True
        )
        y = torch.tensor(
            [-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True
        )
        out = torch.copysign(x, y)
        out.sum().backward()
        self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
        self.assertEqual(y.grad.tolist(), [0.0] * 3)

    @dtypes(torch.bfloat16, torch.float)
    def test_div(self, device, dtype):
        for op, method, inplace in (
            (torch.div, torch.Tensor.div, torch.Tensor.div_),
            (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_),
        ):
            m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
            res1 = m1.clone()
            inplace(res1[:, 3], 2)
            res2 = m1.clone()
            for i in range(m1.size(0)):
                res2[i, 3] = res2[i, 3] / 2
            self.assertEqual(res1, res2)

            if dtype == torch.bfloat16:
                a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
                a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
                self.assertEqual(
                    op(a1, a2),
                    torch.tensor([2.1, 3.1], dtype=dtype, device=device),
                    atol=0.01,
                    rtol=0,
                )
                self.assertEqual(method(a1, a2), op(a1, a2))

    @dtypes(torch.bfloat16, torch.float)
    def test_true_divide_out(self, device, dtype):
        a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
        a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device)
        res = torch.empty_like(a1)
        self.assertEqual(
            torch.true_divide(a1, a2, out=res),
            torch.tensor([2.1, 3.1], dtype=dtype, device=device),
            atol=0.01,
            rtol=0,
        )

    @dtypes(torch.half)
    def test_divmul_scalar(self, device, dtype):
        x = torch.tensor(100.0, device=device, dtype=dtype)
        x_ref = x.float()
        scale = 1e5
        res = x.div(scale)
        expected = x_ref.div(scale)
        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)

        x = torch.tensor(1e-5, device=device, dtype=dtype)
        x_ref = x.float()
        res = x.mul(scale)
        expected = x_ref.mul(scale)
        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)

        res = scale * x
        self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0)

    @dtypesIfCUDA(
        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
    )
    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
    def test_floor_divide_tensor(self, device, dtype):
        x = torch.randn(10, device=device).mul(30).to(dtype)
        y = torch.arange(1, 11, dtype=dtype, device=device)

        z = x // y
        z_alt = torch.floor(x.double() / y.double()).to(dtype)

        self.assertEqual(z.dtype, x.dtype)
        self.assertEqual(z, z_alt)

    @dtypesIfCUDA(
        *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128}
    )
    @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128})
    def test_floor_divide_scalar(self, device, dtype):
        x = torch.randn(100, device=device).mul(10).to(dtype)

        z = x // 3
        z_alt = torch.tensor(
            [math.floor(v.item() / 3.0) for v in x], dtype=x.dtype, device=device
        )

        self.assertEqual(z.dtype, x.dtype)
        self.assertEqual(z, z_alt)

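    # For reference, the semantics checked by both floor-division tests above:
    # // rounds toward negative infinity (true floor), not toward zero:
    #   >>> torch.tensor([-7.0]) // torch.tensor([2.0])
    #   tensor([-4.])
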
    @onlyCPU
    @dtypes(*get_all_math_dtypes("cpu"))
    def test_rdiv(self, device, dtype):
        if dtype is torch.float16:
            return
        elif dtype.is_complex:
            x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4)
        else:
            x = torch.rand(100, device=device).add(1).mul(4).to(dtype)
        y = 30 / x
        z = torch.tensor([30 / v.item() for v in x], device=device)
        self.assertEqual(y, z, exact_dtype=False)

    @dtypes(*floating_types_and(torch.half))
    def test_fmod_remainder_by_zero_float(self, device, dtype):
        fn_list = (torch.fmod, torch.remainder)
        for fn in fn_list:
            # floating-point fmod/remainder by zero is nan on both CPU and GPU
            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
            zero = torch.zeros_like(x)
            self.assertTrue(torch.all(fn(x, 0.0).isnan()))
            self.assertTrue(torch.all(fn(x, zero).isnan()))

    @onlyNativeDeviceTypes
    @dtypes(*integral_types())
    def test_fmod_remainder_by_zero_integral(self, device, dtype):
        fn_list = (torch.fmod, torch.remainder)
        for fn in fn_list:
            # check integral tensor fmod/remainder to zero
            x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
            zero = torch.zeros_like(x)
            # RuntimeError on CPU
            if self.device_type == "cpu":
                with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
                    fn(x, zero)
            elif torch.version.hip is not None:
                # ROCm behavior: x % 0 is a no-op; the unmodified input is returned
                self.assertEqual(fn(x, zero), x)
            else:
                # CUDA behavior: division by zero is undefined; in practice the
                # result is a bit pattern of all ones, which for int64 compares
                # differently for negative and non-negative dividends
                if dtype == torch.int64:
                    self.assertEqual(fn(x, zero) == 4294967295, x >= 0)
                    self.assertEqual(fn(x, zero) == -1, x < 0)
                else:
                    value = 255 if dtype == torch.uint8 else -1
                    self.assertTrue(torch.all(fn(x, zero) == value))

    @dtypes(*all_types_and(torch.half))
    def test_fmod_remainder(self, device, dtype):
        # use numpy as the reference
        def _helper(x, mod, fns_list):
            for fn, inplace_fn, ref_fn in fns_list:
                np_x = x.cpu().numpy() if torch.is_tensor(x) else x
                np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod
                exp = ref_fn(np_x, np_mod)
                exp = torch.from_numpy(exp)
                res = fn(x, mod)

                self.assertEqual(res, exp, exact_dtype=False)

                if torch.is_tensor(x):
                    # out variant
                    out = torch.empty(0, device=device, dtype=res.dtype)
                    fn(x, mod, out=out)
                    self.assertEqual(out, exp, exact_dtype=False)
                    self.assertEqual(out.size(), torch.Size([10, 10]))
                    # in-place (type cast runtime error)
                    try:
                        inplace_fn(x, mod)
                        self.assertEqual(x, exp, exact_dtype=False)
                    except RuntimeError as e:
                        self.assertRegex(
                            str(e),
                            "result type (Half|Float|Double) "
                            "can't be cast to the desired output "
                            "type (Byte|Char|Short|Int|Long)",
                        )

        x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        # mod with the same dtype as x
        mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        # exclude 0
        mod[mod == 0] = 1

        # mods: integer, float, tensor, non-contiguous tensor
        mods = [3, 2.3, mod, mod.t()]
        # mod with a floating-point dtype
        if dtype in integral_types():
            mod_float = make_tensor(
                (10, 10), device=device, dtype=torch.float, low=-9, high=9
            )
            mod_float[mod_float == 0] = 1
            mods.append(mod_float)

        for dividend, mod in product([x, x.t()], mods):
            _helper(
                dividend,
                mod,
                (
                    (torch.fmod, torch.Tensor.fmod_, np.fmod),
                    (torch.remainder, torch.Tensor.remainder_, np.remainder),
                ),
            )

        # tests for torch.remainder(scalar, tensor)
        for dividend, mod in product([5, 3.14], mods):
            if torch.is_tensor(mod):
                _helper(
                    dividend,
                    mod,
                    ((torch.remainder, torch.Tensor.remainder_, np.remainder),),
                )

    @dtypes(torch.float, torch.double)
    def test_remainder_fmod_large_dividend(self, device, dtype):
        alarge = 1e9
        pi = 3.14159265358979
        for avalue in [alarge, -alarge]:
            for bvalue in [pi, -pi]:
                a = torch.tensor([avalue], dtype=dtype, device=device)
                b = torch.tensor([bvalue], dtype=dtype, device=device)
                c = torch.remainder(a, b)
                d = torch.fmod(a, b)
                self.assertTrue(
                    (b[0] > 0) == (c[0] > 0)
                )  # remainder has the same sign as the divisor
                self.assertTrue(
                    (a[0] > 0) == (d[0] > 0)
                )  # fmod has the same sign as the dividend
                self.assertTrue(
                    abs(c[0]) < abs(b[0])
                )  # remainder is within the range of the divisor
                self.assertTrue(
                    abs(d[0]) < abs(b[0])
                )  # fmod is within the range of the divisor
                if (a[0] > 0) == (b[0] > 0):
                    self.assertTrue(c[0] == d[0])  # remainder is the same as fmod
                else:
                    self.assertTrue(
                        abs(c[0] - d[0]) == abs(b[0])
                    )  # they differ by one divisor

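    # The sign conventions asserted above, in one place: torch.remainder takes
    # the divisor's sign (Python-style %), torch.fmod the dividend's (C-style):
    #   >>> torch.remainder(torch.tensor([-3.0]), 2.0)
    #   tensor([1.])
    #   >>> torch.fmod(torch.tensor([-3.0]), 2.0)
    #   tensor([-1.])
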
    @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    def test_hypot(self, device, dtype):
        inputs = [
            (
                torch.randn(10, device=device).to(dtype),
                torch.randn(10, device=device).to(dtype),
            ),
            (
                torch.randn((3, 3, 3), device=device).to(dtype),
                torch.randn((3, 3, 3), device=device).to(dtype),
            ),
            (
                torch.randn((10, 1), device=device).to(dtype),
                torch.randn((10, 1), device=device).to(dtype).transpose(0, 1),
            ),
            (
                torch.randint(100, (10,), device=device, dtype=torch.long),
                torch.randn(10, device=device).to(dtype),
            ),
        ]
        for input in inputs:
            actual = torch.hypot(input[0], input[1])
            if dtype in [torch.bfloat16, torch.half]:
                expected = torch.sqrt(input[0] * input[0] + input[1] * input[1])
            else:
                expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
            self.assertEqual(actual, expected, exact_dtype=False)

    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_gcd(self, device, dtype):
        # tests gcd(0, 0) and gcd(0, a) cases
        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
        actual = torch.gcd(t1, t2)
        expected = np.gcd([0, 10, 0], [0, 0, 10])
        self.assertEqual(actual, expected, exact_dtype=False)

        if dtype == torch.uint8:
            # test unsigned integers with potential sign issues (i.e., uint8 with value >= 128)
            a = torch.tensor([190, 210], device=device, dtype=dtype)
            b = torch.tensor([190, 220], device=device, dtype=dtype)
            actual = torch.gcd(a, b)
            expected = torch.tensor([190, 10], device=device, dtype=dtype)
            self.assertEqual(actual, expected)
        else:
            # compares with numpy
            a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
            b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
            actual = torch.gcd(a, b)
            expected = np.gcd(a.cpu().numpy(), b.cpu().numpy())
            self.assertEqual(actual, expected)

    @onlyNativeDeviceTypes
    @dtypes(torch.int16, torch.int32, torch.int64)
    def test_lcm(self, device, dtype):
        # tests lcm(0, 0) and lcm(0, a) cases
        t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
        t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
        actual = torch.lcm(t1, t2)
        expected = np.lcm([0, 10, 0], [0, 0, 10])
        self.assertEqual(actual, expected, exact_dtype=False)

        # compares with numpy
        a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
        actual = torch.lcm(a, b)
        expected = np.lcm(a.cpu().numpy(), b.cpu().numpy())
        self.assertEqual(actual, expected, exact_dtype=False)

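    # The zero conventions exercised by both tests above, following NumPy:
    # gcd(0, a) == a, gcd(0, 0) == 0, and lcm(0, a) == 0.
    #   >>> torch.gcd(torch.tensor([0, 0]), torch.tensor([0, 10]))
    #   tensor([ 0, 10])
    #   >>> torch.lcm(torch.tensor([0]), torch.tensor([10]))
    #   tensor([0])
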
    @onlyNativeDeviceTypes
    @dtypesIfCPU(torch.float32, torch.float64, torch.float16)
    @dtypes(torch.float32, torch.float64)
    def test_nextafter(self, device, dtype):
        # special cases
        t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype)
        t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype)
        actual = torch.nextafter(t1, t2)
        expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy())
        self.assertEqual(actual, expected, atol=0, rtol=0)

        actual = torch.nextafter(t2, t1)
        expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy())
        self.assertEqual(actual, expected, atol=0, rtol=0)

        t1 = torch.tensor([0, nan], device=device, dtype=dtype)
        t2 = torch.tensor([nan, 0], device=device, dtype=dtype)
        self.assertTrue(torch.nextafter(t1, t2).isnan().all())

        a = torch.randn(100, device=device, dtype=dtype)
        b = torch.randn(100, device=device, dtype=dtype)
        actual = torch.nextafter(a, b)
        expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @onlyNativeDeviceTypes
    @dtypes(torch.bfloat16)
    def test_nextafter_bfloat16(self, device, dtype):
        cases = (
            # (from, to, expected)
            (0, 1, 9.183549615799121e-41),
            (0, -1, -9.183549615799121e-41),
            (1, -2, 0.99609375),
            (1, 2, 1.0078125),
            (-1, -2, -1.0078125),
            (-1, 0, -0.99609375),
            (20, -3000, 19.875),
            (3000, -20, 2992.0),
            (-3000, 20, -2992.0),
            (65536, 0, 65280.0),
            (65536, inf, 66048.0),
            (-65536, 0, -65280.0),
            (-65536, -inf, -66048.0),
            (nan, 0, nan),
            (0, nan, nan),
            (nan, nan, nan),
            (nan, inf, nan),
            (inf, nan, nan),
            (inf, -inf, 3.3895313892515355e38),
            (-inf, inf, -3.3895313892515355e38),
            (inf, 0, 3.3895313892515355e38),
            (0, inf, 9.183549615799121e-41),
            (-inf, 0, -3.3895313892515355e38),
            (0, -inf, -9.183549615799121e-41),
        )

        for from_v, to_v, expected in cases:
            from_t = torch.tensor([from_v], device=device, dtype=dtype)
            to_t = torch.tensor([to_v], device=device, dtype=dtype)
            actual = torch.nextafter(from_t, to_t).item()
            self.assertEqual(actual, expected, atol=0, rtol=0)

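    # Where the expected constants above come from: bfloat16 has 7 explicit
    # mantissa bits, so the smallest positive subnormal is 2**-133
    # (~9.1835e-41), which is nextafter(0, 1); and the step just below 1.0 is
    # 2**-8, giving nextafter(1, -2) == 1 - 0.00390625 == 0.99609375.
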
    def _test_cop(self, torchfn, mathfn, dtype, device):
        def reference_implementation(res2):
            for i, j in iter_indices(sm1):
                idx1d = i * sm1.size(0) + j
                res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
            return res2

        # contiguous
        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
        m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device)
        sm1 = m1[4]
        sm2 = m2[4]

        res1 = torchfn(sm1, sm2.view(10, 10))
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)

        # non-contiguous
        m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
        m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device)
        sm1 = m1[:, 4]
        sm2 = m2[:, 4]
        # view as sm1's size
        sm2.set_(
            sm2.storage(),
            sm2.storage_offset(),
            sm1.size(),
            (sm2.stride()[0] * 10, sm2.stride()[0]),
        )
        res1 = torchfn(sm1, sm2)
        # reference_implementation assumes a 1-d sm2
        sm2.set_(
            sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()
        )
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)

    @onlyCPU
    @dtypes(torch.float)
    def test_cdiv(self, device, dtype):
        self._test_cop(torch.div, operator.truediv, dtype, device)

    @onlyCPU
    @dtypes(torch.float)
    def test_cremainder(self, device, dtype):
        self._test_cop(torch.remainder, operator.mod, dtype, device)

    @onlyCPU
    @dtypes(torch.float)
    def test_cmul(self, device, dtype):
        self._test_cop(torch.mul, operator.mul, dtype, device)

    @onlyCPU
    @dtypes(torch.float)
    def test_cpow(self, device, dtype):
        self._test_cop(
            torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device
        )

    @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_floor_divide_zero(self, device, dtype):
        a = torch.tensor([0, 1], dtype=dtype, device=device)
        b = torch.tensor([0, 1], dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
            with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
                a // b

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_muldiv_scalar(self, device, dtype):
        x = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None)
        s = make_tensor((1,), dtype=dtype, device="cpu", low=None, high=None).item()
        y = torch.full_like(x, s)
        self.assertEqual(x * s, x * y)
        self.assertEqual(s * x, y * x)
        self.assertEqual(x / s, x / y)
        self.assertEqual(s / x, y / x)

    def _generate_input(self, shape, dtype, device, with_extremal):
        if shape == ():
            x = torch.tensor((), dtype=dtype, device=device)
        else:
            if dtype.is_floating_point or dtype.is_complex:
                # work around torch.randn not being implemented for bfloat16
                if dtype == torch.bfloat16:
                    x = torch.randn(*shape, device=device) * random.randint(30, 100)
                    x = x.to(torch.bfloat16)
                else:
                    x = torch.randn(
                        *shape, dtype=dtype, device=device
                    ) * random.randint(30, 100)
                x[torch.randn(*shape) > 0.5] = 0
                if with_extremal and dtype.is_floating_point:
                    # use extremal values
                    x[torch.randn(*shape) > 0.5] = float("nan")
                    x[torch.randn(*shape) > 0.5] = float("inf")
                    x[torch.randn(*shape) > 0.5] = float("-inf")
                elif with_extremal and dtype.is_complex:
                    x[torch.randn(*shape) > 0.5] = complex("nan")
                    x[torch.randn(*shape) > 0.5] = complex("inf")
                    x[torch.randn(*shape) > 0.5] = complex("-inf")
            elif dtype == torch.bool:
                x = torch.zeros(shape, dtype=dtype, device=device)
                x[torch.randn(*shape) > 0.5] = True
            else:
                x = torch.randint(15, 100, shape, dtype=dtype, device=device)

        return x

    @dtypes(
        *list(
            itertools.combinations_with_replacement(
                all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2
            )
        )
    )
    def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
        # tests all combinations of broadcasting and type promotion with a
        # range of dtypes and input shapes, and with extremal values
        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None):
            # working around the fact that numpy doesn't support bfloat16
            # by letting numpy treat them as float32's
            x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
            y_np = (
                y.cpu().numpy()
                if y.dtype != torch.bfloat16
                else y.to(torch.float32).cpu().numpy()
            )
            self.compare_with_numpy(
                lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y),
                lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np),
                x_np,
            )

        complex_op_denylist = [
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        ]  # complex not supported
        input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)]
        op_pairs = [
            (torch.lt, np.less),
            (torch.le, np.less_equal),
            (torch.gt, np.greater),
            (torch.ge, np.greater_equal),
            (torch.eq, np.equal),
            (torch.ne, np.not_equal),
            (torch.logical_and, np.logical_and),
            (torch.logical_or, np.logical_or),
            (torch.logical_xor, np.logical_xor),
        ]

        for size1 in input_sizes:
            size2 = (2,) + size1  # perform broadcasting
            for with_extremal in [False, True]:
                a = self._generate_input(size1, dtypes[0], device, with_extremal)
                b = self._generate_input(size2, dtypes[1], device, with_extremal)
                for torch_op, numpy_op in op_pairs:
                    if (
                        dtypes[0].is_complex or dtypes[1].is_complex
                    ) and torch_op in complex_op_denylist:
                        continue
                    # functional version of op
                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b)

                    # functional comparison ops always return bool tensors
                    self.assertEqual(torch_op(a, b).dtype, torch.bool)

                    # out version of op
                    out = torch.zeros(
                        1, dtype=torch.complex128
                    )  # all casts to complex128 are safe
                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out)

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
    def test_signed_shift(self, device, dtype):
        "Ensure that signed integer bit shifting works as expected."
        a = torch.tensor([-10, 10], device=device, dtype=dtype)
        expected_l = torch.tensor(
            [-40, 40], device=device, dtype=dtype
        )
        self.assertEqual(a << 2, expected_l)
        self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a)
        expected_r = torch.tensor(
            [-5, 5], device=device, dtype=dtype
        )
        self.assertEqual(a >> 1, expected_r)
        self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a)

    @onlyNativeDeviceTypes
    @dtypes(*get_all_int_dtypes())
    def test_shift_limits(self, device, dtype):
        "Ensure that integer bit shifting works as expected with out-of-limits shift values."
        iinfo = torch.iinfo(dtype)
        bits = iinfo.bits
        low = iinfo.min
        high = iinfo.max
        exact_dtype = dtype != torch.uint8
        for input in (
            torch.tensor([-1, 0, 1], device=device, dtype=dtype),
            torch.tensor([low, high], device=device, dtype=dtype),
            make_tensor((64, 64, 64), low=low, high=high, device=device, dtype=dtype),
        ):
            shift_left_expected = torch.zeros_like(input)
            shift_right_expected = torch.clamp(input, -1, 0)
            for shift in chain(range(-100, -1), range(bits, 100)):
                shift_left = input << shift
                self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}")
                self.compare_with_numpy(
                    lambda x: x << shift,
                    lambda x: np.left_shift(x, shift),
                    input,
                    exact_dtype=exact_dtype, msg=f"<< {shift}"
                )
                shift_right = input >> shift
                self.assertEqual(shift_right, shift_right_expected, msg=f">> {shift}")
                self.compare_with_numpy(
                    lambda x: x >> shift,
                    lambda x: np.right_shift(x, shift),
                    input,
                    exact_dtype=exact_dtype, msg=f">> {shift}"
                )

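    # The out-of-range semantics asserted above, for reference: unlike C, where
    # shifting by >= the bit width is undefined, these tests pin down left
    # shifts producing 0 and right shifts filling with the sign bit, hence the
    # clamp(input, -1, 0) expectation:
    #   >>> torch.tensor([-1, 1], dtype=torch.int32) >> 100
    #   tensor([-1,  0], dtype=torch.int32)
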
    @onlyNativeDeviceTypes
    @dtypes(
        *list(
            product(
                all_types_and(torch.half, torch.bfloat16, torch.bool),
                all_types_and(torch.half, torch.bfloat16, torch.bool),
            )
        )
    )
    def test_heaviside(self, device, dtypes):
        input_dtype = dtypes[0]
        values_dtype = dtypes[1]

        rng = np.random.default_rng()
        input = np.array(
            rng.integers(-10, 10, size=10),
            dtype=torch_to_numpy_dtype_dict[
                input_dtype if (input_dtype != torch.bfloat16) else torch.float64
            ],
        )
        input[0] = input[3] = input[7] = 0
        values = np.array(
            rng.integers(-10, 10, size=10),
            dtype=torch_to_numpy_dtype_dict[
                values_dtype if (values_dtype != torch.bfloat16) else torch.float64
            ],
        )
        np_result = torch.from_numpy(np.heaviside(input, values)).to(
            device=device, dtype=input_dtype
        )

        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
        out = torch.empty_like(input)

        if input_dtype == values_dtype:
            torch_result = torch.heaviside(input, values)
            self.assertEqual(np_result, torch_result)

            torch_result = input.heaviside(values)
            self.assertEqual(np_result, torch_result)

            torch.heaviside(input, values, out=out)
            self.assertEqual(np_result, out)

            input.heaviside_(values)
            self.assertEqual(np_result, input)
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                "heaviside is not yet implemented for tensors with different dtypes.",
            ):
                torch.heaviside(input, values)
            with self.assertRaisesRegex(
                RuntimeError,
                "heaviside is not yet implemented for tensors with different dtypes.",
            ):
                input.heaviside(values)
            with self.assertRaisesRegex(
                RuntimeError,
                "heaviside is not yet implemented for tensors with different dtypes.",
            ):
                torch.heaviside(input, values, out=out)
            with self.assertRaisesRegex(
                RuntimeError,
                "heaviside is not yet implemented for tensors with different dtypes.",
            ):
                input.heaviside_(values)

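    # The definition being tested, for reference: heaviside(input, values) is
    # 0 where input < 0, `values` where input == 0, and 1 where input > 0:
    #   >>> torch.heaviside(torch.tensor([-1.0, 0.0, 2.0]), torch.tensor([0.5]))
    #   tensor([0.0000, 0.5000, 1.0000])
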
    @onlyCUDA
    def test_heaviside_cross_device(self, device):
        x = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
        y = torch.tensor(0)
        result = torch.heaviside(x, y)
        expect = torch.tensor([0, 1, 0, 1, 0, 1], device=device)
        self.assertEqual(result, expect)

        result = torch.heaviside(y, x)
        expect = torch.tensor([-9, 5, 0, 6, -2, 2], device=device)
        self.assertEqual(result, expect)

        x = torch.tensor([-9, 5, 0, 6, -2, 2])
        y = torch.tensor(0, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            torch.heaviside(x, y)

        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            torch.heaviside(y, x)

    @dtypes(*list(product(complex_types(), complex_types())))
    def test_heaviside_complex(self, device, dtypes):
        input_dtype = dtypes[0]
        values_dtype = dtypes[1]

        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
        input = torch.tensor(data, device=device, dtype=input_dtype)
        values = torch.tensor(data, device=device, dtype=values_dtype)
        out = torch.empty_like(input)
        real = input.real

        with self.assertRaisesRegex(
            RuntimeError, "heaviside is not yet implemented for complex tensors."
        ):
            torch.heaviside(input, real)
        with self.assertRaisesRegex(
            RuntimeError, "heaviside is not yet implemented for complex tensors."
        ):
            real.heaviside(values)
        with self.assertRaisesRegex(
            RuntimeError, "heaviside is not yet implemented for complex tensors."
        ):
            input.heaviside_(values)
        with self.assertRaisesRegex(
            RuntimeError, "heaviside is not yet implemented for complex tensors."
        ):
            torch.heaviside(real, real, out=out)

    def _test_logical(self, device, dtypes, op, a_, b_, expected_res_):
        expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device)
        a = torch.tensor(a_, dtype=dtypes[0], device=device)
        b = torch.tensor(b_, dtype=dtypes[1], device=device)

        # new tensor
        self.assertEqual(expected_res.bool(), getattr(a, op)(b))
        # out arg
        c = torch.empty(0, dtype=torch.bool, device=device)
        getattr(torch, op)(a, b, out=c)
        self.assertEqual(expected_res.bool(), c)

        # in-place
        getattr(a, op + "_")(b)
        self.assertEqual(expected_res, a)

    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_logical_xor(self, device, dtypes):
        self._test_logical(
            device, dtypes, "logical_xor", [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]
        )

    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_logical_and(self, device, dtypes):
        self._test_logical(
            device, dtypes, "logical_and", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]
        )

    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    def test_logical_or(self, device, dtypes):
        self._test_logical(
            device, dtypes, "logical_or", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]
        )

    def test_remainder_overflow(self, device):
        # check remainder with a divisor large enough to trigger the
        # integer-overflow path in the mod implementation
        x = torch.tensor(23500, dtype=torch.int64, device=device)
        q = 392486996410368
        self.assertEqual(x % q, x)
        self.assertEqual(-x % q, q - x)
        self.assertEqual(x % -q, x - q)
        self.assertEqual(-x % -q, -x)

    def test_rpow(self, device):
        m = torch.randn(10, 10, device=device)
        self.assertEqual(torch.pow(2, m), 2**m)

        # test with scalar
        m = torch.randn(1, device=device).squeeze()
        assert m.dim() == 0, "m is intentionally a scalar"
        self.assertEqual(torch.pow(2, m), 2**m)

    @onlyCPU
    def test_ldexp(self, device):
        # random values
        mantissas = torch.randn(64, device=device)
        exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)

        # basic test
        np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
        pt_outcome_1 = torch.ldexp(mantissas, exponents)
        pt_outcome_2 = mantissas.ldexp(exponents)
        self.assertEqual(np_outcome, pt_outcome_1)
        self.assertEqual(np_outcome, pt_outcome_2)
        mantissas.ldexp_(exponents)
        self.assertEqual(np_outcome, mantissas)

        # test bounds
        mantissas = torch.tensor(
            [float("inf"), float("-inf"), float("inf"), float("nan")], device=device
        )
        exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
        np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
        pt_outcome = torch.ldexp(mantissas, exponents)
        self.assertEqual(np_outcome, pt_outcome)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_lerp(self, device, dtype):
        start_end_weight_shapes = [(), (5,), (5, 5)]
        for shapes in product(
            start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes
        ):
            start = torch.randn(shapes[0], device=device, dtype=dtype)
            end = torch.randn(shapes[1], device=device, dtype=dtype)

            # tensor and scalar weights
            weights = [
                torch.randn(shapes[2], device=device, dtype=dtype),
                random.random(),
            ]
            if dtype.is_complex:
                weights += [complex(0, 1), complex(0.4, 1.2)]

            for weight in weights:
                actual = torch.lerp(start, end, weight)
                actual_method = start.lerp(end, weight)
                self.assertEqual(actual, actual_method)
                actual_out = torch.tensor(1.0, dtype=dtype, device=device)
                torch.lerp(start, end, weight, out=actual_out)
                self.assertEqual(actual, actual_out)
                expected = start + weight * (end - start)
                self.assertEqual(expected, actual)

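    # The closed form checked above, for reference: lerp(start, end, weight)
    # computes start + weight * (end - start).
    #   >>> torch.lerp(torch.tensor([0.0]), torch.tensor([10.0]), 0.25)
    #   tensor([2.5000])
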
    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16)
    def test_lerp_lowp(self, device, dtype):
        xvals = (0.0, -30000.0)
        yvals = (0.1, -20000.0)
        xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals]
        ys = [torch.full((4,), yval, device=device, dtype=dtype) for yval in yvals]
        weights = [70000, torch.full((4,), 8, device=device, dtype=dtype)]
        for x, y, w in zip(xs, ys, weights):
            xref = x.float()
            yref = y.float()
            wref = w.float() if isinstance(w, torch.Tensor) else w
            actual = torch.lerp(x, y, w)
            expected = torch.lerp(xref, yref, wref).to(dtype)
            self.assertEqual(actual, expected, atol=0.0, rtol=0.0)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_lerp_lowp_cpu(self, device, dtype):
        xvals = (0.0, -30000.0)
        yvals = (0.1, -20000.0)
        for shape in [(4,), (20,), (3, 10, 10)]:
            xs = [torch.full(shape, xval, device=device, dtype=dtype) for xval in xvals]
            ys = [torch.full(shape, yval, device=device, dtype=dtype) for yval in yvals]
            weights = [70000, torch.full(shape, 8, device=device, dtype=dtype)]
            for x, y, w in zip(xs, ys, weights):
                xref = x.float()
                yref = y.float()
                wref = w.float() if isinstance(w, torch.Tensor) else w
                actual = torch.lerp(x, y, w)
                expected = torch.lerp(xref, yref, wref).to(dtype)
                self.assertEqual(actual, expected, atol=0.0, rtol=0.0)

    def _test_logaddexp(self, device, dtype, base2):
        if base2:
            ref_func = np.logaddexp2
            our_func = torch.logaddexp2
        elif dtype in (torch.complex64, torch.complex128):
            # numpy has not implemented logaddexp for complex
            def _ref_func(x, y):
                return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0)

            ref_func = _ref_func
            our_func = torch.logaddexp
        else:
            ref_func = np.logaddexp
            our_func = torch.logaddexp

        def _test_helper(a, b):
            if dtype == torch.bfloat16:
                ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy())
                v = our_func(a, b)
                self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01)
            else:
                ref = ref_func(a.cpu().numpy(), b.cpu().numpy())
                v = our_func(a, b)
                self.assertEqual(ref, v)

        # simple test
        a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
        b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
        _test_helper(a, b)
        _test_helper(a[:3], b[:3])

        # large value test for numerical stability
        a *= 10000
        b *= 10000
        _test_helper(a, b)
        _test_helper(a[:3], b[:3])

        a = torch.tensor(
            [float("inf"), float("-inf"), float("inf"), float("nan")],
            dtype=dtype,
            device=device,
        )
        b = torch.tensor(
            [float("inf"), float("-inf"), float("-inf"), float("nan")],
            dtype=dtype,
            device=device,
        )
        _test_helper(a, b)

    @skipIfTorchDynamo()
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128)
    def test_logaddexp(self, device, dtype):
        self._test_logaddexp(device, dtype, base2=False)

    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_logaddexp2(self, device, dtype):
        self._test_logaddexp(device, dtype, base2=True)

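    # Why the large-value cases above matter: logaddexp(a, b) can be evaluated
    # stably as max(a, b) + log1p(exp(-|a - b|)), so it stays finite where a
    # naive log(exp(a) + exp(b)) would overflow:
    #   >>> torch.logaddexp(torch.tensor([1000.0]), torch.tensor([1000.0]))
    #   tensor([1000.6931])
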
    def test_add(self, device):
        dtypes = floating_and_complex_types()
        for dtype in dtypes:
            # [res] torch.add([res,] tensor1, tensor2)
            m1 = torch.randn(100, 100, dtype=dtype, device=device)
            v1 = torch.randn(100, dtype=dtype, device=device)

            # contiguous
            res1 = torch.add(m1[4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(1)):
                res2[i] = m1[4, i] + v1[i]
            self.assertEqual(res1, res2)

            m1 = torch.randn(100, 100, device=device)
            v1 = torch.randn(100, device=device)

            # non-contiguous
            res1 = torch.add(m1[:, 4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(0)):
                res2[i] = m1[i, 4] + v1[i]
            self.assertEqual(res1, res2)

            # [res] torch.add([res,] tensor, value)
            m1 = torch.randn(10, 10, device=device)

            # contiguous
            res1 = m1.clone()
            res1[3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(1)):
                res2[3, i] = res2[3, i] + 2
            self.assertEqual(res1, res2)

            # non-contiguous
            m1 = torch.randn(10, 10, device=device)
            res1 = m1.clone()
            res1[:, 3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(0)):
                res2[i, 3] = res2[i, 3] + 2
            self.assertEqual(res1, res2)

            # inter-type
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            self.assertEqual(m1 + 3, m1 + torch.tensor(3))
            self.assertEqual(3 + m1, torch.tensor(3) + m1)

            # contiguous + non-contiguous
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
            res = m1 + m2
            self.assertTrue(res.is_contiguous())
            self.assertEqual(res, m1 + m2.contiguous())

            # 1d + empty
            m1 = torch.tensor([1.0], dtype=dtype, device=device)
            m2 = torch.tensor([], dtype=dtype, device=device)
            self.assertEqual(m1 + m2, [])

        # inter-type unsigned
        one = torch.tensor(1, dtype=torch.uint8, device=device)
        self.assertEqual(torch.add(one, 1), 2)
        self.assertEqual(torch.add(one, 1).dtype, torch.uint8)

        # bool
        m1 = torch.tensor(
            [True, False, False, True, False, False], dtype=torch.bool, device=device
        )
        m2 = torch.tensor(
            [True, True, False, False, False, True], dtype=torch.bool, device=device
        )
        expected = torch.tensor(
            [True, True, False, True, False, True], dtype=torch.bool, device=device
        )
        self.assertEqual(m1 + m2, expected)

        # bool with alpha=0
        a = torch.zeros(2, 3, dtype=torch.bool, device=device)
        res = torch.add(a, a, alpha=0)
        expected = torch.zeros(2, 3, device=device).bool()
        self.assertEqual(res, expected)

        # bfloat16
        m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
        m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16)
        self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16))

        # different alpha types
        m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
        m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
        # add complex numbers with float alpha
        res = torch.add(m1, m2, alpha=0.1)
        expected = torch.tensor(
            [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # add complex numbers with complex alpha
        res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
        expected = torch.tensor(
            [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # add complex numbers with integer alpha
        res = torch.add(m1, m2, alpha=2)
        expected = torch.tensor(
            [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device
        )
        self.assertEqual(res, expected)

        # mismatched alpha
        m1 = torch.tensor([1], dtype=torch.int8, device=device)
        m2 = torch.tensor([2], dtype=torch.int8, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"Boolean alpha only supported for Boolean results\.",
            lambda: torch.add(m1, m2, alpha=True),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"For integral input tensors, argument alpha must not be a floating point number\.",
            lambda: torch.add(m1, m2, alpha=1.0),
        )

        # mismatched alpha: float / double tensors with a complex alpha
        msg = r"For non-complex input tensors, argument alpha must not be a complex number\."
        m1 = torch.tensor([3.0, 4.0], device=device)
        m2 = torch.tensor([4.0, 3.0], device=device)
        self.assertRaisesRegex(
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
        )

        m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device)
        m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device)
        self.assertRaisesRegex(
            RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
        )

        # complex result cast to a real out tensor
        m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
        m2 = torch.tensor(4.0, dtype=torch.float64)
        self.assertRaisesRegex(
            RuntimeError,
            r"result type ComplexFloat can't be cast to the desired output type Double",
            lambda: torch.add(m1, m1, out=m2),
        )

    @onlyCUDA
    def test_addsub_half_tensor(self, device):
        x = torch.tensor([60000.0], dtype=torch.half, device=device)
        for op, y, alpha in (
            (torch.add, torch.tensor([-60000.0], dtype=torch.half, device=device), 2),
            (torch.sub, torch.tensor([60000.0], dtype=torch.half, device=device), 2),
            (torch.add, -70000.0, 1),
            (torch.sub, 70000.0, 1),
        ):
            actual = op(x, y, alpha=alpha)
            self.assertTrue(not (actual.isnan() or actual.isinf()))

    def test_sub_typing(self, device):
        m1 = torch.tensor(
            [True, False, False, True, False, False], dtype=torch.bool, device=device
        )
        m2 = torch.tensor(
            [True, True, False, False, False, True], dtype=torch.bool, device=device
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"Subtraction, the `\-` operator, with two bool tensors is not supported. "
            r"Use the `\^` or `logical_xor\(\)` operator instead.",
            lambda: m1 - m2,
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
            lambda: 1 - m1,
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
            r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
            lambda: m2 - 1,
        )

        # mismatched alpha
        m1 = torch.tensor([1], dtype=torch.int8, device=device)
        m2 = torch.tensor([2], dtype=torch.int8, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"Boolean alpha only supported for Boolean results\.",
            lambda: torch.sub(m1, m2, alpha=True),
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"For integral input tensors, argument alpha must not be a floating point number\.",
            lambda: torch.sub(m1, m2, alpha=1.0),
        )

    def test_mul(self, device):
        m1 = torch.randn(10, 10, device=device)
        res1 = m1.clone()
        res1[:, 3].mul_(2)
        res2 = m1.clone()
        for i in range(res1.size(0)):
            res2[i, 3] = res2[i, 3] * 2
        self.assertEqual(res1, res2)

        a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device)
        a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
        self.assertEqual(
            a1 * a2,
            torch.tensor([True, False, False, False], dtype=torch.bool, device=device),
        )

        if device == "cpu":
            a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device)
            a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device)
            self.assertEqual(
                a1 * a2,
                torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device),
                atol=0.01,
                rtol=0,
            )
            self.assertEqual(a1.mul(a2), a1 * a2)

    def test_bool_tensor_comparison_ops(self, device):
        a = torch.tensor(
            [True, False, True, False, True, False], dtype=torch.bool, device=device
        )
        b = torch.tensor(
            [True, False, True, True, True, True], dtype=torch.bool, device=device
        )
        self.assertEqual(
            a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)
        )
        self.assertEqual(
            a == torch.tensor(True, dtype=torch.bool, device=device),
            torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device),
        )
        self.assertEqual(
            a == torch.tensor(0, dtype=torch.bool, device=device),
            torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device),
        )
        self.assertFalse(a.equal(b))

    @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool))
    def test_logical(self, device, dtype):
        if dtype != torch.bool:
            x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype)
            b = torch.tensor([2], device=device, dtype=dtype)
            self.assertEqual(x.lt(2), torch.tensor([True, False, False, False]))
            self.assertEqual(x.le(2), torch.tensor([True, True, False, False]))
            self.assertEqual(x.ge(2), torch.tensor([False, True, True, True]))
            self.assertEqual(x.gt(2), torch.tensor([False, False, True, True]))
            self.assertEqual(x.eq(2), torch.tensor([False, True, False, False]))
            self.assertEqual(x.ne(2), torch.tensor([True, False, True, True]))

            self.assertEqual(x.lt(b), torch.tensor([True, False, False, False]))
            self.assertEqual(x.le(b), torch.tensor([True, True, False, False]))
            self.assertEqual(x.ge(b), torch.tensor([False, True, True, True]))
            self.assertEqual(x.gt(b), torch.tensor([False, False, True, True]))
            self.assertEqual(x.eq(b), torch.tensor([False, True, False, False]))
            self.assertEqual(x.ne(b), torch.tensor([True, False, True, True]))
        else:
            x = torch.tensor([True, False, True, False], device=device)
            self.assertEqual(x.lt(True), torch.tensor([False, True, False, True]))
            self.assertEqual(x.le(True), torch.tensor([True, True, True, True]))
            self.assertEqual(x.ge(True), torch.tensor([True, False, True, False]))
            self.assertEqual(x.gt(True), torch.tensor([False, False, False, False]))
            self.assertEqual(x.eq(True), torch.tensor([True, False, True, False]))
            self.assertEqual(x.ne(True), torch.tensor([False, True, False, True]))

    def test_atan2(self, device):
        def _test_atan2_with_size(size, device):
            a = torch.rand(size=size, device=device, dtype=torch.double)
            b = torch.rand(size=size, device=device, dtype=torch.double)
            actual = a.atan2(b)
            x = a.view(-1)
            y = b.view(-1)
            expected = torch.tensor(
                [math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())],
                dtype=torch.double,
                device=device,
            )
            self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)

            # check against the double result in lower precision
            for lowp_dtype in [torch.bfloat16, torch.float16]:
                if lowp_dtype == torch.bfloat16:
                    rtol = 0
                    atol = 0.02
                else:
                    rtol = 0
                    atol = 0.001
                a_16 = a.to(dtype=lowp_dtype)
                b_16 = b.to(dtype=lowp_dtype)
                actual_16 = a_16.atan2(b_16)
                self.assertEqual(actual_16, actual.to(dtype=lowp_dtype))
                self.assertEqual(expected, actual_16.view(-1), exact_dtype=False, rtol=rtol, atol=atol)

        _test_atan2_with_size((2, 2), device)
        _test_atan2_with_size((3, 3), device)
        _test_atan2_with_size((5, 5), device)

    def test_atan2_edgecases(self, device):
        def _test_atan2(x, y, expected, device, dtype):
            expected_tensor = torch.tensor([expected], dtype=dtype, device=device)
            x_tensor = torch.tensor([x], dtype=dtype, device=device)
            y_tensor = torch.tensor([y], dtype=dtype, device=device)
            actual = torch.atan2(y_tensor, x_tensor)
            self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)

        for dtype in [torch.float, torch.double]:
            _test_atan2(0, 0, 0, device, dtype)
            _test_atan2(0, 1, math.pi / 2, device, dtype)
            _test_atan2(0, -1, math.pi / -2, device, dtype)
            _test_atan2(-1, 0, math.pi, device, dtype)
            _test_atan2(1, 0, 0, device, dtype)
            _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype)
            _test_atan2(1, 1, math.pi / 4, device, dtype)
            _test_atan2(1, -1, math.pi / -4, device, dtype)
            _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype)

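    # Why each quadrant gets its own edge case above: atan2(y, x) uses the
    # signs of both arguments to recover the full angle, unlike plain atan:
    #   >>> torch.atan2(torch.tensor([1.0]), torch.tensor([-1.0]))
    #   tensor([2.3562])  # 3*pi/4, whereas atan(1 / -1) would give -pi/4
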
    def test_trapezoid(self, device):
        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, dx=dx, dim=dim)
            expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False)

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
            expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual.cpu(), exact_dtype=False)

        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 0, [], device)
        test_x((0, 2), 1, [1.0, 2.0], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device)
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
            test_dx((2, 3), 2, 1.0, device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0], device)
            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)

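    # The rule being tested, on a tiny example: the trapezoid rule sums
    # (y[i] + y[i+1]) / 2 * dx over adjacent samples.
    #   >>> torch.trapezoid(torch.tensor([1.0, 2.0, 3.0]), dx=1.0)
    #   tensor(4.)
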
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @skipIf(
        sys.version_info >= (3, 12), "Failing on Python 3.12"
    )
    def test_cumulative_trapezoid(self, device):
        import scipy.integrate

        if hasattr(scipy.integrate, "cumulative_trapezoid"):
            scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid
        else:  # older versions of SciPy use a different name
            scipy_cumulative_trapezoid = scipy.integrate.cumtrapz

        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(t, dx=dx, dim=dim)
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False, atol=1e-4, rtol=1e-4)

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(
                expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4
            )

        def test_empty_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            self.assertEqual(torch.empty(actual.shape), actual)

        test_dx((2,), -1, 1, device)
        test_dx((3, 3), -1, 1, device)
        test_dx((4, 2), 0, 1, device)
        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_dx((512, 512), 1, 1.0, device)
        test_dx((100, 100, 100), 1, 1.0, device)

        test_x((2,), -1, [100, 50], device)
        test_x((4, 2), 0, [2, 3, 4, 5], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 1, [1, 2], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)

        test_empty_x(
            (0, 2), 0, [], device
        )

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
            test_dx((2, 3), 2, 1.0, device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0], device)
            test_x((0, 2), 0, [1.0, 2.0], device)
            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
        with self.assertRaisesRegex(
            RuntimeError, "Currently, we only support dx as a real number"
        ):
            test_dx((2, 2), -1, complex(1, 1), device)
        with self.assertRaisesRegex(
            TypeError, "received an invalid combination of arguments"
        ):
            actual = torch.cumulative_trapezoid(
                torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3
            )

    @dtypes(torch.double)
    def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
        sz = 3
        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
        self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device)
        self.unary_check_input_output_mem_overlap(
            doubles, sz, lambda input, out: torch.pow(input, 42, out=out)
        )
        self.unary_check_input_output_mem_overlap(
            doubles, sz, lambda input, out: torch.pow(42, input, out=out)
        )

    @dtypes(
        *product(
            all_types_and_complex_and(torch.half, torch.bfloat16),
            all_types_and_complex_and(torch.half, torch.bfloat16),
        )
    )
    def test_float_power(self, device, dtypes):
        def to_np(value):
            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
                return value.to(torch.float).cpu().numpy()
            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

        base_dtype = dtypes[0]
        exp_dtype = dtypes[1]
        out_dtype = (
            torch.complex128
            if base_dtype.is_complex or exp_dtype.is_complex
            else torch.float64
        )

        base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100)
        # zero raised to negative or zero powers disagrees between PyTorch and
        # NumPy, so base deliberately contains no zeros
        exp = make_tensor((30,), dtype=exp_dtype, device=device, low=-2, high=2)
        exp[0] = exp[4] = exp[6] = 0

        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
        complex_exponents = exponents + [
            -2.5j,
            -1.0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]

        for op in (
            torch.float_power,
            torch.Tensor.float_power,
            torch.Tensor.float_power_,
        ):
            # case of Tensor x Tensor
            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
                with self.assertRaisesRegex(
                    RuntimeError, "operation's result requires dtype"
                ):
                    op(base.clone(), exp)
            else:
                result = op(base.clone(), exp)
                self.assertEqual(expected, result)

            if op is torch.float_power:
                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
                op(base, exp, out=out)
                self.assertEqual(expected, out)

            # case of Tensor x Scalar
            for i in complex_exponents if exp_dtype.is_complex else exponents:
                out_dtype_scalar_exp = (
                    torch.complex128
                    if base_dtype.is_complex or type(i) == complex
                    else torch.float64
                )
                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))

                if (
                    op is torch.Tensor.float_power_
                    and base_dtype != out_dtype_scalar_exp
                ):
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        op(base.clone(), i)
                else:
                    result = op(base.clone(), i)
                    self.assertEqual(expected_scalar_exp, result)

                if op is torch.float_power:
                    out = torch.empty_like(base).to(
                        device=device, dtype=out_dtype_scalar_exp
                    )
                    op(base, i, out=out)
                    self.assertEqual(expected_scalar_exp, out)

        # case of Scalar x Tensor
        for i in complex_exponents if base_dtype.is_complex else exponents:
            out_dtype_scalar_base = (
                torch.complex128
                if exp_dtype.is_complex or type(i) == complex
                else torch.float64
            )
            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))

            result = torch.float_power(i, exp)
            self.assertEqual(expected_scalar_base, result)

            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
            torch.float_power(i, exp, out=out)
            self.assertEqual(expected_scalar_base, out)

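    # The promotion rule exercised above, for reference: float_power always
    # computes in double (or complex double) precision, whatever the inputs:
    #   >>> torch.float_power(torch.tensor([2]), 3)
    #   tensor([8.], dtype=torch.float64)
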
    def test_float_power_exceptions(self, device):
        def _promo_helper(x, y):
            for i in (x, y):
                if type(i) == complex:
                    return torch.complex128
                elif type(i) == torch.Tensor and i.is_complex():
                    return torch.complex128
            return torch.double

        test_cases = (
            (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25),
            (
                torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device),
                2.0,
            ),
        )
        for base, exp in test_cases:
            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
                out = torch.empty(1, device=device, dtype=out_dtype)
                required_dtype = _promo_helper(base, exp)

                if out.dtype == required_dtype:
                    torch.float_power(base, exp, out=out)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.float_power(base, exp, out=out)

                if base.dtype == required_dtype:
                    torch.Tensor.float_power_(base.clone(), exp)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.Tensor.float_power_(base.clone(), exp)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool)
        )
    )
    def test_xlogy_xlog1py(self, device, dtypes):
        x_dtype, y_dtype = dtypes

        def out_variant_helper(torch_fn, x, y):
            expected = torch_fn(x, y)
            out = torch.empty_like(expected)
            torch_fn(x, y, out=out)
            self.assertEqual(expected, out)

        def xlogy_inplace_variant_helper(x, y):
            if x.dtype in integral_types_and(torch.bool):
                with self.assertRaisesRegex(
                    RuntimeError, "can't be cast to the desired output type"
                ):
                    x.clone().xlogy_(y)
            else:
                expected = torch.empty_like(x)
                torch.xlogy(x, y, out=expected)
                inplace_out = x.clone().xlogy_(y)
                self.assertEqual(expected, inplace_out)

        def test_helper(torch_fn, reference_fn, inputs, scalar=None):
            x, y, z = inputs
            torch_fn_partial = partial(torch_fn, x)
            reference_fn_partial = partial(reference_fn, x.cpu().numpy())
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, x, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, y, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, z, exact_dtype=False
            )

            val = scalar if scalar is not None else x
            out_variant_helper(torch_fn, val, x)
            out_variant_helper(torch_fn, val, y)
            out_variant_helper(torch_fn, val, z)

        # Tensor-Tensor Test (tensors of same and different shape)
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000)

        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        test_helper(*xlogy_fns, (x, y, z))
        xlogy_inplace_variant_helper(x, x)
        xlogy_inplace_variant_helper(x, y)
        xlogy_inplace_variant_helper(x, z)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p))

        # Scalar-Tensor Test
        test_helper(*xlogy_fns, (x, y, z), 3.14)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p), 3.14)

        # Special Values Tensor-Tensor
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        zeros = torch.zeros(7, dtype=y_dtype, device=device)

        def test_zeros_special_helper(torch_fn, reference_fn, scalar=False):
            zeros_t = 0 if scalar else zeros
            zeros_np = 0 if scalar else zeros.cpu().numpy()
            torch_fn_partial = partial(torch_fn, zeros_t)
            reference_fn_partial = partial(reference_fn, zeros_np)
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, t, exact_dtype=False
            )
            out_variant_helper(torch_fn, zeros_t, t)

        test_zeros_special_helper(*xlogy_fns)
        xlogy_inplace_variant_helper(zeros, t)
        test_zeros_special_helper(*xlog1py_fns)

        # Special Values Scalar-Tensor
        test_zeros_special_helper(*xlogy_fns, scalar=True)
        test_zeros_special_helper(*xlog1py_fns, scalar=True)

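    # The special-value convention behind the zeros tests above: xlogy defines
    # 0 * log(y) as 0 even where log(y) alone would be nan or -inf:
    #   >>> torch.xlogy(torch.tensor([0.0]), torch.tensor([0.0]))
    #   tensor([0.])
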
    @dtypes(torch.float64)
    def test_xlogy_xlog1py_gradients(self, device, dtype):
        make_arg = partial(torch.tensor, dtype=dtype, device=device, requires_grad=True)

        zeros = torch.zeros((2,), dtype=dtype, device=device)

        x = make_arg([0.0, 0.0])
        y = make_arg([-1.5, 0.0])
        torch.special.xlogy(x, y).sum().backward()
        self.assertEqual(x.grad, zeros)

        x = make_arg([0.0, 0.0])
        y = make_arg([-2.5, -1.0])
        torch.special.xlog1py(x, y).sum().backward()
        self.assertEqual(x.grad, zeros)

    def test_xlogy_xlog1py_scalar_type_promotion(self, device):
        # Python numbers should not participate in type promotion at the same
        # priority level as 0-dim tensors
        t = torch.randn((), dtype=torch.float32, device=device)

        self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
        self.assertEqual(t.dtype, torch.xlogy(t, 5.0).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.0).dtype)

        self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
        self.assertEqual(t.dtype, torch.xlogy(5.0, t).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(5, t).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(5.0, t).dtype)

4261
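    # SciPy has no bfloat16, so inputs are upcast to float32 for the reference
    # and compared with exact_dtype=False.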
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    def test_xlogy_xlog1py_bfloat16(self, device):
        def _compare_helper(x, y, torch_fn, reference_fn):
            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
            expected = torch.from_numpy(reference_fn(x_np, y_np))
            actual = torch_fn(x, y)
            self.assertEqual(expected, actual, exact_dtype=False)
        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16

        # Tensor-Tensor Test
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000)

        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        _compare_helper(x, x, *xlogy_fns)
        _compare_helper(x, y, *xlogy_fns)
        _compare_helper(x, z, *xlogy_fns)
        _compare_helper(x, 3.14, *xlogy_fns)
        _compare_helper(y, 3.14, *xlogy_fns)
        _compare_helper(z, 3.14, *xlogy_fns)

        _compare_helper(x_1p, x_1p, *xlog1py_fns)
        _compare_helper(x_1p, y_1p, *xlog1py_fns)
        _compare_helper(x_1p, z_1p, *xlog1py_fns)
        _compare_helper(x_1p, 3.14, *xlog1py_fns)
        _compare_helper(y_1p, 3.14, *xlog1py_fns)
        _compare_helper(z_1p, 3.14, *xlog1py_fns)
        # Special Values Tensor Test
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        zeros = torch.zeros(7, dtype=y_dtype, device=device)

        _compare_helper(t, zeros, *xlogy_fns)
        _compare_helper(t, 0.0, *xlogy_fns)

        _compare_helper(t, zeros, *xlog1py_fns)
        _compare_helper(t, 0.0, *xlog1py_fns)
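    # torch.special.zeta computes the Hurwitz zeta function zeta(x, q), checked
    # against scipy.special.zeta for every dtype pair and broadcast pattern.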
    @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool)))
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    def test_zeta(self, device, dtypes):
        x_dtype, q_dtype = dtypes

        def test_helper(x, q):
            x_np = x if isinstance(x, float) else x.cpu().numpy()
            q_np = q if isinstance(q, float) else q.cpu().numpy()
            expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
            actual = torch.special.zeta(x, q)

            rtol, atol = None, None
            if self.device_type == "cpu":
                rtol, atol = 1e-6, 1e-6
            self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
        # x tensor - q tensor, same shape
        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
        test_helper(x, q)

        # x tensor - q tensor, broadcast lhs
        x = make_tensor((2, 1, 4), dtype=x_dtype, device=device)
        q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
        test_helper(x, q)

        # x tensor - q tensor, broadcast rhs
        x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
        test_helper(x, q)

        # x tensor - q tensor, broadcast both
        x = make_tensor((2, 3, 1), dtype=x_dtype, device=device)
        q = make_tensor((2, 1, 4), dtype=q_dtype, device=device)
        test_helper(x, q)

        # x scalar - q tensor
        for x in np.linspace(-5, 5, num=10).tolist():
            if not q_dtype.is_floating_point:
                q_dtype = torch.get_default_dtype()
            q = make_tensor((2, 3, 4), dtype=q_dtype, device=device)
            test_helper(x, q)

        # x tensor - q scalar
        for q in np.linspace(-5, 5, num=10).tolist():
            if not x_dtype.is_floating_point:
                x_dtype = torch.get_default_dtype()
            x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
            test_helper(x, q)
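    # complex32 (chalf) only implements a subset of ops, so the Python-scalar
    # multiply is checked directly here rather than through OpInfo samples.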
    @onlyCUDA
    @dtypes(torch.chalf)
    def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype):
        # A chalf tensor times a Python scalar should match multiplying by a
        # same-dtype scalar tensor of that value.
        x = make_tensor((2, 2), device=device, dtype=dtype)
        self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype))
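# Python returns NotImplemented from a binary dunder so the interpreter can
# try the reflected operation on the other operand; the generated tests below
# verify torch.Tensor follows that protocol for unknown types.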
tensor_binary_ops = [
    # NOTE: reconstructed, representative subset; the original file enumerates
    # every Tensor binary dunder (comparison, arithmetic, and bitwise,
    # including reflected and in-place variants).
    "__add__", "__radd__", "__sub__", "__rsub__",
    "__mul__", "__rmul__", "__truediv__", "__rtruediv__",
    "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__",
]
def generate_not_implemented_tests(cls):
    # A type torch knows nothing about: every Tensor binary dunder should
    # return NotImplemented for it rather than raising.
    class UnknownType:
        pass

    def create_test_func(op):
        # NOTE: torch.double here is an assumption; a single dtype suffices
        # since the dunder bails out before inspecting the tensor's values.
        @dtypes(torch.double)
        def test(self, device, dtype):
            tensor = torch.empty((), device=device, dtype=dtype)
            result = getattr(tensor, op)(UnknownType())
            self.assertEqual(result, NotImplemented)

        return test

    for op in tensor_binary_ops:
        test_name = f"test_{op}_not_implemented"
        assert not hasattr(cls, test_name), f"{test_name} already in {cls.__name__}"

        setattr(cls, test_name, create_test_func(op))
generate_not_implemented_tests(TestBinaryUfuncs)
instantiate_device_type_tests(TestBinaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()