import math
import random
import unittest
from numbers import Number

import numpy as np

import torch
from torch import inf, nan
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    torch_to_numpy_dtype_dict,
    numpy_to_torch_dtype_dict,
    TEST_SCIPY,
    IS_WINDOWS,
    gradcheck,
    is_iterable_of_tensors,
)
from torch.testing._internal.common_methods_invocations import (
    unary_ufuncs,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    dtypes,
    dtypesIfCPU,
    dtypesIfCUDA,
    precisionOverride,
    onlyNativeDeviceTypes,
)
from torch.utils import _pytree as pytree
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    floating_types_and,
    all_types_and_complex_and,
    integral_types_and,
    complex_types,
    floating_and_complex_types_and,
    get_all_math_dtypes,
)

if TEST_SCIPY:
    import scipy.special

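# Ops that have a NumPy/SciPy reference implementation (op.ref) and can
# therefore be validated against reference numerics.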
reference_filtered_ops = list(filter(lambda op: op.ref is not None, unary_ufuncs))

class TestUnaryUfuncs(TestCase):
    @ops(
        [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)],
        allowed_dtypes=floating_types_and(torch.bfloat16, torch.half),
    )
    def test_float_domains(self, device, dtype, op):
        eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)

        low, high = op.domain
        if low is not None:
            low_tensor = torch.tensor(low, device=device, dtype=dtype)
            for epsilon in eps:
                lower_tensor = low_tensor - epsilon

                # Skips the check if the difference is not representable,
                # which can occur if, for example, the difference is small
                # and the dtype is imprecise (like bfloat16)
                if lower_tensor.item() == low_tensor.item():
                    continue

                result = op(lower_tensor)
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {lower_tensor.item()} outside lower domain boundary"
                        f" {low} produced {result.item()}, not nan!"
                    ),
                )

        if high is not None:
            high_tensor = torch.tensor(high, device=device, dtype=dtype)
            for epsilon in eps:
                higher_tensor = high_tensor + epsilon

                # See the representability note above
                if higher_tensor.item() == high_tensor.item():
                    continue

                result = op(higher_tensor)
                self.assertEqual(
                    result.item(),
                    float("nan"),
                    msg=(
                        f"input of {higher_tensor.item()} outside upper domain boundary"
                        f" {high} produced {result.item()}, not nan!"
                    ),
                )
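    # assertEqualHelper bridges torch-vs-NumPy comparisons: NumPy references may
    # return scalars or arrays, and their dtypes can legitimately differ from
    # the tensor's (NumPy has no bfloat16, and several SciPy ops promote half
    # to float32), so dtype checking is relaxed in a controlled way.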
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg, **kwargs)
        elif isinstance(expected, np.ndarray):
            # Handles exact dtype comparisons between arrays and tensors
            if exact_dtype:
                if (
                    actual.dtype is torch.bfloat16
                    or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]
                ):
                    # Allows the array dtype to be float32 when comparing with
                    # bfloat16 tensors, since NumPy doesn't support bfloat16,
                    # and some reference ops promote half to float32
                    if expected.dtype == np.float32:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                        )
                    elif expected.dtype == np.float64:
                        assert actual.dtype in (
                            torch.float16,
                            torch.bfloat16,
                            torch.float32,
                            torch.float64,
                        )
                    else:
                        self.fail(
                            f"Expected dtype {expected.dtype} but got {actual.dtype}!"
                        )

            self.assertEqual(
                actual,
                torch.from_numpy(expected).to(actual.dtype),
                msg,
                exact_device=False,
                **kwargs,
            )
        else:
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
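    # _test_reference_numerics is the engine behind the test_reference_numerics_*
    # variants below: it evaluates the torch op and its NumPy reference on the
    # same values and compares the results elementwise.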
    def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True):
        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False

            # NOTE: the dtype-specific rtol/atol values below are assumptions
            # of this reconstruction; the original's exact tolerances were
            # elided
            if dtype in [torch.uint8, torch.int8, torch.bool]:
                # PyTorch computes in the default scalar type (float) while
                # NumPy may compute in half, so looser tolerances are used
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1e-3,
                    atol=1e-2,
                )
            elif dtype is torch.bfloat16:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=16e-3,
                    atol=1e-5,
                )
            elif dtype is torch.half:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    exact_dtype=exact_dtype,
                    rtol=1.2e-3,
                    atol=1e-3,
                )
            else:
                self.assertEqualHelper(
                    actual,
                    expected,
                    msg,
                    dtype=dtype,
                    equal_nan=equal_nan,
                    exact_dtype=exact_dtype,
                )

        for t in tensors:
            t = t.input
            torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t)
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            elif dtype is torch.complex32:
                a = t.cpu().to(torch.complex64).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t, **torch_kwargs)
            expected = op.ref(a, **numpy_kwargs)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = (
                    "Failed to produce expected results! Input tensor was"
                    f" {t}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                for x, y in zip(expected, actual):
                    # Tests multi-output results
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
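    # The next four tests draw samples from different value regimes: typical
    # values, small values near zero, large magnitudes, and extremal values
    # (inf/nan), since unary ops commonly fail only at regime boundaries.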
    @ops(reference_filtered_ops)
    def test_reference_numerics_normal(self, device, dtype, op):
        tensors = generate_elementwise_unary_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @ops(reference_filtered_ops)
    def test_reference_numerics_small(self, device, dtype, op):
        if dtype in (torch.bool,):
            raise self.skipTest("bool has no small values")

        tensors = generate_elementwise_unary_small_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @ops(reference_filtered_ops)
    def test_reference_numerics_large(self, device, dtype, op):
        if dtype in (torch.bool, torch.uint8, torch.int8):
            raise self.skipTest("bool, uint8, and int8 dtypes have no large values")

        tensors = generate_elementwise_unary_large_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)

    @ops(
        reference_filtered_ops,
        allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
    )
    def test_reference_numerics_extremal(self, device, dtype, op):
        tensors = generate_elementwise_unary_extremal_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, tensors)
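    # Contiguity tests: a unary elementwise op must compute the same values
    # regardless of input memory layout, so each test compares a contiguous
    # tensor against a strided, transposed, expanded, or size-1 view carrying
    # the same values.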
    # NOTE: the @ops(unary_ufuncs) decorators in the contiguity tests below are
    # reconstructed from context; each test receives (device, dtype, op)
    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        contig = make_tensor(
            (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig[::2]

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        result = pytree.tree_map(lambda x: x[::2], result)
        self.assertEqual(result, expected)

    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        contig = make_tensor(
            (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        non_contig = contig.T

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        expected = op(non_contig, **torch_kwargs)
        result = op(contig, **torch_kwargs)
        result = pytree.tree_map(lambda x: x.T, result)
        self.assertEqual(result, expected)

    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        shapes = [(5, 7), (1024,)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
            non_contig.copy_(contig)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        # NOTE: the tensor shape here is an assumption of this reconstruction;
        # it only needs to make contig[:, 1, ...] a non-contiguous view
        contig = make_tensor(
            (2, 2, 1, 2),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        non_contig = contig[:, 1, ...]
        contig = non_contig.contiguous()

        self.assertTrue(contig.is_contiguous())
        self.assertFalse(non_contig.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        shapes = [(1, 3), (1, 7), (5, 7)]
        for shape in shapes:
            contig = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            non_contig = contig.clone().expand(3, -1, -1)

            self.assertTrue(contig.is_contiguous())
            self.assertFalse(non_contig.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
            contig = op(contig, **torch_kwargs)
            non_contig = op(non_contig, **torch_kwargs)
            for i in range(3):
                non_contig_i = pytree.tree_map(lambda x: x[i], non_contig)
                self.assertEqual(
                    contig, non_contig_i, msg="non-contiguous expand[" + str(i) + "]"
                )

    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        contig = make_tensor(
            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )
        contig = contig[:1, :50]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))

    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        contig = make_tensor(
            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
        contig2 = torch.empty(contig.size(), device=device, dtype=dtype)
        contig2.copy_(contig)

        self.assertTrue(contig.is_contiguous())
        self.assertTrue(contig2.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, contig)
        self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs))
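    # test_batch_vs_slicing checks that applying the op to a whole batch matches
    # applying it row by row and stacking the per-row results.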
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        input = make_tensor(
            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )

        torch_kwargs, _ = op.sample_kwargs(device, dtype, input)
        actual = op(input, **torch_kwargs)

        all_outs = [op(slice, **torch_kwargs) for slice in input]
        if is_iterable_of_tensors(actual):
            expected = [
                torch.stack([out[i] for out in all_outs]) for i in range(len(actual))
            ]
        else:
            expected = torch.stack(all_outs)

        self.assertEqual(actual, expected)
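    # nan_to_num tests: NaN/inf replacement is validated against np.nan_to_num,
    # through the out= variant, through the bfloat16 autograd path, and for
    # complex tensors, where replacement applies independently to the real and
    # imaginary components.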
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_nan_to_num(self, device, dtype):
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values
                extremals = [float("nan"), float("inf"), -float("inf")]
                for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
                    x[idx, :] = extremal

            if not contiguous:
                x = x.T

            # With args
            nan = random.random()
            posinf = random.random() * 5
            neginf = random.random() * 10

            self.compare_with_numpy(
                lambda x: x.nan_to_num(nan=nan, posinf=posinf),
                lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
                x,
            )
            self.compare_with_numpy(
                lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
                lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
                x,
            )

            # Out
            out = torch.empty_like(x)
            result = torch.nan_to_num(x)
            torch.nan_to_num(x, out=out)
            self.assertEqual(result, out)

            result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
            torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
            self.assertEqual(result, out)

    def test_nan_to_num_bfloat16(self, device):
        def test_dtype(fn, input, dtype):
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2, exact_dtype=False)
            self.assertEqual(input.grad, input2.grad, exact_dtype=False)

        def func():
            return torch.nan_to_num

        shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
        for shape in shapes:
            x = torch.randn(shape, device=device)
            extremals = [float("nan"), float("inf"), -float("inf")]
            for id1, id2, extremal in zip(
                torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals
            ):
                x[0, id1, id2, :] = extremal
            test_dtype(func(), x, torch.bfloat16)

    @dtypes(torch.complex64, torch.complex128)
    def test_nan_to_num_complex(self, device, dtype):
        value_dtype = torch.tensor([], dtype=dtype).real.dtype

        def gen_tensor(a):
            return torch.view_as_complex(
                torch.tensor(a, dtype=value_dtype, device=device)
            )

        for extremal, kwarg_name in zip(
            ["nan", "inf", "-inf"], ["nan", "posinf", "neginf"]
        ):
            a = gen_tensor([123, float(extremal)])
            res = torch.nan_to_num(a, **{kwarg_name: 12})
            res_check = gen_tensor([123, 12])
            self.assertEqual(res, res_check)

            a = gen_tensor([float(extremal), 456])
            res = torch.nan_to_num(a, **{kwarg_name: 21})
            res_check = gen_tensor([21, 456])
            self.assertEqual(res, res_check)
    @dtypes(torch.cdouble)
    def test_complex_edge_values(self, device, dtype):
        x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device)
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
        # The acos comparison is skipped on Windows CUDA cdouble due to a
        # platform discrepancy
        if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device):
            self.compare_with_numpy(torch.acos, np.arccos, x)

        x = torch.tensor(
            (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j,
            dtype=dtype,
            device=device,
        )
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma_special(self, device, dtype):
        # Compares digamma against closed-form special values derived from the
        # Gauss digamma theorem (identities for 1/2, 1/3, 1/4, 1/6, and 1/8)
        euler = 0.57721566490153286
        dataset = [
            (0.5, -2 * math.log(2) - euler),
            (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
            (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
            (
                1 / 6,
                -math.pi * math.sqrt(3) / 2
                - 2 * math.log(2)
                - 3 * math.log(3) / 2
                - euler,
            ),
            (
                1 / 8,
                -math.pi / 2
                - 4 * math.log(2)
                - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2)))
                / math.sqrt(2)
                - euler,
            ),
        ]
        x = torch.tensor(dataset, device=device, dtype=dtype)
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)

    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma(self, device, dtype):
        # NOTE: the original test's long list of sample points (including
        # values near the negative-integer poles of digamma) is elided in this
        # excerpt; a few representative points stand in for it
        tensor = torch.tensor(
            [1.0, 0.5, 0.25, 10.5, -0.5],
            device=device,
            dtype=dtype,
        )
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
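    # torch.frexp decomposes x into mantissa and integer exponent with
    # x == mantissa * 2**exponent; the tests check values against np.frexp and
    # the dtype requirements of the out= variant.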
    @dtypes(*floating_types_and(torch.half))
    def test_frexp(self, device, dtype):
        input = make_tensor((50, 50), dtype=dtype, device=device)
        mantissa, exponent = torch.frexp(input)
        np_mantissa, np_exponent = np.frexp(input.cpu().numpy())

        self.assertEqual(mantissa, np_mantissa)
        self.assertEqual(exponent, np_exponent)

        self.assertTrue(exponent.dtype == torch.int32)
        self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype)
    def test_frexp_assert_raises(self, device):
        invalid_input_dtypes = integral_types_and(torch.bool) + complex_types()
        for dtype in invalid_input_dtypes:
            input = make_tensor((50, 50), dtype=dtype, device=device)
            with self.assertRaisesRegex(
                RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"
            ):
                torch.frexp(input)

        for dtype in floating_types_and(torch.half):
            input = make_tensor((50, 50), dtype=dtype, device=device)

            dtypes = list(
                all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)
            )
            # Any mantissa dtype other than the input's must be rejected
            dtypes.remove(dtype)
            for mantissa_dtype in dtypes:
                mantissa = torch.empty_like(input, dtype=mantissa_dtype)
                exponent = torch.empty_like(input, dtype=torch.int)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+",
                ):
                    torch.frexp(input, out=(mantissa, exponent))

            # Any exponent dtype other than int must be rejected
            dtypes.append(dtype)
            dtypes.remove(torch.int)
            for exponent_dtype in dtypes:
                mantissa = torch.empty_like(input)
                exponent = torch.empty_like(input, dtype=exponent_dtype)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects exponent to have int dtype but got .+",
                ):
                    torch.frexp(input, out=(mantissa, exponent))
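    # polygamma rejects negative orders, and the ~ operator is only defined for
    # integer and bool tensors; the next two tests pin down both behaviors.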
    def test_polygamma_neg(self, device):
        with self.assertRaisesRegex(
            RuntimeError, r"polygamma\(n, x\) does not support negative n\."
        ):
            torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device))
    def test_op_invert(self, device):
        res = 0xFFFF - torch.arange(127, dtype=torch.int8)
        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            a = torch.arange(127, dtype=dtype)
            self.assertEqual(res.to(dtype), ~a)

        self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True]))

        # ~ raises for floating-point dtypes
        for dtype in (torch.half, torch.float, torch.double):
            a = torch.zeros(10, dtype=dtype)
            with self.assertRaises(TypeError):
                ~a
    @dtypes(torch.complex64, torch.complex128)
    def test_abs_angle_complex_to_float(self, device, dtype):
        # Constructs random complex values
        from random import random

        random_vals = []
        for multiplier in (-1, 1, -10, 10, -100, 100):
            random_vals.append(
                complex(random() * multiplier, random() * multiplier)
            )

        for vals in (random_vals, []):
            a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype])
            t = torch.tensor(vals, device=device, dtype=dtype)

            for fn_name in ("abs", "angle"):
                torch_fn = getattr(torch, fn_name)
                np_fn = getattr(np, fn_name)

                # Tests the function
                np_result = torch.from_numpy(np_fn(a))
                torch_result = torch_fn(t).cpu()
                self.assertEqual(np_result, torch_result, exact_dtype=True)

                # Tests float out
                float_dtype = (
                    torch.float32 if dtype is torch.complex64 else torch.float64
                )
                np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype])
                float_out = torch.empty_like(t, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests float out (resized out)
                float_out = torch.empty(1, device=device, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests complex out
                np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype])
                complex_out = torch.empty_like(t)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests complex out (resized out)
                complex_out = torch.empty(0, device=device, dtype=dtype)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests long out behavior (expected failure)
                long_out = torch.empty(0, device=device, dtype=torch.long)
                with self.assertRaises(RuntimeError):
                    torch_fn(t, out=long_out)

                # Tests the in-place variant; in-place abs is rejected for
                # complex tensors
                if fn_name == "abs":
                    torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
                    if dtype.is_complex:
                        with self.assertRaisesRegex(
                            RuntimeError,
                            "In-place abs is not supported for complex tensors.",
                        ):
                            torch_inplace_method(t)
                    else:
                        # (unreachable for the complex dtypes this test runs)
                        np_fn(a, out=a)
                        torch_inplace_method(t)
                        self.assertEqual(torch.from_numpy(a), t.cpu())

                # angle has no in-place variant
                if fn_name == "angle":
                    with self.assertRaises(AttributeError):
                        torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
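    # Memory-overlap checks: an out= kernel may not partially overlap its
    # input, and an in-place op may not be fed an input that aliases a single
    # memory location (e.g. a stride-0 expand). expected_failure inverts a
    # check for ops whose overlap detection is known to be missing.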
    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                inplace_op(*inputs)
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "single memory location"):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                _test(op, data[0:sz], data[1 : sz + 1])
        else:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                    _test(op, data[0:sz], data[1 : sz + 1])
    @dtypes(torch.double)
    def test_unary_out_op_mem_overlap(self, device, dtype):
        sz = 3  # NOTE: buffer half-size reconstructed; the original value was elided
        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
        positives = torch.randint(1, 100, (2 * sz,), device=device).double()
        ints = torch.randint(-100, 100, (2 * sz,), device=device)
        unary_mem_overlap_cases = [
            ("abs", doubles, True, True, "cpu"),
            ("abs", doubles, True, True, "cuda"),
            ("acos", doubles, True, True, "cpu"),
            ("acos", doubles, True, True, "cuda"),
            ("asin", doubles, True, True, "cpu"),
            ("asin", doubles, True, True, "cuda"),
            ("atan", doubles, True, True, "cpu"),
            ("atan", doubles, True, True, "cuda"),
            ("acosh", doubles, True, True, "cpu"),
            ("acosh", doubles, True, True, "cuda"),
            ("asinh", doubles, True, True, "cpu"),
            ("asinh", doubles, True, True, "cuda"),
            ("atanh", doubles, True, True, "cpu"),
            ("atanh", doubles, True, True, "cuda"),
            ("bitwise_not", ints, True, True, "cpu"),
            ("bitwise_not", ints, True, True, "cuda"),
            ("ceil", doubles, True, True, "cpu"),
            ("ceil", doubles, True, True, "cuda"),
            ("cos", doubles, True, True, "cpu"),
            ("cos", doubles, True, True, "cuda"),
            ("cosh", doubles, True, True, "cpu"),
            ("cosh", doubles, True, True, "cuda"),
            ("digamma", doubles, True, True, "cpu"),
            ("erf", doubles, True, True, "cpu"),
            ("erf", doubles, True, True, "cuda"),
            ("erfc", doubles, True, True, "cpu"),
            ("erfc", doubles, True, True, "cuda"),
            ("erfinv", doubles, True, True, "cpu"),
            ("erfinv", doubles, True, True, "cuda"),
            ("exp", doubles, True, True, "cpu"),
            ("exp", doubles, True, True, "cuda"),
            ("exp2", doubles, True, True, "cpu"),
            ("exp2", doubles, True, True, "cuda"),
            ("expm1", doubles, True, True, "cpu"),
            ("expm1", doubles, True, True, "cuda"),
            ("floor", doubles, True, True, "cpu"),
            ("floor", doubles, True, True, "cuda"),
            ("frac", doubles, True, True, "cpu"),
            ("frac", doubles, True, True, "cuda"),
            ("i0", doubles, True, True, "cpu"),
            ("i0", doubles, True, True, "cuda"),
            ("log", positives, True, True, "cpu"),
            ("log", positives, True, True, "cuda"),
            ("log10", positives, True, True, "cpu"),
            ("log10", positives, True, True, "cuda"),
            ("log1p", positives, True, True, "cpu"),
            ("log1p", positives, True, True, "cuda"),
            ("log2", positives, True, True, "cpu"),
            ("log2", positives, True, True, "cuda"),
            ("neg", doubles, True, True, "cpu"),
            ("neg", doubles, True, True, "cuda"),
            ("reciprocal", doubles, True, True, "cpu"),
            ("reciprocal", doubles, True, True, "cuda"),
            ("round", doubles, True, True, "cpu"),
            ("round", doubles, True, True, "cuda"),
            ("rsqrt", positives, True, True, "cpu"),
            ("rsqrt", positives, True, True, "cuda"),
            ("sin", doubles, True, True, "cpu"),
            ("sin", doubles, True, True, "cuda"),
            ("sinh", doubles, True, True, "cpu"),
            ("sinh", doubles, False, True, "cuda"),
            ("sigmoid", doubles, True, True, "cpu"),
            ("sigmoid", doubles, True, True, "cuda"),
            ("logit", doubles, True, True, "cpu"),
            ("logit", doubles, True, True, "cuda"),
            ("sqrt", doubles, True, True, "cpu"),
            ("sqrt", doubles, False, True, "cuda"),
            ("tan", doubles, True, True, "cpu"),
            ("tan", doubles, True, True, "cuda"),
            ("tanh", doubles, True, True, "cpu"),
            ("tanh", doubles, True, True, "cuda"),
            ("trunc", doubles, True, True, "cpu"),
            ("trunc", doubles, True, True, "cuda"),
        ]
        for (
            fn,
            t,
            has_input_output_mem_overlap_check,
            has_internal_mem_overlap_check,
            dev,
        ) in unary_mem_overlap_cases:
            if dev != device:
                continue
            out_fn = getattr(torch, fn)
            in_fn = getattr(torch.Tensor, fn + "_")

            self.unary_check_input_output_mem_overlap(
                t,
                sz,
                out_fn,
                expected_failure=not has_input_output_mem_overlap_check,
            )

            self.check_internal_mem_overlap(
                in_fn,
                1,
                dtype,
                dev,
                expected_failure=not has_internal_mem_overlap_check,
            )
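    # hardshrink(x, lambd) zeroes every element with |x| <= lambd and keeps the
    # rest; the tests cover fixed inputs, the default lambd of 0.5, a
    # non-contiguous input, and boundary lambd values (0, tiny, max, inf).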
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink(self, device, dtype):
        data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2)
        self.assertEqual(
            torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.3),
        )
        self.assertEqual(
            torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.5),
        )

        self.assertEqual(data.hardshrink(), data.hardshrink(0.5))

        self.assertEqual(
            torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2),
            data.t().hardshrink(0.3),
        )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink_edge_cases(self, device, dtype) -> None:
        def h(values, l_expected):
            for l, expected in l_expected.items():
                values_tensor = torch.tensor(
                    [float(v) for v in values], dtype=dtype, device=device
                )
                expected_tensor = torch.tensor(
                    [float(v) for v in expected], dtype=dtype, device=device
                )
                self.assertEqual(
                    expected_tensor == values_tensor.hardshrink(l),
                    torch.ones_like(values_tensor, dtype=torch.bool),
                )

        def test_helper(min, max):
            h(
                [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                {
                    0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
                    0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
                    1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
                    max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
                    inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
            )

        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)
    # NOTE: the @dtypes decorator is reconstructed; the body uses a single
    # floating dtype
    @dtypes(torch.float)
    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
    def test_exp_slow(self, device, dtype):
        # Regression test for exp on very large contiguous tensors: every
        # element of exp(ones(2**31)) must equal exp(ones(1))
        a = torch.exp(torch.ones(2**31, dtype=dtype, device=device))
        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
        self.assertEqual(a, b.expand(2**31))
    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardswish(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.multiply(
            inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
        )

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)

        # Normal
        self.assertEqual(
            torch.nn.functional.hardswish(inputTensor), expectedOutputTensor
        )

        # In-place
        inputTensorCpy = inputTensor.clone().detach()
        torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
        self.assertEqual(inputTensorCpy, expectedOutputTensor)
    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid(self, device, dtype):
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)

        # Normal
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensor),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )

        # In-place
        inputTensorCpy = inputTensor.clone().detach()
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True),
            torch.tensor(expectedOutput, dtype=dtype, device=device),
        )
    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid_backward(self, device, dtype):
        inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0]
        expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0]
        inputTensor = torch.tensor(
            inputValues, dtype=dtype, device=device
        ).requires_grad_(True)
        expectedTensor = torch.tensor(expectedValues, dtype=dtype, device=device)
        out = torch.nn.functional.hardsigmoid(inputTensor)
        out.backward(torch.ones_like(inputTensor))
        self.assertEqual(inputTensor.grad, expectedTensor)
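    # silu (x * sigmoid(x)) and mish (x * tanh(softplus(x))) are validated
    # against NumPy/SciPy-computed references on contiguous and non-contiguous
    # inputs, and against precomputed values for complex inputs.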
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float, torch.double)
    def test_silu(self, device, dtype):
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * scipy.special.expit(input_np)

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        # NOTE: tolerance values reconstructed; the originals were elided
        atol = 1e-6
        rtol = 1e-6

        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.silu(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
    @dtypes(torch.complex64, torch.complex128)
    def test_silu_complex(self, device, dtype):
        # Precomputed reference values for complex silu
        # NOTE: atol/rtol reconstructed; the originals were elided
        atol = 1e-6
        rtol = 1e-6
        inouts = [
            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
        ]

        for inp, out in inouts:
            res = torch.nn.functional.silu(
                torch.tensor(inp, dtype=dtype, device=device)
            )
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

        for inp, out in inouts:
            res = torch.nn.functional.silu(
                torch.tensor(inp, dtype=dtype, device=device), inplace=True
            )
            self.assertFalse(torch.any(torch.isnan(res)))
            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
    @dtypes(torch.double)
    def test_sinc(self, device, dtype):
        # The derivative of sinc at x=0 must be special-cased: a naive
        # computation yields 0/0 -> NaN, and inputs whose squares underflow to
        # zero are similarly delicate, so gradcheck runs at exactly these points
        a = torch.tensor(
            [0.0, torch.finfo(torch.double).tiny, 1.0],
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        gradcheck(torch.sinc, a)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float, torch.double)
    def test_mish(self, device, dtype):
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np)))

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        # NOTE: tolerance values reconstructed; the originals were elided
        atol = 1e-6
        rtol = 1e-6

        input = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol
        )
        self.assertEqual(
            torch.nn.functional.mish(input, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        input = torch.from_numpy(input_np).clone().to(device)
        input_noncontig = input.transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
    @dtypes(torch.complex64, torch.complex128)
    def test_log1p_complex(self, device, dtype):
        # Precomputed reference values; NumPy's log1p is not used because it
        # loses precision for small complex inputs
        inouts = [
            (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j),
            (1e-19 + 1e-18j, 1e-19 + 1e-18j),
            (1e-18 + 0.1j, 0.00497517 + 0.0996687j),
            (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j),
            (0.5 + 0j, 0.40546510810816 + 0j),
            (0.0 + 0.5j, 0.111571776 + 0.463647609j),
            (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j),
            (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j),
            (2.0j, 0.80471895621705014 + 1.1071487177940904j),
            (-2.0j, 0.80471895621705014 - 1.1071487177940904j),
        ]

        # Test the extreme values
        if dtype == torch.complex128:
            inouts += [
                (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e250 + 1j, 575.6462732485114 + 1e-250j),
                (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j),
                (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                (1e-250 + 2e-250j, 1e-250 + 2e-250j),
                (1e250 + 1e-250j, 575.6462732485114 + 0.0j),
            ]
        elif dtype == torch.complex64:
            inouts += [
                (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e30 + 1j, 69.07755278982137 + 1e-30j),
                (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j),
                (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                (1e-30 + 2e-30j, 1e-30 + 2e-30j),
                (1e30 + 1e-30j, 69.07755278982137 + 0.0j),
            ]

        # Test log1p point by point
        for inp, out in inouts:
            res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            # atol == 0.0 because some components have very small values
            self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6)
            self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6)

        # Test log1p on a whole tensor of values
        inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts))
        inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device)
        out_tens = torch.tensor(out_lst, dtype=dtype, device=device)
        res_tens = torch.log1p(inp_tens)
        self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6)
        self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6)
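    # threshold and the regularized incomplete gamma functions follow;
    # _helper_test_igamma sweeps log-spaced argument pairs and compares
    # torch.igamma/igammac against scipy.special.gammainc/gammaincc.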
@dtypes(*get_all_math_dtypes("cpu"))
1210
def test_threshold(self, device, dtype):
1211
if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex:
1214
torch.randn(100, dtype=torch.float, device=device)
1218
y = torch.threshold(x, 0, 0)
1219
self.assertTrue(y.le(0).any())
1221
    def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn):
        exp1 = 2.71828182846
        vec1 = torch.logspace(
            loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device
        ).unsqueeze(-1)
        vec1 = vec1.to(dtype)
        inputs = [
            (vec1, vec1.transpose(0, 1)),
            (vec1[::2, :], vec1[::2, :]),  # contiguous/discontiguous combinations
            (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]),
            (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]),
        ]
        half_prec = dtype in [torch.bfloat16, torch.float16]
        for input0, input1 in inputs:
            actual = torch_fcn(input0, input1)
            if half_prec:
                # SciPy doesn't support half precision, so the reference is
                # computed in float
                input0 = input0.to(torch.float)
                input1 = input1.to(torch.float)
            expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy())
            expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)
    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igamma_common(self, device, dtype):
        # Test igamma for a reasonable range of log-spaced values
        # (NOTE: the loglo/loghi bounds are reconstructed)
        loglo = -4
        loghi = 4
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc
        )

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igammac_common(self, device, dtype):
        # Test igammac for a reasonable range of log-spaced values
        # (NOTE: the loglo/loghi bounds are reconstructed)
        loglo = -4
        loghi = 4
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc
        )
    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igamma_edge_cases(self, device, dtype):
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.zeros((3,), **tkwargs) + float("inf")
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.zeros((3,), **tkwargs) + float("nan")
        inpouts = [
            # ((a, x), out)
            ((zeros, small_to_inf), ones),
            ((small_to_inf, zeros), zeros),
            ((infs, zero_to_large), zeros),
            ((zero_to_large, infs), ones),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for inputs, output in inpouts:
            input0, input1 = inputs
            calc = torch.igamma(input0, input1)
            if torch.all(torch.isnan(output)):
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igammac_edge_cases(self, device, dtype):
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.zeros((3,), **tkwargs) + float("inf")
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.zeros((3,), **tkwargs) + float("nan")
        inpouts = [
            # ((a, x), out)
            ((zeros, small_to_inf), zeros),
            ((small_to_inf, zeros), ones),
            ((infs, zero_to_large), ones),
            ((zero_to_large, infs), zeros),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for inputs, output in inpouts:
            input0, input1 = inputs
            calc = torch.igammac(input0, input1)
            if torch.all(torch.isnan(output)):
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)
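    # i0 (modified Bessel function of the first kind): the helpers compare
    # torch.i0 against scipy.special.i0 over ranges sized so the result stays
    # finite in the dtype under test.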
    def _i0_helper(self, t):
        # Compares torch.i0 against the SciPy reference
        dtype = t.dtype
        actual = torch.i0(t)
        if dtype is torch.bfloat16:
            t = t.to(torch.float32)
        expected = scipy.special.i0(t.cpu().numpy())
        # Casting down is required for half/bfloat16, since SciPy upcasts to
        # at least float32
        if dtype is torch.bfloat16 or dtype is torch.float16:
            expected = torch.from_numpy(expected).to(dtype)
        self.assertEqual(actual, expected)

    def _i0_range_helper(self, range, device, dtype):
        # i0 tests are split by value range so each dtype is only exercised
        # where i0 does not overflow for it
        for r in (range, -range):
            t = torch.rand(1000, device=device).to(dtype) * r
            self._i0_helper(t)
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range1(self, device, dtype):
        # i0(13.25) is approximately the largest value representable in half
        self._i0_range_helper(13.25, device, dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range2(self, device, dtype):
        # i0(88.5) is approximately the largest value representable in float32
        self._i0_range_helper(88.5, device, dtype)

    @dtypes(torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range3(self, device, dtype):
        # i0(709.75) is approximately the largest value representable in float64
        self._i0_range_helper(709.75, device, dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_special(self, device, dtype):
        t = torch.tensor([], device=device, dtype=dtype)
        self.assertTrue(torch.i0(t).numel() == 0)

        t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
        self.assertTrue(torch.i0(t).isnan().all())
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_special_i0_i1_vs_scipy(self, device, dtype):
        def check_equal(t, torch_fn, scipy_fn):
            # Compares against the SciPy reference
            actual = torch_fn(t)
            if dtype is torch.bfloat16:
                t = t.to(torch.float32)
            expected = scipy_fn(t.cpu().numpy())

            # Casting down is required for half/bfloat16 since SciPy upcasts
            if dtype is torch.bfloat16 or dtype is torch.float16:
                expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)

        t = torch.tensor([], device=device, dtype=dtype)
        check_equal(t, torch.i0, scipy.special.i0)
        check_equal(t, torch.special.i0e, scipy.special.i0e)
        if dtype not in [torch.half, torch.bfloat16]:
            check_equal(t, torch.special.i1, scipy.special.i1)
            check_equal(t, torch.special.i1e, scipy.special.i1e)

        # NOTE: the default sweep range below is an assumption of this
        # reconstruction; half uses a narrower range to avoid overflow
        range = (-1e5, 1e5)
        if dtype == torch.half:
            range = (-65000, 65000)

        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
        check_equal(t, torch.i0, scipy.special.i0)
        check_equal(t, torch.special.i0e, scipy.special.i0e)
        if dtype not in [torch.half, torch.bfloat16]:
            check_equal(t, torch.special.i1, scipy.special.i1)
            check_equal(t, torch.special.i1e, scipy.special.i1e)

        # Checks the finfo boundary values
        info = torch.finfo(dtype)
        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
        check_equal(t, torch.i0, scipy.special.i0)
        check_equal(t, torch.special.i0e, scipy.special.i0e)
        if dtype not in [torch.half, torch.bfloat16]:
            check_equal(t, torch.special.i1, scipy.special.i1)
            check_equal(t, torch.special.i1e, scipy.special.i1e)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_special_ndtr_vs_scipy(self, device, dtype):
        def check_equal(t):
            # Compares against the SciPy reference
            actual = torch.special.ndtr(t)
            expected = scipy.special.ndtr(t.cpu().numpy())
            self.assertEqual(actual, expected)

        # NOTE: the sweep range is an assumption of this reconstruction
        range = (-10, 10)
        t = torch.linspace(*range, 1, device=device, dtype=dtype)
        check_equal(t)

        # Checks the finfo boundary values
        info = torch.finfo(dtype)
        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
        check_equal(t)

    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_special_log_ndtr_vs_scipy(self, device, dtype):
        def check_equal(t):
            # Compares against the SciPy reference
            actual = torch.special.log_ndtr(t)
            expected = scipy.special.log_ndtr(t.cpu().numpy())
            self.assertEqual(actual, expected)

        # Checks the finfo boundary values
        info = torch.finfo(dtype)
        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
        check_equal(t)
    # NOTE: the @dtypes decorator on this test is reconstructed from how the
    # body uses dtype (it needs an integer type wide enough for 2**31 + 1)
    @dtypes(torch.long)
    def test_abs_big_number(self, device, dtype):
        bignumber = 2**31 + 1
        res = torch.tensor([bignumber], device=device, dtype=dtype)
        self.assertGreater(res.abs()[0], 0)

    @dtypes(torch.float, torch.double)
    def test_abs_signed_zero(self, device, dtype):
        # Both abs(0.0) and abs(-0.0) should result in 0.0
        size = 128 + 1  # odd size so both the vectorized and scalar paths run
        inp = torch.zeros(size, device=device, dtype=dtype)
        inp[::2] = -0.0
        inp = inp.abs()
        for v in inp:
            self.assertGreater(math.copysign(1.0, v), 0.0)

    @dtypes(torch.float, torch.double)
    def test_abs_zero(self, device, dtype):
        # Both abs(0.0) and abs(-0.0) should result in 0.0
        abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist()
        for num in abs_zeros:
            self.assertGreater(math.copysign(1.0, num), 0.0)
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
        # isposinf and isneginf only support boolean out= tensors
        vals = (float("inf"), -float("inf"), 1.2)
        t = torch.tensor(vals, device=device)
        for torch_op in (torch.isposinf, torch.isneginf):
            out = torch.empty_like(t, dtype=dtype)
            with self.assertRaisesRegex(
                RuntimeError, "does not support non-boolean outputs"
            ):
                torch_op(t, out=out)
    def test_nonzero_empty(self, device):
        def assert_tuple_empty(tup, dim):
            self.assertEqual(dim, len(tup))
            for t in tup:
                self.assertEqual(torch.Size([0]), t.shape)

        x = torch.randn(0, 2, 0, 5, 0, device=device)
        y = torch.nonzero(x)
        z = torch.nonzero(x, as_tuple=True)

        self.assertEqual(0, y.numel())
        self.assertEqual(torch.Size([0, 5]), y.shape)
        assert_tuple_empty(z, 5)

        x = torch.tensor(0.5, device=device)
        y = torch.nonzero(x)
        # nonzero with as_tuple returns a tuple of length 1 for a zero-dim
        # tensor, matching NumPy's behavior
        z = torch.nonzero(x, as_tuple=True)
        self.assertEqual(1, len(z))
        self.assertEqual(torch.zeros(1, dtype=torch.long), z[0])

        x = torch.zeros((), device=device)
        y = torch.nonzero(x)
        z = torch.nonzero(x, as_tuple=True)
        self.assertEqual(torch.Size([0, 0]), y.shape)
        self.assertEqual(1, len(z))
        self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
    @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16))
    def test_exp(self, device, dtype):
        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
            a = (
                torch.tensor(v, dtype=dtype, device=device)
                * torch.arange(18, device=device)
                / 3
                * math.pi
            )
            a = a.to(dtype)
            if dtype == torch.bfloat16:
                continue  # NumPy has no bfloat16 to compare against
            self.compare_with_numpy(torch.exp, np.exp, a)

            if dtype.is_complex:
                inf_real_zero_imag_in = torch.tensor(
                    complex(float("inf"), 0), device=device, dtype=dtype
                )
                inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item()
                self.assertTrue(math.isinf(inf_real_zero_imag_out.real))
                if self.device_type == "cpu":
                    pass
                    # Assertions for the CPU path are omitted here because the
                    # imaginary component could not be reproduced consistently
                else:
                    self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0)
                    self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)

                zero_real_inf_imag_in = torch.tensor(
                    complex(0, float("inf")), device=device, dtype=dtype
                )
                zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item()
                self.assertTrue(math.isnan(zero_real_inf_imag_out.real))
                self.assertTrue(math.isnan(zero_real_inf_imag_out.imag))
                # Ensure we are notified should NumPy's behavior change
                self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in)

                inf_real_imag_in = torch.tensor(
                    complex(float("inf"), float("inf")), device=device, dtype=dtype
                )
                inf_real_imag_out = torch.exp(inf_real_imag_in).item()
                if self.device_type == "cpu":
                    pass
                    # Assertions for the CPU path are omitted here because the
                    # result could not be reproduced consistently
                else:
                    self.assertTrue(math.isinf(inf_real_imag_out.real))
                    self.assertTrue(math.isnan(inf_real_imag_out.imag))
                    self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)

                inf_real_nan_imag_in = torch.tensor(
                    complex(float("inf"), float("nan")), device=device, dtype=dtype
                )
                inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item()
                if self.device_type == "cpu":
                    pass
                    # Assertions for the CPU path are omitted here because the
                    # result could not be reproduced consistently
                else:
                    self.assertTrue(math.isinf(inf_real_nan_imag_out.real))
                    self.assertTrue(math.isnan(inf_real_nan_imag_out.imag))
                    self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)

                nan_real_inf_imag_in = torch.tensor(
                    complex(float("nan"), float("inf")), device=device, dtype=dtype
                )
                nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item()
                self.assertTrue(math.isnan(nan_real_inf_imag_out.real))
                self.assertTrue(math.isnan(nan_real_inf_imag_out.imag))
                # Ensure we are notified should NumPy's behavior change
                self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in)

instantiate_device_type_tests(TestUnaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()