pytorch

Форк
0
/
test_unary_ufuncs.py 
1619 строк · 66.3 Кб
1
# Owner(s): ["module: tests"]
2

3
import torch
4
import numpy as np
5

6
import math
7
from numbers import Number
8
import random
9
import unittest
10

11
from torch import inf, nan
12
from torch.testing._internal.common_utils import (
13
    TestCase,
14
    run_tests,
15
    torch_to_numpy_dtype_dict,
16
    numpy_to_torch_dtype_dict,
17
    suppress_warnings,
18
    TEST_SCIPY,
19
    slowTest,
20
    skipIfNoSciPy,
21
    IS_WINDOWS,
22
    gradcheck,
23
    is_iterable_of_tensors,
24
)
25
from torch.testing._internal.common_methods_invocations import (
26
    unary_ufuncs,
27
    generate_elementwise_unary_tensors,
28
    generate_elementwise_unary_small_value_tensors,
29
    generate_elementwise_unary_large_value_tensors,
30
    generate_elementwise_unary_extremal_value_tensors,
31
)
32
from torch.testing._internal.common_device_type import (
33
    instantiate_device_type_tests,
34
    ops,
35
    dtypes,
36
    onlyCPU,
37
    onlyNativeDeviceTypes,
38
    onlyCUDA,
39
    dtypesIfCUDA,
40
    precisionOverride,
41
    dtypesIfCPU,
42
)
43
from torch.utils import _pytree as pytree
44

45
from torch.testing import make_tensor
46
from torch.testing._internal.common_dtype import (
47
    floating_types_and,
48
    all_types_and_complex_and,
49
    integral_types_and,
50
    get_all_math_dtypes,
51
    complex_types,
52
    floating_and_complex_types_and,
53
)
54

55
if TEST_SCIPY:
56
    import scipy
57

58
# Refer [scipy reference filter]
# Keep only the operators whose reference function is available in the
# current environment (consumed by the reference_numerics tests).
reference_filtered_ops = [op for op in unary_ufuncs if op.ref is not None]
62

63
# Tests for unary "universal functions (ufuncs)" that accept a single
64
# tensor and have common properties like:
65
#   - they are elementwise functions
66
#   - the input shape is the output shape
67
#   - they typically have method and inplace variants
68
#   - they typically support the out kwarg
69
#   - they typically have NumPy or SciPy references
70

71
# See NumPy's universal function documentation
72
# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
73
# about the concept of ufuncs.
74

75

76
# TODO: port test_unary_out_op_mem_overlap
77
# TODO: add test for inplace variants erroring on broadcasted inputs
78
class TestUnaryUfuncs(TestCase):
79
    exact_dtype = True
80

81
    @ops(
        [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)],
        allowed_dtypes=floating_types_and(torch.bfloat16, torch.half),
    )
    def test_float_domains(self, device, dtype, op):
        """Check that inputs just outside ``op.domain`` evaluate to nan.

        For every offset in a fixed sweep, the op is evaluated at
        ``low - offset`` and ``high + offset`` (whichever bounds are set) and
        the scalar result is compared against nan.
        """
        offsets = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100)
        lower_bound, upper_bound = op.domain

        # NOTE: the lower and upper sweeps are kept as separate loops for readability
        if lower_bound is not None:
            boundary = torch.tensor(lower_bound, device=device, dtype=dtype)
            for offset in offsets:
                probe = boundary - offset

                # Only test when the shifted value is representable, i.e. it
                #   differs from the boundary; a small offset can vanish in an
                #   imprecise dtype such as bfloat16
                if probe.item() != boundary.item():
                    out = op(probe)
                    self.assertEqual(
                        out.item(),
                        float("nan"),
                        msg=(
                            f"input of {probe.item()} outside lower domain boundary"
                            f" {lower_bound} produced {out.item()}, not nan!"
                        ),
                    )

        if upper_bound is not None:
            boundary = torch.tensor(upper_bound, device=device, dtype=dtype)
            for offset in offsets:
                probe = boundary + offset

                # Same representability check as for the lower bound
                if probe.item() != boundary.item():
                    out = op(probe)
                    self.assertEqual(
                        out.item(),
                        float("nan"),
                        msg=(
                            f"input of {probe.item()} outside upper domain boundary"
                            f" {upper_bound} produced {out.item()}, not nan!"
                        ),
                    )
129

130
    # Helper for comparing torch tensors and numpy arrays
131
    # TODO: should this or assertEqual also validate that strides are equal?
132
    def assertEqualHelper(
        self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs
    ):
        """Compare a torch tensor against a scalar, a NumPy array, or anything else.

        Args:
            actual: the torch result; must be a ``torch.Tensor``.
            expected: a Python/NumPy scalar, a ``np.ndarray``, or any other
                object accepted by ``assertEqual``.
            msg: failure message forwarded to ``assertEqual``.
            dtype: accepted for the callers' convenience; not read here.
            exact_dtype: when true, array dtypes that do not match the tensor
                dtype are only tolerated for the float32/float64 promotions
                listed below.
            **kwargs: forwarded to ``assertEqual`` (e.g. rtol/atol/equal_nan).
        """
        assert isinstance(actual, torch.Tensor)

        # Some NumPy functions return scalars, not arrays
        if isinstance(expected, Number):
            self.assertEqual(actual.item(), expected, msg, **kwargs)
            return

        if not isinstance(expected, np.ndarray):
            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
            return

        # bfloat16 is checked first so the dtype table is never indexed with it
        needs_dtype_check = exact_dtype and (
            actual.dtype is torch.bfloat16
            or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]
        )
        if needs_dtype_check:
            # A float32 array is accepted for bfloat16 tensors because NumPy
            #   has no bfloat16 dtype. Ops like scipy.special.erf and
            #   scipy.special.erfc also promote float16 to float32.
            up_to_float32 = (torch.float16, torch.bfloat16, torch.float32)
            if expected.dtype == np.float32:
                assert actual.dtype in up_to_float32
            elif expected.dtype == np.float64:
                assert actual.dtype in up_to_float32 + (torch.float64,)
            else:
                self.fail(
                    f"Expected dtype {expected.dtype} but got {actual.dtype}!"
                )

        reference = torch.from_numpy(expected).to(actual.dtype)
        self.assertEqual(actual, reference, msg, exact_device=False, **kwargs)
178

179
    # Tests that the function and its (array-accepting) reference produce the same
180
    #   values on given tensors
181
    def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True):
        """Compare ``op`` against ``op.ref`` on every sample in ``tensors``.

        Args:
            dtype: torch dtype under test; selects the comparison tolerances.
            op: operator object providing ``sample_kwargs``, ``ref`` and being
                callable on a tensor.
            tensors: iterable of samples; each sample's ``.input`` is the tensor
                passed to the op.
            equal_nan: forwarded to the comparison for every dtype. Previously
                it was dropped for the uint8/int8/bool, bfloat16 and half
                branches.
        """
        # Per-dtype tolerance overrides; an empty dict keeps assertEqual's defaults
        if dtype in [torch.uint8, torch.int8, torch.bool]:
            # NOTE: For these dtypes, PyTorch computes in the default scalar type (float)
            # while NumPy computes in float16
            tolerances = {"rtol": 1e-3, "atol": 1e-2}
        elif dtype is torch.bfloat16:
            # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149
            tolerances = {"rtol": 16e-3, "atol": 1e-5}
        elif dtype is torch.half:
            tolerances = {"rtol": 1.2e-03, "atol": 1e-03}
        else:
            tolerances = {}

        def _helper_reference_numerics(
            expected, actual, msg, exact_dtype, equal_nan=True
        ):
            # Exact dtype matching is meaningless when the reference's dtype
            # cannot be cast to the dtype under test
            if not torch.can_cast(
                numpy_to_torch_dtype_dict[expected.dtype.type], dtype
            ):
                exact_dtype = False

            self.assertEqualHelper(
                actual,
                expected,
                msg,
                dtype=dtype,
                equal_nan=equal_nan,
                exact_dtype=exact_dtype,
                **tolerances,
            )

        for sample in tensors:
            t = sample.input
            torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t)

            # The input is upcast before .numpy() for bfloat16 and complex32
            if dtype is torch.bfloat16:
                a = t.cpu().to(torch.float32).numpy()
            elif dtype is torch.complex32:
                a = t.cpu().to(torch.complex64).numpy()
            else:
                a = t.cpu().numpy()

            actual = op(t, **torch_kwargs)
            expected = op.ref(a, **numpy_kwargs)

            # Crafts a custom error message for smaller, printable tensors
            if t.numel() < 10:
                msg = (
                    "Failed to produce expected results! Input tensor was"
                    f" {t}, torch result is {actual}, and reference result is"
                    f" {expected}."
                )
            else:
                msg = None

            exact_dtype = True
            if isinstance(actual, torch.Tensor):
                _helper_reference_numerics(
                    expected, actual, msg, exact_dtype, equal_nan
                )
            else:
                # Multi-output ops: compare each output with its reference
                for x, y in zip(expected, actual):
                    _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan)
265

266
    # Tests that the function and its (array-accepting) reference produce the same
267
    #   values on a range of tensors, including empty tensors, scalar tensors,
268
    #   1D tensors and a large 2D tensor with interesting and extremal values
269
    #   and noncontiguities.
270
    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_normal(self, device, dtype, op):
        """Compare ``op`` with its reference on the general elementwise samples."""
        samples = generate_elementwise_unary_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, samples)
277

278
    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_small(self, device, dtype, op):
        """Compare ``op`` with its reference on small-value samples (bool is skipped)."""
        if dtype == torch.bool:
            # skipTest raises, so nothing below runs for bool
            self.skipTest("bool has no small values")

        samples = generate_elementwise_unary_small_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, samples)
288

289
    @suppress_warnings
    @ops(reference_filtered_ops)
    def test_reference_numerics_large(self, device, dtype, op):
        """Compare ``op`` with its reference on large-value samples."""
        dtypes_without_large_values = (torch.bool, torch.uint8, torch.int8)
        if dtype in dtypes_without_large_values:
            # skipTest raises, so nothing below runs for these dtypes
            self.skipTest("bool, uint8, and int8 dtypes have no large values")

        samples = generate_elementwise_unary_large_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, samples)
299

300
    @suppress_warnings
    @ops(
        reference_filtered_ops,
        allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
    )
    def test_reference_numerics_extremal(self, device, dtype, op):
        """Compare ``op`` with its reference on extremal-value samples."""
        samples = generate_elementwise_unary_extremal_value_tensors(
            op, device=device, dtype=dtype, requires_grad=False
        )
        self._test_reference_numerics(dtype, op, samples)
310

311
    # Tests for testing (non)contiguity consistency
312
    @ops(unary_ufuncs)
    def test_contig_vs_every_other(self, device, dtype, op):
        """The op on an every-other-element view must match the strided slice of the dense result."""
        dense = make_tensor(
            (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        every_other = dense[::2]

        self.assertTrue(dense.is_contiguous())
        self.assertFalse(every_other.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, every_other)
        expected = op(every_other, **torch_kwargs)
        dense_out = op(dense, **torch_kwargs)
        # tree_map applies the same striding to each tensor in the result
        actual = pytree.tree_map(lambda out: out[::2], dense_out)
        self.assertEqual(actual, expected)
327

328
    @ops(unary_ufuncs)
    def test_contig_vs_transposed(self, device, dtype, op):
        """The op on a transposed view must match the transpose of the contiguous result."""
        dense = make_tensor(
            (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1]
        )
        transposed = dense.T

        self.assertTrue(dense.is_contiguous())
        self.assertFalse(transposed.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, dense)
        expected = op(transposed, **torch_kwargs)
        dense_out = op(dense, **torch_kwargs)
        # tree_map transposes each tensor in the result
        actual = pytree.tree_map(lambda out: out.T, dense_out)
        self.assertEqual(actual, expected)
343

344
    @ops(unary_ufuncs)
    def test_non_contig(self, device, dtype, op):
        """The op must give equal results on a contiguous tensor and a strided copy of it."""
        for shape in ((5, 7), (1024,)):
            dense = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            # Selecting index 0 of an extra trailing dim of size 2 yields a
            # non-contiguous view with the same shape as `dense`
            backing = torch.empty(shape + (2,), device=device, dtype=dtype)
            strided = backing[..., 0]
            strided.copy_(dense)

            self.assertTrue(dense.is_contiguous())
            self.assertFalse(strided.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, dense)
            self.assertEqual(op(dense, **torch_kwargs), op(strided, **torch_kwargs))
359

360
    @ops(unary_ufuncs)
    def test_non_contig_index(self, device, dtype, op):
        """The op must agree on an indexed (non-contiguous) view and its contiguous copy."""
        base = make_tensor(
            (2, 2, 1, 2),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        indexed = base[:, 1, ...]
        dense = indexed.contiguous()

        self.assertTrue(dense.is_contiguous())
        self.assertFalse(indexed.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, dense)
        self.assertEqual(op(dense, **torch_kwargs), op(indexed, **torch_kwargs))
377

378
    @ops(unary_ufuncs)
    def test_non_contig_expand(self, device, dtype, op):
        """Every slice of the op applied to an expanded tensor must equal the op on the source."""
        for shape in ((1, 3), (1, 7), (5, 7)):
            dense = make_tensor(
                shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
            )
            expanded = dense.clone().expand(3, -1, -1)

            self.assertTrue(dense.is_contiguous())
            self.assertFalse(expanded.is_contiguous())

            torch_kwargs, _ = op.sample_kwargs(device, dtype, dense)
            dense_out = op(dense, **torch_kwargs)
            expanded_out = op(expanded, **torch_kwargs)
            for i in range(3):
                # tree_map picks the i-th slice of each tensor in the result
                slice_i = pytree.tree_map(lambda x: x[i], expanded_out)
                self.assertEqual(
                    dense_out, slice_i, msg=f"non-contiguous expand[{i}]"
                )
398

399
    @ops(unary_ufuncs)
    def test_contig_size1(self, device, dtype, op):
        """The op must agree on a size-1-leading-dim slice and a fresh copy of it."""
        full = make_tensor(
            (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )
        sliced = full[:1, :50]
        copied = torch.empty(sliced.size(), device=device, dtype=dtype)
        copied.copy_(sliced)

        # Both the slice and the copy are expected to report as contiguous
        self.assertTrue(sliced.is_contiguous())
        self.assertTrue(copied.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, sliced)
        self.assertEqual(op(sliced, **torch_kwargs), op(copied, **torch_kwargs))
413

414
    @ops(unary_ufuncs)
    def test_contig_size1_large_dim(self, device, dtype, op):
        """Same as the size-1 slice check, but on a 12-dimensional tensor."""
        full = make_tensor(
            (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4),
            dtype=dtype,
            device=device,
            low=op.domain[0],
            high=op.domain[1],
        )
        # Keep only the first entry of dim 0; all other dims are taken whole
        sliced = full[:1, ...]
        copied = torch.empty(sliced.size(), device=device, dtype=dtype)
        copied.copy_(sliced)

        self.assertTrue(sliced.is_contiguous())
        self.assertTrue(copied.is_contiguous())

        torch_kwargs, _ = op.sample_kwargs(device, dtype, sliced)
        self.assertEqual(op(sliced, **torch_kwargs), op(copied, **torch_kwargs))
432

433
    # Tests that computation on multiple batches is the same as
    # per-batch computation.
435
    @ops(unary_ufuncs)
    def test_batch_vs_slicing(self, device, dtype, op):
        """Applying the op to a whole batch must match stacking its per-row results."""
        batch = make_tensor(
            (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1]
        )

        torch_kwargs, _ = op.sample_kwargs(device, dtype, batch)
        actual = op(batch, **torch_kwargs)

        per_row = [op(row, **torch_kwargs) for row in batch]
        if is_iterable_of_tensors(actual):
            # Multi-output op: stack the i-th output of every row separately
            expected = [
                torch.stack([row_outs[i] for row_outs in per_row])
                for i in range(len(actual))
            ]
        else:
            expected = torch.stack(per_row)

        self.assertEqual(actual, expected)
451

452
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_nan_to_num(self, device, dtype):
        """Check ``nan_to_num`` against NumPy and check its ``out=`` variant.

        Runs once on a transposed (non-contiguous) input and once on a
        contiguous one. For floating-point dtypes, three whole rows are
        overwritten with nan, +inf and -inf before the comparison.
        """
        for contiguous in [False, True]:
            x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device)

            if dtype.is_floating_point:
                # Add extremal values, each on its own row. Sampling rows
                # without replacement guarantees that no extremal overwrites
                # another, so nan, +inf and -inf are all present every run.
                extremals = [float("nan"), float("inf"), -float("inf")]
                rows = torch.randperm(x.size(0))[: len(extremals)]
                for row, extremal in zip(rows, extremals):
                    x[row, :] = extremal

            if not contiguous:
                x = x.T

            # With args. The names avoid shadowing the module-level `nan`.
            nan_value = random.random()
            posinf_value = random.random() * 5
            neginf_value = random.random() * 10

            self.compare_with_numpy(
                lambda x: x.nan_to_num(nan=nan_value, posinf=posinf_value),
                lambda x: np.nan_to_num(x, nan=nan_value, posinf=posinf_value),
                x,
            )
            self.compare_with_numpy(
                lambda x: x.nan_to_num(posinf=posinf_value, neginf=neginf_value),
                lambda x: np.nan_to_num(x, posinf=posinf_value, neginf=neginf_value),
                x,
            )

            # Out Variant
            out = torch.empty_like(x)
            result = torch.nan_to_num(x)
            torch.nan_to_num(x, out=out)
            self.assertEqual(result, out)

            result = torch.nan_to_num(
                x, nan=nan_value, posinf=posinf_value, neginf=neginf_value
            )
            torch.nan_to_num(
                x, out=out, nan=nan_value, posinf=posinf_value, neginf=neginf_value
            )
            self.assertEqual(result, out)
491

492
    @onlyCPU
    def test_nan_to_num_bfloat16(self, device):
        """Check bfloat16 ``nan_to_num`` forward and backward against float32.

        Each input has nan, +inf and -inf written into three distinct
        ``x[0, c, r, :]`` slices before the comparison.
        """

        def test_dtype(fn, input, dtype):
            # Run fn on a low-precision copy and on a float32 copy of the same
            # data, then compare outputs and gradients (dtype-insensitively)
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2, exact_dtype=False)
            self.assertEqual(input.grad, input2.grad, exact_dtype=False)

        shapes = [[1, 3, 6, 6], [1, 3, 256, 256], [1, 3, 6, 128]]
        shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
        for shape in shapes:
            x = torch.randn(shape, device=device)
            extremals = [float('nan'), float('inf'), -float('inf')]
            # Choose 3 distinct (c, r) positions with c in {0, 1} and r in
            # {0..4} (the same ranges as before). Sampling without replacement
            # keeps one extremal from overwriting another.
            positions = torch.randperm(2 * 5)[: len(extremals)].tolist()
            for position, extremal in zip(positions, extremals):
                x[0, position // 5, position % 5, :] = extremal
            test_dtype(torch.nan_to_num, x, torch.bfloat16)
516

517
    @dtypes(torch.complex64, torch.complex128)
    def test_nan_to_num_complex(self, device, dtype):
        """Check that ``nan_to_num`` replaces an extremal real or imaginary part independently."""
        value_dtype = torch.tensor([], dtype=dtype).real.dtype

        def gen_tensor(a):
            # Build one complex scalar from a [real, imag] pair
            pair = torch.tensor(a, dtype=value_dtype, device=device)
            return torch.view_as_complex(pair)

        cases = (("nan", "nan"), ("inf", "posinf"), ("-inf", "neginf"))
        for extremal, kwarg_name in cases:
            # Extremal in the imaginary part
            res = torch.nan_to_num(
                gen_tensor([123, float(extremal)]), **{kwarg_name: 12}
            )
            self.assertEqual(res, gen_tensor([123, 12]))

            # Extremal in the real part
            res = torch.nan_to_num(
                gen_tensor([float(extremal), 456]), **{kwarg_name: 21}
            )
            self.assertEqual(res, gen_tensor([21, 456]))
534

535
    @dtypes(torch.cdouble)
    def test_complex_edge_values(self, device, dtype):
        """Compare sqrt and acos with NumPy on complex values of very large magnitude."""
        # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424
        x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device)
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)

        # acos test reference: https://github.com/pytorch/pytorch/issues/42952
        # Skip on Windows, as CUDA acos returns conjugate value
        # see https://github.com/pytorch/pytorch/issues/52299
        windows_cuda_cdouble = (
            IS_WINDOWS and dtype == torch.cdouble and "cuda" in device
        )
        if not windows_cuda_cdouble:
            self.compare_with_numpy(torch.acos, np.arccos, x)

        real_part = -1.0e60 if dtype == torch.cdouble else -1.0e20
        x = torch.tensor(real_part - 4988429.2j, dtype=dtype, device=device)
        self.compare_with_numpy(torch.sqrt, np.sqrt, x)
552

553
    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma_special(self, device, dtype):
        """Compare ``torch.digamma`` with SciPy on a table of special values.

        Note that the whole table is converted to one 2-D tensor, so both
        columns are fed to digamma as inputs and compared against SciPy.
        """
        # Based on SciPy test for the following special values.
        # Reference:
        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
        euler = 0.57721566490153286
        pi = math.pi
        log2, log3 = math.log(2), math.log(3)
        sqrt2, sqrt3 = math.sqrt(2), math.sqrt(3)

        dataset = [
            (0.0, -0.0),
            (1, -euler),
            (0.5, -2 * log2 - euler),
            (1 / 3, -pi / (2 * sqrt3) - 3 * log3 / 2 - euler),
            (1 / 4, -pi / 2 - 3 * log2 - euler),
            (1 / 6, -pi * sqrt3 / 2 - 2 * log2 - 3 * log3 / 2 - euler),
            (
                1 / 8,
                -pi / 2
                - 4 * log2
                - (pi + math.log(2 + sqrt2) - math.log(2 - sqrt2)) / sqrt2
                - euler,
            ),
        ]
        x = torch.tensor(dataset, device=device, dtype=dtype)
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)
584

585
    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
    @dtypes(torch.float, torch.double)
    def test_digamma(self, device, dtype):
        """Compare ``torch.digamma`` with SciPy at and near its poles.

        The inputs include values just beside the non-positive integers, the
        poles themselves, and both signed zeros.
        """
        # Tests pole behavior
        tensor = torch.tensor(
            [
                -0.999999994,
                -1.999999994,
                -2.0000000111,
                -100.99999994,
                0.000000111,
                -1931.99999994,
                -0.000000111,
                # Float literals are required here: the integer literal `-0`
                # is just 0, so negative zero would never reach digamma.
                0.0,
                -0.0,
                -1,
                -2,
                -931,
            ],
            dtype=dtype,
            device=device,
        )
        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
608

609
    @dtypes(*floating_types_and(torch.half))
    def test_frexp(self, device, dtype):
        """Check ``torch.frexp`` values and the exponent dtype against ``np.frexp``."""
        values = make_tensor((50, 50), dtype=dtype, device=device)
        mantissa, exponent = torch.frexp(values)
        np_mantissa, np_exponent = np.frexp(values.cpu().numpy())

        self.assertEqual(mantissa, np_mantissa)
        self.assertEqual(exponent, np_exponent)

        # torch.frexp returns exponent in int32 to be compatible with np.frexp
        exponent_is_int32 = exponent.dtype == torch.int32
        self.assertTrue(exponent_is_int32)
        matches_numpy = torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype
        self.assertTrue(matches_numpy)
621

622
    def test_frexp_assert_raises(self, device):
        """Check the errors ``torch.frexp`` raises for bad input and ``out=`` dtypes."""
        # Integral, bool and complex inputs are rejected outright
        for dtype in integral_types_and(torch.bool) + complex_types():
            bad_input = make_tensor((50, 50), dtype=dtype, device=device)
            with self.assertRaisesRegex(
                RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"
            ):
                torch.frexp(bad_input)

        for dtype in floating_types_and(torch.half):
            source = make_tensor((50, 50), dtype=dtype, device=device)

            # Named to avoid shadowing the `dtypes` decorator imported by the file
            other_dtypes = list(
                all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)
            )

            # Mantissa out must have the input dtype: try every other dtype
            other_dtypes.remove(dtype)
            for mantissa_dtype in other_dtypes:
                mantissa = torch.empty_like(source, dtype=mantissa_dtype)
                exponent = torch.empty_like(source, dtype=torch.int)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+",
                ):
                    torch.frexp(source, out=(mantissa, exponent))

            # Exponent out must be int: try every dtype except torch.int
            other_dtypes.append(dtype)
            other_dtypes.remove(torch.int)
            for exponent_dtype in other_dtypes:
                mantissa = torch.empty_like(source)
                exponent = torch.empty_like(source, dtype=exponent_dtype)
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"torch\.frexp\(\) expects exponent to have int dtype but got .+",
                ):
                    torch.frexp(source, out=(mantissa, exponent))
657

658
    def test_polygamma_neg(self, device):
        """``torch.polygamma`` must raise a RuntimeError for a negative order n."""
        values = torch.tensor([1.0, 2.0], device=device)
        expected_message = r"polygamma\(n, x\) does not support negative n\."
        with self.assertRaisesRegex(RuntimeError, expected_message):
            torch.polygamma(-1, values)
663

664
    # TODO resolve with opinfos
665
    @onlyCPU
    def test_op_invert(self, device):
        """Check ``~`` on integer and bool tensors, and that it raises TypeError on floats."""
        expected = 0xFFFF - torch.arange(127, dtype=torch.int8)
        integer_dtypes = (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        )
        for dtype in integer_dtypes:
            values = torch.arange(127, dtype=dtype)
            self.assertEqual(expected.to(dtype), ~values)

        self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True]))

        # test exceptions
        for dtype in (torch.half, torch.float, torch.double):
            zeros = torch.zeros(10, dtype=dtype)
            with self.assertRaises(TypeError):
                _ = ~zeros
679

680
    @dtypes(torch.complex64, torch.complex128)
    def test_abs_angle_complex_to_float(self, device, dtype):
        """Check ``abs`` and ``angle`` on complex inputs against NumPy.

        For a list of random complex values and for an empty input, both
        functions are checked in their functional form, with float and complex
        ``out=`` tensors (pre-sized and resized), with an invalid long ``out=``,
        and for their in-place behavior.

        The in-place abs check previously ended with ``return``, which exited
        the whole test during the first iteration. ``angle`` and the empty
        input were therefore never tested. The loops now run to completion.
        """
        # Constructs random complex values
        from random import random

        random_vals = []
        for multiplier in (-1, 1, -10, 10, -100, 100):
            for _ in range(10):
                random_vals.append(
                    complex(random() * multiplier, random() * multiplier)
                )

        float_dtype = torch.float32 if dtype is torch.complex64 else torch.float64

        for vals in (random_vals, []):
            a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype])
            t = torch.tensor(vals, device=device, dtype=dtype)

            for fn_name in ("abs", "angle"):
                torch_fn = getattr(torch, fn_name)
                np_fn = getattr(np, fn_name)

                # Tests function
                np_result = torch.from_numpy(np_fn(a))
                torch_result = torch_fn(t).cpu()
                self.assertEqual(np_result, torch_result, exact_dtype=True)

                # Tests float out
                np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype])
                float_out = torch.empty_like(t, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests float out (resized out)
                float_out = torch.empty(1, device=device, dtype=float_dtype)
                torch_fn(t, out=float_out)
                self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu())

                # Tests complex out
                np_complex_out = np_fn(a).astype(torch_to_numpy_dtype_dict[dtype])
                complex_out = torch.empty_like(t)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests complex out (resized out)
                complex_out = torch.empty(0, device=device, dtype=dtype)
                torch_fn(t, out=complex_out)
                self.assertEqual(torch.from_numpy(np_complex_out), complex_out.cpu())

                # Tests long out behavior (expected failure)
                long_out = torch.empty(0, device=device, dtype=torch.long)
                with self.assertRaises(RuntimeError):
                    torch_fn(t, out=long_out)

                # Tests inplace
                if fn_name == "abs":
                    torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
                    if dtype.is_complex:
                        # In-place abs must be rejected for complex tensors.
                        # `a` and `t` are left untouched so the remaining
                        # iterations still compare matching inputs.
                        with self.assertRaisesRegex(
                            RuntimeError,
                            "In-place abs is not supported for complex tensors.",
                        ):
                            torch_inplace_method(t)
                    else:
                        np_fn(a, out=a)
                        torch_inplace_method(t)
                        self.assertEqual(torch.from_numpy(a), t.cpu())

                # Note: angle does not have an in-place variant
                if fn_name == "angle":
                    with self.assertRaises(AttributeError):
                        torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
753

754
    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        """Check how an in-place op reacts to a self-overlapping destination.

        The first argument is a one-element tensor expanded to ``(3, 3)``, so
        every position of it refers to the same storage element.

        Args:
            inplace_op: A callable, or the name of a ``torch.Tensor`` attribute
                to look up, invoked as ``inplace_op(*inputs)``.
            num_inputs: Total number of tensor arguments; the additional
                ``num_inputs - 1`` arguments are created with ``randn_like``.
            dtype: dtype of the generated tensors.
            device: Device of the generated tensors.
            expected_failure: If false, the op must raise a ``RuntimeError``
                matching "single memory location". If true, that check itself
                must fail with an ``AssertionError``.
        """
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)

        overlapping = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [overlapping] + [
            torch.randn_like(overlapping) for _ in range(num_inputs - 1)
        ]

        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                inplace_op(*inputs)
        else:
            # The inner check is expected to fail, which surfaces as an
            # AssertionError that the outer context manager consumes.
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "single memory location"):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        """Exercise ``op(input, out=output)`` for three input/output layouts.

        Slices of ``data`` are used so that the output is (1) the same slice
        as the input, (2) disjoint from the input, and (3) shifted by one
        element so that it partially overlaps the input.

        Args:
            data: 1-D tensor; slices up to index ``2 * sz`` are taken from it.
            sz: Length of each slice.
            op: Callable accepting ``(input, out=...)`` and exposing
                ``__name__`` (used as the assertion message).
            expected_failure: If false, the partially overlapping case must
                raise a ``RuntimeError`` matching "unsupported operation". If
                true, that check itself must fail with an ``AssertionError``.
        """

        def _test(op, output, input):
            # Reference result computed into a fresh, non-aliasing buffer.
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
        # output partially overlaps with input:
        overlapping_out = data[0:sz]
        overlapping_in = data[1 : sz + 1]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                _test(op, overlapping_out, overlapping_in)
        else:
            # The inner check is expected to fail, which surfaces as an
            # AssertionError that the outer context manager consumes.
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                    _test(op, overlapping_out, overlapping_in)

    # TODO: run on non-native device types
    @dtypes(torch.double)
    def test_unary_out_op_mem_overlap(self, device, dtype):
        # Runs both memory-overlap helpers over a table of unary ops: the
        # out= variant via unary_check_input_output_mem_overlap and the
        # in-place variant via check_internal_mem_overlap.
        sz = 3
        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
        positives = torch.randint(1, 100, (2 * sz,), device=device).double()
        ints = torch.randint(-100, 100, (2 * sz,), device=device)

        # (op name, input data). Each op is registered for both "cpu" and
        # "cuda" unless one of the exception sets below says otherwise.
        op_inputs = [
            ("abs", doubles),
            ("acos", doubles),
            ("asin", doubles),
            ("atan", doubles),
            ("acosh", doubles),
            ("asinh", doubles),
            ("atanh", doubles),
            ("bitwise_not", ints),
            ("ceil", doubles),
            ("cos", doubles),
            ("cosh", doubles),
            ("digamma", doubles),
            ("erf", doubles),
            ("erfc", doubles),
            ("erfinv", doubles),
            ("exp", doubles),
            ("exp2", doubles),
            ("expm1", doubles),
            ("floor", doubles),
            ("frac", doubles),
            ("i0", doubles),
            ("log", positives),
            ("log10", positives),
            ("log1p", positives),
            ("log2", positives),
            ("neg", doubles),
            ("reciprocal", doubles),
            ("round", doubles),
            ("rsqrt", positives),
            ("sin", doubles),
            ("sinh", doubles),
            ("sigmoid", doubles),
            ("logit", doubles),
            ("sqrt", doubles),
            ("tan", doubles),
            ("tanh", doubles),
            ("trunc", doubles),
        ]
        # Ops that are only registered for "cpu".
        cpu_only = {"digamma"}
        # Ops whose "cuda" entry is flagged as lacking the input/output
        # overlap check, so the partial-overlap check is expected to fail.
        cuda_without_input_output_check = {"sinh", "sqrt"}
        # Every entry is flagged as having the internal-overlap check.
        has_internal_mem_overlap_check = True

        # NOTE(review): cases are selected by exact string equality with
        # `device`. If the harness passes an indexed name such as "cuda:0",
        # no cuda case would run -- confirm against the device-type framework.
        if device not in ("cpu", "cuda"):
            return

        for fn, inputs in op_inputs:
            if device == "cuda" and fn in cpu_only:
                continue
            has_input_output_mem_overlap_check = not (
                device == "cuda" and fn in cuda_without_input_output_check
            )
            out_fn = getattr(torch, fn)
            in_fn = getattr(torch.Tensor, fn + "_")

            self.unary_check_input_output_mem_overlap(
                inputs,
                sz,
                out_fn,
                expected_failure=not has_input_output_mem_overlap_check,
            )

            self.check_internal_mem_overlap(
                in_fn,
                1,
                dtype,
                device,
                expected_failure=not has_internal_mem_overlap_check,
            )

    # TODO: opinfo hardshrink
    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink(self, device, dtype):
        # Checks Tensor.hardshrink on a 2x2 tensor for explicit thresholds,
        # the default threshold, and a transposed (non-contiguous) input.
        data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2)
        self.assertEqual(
            torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.3),
        )
        self.assertEqual(
            torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2),
            data.hardshrink(0.5),
        )

        # test default lambd=0.5
        self.assertEqual(data.hardshrink(), data.hardshrink(0.5))

        # test non-contiguous case
        self.assertEqual(
            torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2),
            data.t().hardshrink(0.3),
        )

    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardshrink_edge_cases(self, device, dtype) -> None:
        # Checks hardshrink at the extremes of the dtype: zero, the smallest
        # normal value (finfo.tiny), the largest finite value and infinity,
        # both as inputs and as thresholds.
        def check(values, expected_by_lambd):
            # hardshrink is out-of-place, so the input can be built once.
            values_tensor = torch.tensor(
                [float(v) for v in values], dtype=dtype, device=device
            )
            for lambd, expected in expected_by_lambd.items():
                expected_tensor = torch.tensor(
                    [float(v) for v in expected], dtype=dtype, device=device
                )
                # Compare element-wise with == and require every position to
                # match, rather than using a tolerance-based comparison.
                self.assertEqual(
                    expected_tensor == values_tensor.hardshrink(lambd),
                    torch.ones_like(values_tensor, dtype=torch.bool),
                )

        def test_helper(tiny, largest):
            # Each row maps a threshold to the expected output for `values`.
            values = [
                0.0, tiny, -tiny, 0.1, -0.1, 1.0, -1.0, largest, -largest, inf, -inf
            ]
            check(
                values,
                {
                    0.0: [0.0, tiny, -tiny, 0.1, -0.1, 1.0, -1.0, largest, -largest, inf, -inf],
                    tiny: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, largest, -largest, inf, -inf],
                    0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, largest, -largest, inf, -inf],
                    1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, largest, -largest, inf, -inf],
                    largest: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
                    inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                },
            )

        test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max)

    @onlyCPU
    @slowTest
    @dtypes(torch.float)
    @unittest.skipIf(True, "Insufficient memory on linux.(2|4)xlarge")
    def test_exp_slow(self, device, dtype):
        # Test for https://github.com/pytorch/pytorch/issues/17271
        # This is pretty slow on my Macbook but it only takes a few
        # seconds on a beefy Xeon server
        numel = 2**31
        a = torch.exp(torch.ones(numel, dtype=dtype, device=device))
        b = torch.exp(torch.ones(1, dtype=dtype, device=device))
        self.assertEqual(a, b.expand(numel))

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardswish(self, device, dtype):
        # Compares F.hardswish (out-of-place and in-place) with the NumPy
        # expression x * min(max(x + 3, 0), 6) / 6.
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.multiply(
            inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0
        )

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardswish(inputTensor), expectedOutputTensor
        )

        # inplace
        inputTensorCpy = inputTensor.clone().detach()
        torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
        self.assertEqual(inputTensorCpy, expectedOutputTensor)

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid(self, device, dtype):
        # Compares F.hardsigmoid (out-of-place and in-place) with the NumPy
        # expression min(max(x + 3, 0), 6) / 6.
        inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
        expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0

        inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
        expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device)

        # normal
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensor),
            expectedOutputTensor,
        )

        # inplace (the returned tensor is what gets compared)
        inputTensorCpy = inputTensor.clone().detach()
        self.assertEqual(
            torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True),
            expectedOutputTensor,
        )

    @precisionOverride(
        {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}
    )
    @dtypes(torch.float, torch.double, torch.bfloat16)
    def test_hardsigmoid_backward(self, device, dtype):
        # Checks the gradient of F.hardsigmoid: 1/6 at +-2 and 0 at the
        # other sampled points (+-3 and +-6).
        inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0]
        expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0]
        inputTensor = torch.tensor(
            inputValues, dtype=dtype, device=device
        ).requires_grad_()
        expectedTensor = torch.tensor(expectedValues, dtype=dtype, device=device)
        out = torch.nn.functional.hardsigmoid(inputTensor)
        # A gradient of ones makes inputTensor.grad the element-wise derivative.
        out.backward(torch.ones_like(inputTensor))
        self.assertEqual(inputTensor.grad, expectedTensor)

    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_silu(self, device, dtype):
        # Compares F.silu with x * scipy.special.expit(x) on random values
        # plus one row of hand-picked values, for contiguous and transposed
        # inputs, out-of-place and in-place.
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * scipy.special.expit(input_np)

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        input_contig = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.silu(input_contig),
            expected_output,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.silu(input_contig, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        # The in-place call above overwrote its input, so start from a fresh copy.
        input_noncontig = torch.from_numpy(input_np).clone().to(device).transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.silu(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_silu_complex(self, device, dtype):
        # Checks F.silu on scalar complex inputs against precomputed values,
        # first out-of-place for every input and then in-place.
        atol = 1e-6
        rtol = 1e-6
        inouts = [
            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
        ]

        for inplace in (False, True):
            for inp, out in inouts:
                res = torch.nn.functional.silu(
                    torch.tensor(inp, dtype=dtype, device=device), inplace=inplace
                )
                self.assertFalse(torch.any(torch.isnan(res)))
                self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
                self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)

    # It is not obvious how to merge this into OpInfo because these inputs
    # succeed for gradcheck but are expected to fail for gradgradcheck
    @dtypes(torch.double)
    def test_sinc(self, device, dtype):
        # The derivative of sinc(x) at x=0 has to be special cased.
        # A naive computation will result in 0/0 -> NaN.
        # We also need to be careful when we are very close to 0, as the
        # derivative's denominator is squared, and there are some floats
        # that are positive and whose squares are zero.
        a = torch.tensor(
            [0.0, torch.finfo(torch.double).tiny, 1.0],
            dtype=dtype,
            requires_grad=True,
            device=device,
        )
        gradcheck(torch.sinc, a)

    @skipIfNoSciPy
    @dtypes(torch.float, torch.double)
    def test_mish(self, device, dtype):
        # Compares F.mish with x * tanh(log1p(exp(x))) computed in NumPy on
        # random values plus one row of hand-picked values, for contiguous
        # and transposed inputs, out-of-place and in-place.
        input_np = np.random.randn(5, 8)
        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
        input_np = np.concatenate((input_np, special_input), axis=0).astype(
            torch_to_numpy_dtype_dict[dtype]
        )
        expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np)))

        expected_output = torch.from_numpy(expected_output_np).to(device)
        expected_output_noncontig = expected_output.transpose(0, 1)

        atol = 1e-6
        rtol = 1e-6

        input_contig = torch.from_numpy(input_np).clone().contiguous().to(device)
        self.assertEqual(
            torch.nn.functional.mish(input_contig),
            expected_output,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.mish(input_contig, inplace=True),
            expected_output,
            atol=atol,
            rtol=rtol,
        )

        # The in-place call above overwrote its input, so start from a fresh copy.
        input_noncontig = torch.from_numpy(input_np).clone().to(device).transpose(0, 1)
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )
        self.assertEqual(
            torch.nn.functional.mish(input_noncontig, inplace=True),
            expected_output_noncontig,
            atol=atol,
            rtol=rtol,
        )

    @dtypes(torch.complex64, torch.complex128)
    def test_log1p_complex(self, device, dtype):
        # The output values here were obtained using arbitrary precision math (mpmath)
        # and double checked with WolframAlpha.
        # Not using numpy's log1p here because by the time of writing this,
        # np.log1p has precision problems for small complex input values, see here:
        # https://github.com/numpy/numpy/issues/22609
        inouts = [
            (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j),
            (1e-19 + 1e-18j, 1e-19 + 1e-18j),
            (1e-18 + 0.1j, 0.00497517 + 0.0996687j),
            (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j),
            (0.5 + 0j, 0.40546510810816 + 0j),
            (0.0 + 0.5j, 0.111571776 + 0.463647609j),
            (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j),
            (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j),
            (2.0j, 0.80471895621705014 + 1.1071487177940904j),
            (-2.0j, 0.80471895621705014 - 1.1071487177940904j),
        ]
        # test the extreme values (magnitudes chosen per dtype)
        if dtype == torch.complex128:
            inouts.extend(
                [
                    (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                    (1e250 + 1j, 575.6462732485114 + 1e-250j),
                    (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j),
                    (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j),
                    (1e-250 + 2e-250j, 1e-250 + 2e-250j),
                    (1e250 + 1e-250j, 575.6462732485114 + 0.0j),
                ]
            )
        elif dtype == torch.complex64:
            inouts.extend(
                [
                    (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                    (1e30 + 1j, 69.07755278982137 + 1e-30j),
                    (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j),
                    (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j),
                    (1e-30 + 2e-30j, 1e-30 + 2e-30j),
                    (1e30 + 1e-30j, 69.07755278982137 + 0.0j),
                ]
            )

        # test the log1p individually
        for inp, out in inouts:
            res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device))
            self.assertFalse(torch.any(torch.isnan(res)))
            # setting up atol == 0.0 because some part has very small values
            self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6)
            self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6)

        # test the log1p in tensor
        inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts))
        inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device)
        out_tens = torch.tensor(out_lst, dtype=dtype, device=device)
        res_tens = torch.log1p(inp_tens)
        self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6)
        self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6)

    # do ops like threshold need a test_unary(_nonufunc) test suite?
    @onlyCPU
    @dtypes(*get_all_math_dtypes("cpu"))
    def test_threshold(self, device, dtype):
        # uint8, float16 and complex dtypes are not exercised by this test.
        if dtype == torch.uint8 or dtype == torch.float16 or dtype.is_complex:
            return

        # 100 is wide enough to use AVX2 instructions for all types
        x = (
            torch.randn(100, dtype=torch.float, device=device)
            .sign()
            .to(dtype=dtype)
        )
        y = torch.threshold(x, 0, 0)
        self.assertTrue(y.le(0).any())

    def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn):
        """Compare a two-argument torch function with its SciPy counterpart.

        Inputs are 500 log-spaced values (base ~e) between ``exp(loglo)`` and
        ``exp(loghi)``, generated in float64 and then cast to ``dtype``.

        Args:
            loglo: Exponent of the smallest input value.
            loghi: Exponent of the largest input value.
            device: Device on which the inputs are created.
            dtype: dtype the inputs are cast to before calling ``torch_fcn``.
            torch_fcn: Function under test, called as ``torch_fcn(a, x)``.
            scipy_fcn: Reference, called with the inputs converted to NumPy.
        """
        exp1 = 2.71828182846
        vec1 = torch.logspace(
            loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device
        ).unsqueeze(-1)
        vec1 = vec1.to(dtype)
        half_len = vec1.shape[0] // 2
        inputs = [
            (vec1, vec1.transpose(0, 1)),
            (vec1, vec1),  # for large number, it should approach 0.5
            (vec1, 0.5 * vec1),  # test for considerable ratio
            (vec1, 2.0 * vec1),
            (vec1[::2, :], vec1[::2, :]),  # contiguous/noncontiguous tests
            (vec1[::2, :], vec1[:half_len, :]),
            (vec1[:half_len, :], vec1[::2, :]),
        ]
        half_prec = dtype in [torch.bfloat16, torch.float16]
        for input0, input1 in inputs:
            actual = torch_fcn(input0, input1)
            if half_prec:
                # The reference is evaluated on float32 copies for the
                # 16-bit dtypes and cast back to `dtype` below.
                input0 = input0.to(torch.float)
                input1 = input1.to(torch.float)
            expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy())
            expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igamma_common(self, device, dtype):
        # test igamma for reasonable range of values
        loglo = -4  # approx 0.018
        loghi = 4  # approx 54.6
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc
        )

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @onlyNativeDeviceTypes
    def test_igammac_common(self, device, dtype):
        # test igammac for reasonable range of values
        loglo = -4  # approx 0.018
        loghi = 4  # approx 54.6
        self._helper_test_igamma(
            loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc
        )

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igamma_edge_cases(self, device, dtype):
        # Checks torch.igamma(a, x) for zero, infinite and negative arguments
        # against the expected limits listed in the table below.
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.full((3,), float("inf"), **tkwargs)
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.full((3,), float("nan"), **tkwargs)
        inpouts = [
            # (a    ,    x),       out
            ((zeros, small_to_inf), ones),
            ((small_to_inf, zeros), zeros),
            ((infs, zero_to_large), zeros),
            ((zero_to_large, infs), ones),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for (input0, input1), output in inpouts:
            calc = torch.igamma(input0, input1)
            if torch.all(torch.isnan(output)):
                # For an all-NaN expectation, only require every result to be NaN.
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)

    @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    @dtypes(torch.float32, torch.float64)
    @onlyNativeDeviceTypes
    def test_igammac_edge_cases(self, device, dtype):
        # Checks torch.igammac(a, x) for zero, infinite and negative arguments
        # against the expected limits listed in the table below.
        tkwargs = {"dtype": dtype, "device": device}
        infs = torch.full((3,), float("inf"), **tkwargs)
        zeros = torch.zeros((3,), **tkwargs)
        ones = torch.ones((3,), **tkwargs)
        zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs)
        small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs)
        nans = torch.full((3,), float("nan"), **tkwargs)
        inpouts = [
            # (a    ,    x),       out
            ((zeros, small_to_inf), zeros),
            ((small_to_inf, zeros), ones),
            ((infs, zero_to_large), ones),
            ((zero_to_large, infs), zeros),
            ((zeros, zeros), nans),
            ((infs, infs), nans),
            ((-small_to_inf, small_to_inf), nans),
        ]
        for (input0, input1), output in inpouts:
            calc = torch.igammac(input0, input1)
            if torch.all(torch.isnan(output)):
                # For an all-NaN expectation, only require every result to be NaN.
                self.assertTrue(torch.all(torch.isnan(calc)))
            else:
                self.assertEqual(calc, output)

    def _i0_helper(self, t):
        """Assert that ``torch.i0(t)`` matches ``scipy.special.i0``.

        Args:
            t: Input tensor. bfloat16 inputs are converted to float32 before
                being handed to SciPy; for bfloat16 and float16 the reference
                result is cast back to the input dtype before comparing.
        """
        # Test by comparing to scipy
        dtype = t.dtype
        actual = torch.i0(t)
        if dtype is torch.bfloat16:
            t = t.to(torch.float32)
        expected = scipy.special.i0(t.cpu().numpy())
        # Casting down for dtype float16 is required since scipy upcasts to float32
        if dtype is torch.bfloat16 or dtype is torch.float16:
            expected = torch.from_numpy(expected).to(dtype)
        self.assertEqual(actual, expected)

    def _i0_range_helper(self, range, device, dtype):
        """Run ``_i0_helper`` on random inputs scaled by ``range`` and ``-range``.

        Args:
            range: Positive bound; samples are ``rand(1000) * range`` and
                ``rand(1000) * -range``. (The name shadows the builtin but is
                kept because it is part of the signature.)
            device: Device on which the samples are generated.
            dtype: dtype the samples are cast to before scaling.
        """
        # i0 tests are broken up by the domain for which the function does not overflow for each dtype
        # This is done to ensure that the function performs well across all possible input values, without worrying
        # about inf or nan possibilities
        for r in (range, -range):
            t = torch.rand(1000, device=device).to(dtype) * r
            self._i0_helper(t)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range1(self, device, dtype):
        # This tests the domain for i0 for which float16 does not overflow
        # The domain is (-13.25, 13.25)
        self._i0_range_helper(13.25, device, dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range2(self, device, dtype):
        # This tests the domain for i0 for which float32 and bfloat16 does not overflow
        # The domain is (-88.5, 88.5)
        self._i0_range_helper(88.5, device, dtype)

    @dtypes(torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_range3(self, device, dtype):
        # This tests the domain for i0 for which float64 does not overflow
        # The domain is (-709.75, 709.75)
        self._i0_range_helper(709.75, device, dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_i0_special(self, device, dtype):
        # Empty input: compared against scipy via the shared helper.
        t = torch.tensor([], device=device, dtype=dtype)
        self._i0_helper(t)

        # Non-finite input: this test requires torch.i0 to return NaN for
        # +inf, -inf and NaN alike.
        t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
        self.assertTrue(torch.i0(t).isnan().all())

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.bfloat16, torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_special_i0_i1_vs_scipy(self, device, dtype):
        # Compares torch.i0 / special.i0e (and special.i1 / special.i1e for
        # dtypes other than half and bfloat16) with their scipy.special
        # counterparts on an empty tensor, a wide linspace and finfo extremes.
        def check_equal(t, torch_fn, scipy_fn):
            # Test by comparing to scipy
            actual = torch_fn(t)
            if dtype is torch.bfloat16:
                t = t.to(torch.float32)
            expected = scipy_fn(t.cpu().numpy())

            # Casting down for dtype float16 is required since scipy upcasts to float32
            if dtype is torch.bfloat16 or dtype is torch.float16:
                expected = torch.from_numpy(expected).to(dtype)
            self.assertEqual(actual, expected)

        def check_all(t):
            check_equal(t, torch.i0, scipy.special.i0)
            check_equal(t, torch.special.i0e, scipy.special.i0e)
            if dtype not in [torch.half, torch.bfloat16]:
                check_equal(t, torch.special.i1, scipy.special.i1)
                check_equal(t, torch.special.i1e, scipy.special.i1e)

        check_all(torch.tensor([], device=device, dtype=dtype))

        # half gets a narrower interval than the other dtypes.
        bounds = (-1e7, 1e7)
        if dtype == torch.half:
            bounds = (-65000, 65000)
        check_all(torch.linspace(*bounds, int(1e4), device=device, dtype=dtype))

        # NaN, inf, -inf are tested in reference_numerics tests.
        info = torch.finfo(dtype)
        extremes = [info.min, info.max, info.eps, info.tiny]
        check_all(torch.tensor(extremes, dtype=dtype, device=device))

    @dtypes(torch.float32, torch.float64)
1426
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1427
    def test_special_ndtr_vs_scipy(self, device, dtype):
        """Compare ``torch.special.ndtr`` against ``scipy.special.ndtr``.

        Inputs are evenly spaced points over [-10, 10] and the ``torch.finfo``
        values ``min``, ``max``, ``eps`` and ``tiny`` for ``dtype``.
        """

        def check_equal(t):
            # Test by comparing to scipy
            actual = torch.special.ndtr(t)
            expected = scipy.special.ndtr(t.cpu().numpy())
            self.assertEqual(actual, expected)

        # Sample the whole interval: with steps=1 linspace yields only the
        # start point, so the upper bound would never be exercised.
        lo, hi = -10, 10
        check_equal(torch.linspace(lo, hi, int(1e4), device=device, dtype=dtype))

        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
        info = torch.finfo(dtype)
        check_equal(
            torch.tensor(
                [info.min, info.max, info.eps, info.tiny], dtype=dtype, device=device
            )
        )
1443

1444
    @dtypes(torch.float32, torch.float64)
1445
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1446
    def test_special_log_ndtr_vs_scipy(self, device, dtype):
        """Compare ``torch.special.log_ndtr`` against ``scipy.special.log_ndtr``.

        The input holds the ``torch.finfo`` values ``min``, ``max``, ``eps``
        and ``tiny`` for ``dtype``.
        """

        def check_equal(t):
            # Test by comparing with scipy
            actual = torch.special.log_ndtr(t)
            expected = scipy.special.log_ndtr(t.cpu().numpy())
            self.assertEqual(actual, expected)

        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
        info = torch.finfo(dtype)
        check_equal(
            torch.tensor(
                [info.min, info.max, info.eps, info.tiny], dtype=dtype, device=device
            )
        )
1458

1459
    # TODO: allow large opinfo values to be opted-into via metadata
1460
    @dtypes(torch.long)
1461
    def test_abs_big_number(self, device, dtype):
        """Check that ``abs`` of ``2**31 + 1`` is positive."""
        bignumber = 2**31 + 1
        big = torch.tensor([bignumber], device=device, dtype=dtype)
        self.assertGreater(big.abs()[0], 0)
1465

1466
    # TODO: add signed zero testing to opinfos
1467
    @dtypes(torch.float, torch.double)
1468
    def test_abs_signed_zero(self, device, dtype):
        """Check that ``abs`` returns positive zero for both ``0.0`` and ``-0.0``."""
        # Pick a large enough number with remainder so that both vectorized
        # and nonvectorized op is tested.
        size = 128 + 1
        inp = torch.zeros(size, device=device, dtype=dtype)
        inp[::2] = -0.0  # every other element becomes negative zero

        # Convert once to Python floats (as test_abs_zero does); copysign
        # moves the sign of each zero onto 1.0 so it can be compared.
        for value in inp.abs().tolist():
            self.assertGreater(math.copysign(1.0, value), 0.0)
1477

1478
    # TODO: update to compare against NumPy by rationalizing with OpInfo
1479
    @onlyCUDA
1480
    @dtypes(torch.float, torch.double)
1481
    def test_abs_zero(self, device, dtype):
        """Check that ``abs(0.0)`` and ``abs(-0.0)`` both give positive zero."""
        # Both abs(0.0) and abs(-0.0) should result in 0.0
        zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype)
        # copysign moves the sign of each zero onto 1.0 so it can be compared.
        for num in zeros.abs().tolist():
            self.assertGreater(math.copysign(1.0, num), 0.0)
1486

1487
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
1488
    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
        """Check that ``isposinf``/``isneginf`` reject an ``out=`` tensor of ``dtype``.

        Each op is expected to raise ``RuntimeError`` with a message matching
        "does not support non-boolean outputs".
        """
        # test non-boolean tensors as the `out=` parameters
        # boolean outputs are tested in the above testcases
        vals = (float("inf"), -float("inf"), 1.2)
        t = torch.tensor(vals, device=device)
        for torch_op in (torch.isposinf, torch.isneginf):
            out = torch.empty_like(t, dtype=dtype)
            with self.assertRaisesRegex(
                RuntimeError, "does not support non-boolean outputs"
            ):
                torch_op(t, out=out)
1499

1500
    def test_nonzero_empty(self, device):
        """Check ``torch.nonzero`` output shapes for empty and zero-dim inputs.

        Covers a five-dimensional tensor with no elements, a zero-dim tensor
        holding a nonzero value, and a zero-dim tensor holding zero, in both
        the default and ``as_tuple=True`` forms.
        """

        def assert_tuple_empty(tup, dim):
            # Expect `dim` index tensors, each of shape [0].
            self.assertEqual(dim, len(tup))
            for t in tup:
                self.assertEqual(torch.Size([0]), t.shape)

        # Five-dimensional input with no elements.
        x = torch.randn(0, 2, 0, 5, 0, device=device)
        y = torch.nonzero(x)
        z = torch.nonzero(x, as_tuple=True)

        self.assertEqual(0, y.numel())
        self.assertEqual(torch.Size([0, 5]), y.shape)
        assert_tuple_empty(z, 5)

        # Zero-dim tensor holding a nonzero value.
        x = torch.tensor(0.5, device=device)
        y = torch.nonzero(x)
        # One nonzero element and zero index columns for a zero-dim input.
        self.assertEqual(torch.Size([1, 0]), y.shape)
        # nonzero with as_tuple returns a
        # tuple of len 1 for a zero-dim tensor.
        # This is done to match Numpy behavior.
        z = torch.nonzero(x, as_tuple=True)
        self.assertEqual(1, len(z))
        self.assertEqual(torch.zeros(1, dtype=torch.long), z[0])

        # Zero-dim tensor holding zero.
        x = torch.zeros((), device=device)
        y = torch.nonzero(x)
        z = torch.nonzero(x, as_tuple=True)
        self.assertEqual(torch.Size([0, 0]), y.shape)
        self.assertEqual(1, len(z))
        self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
1529

1530
    # TODO: rationalize with exp OpInfo
1531
    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
1532
    @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16))
1533
    def test_exp(self, device, dtype):
        """Check ``torch.exp`` against ``np.exp``, including complex special values.

        For each scale ``v`` (2 and -2, plus ``1j`` and ``1 + 1j`` for complex
        dtypes) the input ``v * arange(18) / 3 * pi`` is compared through
        ``self.compare_with_numpy``. Complex dtypes additionally check inputs
        whose real or imaginary part is ``inf`` or ``nan``. Nothing is
        compared for ``torch.bfloat16``.
        """
        # bfloat16 overflows
        if dtype == torch.bfloat16:
            return

        scales = (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ())
        for v in scales:
            a = (
                torch.tensor(v, dtype=dtype, device=device)
                * torch.arange(18, device=device)
                / 3
                * math.pi
            )
            a = a.to(dtype)
            self.compare_with_numpy(torch.exp, np.exp, a)

        if not dtype.is_complex:
            return

        # The special-value checks below do not depend on the scale, so they
        # run once rather than once per loop iteration.
        #
        # On CPU several of them are skipped: the results there are known to
        # be incorrect and could not be reproduced consistently, see
        # https://github.com/pytorch/pytorch/issues/40590
        on_cpu = self.device_type == "cpu"

        def exp_item(real, imag):
            # Returns the input tensor and exp(input) as a Python complex.
            inp = torch.tensor(complex(real, imag), device=device, dtype=dtype)
            return inp, torch.exp(inp).item()

        # exp(inf + 0j): real part is inf; imaginary part should be exactly 0.
        inp, out = exp_item(float("inf"), 0)
        self.assertTrue(math.isinf(out.real))
        if not on_cpu:
            self.assertEqual(out.imag, 0, atol=0, rtol=0)
            self.compare_with_numpy(torch.exp, np.exp, inp)

        # exp(0 + inf*j): both parts are NaN.
        inp, out = exp_item(0, float("inf"))
        self.assertTrue(math.isnan(out.real))
        self.assertTrue(math.isnan(out.imag))
        # Ensure we are notified when NumPy changes its behavior
        self.compare_with_numpy(torch.exp, np.exp, inp)

        # exp(inf + inf*j) and exp(inf + nan*j): real part inf, imaginary NaN.
        for imag in (float("inf"), float("nan")):
            inp, out = exp_item(float("inf"), imag)
            if not on_cpu:
                self.assertTrue(math.isinf(out.real))
                self.assertTrue(math.isnan(out.imag))
                self.compare_with_numpy(torch.exp, np.exp, inp)

        # exp(nan + inf*j): both parts are NaN.
        inp, out = exp_item(float("nan"), float("inf"))
        self.assertTrue(math.isnan(out.real))
        self.assertTrue(math.isnan(out.imag))
        # Ensure we are notified when NumPy changes its behavior
        self.compare_with_numpy(torch.exp, np.exp, inp)
1614

1615

1616
# Hand the generic test class and this module's globals() to the device-type
# test instantiation helper.
instantiate_device_type_tests(TestUnaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()
1620

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.