# pytorch / test_testing.py
# (GitHub page header: Fork 0 · 2401 lines · 94.8 KB)
1
# Owner(s): ["module: tests"]
2

3
import collections
4
import doctest
5
import functools
6
import importlib
7
import inspect
8
import itertools
9
import math
10
import os
11
import re
12
import subprocess
13
import sys
14
import unittest.mock
15
from typing import Any, Callable, Iterator, List, Tuple
16

17
import torch
18

19
from torch.testing import make_tensor
20
from torch.testing._internal.common_utils import \
21
    (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22
     parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf)
23
from torch.testing._internal.common_device_type import \
24
    (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
25
     get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
26
     deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes)
27
from torch.testing._internal.common_methods_invocations import op_db
28
from torch.testing._internal import opinfo
29
from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
30
from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
31
from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
32
import operator
33

34
# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
    """Device-generic tests for ``TestCase.assertEqual``, ``torch.isclose`` and OpInfo dtype helpers.

    Most tests take a ``device`` argument (and, when decorated with ``@dtypes``, a ``dtype``
    argument) and are meant to be instantiated per device type.
    """

    # Ensure that assertEqual handles numpy arrays properly
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_assertEqual_numpy(self, device, dtype):
        S = 10
        test_sizes = [
            (),
            (0,),
            (S,),
            (S, S),
            (0, S),
            (S, 0)]
        for test_size in test_sizes:
            a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5)
            a_n = a.cpu().numpy()
            msg = f'size: {test_size}'
            # Compare in both argument orders and numpy-vs-numpy, with zero tolerance.
            self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
            self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
            self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)

    def test_assertEqual_longMessage(self):
        actual = "actual"
        expected = "expected"

        long_message = self.longMessage
        try:
            # Capture the default error message by forcing TestCase.longMessage = False
            self.longMessage = False
            try:
                self.assertEqual(actual, expected)
            except AssertionError as error:
                default_msg = str(error)
            else:
                raise AssertionError("AssertionError not raised")

            # With longMessage enabled, the user message must be appended after the default one.
            self.longMessage = True
            extra_msg = "sentinel"
            with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")):
                self.assertEqual(actual, expected, msg=extra_msg)
        finally:
            self.longMessage = long_message

    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
        """Check ``torch.isclose`` on one-element tensors built from each test case.

        Args:
            tests: Iterable of ``(a, b, expected)`` triples. ``a`` and ``b`` are wrapped into
                one-element tensors of ``dtype`` on ``device``; ``expected`` is the boolean
                result ``torch.isclose`` must produce for them.
            device: Device the tensors are created on.
            dtype: Dtype of the created tensors.
            equal_nan, atol, rtol: Forwarded unchanged to ``torch.isclose``.
        """
        for a_val, b_val, expected in tests:
            a = torch.tensor((a_val,), device=device, dtype=dtype)
            b = torch.tensor((b_val,), device=device, dtype=dtype)

            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
            self.assertEqual(actual.item(), expected)

    def test_isclose_bool(self, device):
        tests = (
            (True, True, True),
            (False, False, True),
            (True, False, False),
            (False, True, False),
        )

        self._isclose_helper(tests, device, torch.bool, False)

    @dtypes(torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64)
    def test_isclose_integer(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, 1, False),
            (1, 0, False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        tests = [
            (0, 1, True),
            (1, 0, False),
            (1, 3, True),
        ]

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # For uint8, -1 cannot be represented as-is, so these pairs are expected not to be close.
        if dtype is torch.uint8:
            tests = [
                (-1, 1, False),
                (1, -1, False)
            ]
        else:
            tests = [
                (-1, 1, True),
                (1, -1, True)
            ]

        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)

    @onlyNativeDeviceTypes
    @dtypes(torch.float16, torch.float32, torch.float64)
    def test_isclose_float(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, -1, False),
            (float('inf'), float('inf'), True),
            (-float('inf'), float('inf'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), False),
            (0, float('nan'), False),
            (1, 1, True),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        # half precision cannot resolve a 1e-6 step, so a coarser eps is used for it
        eps = 1e-2 if dtype is torch.half else 1e-6
        tests = (
            (0, 1, True),
            (0, 1 + eps, False),
            (1, 0, False),
            (1, 3, True),
            (1 - eps, 3, False),
            (-.25, .5, True),
            (-.25 - eps, .5, False),
            (.25, -.5, True),
            (.25 + eps, -.5, False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (0, float('nan'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), True),
        )

        self._isclose_helper(tests, device, dtype, True)

    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    @dtypes(torch.complex64, torch.complex128)
    def test_isclose_complex(self, device, dtype):
        tests = (
            (complex(1, 1), complex(1, 1 + 1e-8), True),
            (complex(0, 1), complex(1, 1), False),
            (complex(1, 1), complex(1, 0), False),
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, float('nan')), complex(1, float('nan')), False),
            (complex(1, 1), complex(1, float('inf')), False),
            (complex(float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
            (complex(float('inf'), 1), complex(float('inf'), 1), True),
            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        eps = 1e-6
        tests = (
            # Complex versions of float tests (real part)
            (complex(0, 0), complex(1, 0), True),
            (complex(0, 0), complex(1 + eps, 0), False),
            (complex(1, 0), complex(0, 0), False),
            (complex(1, 0), complex(3, 0), True),
            (complex(1 - eps, 0), complex(3, 0), False),
            (complex(-.25, 0), complex(.5, 0), True),
            (complex(-.25 - eps, 0), complex(.5, 0), False),
            (complex(.25, 0), complex(-.5, 0), True),
            (complex(.25 + eps, 0), complex(-.5, 0), False),
            # Complex versions of float tests (imaginary part)
            (complex(0, 0), complex(0, 1), True),
            (complex(0, 0), complex(0, 1 + eps), False),
            (complex(0, 1), complex(0, 0), False),
            (complex(0, 1), complex(0, 3), True),
            (complex(0, 1 - eps), complex(0, 3), False),
            (complex(0, -.25), complex(0, .5), True),
            (complex(0, -.25 - eps), complex(0, .5), False),
            (complex(0, .25), complex(0, -.5), True),
            (complex(0, .25 + eps), complex(0, -.5), False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # atol and rtol tests for isclose
        tests = (
            # Complex-specific tests
            (complex(1, -1), complex(-1, 1), False),
            (complex(1, -1), complex(2, -2), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.5), math.sqrt(.5)), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.501), math.sqrt(.499)), False),
            (complex(2, 4), complex(1., 8.8523607), True),
            (complex(2, 4), complex(1., 8.8523607 + eps), False),
            (complex(1, 99), complex(4, 100), True),
        )
        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, 1), complex(float('nan'), 1), False),
            (complex(float('nan'), 1), complex(float('nan'), 1), True),
            (complex(float('nan'), 1), complex(1, float('nan')), True),
            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
        )
        self._isclose_helper(tests, device, dtype, True)

    # Tests that isclose with rtol or atol values less than zero throws a
    #   RuntimeError
    @dtypes(torch.bool, torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64)
    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
        t = torch.tensor((1,), device=device, dtype=dtype)

        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=1, rtol=-1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=-1)

    def test_isclose_equality_shortcut(self):
        # For values >= 2**53, integers differing by 1 can no longer be differentiated by torch.float64 or lower
        # precision floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be considered
        # close if they were not compared as integers.
        a = torch.tensor(2 ** 53, dtype=torch.int64)
        b = a + 1

        self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))

    @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_isclose_nan_equality_shortcut(self, device, dtype):
        # With equal_nan=True, NaN values (and complex values with NaN in different components)
        # must compare as close even when rtol == atol == 0.
        if dtype.is_floating_point:
            a = b = torch.nan
        else:
            a = complex(torch.nan, 0)
            b = complex(0, torch.nan)

        expected = True
        tests = [(a, b, expected)]

        self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)

    # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early
    # when CUDA assert was thrown. Because all subsequent test will fail if that happens.
    # These tests are slow because it spawn another process to run test suite.
    # See: https://github.com/pytorch/pytorch/issues/49019
    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
        # test to ensure common_utils.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self):
        x = torch.rand(10, device='cuda')
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self):
        x1 = torch.tensor([0., 1.], device='cuda')
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
        # test to ensure common_device_type.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
        # test to ensure common_distributed.py override should not early terminate CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase

class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # we are currently disabling CUDA early termination for distributed tests.
        self.assertIn('errors=2', stderr)

    @expectedFailureMeta  # This is only supported for CPU and CUDA
    @onlyNativeDeviceTypes
    def test_get_supported_dtypes(self, device):
        # Test the `get_supported_dtypes` helper function.
        # We acquire the dtypes for few Ops dynamically and verify them against
        # the correct statically described values.
        ops_to_test = [op for op in op_db if op.name in ('atan2', 'topk', 'xlogy')]

        for op in ops_to_test:
            dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type)
            dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes)
            # Named `expected_dtypes` (not `dtypes`) so the module-level `dtypes` decorator is not shadowed.
            if self.device_type == 'cpu':
                expected_dtypes = op.dtypes
            else:  # device_type == 'cuda'
                expected_dtypes = op.dtypesIfCUDA

            # assertEqual (rather than assertTrue(a == b)) reports the differing dtypes on failure.
            self.assertEqual(set(expected_dtypes), set(dynamic_dtypes))
            self.assertEqual(set(expected_dtypes), set(dynamic_dispatch.dispatch_fn()))

    @onlyCPU
    @ops(
        [
            op
            for op in op_db
            # keep only ops whose CPU and CUDA supported dtypes differ
            if op.supported_dtypes("cpu").symmetric_difference(op.supported_dtypes("cuda"))
        ][:1],
        dtypes=OpDTypes.none,
    )
    def test_supported_dtypes(self, device, op):
        self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda"))
        self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0"))
        self.assertEqual(
            op.supported_dtypes(torch.device("cuda")),
            op.supported_dtypes(torch.device("cuda", index=1)),
        )


# Instantiate the device-generic TestTesting for each device type; the generated
# classes are registered in this module's globals().
instantiate_device_type_tests(
    TestTesting,
    globals(),
)


class TestFrameworkUtils(TestCase):
    """Tests for the device-type filtering environment variables of the device-generic test machinery."""

    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    def test_filtering_env_var(self):
        # Test environment variable selected device type test generator.
        # A minimal device-generic test file is run in a subprocess under different environments,
        # and the "Ran N test(s)" line on its stderr is checked against the number of device-type
        # test bases known to this process.
        test_filter_file_template = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestEnvironmentVariable(TestCase):

    def test_trivial_passing_test(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestEnvironmentVariable,
    globals(),
)

if __name__ == '__main__':
    run_tests()
"""
        test_bases_count = len(get_device_type_test_bases())

        def run_template(env):
            # Run the template with `env` and return the child's stderr as text. Decode leniently:
            # a strict 'ascii' decode raises UnicodeDecodeError on any non-ASCII byte in the
            # child's output, which would hide the assertion this test actually makes.
            _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
            return stderr.decode('utf-8', errors='replace')

        # Test without setting env var should run everything.
        env = dict(os.environ)
        for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]:
            env.pop(k, None)
        self.assertIn(f'Ran {test_bases_count} test', run_template(env))

        # Test with setting only_for should only run 1 test.
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertIn('Ran 1 test', run_template(env))

        # Test with setting except_for should run 1 less device type from default.
        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
        self.assertIn(f'Ran {test_bases_count - 1} test', run_template(env))

        # Test with setting both should throw exception
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertNotIn('OK', run_template(env))


def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
    """Build input pairs for :func:`torch.testing.assert_close` from one pair of examples.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Returns:
        List[Tuple[Any, Any]]: The bare ``(actual, expected)`` pair first, followed by the same two
        values wrapped in sequences (:class:`tuple`, :class:`list`), mappings (:class:`dict`,
        :class:`~collections.OrderedDict`) and nested combinations of both. Some pairs deliberately
        mix container types on the two sides (e.g. tuple vs. list).
    """

    def as_ordered_dict(value: Any) -> collections.OrderedDict:
        # A fresh single-entry OrderedDict keyed by "t", matching the plain-dict variants.
        return collections.OrderedDict([("t", value)])

    flat_pairs = [
        (actual, expected),
        ((actual,), (expected,)),  # tuple vs. tuple
        ([actual], [expected]),  # list vs. list
        ((actual,), [expected]),  # tuple vs. list
        ({"t": actual}, {"t": expected}),  # dict vs. dict
        (as_ordered_dict(actual), as_ordered_dict(expected)),  # OrderedDict vs. OrderedDict
        ({"t": actual}, as_ordered_dict(expected)),  # dict vs. OrderedDict
    ]
    nested_pairs = [
        ([(actual,)], ([expected],)),  # list of tuples vs. tuple of lists
        ([{"t": actual}], (as_ordered_dict(expected),)),  # list of dicts vs. tuple of OrderedDicts
        ({"t": [actual]}, as_ordered_dict((expected,))),  # dict of lists vs. OrderedDict of tuples
    ]
    return flat_pairs + nested_pairs


def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
    """Yield :func:`torch.testing.assert_close` pre-bound to each input pair built from two examples.

    .. note::

        Every test that does not test for a specific input should iterate over this to maximize the coverage.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Yields:
        Callable: :func:`torch.testing.assert_close` with the positional ``(actual, expected)``
        arguments already applied; keyword options can still be passed when calling it.
    """
    yield from (
        functools.partial(torch.testing.assert_close, *pair)
        for pair in make_assert_close_inputs(actual, expected)
    )


class TestAssertClose(TestCase):
    """Behavioral tests for :func:`torch.testing.assert_close`.

    Unless a test targets one specific input, it iterates over ``assert_close_with_inputs`` so the
    same check runs on the bare inputs and on their sequence/mapping-wrapped variants.
    """

    def test_mismatching_types_subclasses(self):
        actual = torch.rand(())
        expected = torch.nn.Parameter(actual)

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_types_type_equality(self):
        actual = torch.empty(())
        expected = torch.nn.Parameter(actual)

        for fn in assert_close_with_inputs(actual, expected):
            # re.escape: the type repr contains '.', which would otherwise act as a regex wildcard.
            with self.assertRaisesRegex(TypeError, re.escape(str(type(expected)))):
                fn(allow_subclasses=False)

    def test_mismatching_types(self):
        actual = torch.empty(2)
        expected = actual.numpy()

        for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)):
            with self.assertRaisesRegex(TypeError, re.escape(str(type(expected)))):
                fn(allow_subclasses=allow_subclasses)

    def test_unknown_type(self):
        actual = "0"
        expected = "0"

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(TypeError, re.escape(str(type(actual)))):
                fn()

    def test_mismatching_shape(self):
        actual = torch.empty(())
        expected = actual.clone().reshape((1,))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "shape"):
                fn()

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
    def test_unknown_layout(self):
        actual = torch.empty((2, 2))
        expected = actual.to_mkldnn()

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(ValueError, "layout"):
                fn()

    def test_meta(self):
        actual = torch.empty((2, 2), device="meta")
        expected = torch.empty((2, 2), device="meta")

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_layout(self):
        strided = torch.empty((2, 2))
        sparse_coo = strided.to_sparse()
        sparse_csr = strided.to_sparse_csr()

        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
            for fn in assert_close_with_inputs(actual, expected):
                with self.assertRaisesRegex(AssertionError, "layout"):
                    fn()

    def test_mismatching_layout_no_check(self):
        strided = torch.randn((2, 2))
        sparse_coo = strided.to_sparse()
        sparse_csr = strided.to_sparse_csr()

        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
            for fn in assert_close_with_inputs(actual, expected):
                fn(check_layout=False)

    def test_mismatching_dtype(self):
        actual = torch.empty((), dtype=torch.float)
        expected = actual.clone().to(torch.int)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "dtype"):
                fn()

    def test_mismatching_dtype_no_check(self):
        actual = torch.ones((), dtype=torch.float)
        expected = actual.clone().to(torch.int)

        for fn in assert_close_with_inputs(actual, expected):
            fn(check_dtype=False)

    def test_mismatching_stride(self):
        actual = torch.empty((2, 2))
        # Same shape and values layout, but with the strides of `actual` reversed.
        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "stride"):
                fn(check_stride=True)

    def test_mismatching_stride_no_check(self):
        actual = torch.rand((2, 2))
        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_only_rtol(self):
        actual = torch.empty(())
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaises(ValueError):
                fn(rtol=0.0)

    def test_only_atol(self):
        actual = torch.empty(())
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaises(ValueError):
                fn(atol=0.0)

    def test_mismatching_values(self):
        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaises(AssertionError):
                fn()

    def test_mismatching_values_rtol(self):
        eps = 1e-3
        actual = torch.tensor(1.0)
        expected = torch.tensor(1.0 + eps)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaises(AssertionError):
                fn(rtol=eps / 2, atol=0.0)

    def test_mismatching_values_atol(self):
        eps = 1e-3
        actual = torch.tensor(0.0)
        expected = torch.tensor(eps)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaises(AssertionError):
                fn(rtol=0.0, atol=eps / 2)

    def test_matching(self):
        actual = torch.tensor(1.0)
        expected = actual.clone()

        # Iterate all input variants, as recommended by `assert_close_with_inputs`.
        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_matching_rtol(self):
        eps = 1e-3
        actual = torch.tensor(1.0)
        expected = torch.tensor(1.0 + eps)

        for fn in assert_close_with_inputs(actual, expected):
            fn(rtol=eps * 2, atol=0.0)

    def test_matching_atol(self):
        eps = 1e-3
        actual = torch.tensor(0.0)
        expected = torch.tensor(eps)

        for fn in assert_close_with_inputs(actual, expected):
            fn(rtol=0.0, atol=eps * 2)

    # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058
    #  We need to check if this test is still needed or if this behavior is now enabled by default.
    def test_matching_conjugate_bit(self):
        actual = torch.tensor(complex(1, 1)).conj()
        expected = torch.tensor(complex(1, -1))

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_matching_nan(self):
        # Without equal_nan=True, NaN values must be reported as mismatching.
        nan = float("NaN")

        tests = (
            (nan, nan),
            (complex(nan, 0), complex(0, nan)),
            (complex(nan, nan), complex(nan, 0)),
            (complex(nan, nan), complex(nan, nan)),
        )

        for actual, expected in tests:
            for fn in assert_close_with_inputs(actual, expected):
                with self.assertRaises(AssertionError):
                    fn()

    def test_matching_nan_with_equal_nan(self):
        nan = float("NaN")

        tests = (
            (nan, nan),
            (complex(nan, 0), complex(0, nan)),
            (complex(nan, nan), complex(nan, 0)),
            (complex(nan, nan), complex(nan, nan)),
        )

        for actual, expected in tests:
            for fn in assert_close_with_inputs(actual, expected):
                fn(equal_nan=True)

    def test_numpy(self):
        tensor = torch.rand(2, 2, dtype=torch.float32)
        actual = tensor.numpy()
        expected = actual.copy()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_scalar(self):
        number = torch.randint(10, size=()).item()
        for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
            # Mixed Python scalar types only compare equal when dtype checking is disabled.
            check_dtype = type(actual) is type(expected)

            for fn in assert_close_with_inputs(actual, expected):
                fn(check_dtype=check_dtype)

    def test_bool(self):
        actual = torch.tensor([True, False])
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_none(self):
        actual = expected = None

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_none_mismatch(self):
        expected = None

        for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
            for fn in assert_close_with_inputs(actual, expected):
                with self.assertRaises(AssertionError):
                    fn()

    def test_docstring_examples(self):
        # Run the doctests embedded in assert_close's docstring and collect every failure report.
        finder = doctest.DocTestFinder(verbose=False)
        runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
        globs = dict(torch=torch)
        doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
        failures = []
        runner.run(doctests, out=lambda report: failures.append(report))
        if failures:
            raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))

    def test_default_tolerance_selection_mismatching_dtypes(self):
        # If the default tolerances were selected based on the promoted dtype, i.e. float64,
        # these tensors wouldn't be considered close.
        actual = torch.tensor(0.99, dtype=torch.bfloat16)
        expected = torch.tensor(1.0, dtype=torch.float64)

        for fn in assert_close_with_inputs(actual, expected):
            fn(check_dtype=False)

    class UnexpectedException(Exception):
        """The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus,
        the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin
        exception here to avoid triggering possible handling of them.
        """

    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
    def test_unexpected_error_originate(self, _):
        actual = torch.tensor(1.0)
        expected = actual.clone()

        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
            torch.testing.assert_close(actual, expected)

    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
    def test_unexpected_error_compare(self, _):
        actual = torch.tensor(1.0)
        expected = actual.clone()

        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
            torch.testing.assert_close(actual, expected)


class TestAssertCloseMultiDevice(TestCase):
    """Device handling of ``assert_close`` for inputs that are placed on different devices."""

    @deviceCountAtLeast(1)
    def test_mismatching_device(self, devices):
        # Every ordered pair of entries taken from the CPU plus the devices under test.
        device_pairs = itertools.permutations(("cpu", *devices), 2)
        for src, dst in device_pairs:
            lhs = torch.empty((), device=src)
            rhs = lhs.clone().to(dst)
            for compare in assert_close_with_inputs(lhs, rhs):
                with self.assertRaisesRegex(AssertionError, "device"):
                    compare()

    @deviceCountAtLeast(1)
    def test_mismatching_device_no_check(self, devices):
        device_pairs = itertools.permutations(("cpu", *devices), 2)
        for src, dst in device_pairs:
            lhs = torch.rand((), device=src)
            rhs = lhs.clone().to(dst)
            # With the device check disabled, these otherwise identical inputs must be accepted.
            for compare in assert_close_with_inputs(lhs, rhs):
                compare(check_device=False)
853

854

855
# Instantiate the generic TestAssertCloseMultiDevice tests for CUDA only (`only_for="cuda"`), using this module's
# namespace (`globals()`) as the scope for the generated classes.
instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")
856

857

858
class TestAssertCloseErrorMessage(TestCase):
    """Checks the content of the error messages raised by ``torch.testing.assert_close``."""

    def test_identifier_tensor_likes(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")):
                fn()

    def test_identifier_scalars(self):
        actual = 3
        expected = 5
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Scalars")):
                fn()

    def test_not_equal(self):
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        # With zero tolerances the message speaks of equality rather than closeness.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("not equal")):
                fn(rtol=0.0, atol=0.0)

    def test_not_close(self):
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        # As soon as any tolerance is non-zero, the message speaks of closeness.
        for fn, (rtol, atol) in itertools.product(
            assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5))
        ):
            with self.assertRaisesRegex(AssertionError, re.escape("not close")):
                fn(rtol=rtol, atol=atol)

    def test_mismatched_elements(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
                fn()

    def test_abs_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 2], [5, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
                fn()

    def test_abs_diff_scalar(self):
        actual = 3
        expected = 5

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")):
                fn()

    def test_rel_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 4], [3, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")):
                fn()

    def test_rel_diff_scalar(self):
        actual = 2
        expected = 4

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")):
                fn()

    def test_zero_div_zero(self):
        actual = torch.tensor([1.0, 0.0])
        expected = torch.tensor([2.0, 0.0])

        for fn in assert_close_with_inputs(actual, expected):
            # The second elements match (0 vs. 0). If 0 / 0 were used for the mismatch computation anyway, a NaN would
            # leak into the error message. This is checked with a plain substring test on purpose:
            # `assertRaisesRegex` applies its pattern with `re.search`, so a pattern like "((?!nan).)*" matches the
            # empty string and can never fail.
            with self.assertRaises(AssertionError) as ctx:
                fn()
            self.assertNotIn("nan", str(ctx.exception).lower())

    def test_rtol(self):
        rtol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")):
                fn(rtol=rtol, atol=0.0)

    def test_atol(self):
        atol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
                fn(rtol=0.0, atol=atol)

    def test_msg_str(self):
        msg = "Custom error message!"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, msg):
                fn(msg=msg)

    def test_msg_callable(self):
        msg = "Custom error message"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        # A callable `msg` receives one argument (ignored here) and returns the message to use.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, msg):
                fn(msg=lambda _: msg)
981

982

983
class TestAssertCloseContainer(TestCase):
984
    def test_sequence_mismatching_len(self):
985
        actual = (torch.empty(()),)
986
        expected = ()
987

988
        with self.assertRaises(AssertionError):
989
            torch.testing.assert_close(actual, expected)
990

991
    def test_sequence_mismatching_values_msg(self):
992
        t1 = torch.tensor(1)
993
        t2 = torch.tensor(2)
994

995
        actual = (t1, t1)
996
        expected = (t1, t2)
997

998
        with self.assertRaisesRegex(AssertionError, re.escape("item [1]")):
999
            torch.testing.assert_close(actual, expected)
1000

1001
    def test_mapping_mismatching_keys(self):
1002
        actual = {"a": torch.empty(())}
1003
        expected = {}
1004

1005
        with self.assertRaises(AssertionError):
1006
            torch.testing.assert_close(actual, expected)
1007

1008
    def test_mapping_mismatching_values_msg(self):
1009
        t1 = torch.tensor(1)
1010
        t2 = torch.tensor(2)
1011

1012
        actual = {"a": t1, "b": t1}
1013
        expected = {"a": t1, "b": t2}
1014

1015
        with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")):
1016
            torch.testing.assert_close(actual, expected)
1017

1018

1019
class TestAssertCloseSparseCOO(TestCase):
    """``assert_close`` on sparse COO tensors."""

    def _coo(self, indices, values):
        # Every tensor built in this class is a 2x2 sparse COO tensor.
        return torch.sparse_coo_tensor(indices, values, size=(2, 2))

    def _check_mismatch(self, actual, expected, fragment):
        # Every input flavour must fail with an AssertionError whose message contains `fragment` verbatim.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching_coalesced(self):
        actual = self._coo(((0, 1), (1, 0)), (1, 2)).coalesce()
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_matching_uncoalesced(self):
        actual = self._coo(((0, 1), (1, 0)), (1, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_sparse_dims(self):
        dense = torch.randn(2, 3, 4)
        # Same data, but converted with the default number of sparse dimensions vs. 2 sparse dimensions.
        self._check_mismatch(
            dense.to_sparse(), dense.to_sparse(2), "number of sparse dimensions in sparse COO tensors"
        )

    def test_mismatching_nnz(self):
        actual = self._coo(((0, 1), (1, 0)), (1, 2))
        # Index (1, 0) is listed twice (uncoalesced), giving three specified values instead of two.
        expected = self._coo(((0, 1, 1,), (1, 0, 0,)), (1, 1, 1))
        self._check_mismatch(actual, expected, "number of specified values in sparse COO tensors")

    def test_mismatching_indices_msg(self):
        actual = self._coo(((0, 1), (1, 0)), (1, 2))
        expected = self._coo(((0, 1), (1, 1)), (1, 2))
        self._check_mismatch(actual, expected, "Sparse COO indices")

    def test_mismatching_values_msg(self):
        actual = self._coo(((0, 1), (1, 0)), (1, 2))
        expected = self._coo(((0, 1), (1, 0)), (1, 3))
        self._check_mismatch(actual, expected, "Sparse COO values")
1109

1110

1111
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
    """``assert_close`` on sparse CSR tensors."""

    def _csr(self, crow_indices, col_indices, values):
        # Every tensor built in this class is a 2x2 sparse CSR tensor.
        return torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _check_mismatch(self, actual, expected, fragment):
        # Every input flavour must fail with an AssertionError whose message contains `fragment` verbatim.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        actual = self._csr((0, 1, 2), (1, 0), (1, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_crow_indices_msg(self):
        col_indices, values = (0, 1), (1, 2)
        actual = self._csr((0, 1, 2), col_indices, values)
        expected = self._csr((0, 2, 2), col_indices, values)
        self._check_mismatch(actual, expected, "Sparse CSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        crow_indices, values = (0, 1, 2), (1, 2)
        actual = self._csr(crow_indices, (1, 0), values)
        expected = self._csr(crow_indices, (1, 1), values)
        self._check_mismatch(actual, expected, "Sparse CSR col_indices")

    def test_mismatching_values_msg(self):
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._csr(crow_indices, col_indices, (1, 2))
        expected = self._csr(crow_indices, col_indices, (1, 3))
        self._check_mismatch(actual, expected, "Sparse CSR values")
1167

1168

1169
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing")
class TestAssertCloseSparseCSC(TestCase):
    """``assert_close`` on sparse CSC tensors."""

    def _csc(self, ccol_indices, row_indices, values):
        # Every tensor built in this class is a 2x2 sparse CSC tensor.
        return torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _check_mismatch(self, actual, expected, fragment):
        # Every input flavour must fail with an AssertionError whose message contains `fragment` verbatim.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        actual = self._csc((0, 1, 2), (1, 0), (1, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        row_indices, values = (0, 1), (1, 2)
        actual = self._csc((0, 1, 2), row_indices, values)
        expected = self._csc((0, 2, 2), row_indices, values)
        self._check_mismatch(actual, expected, "Sparse CSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        ccol_indices, values = (0, 1, 2), (1, 2)
        actual = self._csc(ccol_indices, (1, 0), values)
        expected = self._csc(ccol_indices, (1, 1), values)
        self._check_mismatch(actual, expected, "Sparse CSC row_indices")

    def test_mismatching_values_msg(self):
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._csc(ccol_indices, row_indices, (1, 2))
        expected = self._csc(ccol_indices, row_indices, (1, 3))
        self._check_mismatch(actual, expected, "Sparse CSC values")
1225

1226

1227
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing")
class TestAssertCloseSparseBSR(TestCase):
    """``assert_close`` on sparse BSR tensors (1x1 blocks)."""

    def _bsr(self, crow_indices, col_indices, values):
        # Every tensor built in this class is a 2x2 sparse BSR tensor.
        return torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _check_mismatch(self, actual, expected, fragment):
        # Every input flavour must fail with an AssertionError whose message contains `fragment` verbatim.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        actual = self._bsr((0, 1, 2), (1, 0), ([[1]], [[2]]))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_crow_indices_msg(self):
        col_indices, values = (0, 1), ([[1]], [[2]])
        actual = self._bsr((0, 1, 2), col_indices, values)
        expected = self._bsr((0, 2, 2), col_indices, values)
        self._check_mismatch(actual, expected, "Sparse BSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        crow_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._bsr(crow_indices, (1, 0), values)
        expected = self._bsr(crow_indices, (1, 1), values)
        self._check_mismatch(actual, expected, "Sparse BSR col_indices")

    def test_mismatching_values_msg(self):
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._bsr(crow_indices, col_indices, ([[1]], [[2]]))
        expected = self._bsr(crow_indices, col_indices, ([[1]], [[3]]))
        self._check_mismatch(actual, expected, "Sparse BSR values")
1283

1284

1285
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing")
class TestAssertCloseSparseBSC(TestCase):
    """``assert_close`` on sparse BSC tensors (1x1 blocks)."""

    def _bsc(self, ccol_indices, row_indices, values):
        # Every tensor built in this class is a 2x2 sparse BSC tensor.
        return torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _check_mismatch(self, actual, expected, fragment):
        # Every input flavour must fail with an AssertionError whose message contains `fragment` verbatim.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        actual = self._bsc((0, 1, 2), (1, 0), ([[1]], [[2]]))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        row_indices, values = (0, 1), ([[1]], [[2]])
        actual = self._bsc((0, 1, 2), row_indices, values)
        expected = self._bsc((0, 2, 2), row_indices, values)
        self._check_mismatch(actual, expected, "Sparse BSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        ccol_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._bsc(ccol_indices, (1, 0), values)
        expected = self._bsc(ccol_indices, (1, 1), values)
        self._check_mismatch(actual, expected, "Sparse BSC row_indices")

    def test_mismatching_values_msg(self):
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._bsc(ccol_indices, row_indices, ([[1]], [[2]]))
        expected = self._bsc(ccol_indices, row_indices, ([[1]], [[3]]))
        self._check_mismatch(actual, expected, "Sparse BSC values")
1341

1342

1343
class TestAssertCloseQuantized(TestCase):
    """``assert_close`` on quantized tensors."""

    def test_mismatching_is_quantized(self):
        plain = torch.tensor(1.0)
        quantized = torch.quantize_per_tensor(plain, scale=1.0, zero_point=0, dtype=torch.qint32)

        for compare in assert_close_with_inputs(plain, quantized):
            with self.assertRaisesRegex(AssertionError, "is_quantized"):
                compare()

    def test_mismatching_qscheme(self):
        source = torch.tensor((1.0,))
        per_tensor = torch.quantize_per_tensor(source, scale=1.0, zero_point=0, dtype=torch.qint32)
        per_channel = torch.quantize_per_channel(
            source,
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )

        for compare in assert_close_with_inputs(per_tensor, per_channel):
            with self.assertRaisesRegex(AssertionError, "qscheme"):
                compare()

    def test_matching_per_tensor(self):
        quantized = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32)
        duplicate = quantized.clone()

        for compare in assert_close_with_inputs(quantized, duplicate):
            compare()

    def test_matching_per_channel(self):
        quantized = torch.quantize_per_channel(
            torch.tensor((1.0,)),
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )
        duplicate = quantized.clone()

        for compare in assert_close_with_inputs(quantized, duplicate):
            compare()
1386

1387

1388
class TestMakeTensor(TestCase):
1389
    supported_dtypes = dtypes(
1390
        torch.bool,
1391
        torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
1392
        torch.float16, torch.bfloat16, torch.float32, torch.float64,
1393
        torch.complex32, torch.complex64, torch.complex128,
1394
    )
1395

1396
    @supported_dtypes
1397
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1398
    @parametrize("splat_shape", [False, True])
1399
    def test_smoke(self, dtype, device, shape, splat_shape):
1400
        t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device)
1401

1402
        self.assertIsInstance(t, torch.Tensor)
1403
        self.assertEqual(t.shape, shape)
1404
        self.assertEqual(t.dtype, dtype)
1405
        self.assertEqual(t.device, torch.device(device))
1406

1407
    @supported_dtypes
1408
    @parametrize("requires_grad", [False, True])
1409
    def test_requires_grad(self, dtype, device, requires_grad):
1410
        make_tensor = functools.partial(
1411
            torch.testing.make_tensor,
1412
            dtype=dtype,
1413
            device=device,
1414
            requires_grad=requires_grad,
1415
        )
1416

1417
        if not requires_grad or dtype.is_floating_point or dtype.is_complex:
1418
            t = make_tensor()
1419
            self.assertEqual(t.requires_grad, requires_grad)
1420
        else:
1421
            with self.assertRaisesRegex(
1422
                    ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
1423
            ):
1424
                make_tensor()
1425

1426
    @supported_dtypes
1427
    @parametrize("noncontiguous", [False, True])
1428
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1429
    def test_noncontiguous(self, dtype, device, noncontiguous, shape):
1430
        numel = functools.reduce(operator.mul, shape, 1)
1431

1432
        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
1433
        self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)
1434

1435
    @supported_dtypes
1436
    @parametrize(
1437
        "memory_format_and_shape",
1438
        [
1439
            (None, (2, 3, 4)),
1440
            (torch.contiguous_format, (2, 3, 4)),
1441
            (torch.channels_last, (2, 3, 4, 5)),
1442
            (torch.channels_last_3d, (2, 3, 4, 5, 6)),
1443
            (torch.preserve_format, (2, 3, 4)),
1444
        ],
1445
    )
1446
    def test_memory_format(self, dtype, device, memory_format_and_shape):
1447
        memory_format, shape = memory_format_and_shape
1448

1449
        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format)
1450

1451
        self.assertTrue(
1452
            t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format)
1453
        )
1454

1455
    @supported_dtypes
1456
    def test_noncontiguous_memory_format(self, dtype, device):
1457
        with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
1458
            torch.testing.make_tensor(
1459
                (2, 3, 4, 5),
1460
                dtype=dtype,
1461
                device=device,
1462
                noncontiguous=True,
1463
                memory_format=torch.channels_last,
1464
            )
1465

1466
    @supported_dtypes
1467
    def test_exclude_zero(self, dtype, device):
1468
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2)
1469

1470
        self.assertTrue((t != 0).all())
1471

1472
    @supported_dtypes
1473
    def test_low_high_smoke(self, dtype, device):
1474
        low_inclusive, high_exclusive = 0, 2
1475

1476
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1477
        if dtype.is_complex:
1478
            t = torch.view_as_real(t)
1479

1480
        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1481

1482
    @supported_dtypes
1483
    def test_low_high_default_smoke(self, dtype, device):
1484
        low_inclusive, high_exclusive = {
1485
            torch.bool: (0, 2),
1486
            torch.uint8: (0, 10),
1487
            **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
1488
        }.get(dtype, (-9, 9))
1489

1490
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1491
        if dtype.is_complex:
1492
            t = torch.view_as_real(t)
1493

1494
        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1495

1496
    @parametrize("low_high", [(0, 0), (1, 0), (0, -1)])
1497
    @parametrize("value_types", list(itertools.product([int, float], repeat=2)))
1498
    @supported_dtypes
1499
    def test_low_ge_high(self, dtype, device, low_high, value_types):
1500
        low, high = (value_type(value) for value, value_type in zip(low_high, value_types))
1501

1502
        if low == high and (dtype.is_floating_point or dtype.is_complex):
1503
            with self.assertWarnsRegex(
1504
                    FutureWarning,
1505
                    "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated",
1506
            ):
1507
                t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high)
1508
            self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low))
1509
        else:
1510
            with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"):
1511
                torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1512

1513
    @supported_dtypes
1514
    @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)])
1515
    def test_low_high_nan(self, dtype, device, low_high):
1516
        low, high = low_high
1517

1518
        with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
1519
            torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1520

1521
    @supported_dtypes
1522
    def test_low_high_outside_valid_range(self, dtype, device):
1523
        make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device)
1524

1525
        def get_dtype_limits(dtype):
1526
            if dtype is torch.bool:
1527
                return 0, 1
1528

1529
            info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype)
1530
            # We are using integer bounds here, because otherwise it would be impossible to pass `low` and `high`
1531
            # outside their valid range. Python uses 64bit floating point numbers and thus trying to do something like
1532
            # `torch.ffinfo(torch.float64)max * 2` will always result in `inf`. On the flipside, Pythons `int` is
1533
            # unbounded.
1534
            return int(info.min), int(info.max)
1535

1536
        lowest_inclusive, highest_inclusive = get_dtype_limits(dtype)
1537

1538
        with self.assertRaisesRegex(ValueError, ""):
1539
            low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2)
1540
            make_tensor(low=low, high=high)
1541

1542
        with self.assertRaisesRegex(ValueError, ""):
1543
            make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4)
1544

1545
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1546
    def test_low_high_boolean_integral1(self, dtype, device):
1547
        shape = (10_000,)
1548
        eps = 1e-4
1549

1550
        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps)
1551
        expected = torch.zeros(shape, dtype=dtype, device=device)
1552

1553
        torch.testing.assert_close(actual, expected)
1554

1555
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1556
    def test_low_high_boolean_integral2(self, dtype, device):
1557
        shape = (10_000,)
1558
        if dtype is torch.bool:
1559
            low = 1
1560
        elif dtype is torch.int64:
1561
            # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max`
1562
            low = torch.iinfo(dtype).max - 1
1563
        else:
1564
            low = torch.iinfo(dtype).max
1565
        high = low + 1
1566

1567
        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
1568
        expected = torch.full(shape, low, dtype=dtype, device=device)
1569

1570
        torch.testing.assert_close(actual, expected)
1571

1572

1573
# Instantiate the generic TestMakeTensor tests per device type, using this module's namespace (`globals()`) as the
# scope. No `only_for` / `except_for` filter is passed, so the default set of device types is used.
instantiate_device_type_tests(TestMakeTensor, globals())
1574

1575

1576
def _get_test_names_for_test_class(test_cls):
1577
    """ Convenience function to get all test names for a given test class. """
1578
    test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__
1579
                  if key.startswith('test_')]
1580
    return sorted(test_names)
1581

1582

1583
def _get_test_funcs_for_test_class(test_cls):
1584
    """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """
1585
    test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')]
1586
    return test_funcs
1587

1588

1589
class TestTestParametrization(TestCase):
1590
    def test_default_names(self):
1591

1592
        class TestParametrized(TestCase):
1593
            @parametrize("x", range(5))
1594
            def test_default_names(self, x):
1595
                pass
1596

1597
            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1598
            def test_two_things_default_names(self, x, y):
1599
                pass
1600

1601
        instantiate_parametrized_tests(TestParametrized)
1602

1603
        expected_test_names = [
1604
            'TestParametrized.test_default_names_x_0',
1605
            'TestParametrized.test_default_names_x_1',
1606
            'TestParametrized.test_default_names_x_2',
1607
            'TestParametrized.test_default_names_x_3',
1608
            'TestParametrized.test_default_names_x_4',
1609
            'TestParametrized.test_two_things_default_names_x_1_y_2',
1610
            'TestParametrized.test_two_things_default_names_x_2_y_3',
1611
            'TestParametrized.test_two_things_default_names_x_3_y_4',
1612
        ]
1613
        test_names = _get_test_names_for_test_class(TestParametrized)
1614
        self.assertEqual(expected_test_names, test_names)
1615

1616
    def test_name_fn(self):
1617

1618
        class TestParametrized(TestCase):
1619
            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1620
            def test_custom_names(self, bias):
1621
                pass
1622

1623
            @parametrize("x", [1, 2], name_fn=str)
1624
            @parametrize("y", [3, 4], name_fn=str)
1625
            @parametrize("z", [5, 6], name_fn=str)
1626
            def test_three_things_composition_custom_names(self, x, y, z):
1627
                pass
1628

1629
            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1630
            def test_two_things_custom_names_alternate(self, x, y):
1631
                pass
1632

1633
        instantiate_parametrized_tests(TestParametrized)
1634

1635
        expected_test_names = [
1636
            'TestParametrized.test_custom_names_bias',
1637
            'TestParametrized.test_custom_names_no_bias',
1638
            'TestParametrized.test_three_things_composition_custom_names_1_3_5',
1639
            'TestParametrized.test_three_things_composition_custom_names_1_3_6',
1640
            'TestParametrized.test_three_things_composition_custom_names_1_4_5',
1641
            'TestParametrized.test_three_things_composition_custom_names_1_4_6',
1642
            'TestParametrized.test_three_things_composition_custom_names_2_3_5',
1643
            'TestParametrized.test_three_things_composition_custom_names_2_3_6',
1644
            'TestParametrized.test_three_things_composition_custom_names_2_4_5',
1645
            'TestParametrized.test_three_things_composition_custom_names_2_4_6',
1646
            'TestParametrized.test_two_things_custom_names_alternate_1__2',
1647
            'TestParametrized.test_two_things_custom_names_alternate_1__3',
1648
            'TestParametrized.test_two_things_custom_names_alternate_1__4',
1649
        ]
1650
        test_names = _get_test_names_for_test_class(TestParametrized)
1651
        self.assertEqual(expected_test_names, test_names)
1652

1653
    def test_subtest_names(self):
1654

1655
        class TestParametrized(TestCase):
1656
            @parametrize("bias", [subtest(True, name='bias'),
1657
                                  subtest(False, name='no_bias')])
1658
            def test_custom_names(self, bias):
1659
                pass
1660

1661
            @parametrize("x,y", [subtest((1, 2), name='double'),
1662
                                 subtest((1, 3), name='triple'),
1663
                                 subtest((1, 4), name='quadruple')])
1664
            def test_two_things_custom_names(self, x, y):
1665
                pass
1666

1667
        instantiate_parametrized_tests(TestParametrized)
1668

1669
        expected_test_names = [
1670
            'TestParametrized.test_custom_names_bias',
1671
            'TestParametrized.test_custom_names_no_bias',
1672
            'TestParametrized.test_two_things_custom_names_double',
1673
            'TestParametrized.test_two_things_custom_names_quadruple',
1674
            'TestParametrized.test_two_things_custom_names_triple',
1675
        ]
1676
        test_names = _get_test_names_for_test_class(TestParametrized)
1677
        self.assertEqual(expected_test_names, test_names)
1678

1679
    def test_apply_param_specific_decorators(self):
1680
        # Test that decorators can be applied on a per-param basis.
1681

1682
        def test_dec(func):
1683
            func._decorator_applied = True
1684
            return func
1685

1686
        class TestParametrized(TestCase):
1687
            @parametrize("x", [subtest(1, name='one'),
1688
                               subtest(2, name='two', decorators=[test_dec]),
1689
                               subtest(3, name='three')])
1690
            def test_param(self, x):
1691
                pass
1692

1693
        instantiate_parametrized_tests(TestParametrized)
1694

1695
        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1696
            self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two')
1697

1698
    def test_compose_param_specific_decorators(self):
1699
        # Test that multiple per-param decorators compose correctly.
1700

1701
        def test_dec(func):
1702
            func._decorator_applied = True
1703
            return func
1704

1705
        class TestParametrized(TestCase):
1706
            @parametrize("x", [subtest(1),
1707
                               subtest(2, decorators=[test_dec]),
1708
                               subtest(3)])
1709
            @parametrize("y", [subtest(False, decorators=[test_dec]),
1710
                               subtest(True)])
1711
            def test_param(self, x, y):
1712
                pass
1713

1714
        instantiate_parametrized_tests(TestParametrized)
1715

1716
        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1717
            # Decorator should be applied whenever either x == 2 or y == False.
1718
            should_apply = ('x_2' in name) or ('y_False' in name)
1719
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
1720

1721
    def test_modules_decorator_misuse_error(self):
1722
        # Test that @modules errors out when used with instantiate_parametrized_tests().
1723

1724
        class TestParametrized(TestCase):
1725
            @modules(module_db)
1726
            def test_modules(self, module_info):
1727
                pass
1728

1729
        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1730
            instantiate_parametrized_tests(TestParametrized)
1731

1732
    def test_ops_decorator_misuse_error(self):
1733
        # Test that @ops errors out when used with instantiate_parametrized_tests().
1734

1735
        class TestParametrized(TestCase):
1736
            @ops(op_db)
1737
            def test_ops(self, module_info):
1738
                pass
1739

1740
        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1741
            instantiate_parametrized_tests(TestParametrized)
1742

1743
    def test_multiple_handling_of_same_param_error(self):
1744
        # Test that multiple decorators handling the same param errors out.
1745

1746
        class TestParametrized(TestCase):
1747
            @parametrize("x", range(3))
1748
            @parametrize("x", range(5))
1749
            def test_param(self, x):
1750
                pass
1751

1752
        with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'):
1753
            instantiate_parametrized_tests(TestParametrized)
1754

1755
    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
1756
    def test_subtest_expected_failure(self, x):
1757
        if x == 2:
1758
            raise RuntimeError('Boom')
1759

1760
    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
1761
    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
1762
    def test_two_things_subtest_expected_failure(self, x, y):
1763
        if x == 1 or y == 6:
1764
            raise RuntimeError('Boom')
1765

1766

1767
class TestTestParametrizationDeviceType(TestCase):
1768
    def test_unparametrized_names(self, device):
1769
        # This test exists to protect against regressions in device / dtype test naming
1770
        # due to parametrization logic.
1771

1772
        device = self.device_type
1773

1774
        class TestParametrized(TestCase):
1775
            def test_device_specific(self, device):
1776
                pass
1777

1778
            @dtypes(torch.float32, torch.float64)
1779
            def test_device_dtype_specific(self, device, dtype):
1780
                pass
1781

1782
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1783

1784
        device_cls = locals()[f'TestParametrized{device.upper()}']
1785
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1786
            '{}.test_device_dtype_specific_{}_float32',
1787
            '{}.test_device_dtype_specific_{}_float64',
1788
            '{}.test_device_specific_{}')
1789
        ]
1790
        test_names = _get_test_names_for_test_class(device_cls)
1791
        self.assertEqual(expected_test_names, test_names)
1792

1793
    def test_empty_param_names(self, device):
1794
        # If no param names are passed, ensure things still work without parametrization.
1795
        device = self.device_type
1796

1797
        class TestParametrized(TestCase):
1798
            @parametrize("", [])
1799
            def test_foo(self, device):
1800
                pass
1801

1802
            @parametrize("", range(5))
1803
            def test_bar(self, device):
1804
                pass
1805

1806
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1807

1808
        device_cls = locals()[f'TestParametrized{device.upper()}']
1809
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1810
            '{}.test_bar_{}',
1811
            '{}.test_foo_{}')
1812
        ]
1813
        test_names = _get_test_names_for_test_class(device_cls)
1814
        self.assertEqual(expected_test_names, test_names)
1815

1816
    def test_empty_param_list(self, device):
1817
        # If no param values are passed, ensure a helpful error message is thrown.
1818
        # In the wild, this could indicate reuse of an exhausted generator.
1819
        device = self.device_type
1820

1821
        generator = (a for a in range(5))
1822

1823
        class TestParametrized(TestCase):
1824
            @parametrize("x", generator)
1825
            def test_foo(self, device, x):
1826
                pass
1827

1828
            # Reuse generator from first test function.
1829
            @parametrize("y", generator)
1830
            def test_bar(self, device, y):
1831
                pass
1832

1833
        with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'):
1834
            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1835

1836
    def test_default_names(self, device):
1837
        device = self.device_type
1838

1839
        class TestParametrized(TestCase):
1840
            @parametrize("x", range(5))
1841
            def test_default_names(self, device, x):
1842
                pass
1843

1844
            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1845
            def test_two_things_default_names(self, device, x, y):
1846
                pass
1847

1848

1849
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1850

1851
        device_cls = locals()[f'TestParametrized{device.upper()}']
1852
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1853
            '{}.test_default_names_x_0_{}',
1854
            '{}.test_default_names_x_1_{}',
1855
            '{}.test_default_names_x_2_{}',
1856
            '{}.test_default_names_x_3_{}',
1857
            '{}.test_default_names_x_4_{}',
1858
            '{}.test_two_things_default_names_x_1_y_2_{}',
1859
            '{}.test_two_things_default_names_x_2_y_3_{}',
1860
            '{}.test_two_things_default_names_x_3_y_4_{}')
1861
        ]
1862
        test_names = _get_test_names_for_test_class(device_cls)
1863
        self.assertEqual(expected_test_names, test_names)
1864

1865
    def test_default_name_non_primitive(self, device):
1866
        device = self.device_type
1867

1868
        class TestParametrized(TestCase):
1869
            @parametrize("x", [1, .5, "foo", object()])
1870
            def test_default_names(self, device, x):
1871
                pass
1872

1873
            @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
1874
            def test_two_things_default_names(self, device, x, y):
1875
                pass
1876

1877
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1878

1879
        device_cls = locals()[f'TestParametrized{device.upper()}']
1880
        expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
1881
            '{}.test_default_names_x_1_{}',
1882
            '{}.test_default_names_x_0_5_{}',
1883
            '{}.test_default_names_x_foo_{}',
1884
            '{}.test_default_names_x3_{}',
1885
            '{}.test_two_things_default_names_x_1_y0_{}',
1886
            '{}.test_two_things_default_names_x1_y_0_5_{}',
1887
            '{}.test_two_things_default_names_x2_y2_{}')
1888
        )
1889
        test_names = _get_test_names_for_test_class(device_cls)
1890
        self.assertEqual(expected_test_names, test_names)
1891

1892
    def test_name_fn(self, device):
1893
        device = self.device_type
1894

1895
        class TestParametrized(TestCase):
1896
            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1897
            def test_custom_names(self, device, bias):
1898
                pass
1899

1900
            @parametrize("x", [1, 2], name_fn=str)
1901
            @parametrize("y", [3, 4], name_fn=str)
1902
            @parametrize("z", [5, 6], name_fn=str)
1903
            def test_three_things_composition_custom_names(self, device, x, y, z):
1904
                pass
1905

1906
            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1907
            def test_two_things_custom_names_alternate(self, device, x, y):
1908
                pass
1909

1910
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1911

1912
        device_cls = locals()[f'TestParametrized{device.upper()}']
1913
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1914
            '{}.test_custom_names_bias_{}',
1915
            '{}.test_custom_names_no_bias_{}',
1916
            '{}.test_three_things_composition_custom_names_1_3_5_{}',
1917
            '{}.test_three_things_composition_custom_names_1_3_6_{}',
1918
            '{}.test_three_things_composition_custom_names_1_4_5_{}',
1919
            '{}.test_three_things_composition_custom_names_1_4_6_{}',
1920
            '{}.test_three_things_composition_custom_names_2_3_5_{}',
1921
            '{}.test_three_things_composition_custom_names_2_3_6_{}',
1922
            '{}.test_three_things_composition_custom_names_2_4_5_{}',
1923
            '{}.test_three_things_composition_custom_names_2_4_6_{}',
1924
            '{}.test_two_things_custom_names_alternate_1__2_{}',
1925
            '{}.test_two_things_custom_names_alternate_1__3_{}',
1926
            '{}.test_two_things_custom_names_alternate_1__4_{}')
1927
        ]
1928
        test_names = _get_test_names_for_test_class(device_cls)
1929
        self.assertEqual(expected_test_names, test_names)
1930

1931
    def test_subtest_names(self, device):
1932
        device = self.device_type
1933

1934
        class TestParametrized(TestCase):
1935
            @parametrize("bias", [subtest(True, name='bias'),
1936
                                  subtest(False, name='no_bias')])
1937
            def test_custom_names(self, device, bias):
1938
                pass
1939

1940
            @parametrize("x,y", [subtest((1, 2), name='double'),
1941
                                 subtest((1, 3), name='triple'),
1942
                                 subtest((1, 4), name='quadruple')])
1943
            def test_two_things_custom_names(self, device, x, y):
1944
                pass
1945

1946
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1947

1948
        device_cls = locals()[f'TestParametrized{device.upper()}']
1949
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1950
            '{}.test_custom_names_bias_{}',
1951
            '{}.test_custom_names_no_bias_{}',
1952
            '{}.test_two_things_custom_names_double_{}',
1953
            '{}.test_two_things_custom_names_quadruple_{}',
1954
            '{}.test_two_things_custom_names_triple_{}')
1955
        ]
1956
        test_names = _get_test_names_for_test_class(device_cls)
1957
        self.assertEqual(expected_test_names, test_names)
1958

1959
    def test_ops_composition_names(self, device):
1960
        device = self.device_type
1961

1962
        class TestParametrized(TestCase):
1963
            @ops(op_db)
1964
            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1965
            def test_op_parametrized(self, device, dtype, op, flag):
1966
                pass
1967

1968
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1969

1970
        device_cls = locals()[f'TestParametrized{device.upper()}']
1971
        expected_test_names = []
1972
        for op in op_db:
1973
            for dtype in op.supported_dtypes(torch.device(device).type):
1974
                for flag_part in ('flag_disabled', 'flag_enabled'):
1975
                    expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}'  # noqa: B950
1976
                    expected_test_names.append(expected_name)
1977

1978
        test_names = _get_test_names_for_test_class(device_cls)
1979
        self.assertEqual(sorted(expected_test_names), sorted(test_names))
1980

1981
    def test_modules_composition_names(self, device):
1982
        device = self.device_type
1983

1984
        class TestParametrized(TestCase):
1985
            @modules(module_db)
1986
            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1987
            def test_module_parametrized(self, device, dtype, module_info, training, flag):
1988
                pass
1989

1990
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1991

1992
        device_cls = locals()[f'TestParametrized{device.upper()}']
1993
        expected_test_names = []
1994
        for module_info in module_db:
1995
            for dtype in module_info.dtypes:
1996
                for flag_part in ('flag_disabled', 'flag_enabled'):
1997
                    expected_train_modes = (
1998
                        ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else [''])
1999
                    for training_part in expected_train_modes:
2000
                        expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format(
2001
                            device_cls.__name__, module_info.formatted_name,
2002
                            '_' + training_part if len(training_part) > 0 else '',
2003
                            flag_part, device, dtype_name(dtype))
2004
                        expected_test_names.append(expected_name)
2005

2006
        test_names = _get_test_names_for_test_class(device_cls)
2007
        self.assertEqual(sorted(expected_test_names), sorted(test_names))
2008

2009
    def test_ops_decorator_applies_op_and_param_specific_decorators(self, device):
2010
        # Test that decorators can be applied on a per-op / per-param basis.
2011

2012
        # Create a test op, OpInfo entry, and decorator to apply.
2013
        def test_op(x):
2014
            return -x
2015

2016
        def test_dec(func):
2017
            func._decorator_applied = True
2018
            return func
2019

2020
        test_op_info = OpInfo(
2021
            'test_op',
2022
            op=test_op,
2023
            dtypes=floating_types(),
2024
            sample_inputs_func=lambda _: [],
2025
            decorators=[
2026
                DecorateInfo(test_dec, 'TestParametrized', 'test_op_param',
2027
                             device_type='cpu', dtypes=[torch.float64],
2028
                             active_if=lambda p: p['x'] == 2)
2029
            ])
2030

2031
        class TestParametrized(TestCase):
2032
            @ops(op_db + [test_op_info])
2033
            @parametrize("x", [2, 3])
2034
            def test_op_param(self, device, dtype, op, x):
2035
                pass
2036

2037
            @ops(op_db + [test_op_info])
2038
            @parametrize("y", [
2039
                subtest(4),
2040
                subtest(5, decorators=[test_dec])])
2041
            def test_other(self, device, dtype, op, y):
2042
                pass
2043

2044
            @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16)
2045
            @ops(op_db)
2046
            def test_three(self, device, dtype, op):
2047
                pass
2048

2049
        device = self.device_type
2050
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2051
        device_cls = locals()[f'TestParametrized{device.upper()}']
2052

2053
        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2054
            should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
2055
                            ('test_other' in name and 'y_5' in name) or
2056
                            ('test_three' in name and name.endswith('_int16')))
2057
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2058

2059
    def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):
2060
        # Test that decorators can be applied on a per-module / per-param basis.
2061

2062
        # Create a test module, ModuleInfo entry, and decorator to apply.
2063
        class TestModule(torch.nn.Module):
2064
            def __init__(self) -> None:
2065
                super().__init__()
2066
                self.x = torch.nn.Parameter(torch.randn(3))
2067

2068
            def forward(self, y):
2069
                return self.x + y
2070

2071
        def test_dec(func):
2072
            func._decorator_applied = True
2073
            return func
2074

2075
        test_module_info = ModuleInfo(
2076
            TestModule,
2077
            module_inputs_func=lambda _: [],
2078
            decorators=[
2079
                DecorateInfo(test_dec, 'TestParametrized', 'test_module_param',
2080
                             device_type='cpu', dtypes=[torch.float64],
2081
                             active_if=lambda p: p['x'] == 2)
2082
            ])
2083

2084
        class TestParametrized(TestCase):
2085
            @modules(module_db + [test_module_info])
2086
            @parametrize("x", [2, 3])
2087
            def test_module_param(self, device, dtype, module_info, training, x):
2088
                pass
2089

2090
            @modules(module_db + [test_module_info])
2091
            @parametrize("y", [
2092
                subtest(4),
2093
                subtest(5, decorators=[test_dec])])
2094
            def test_other(self, device, dtype, module_info, training, y):
2095
                pass
2096

2097
            @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64)
2098
            @modules(module_db)
2099
            def test_three(self, device, dtype, module_info):
2100
                pass
2101

2102
        device = self.device_type
2103
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2104
        device_cls = locals()[f'TestParametrized{device.upper()}']
2105

2106
        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2107
            should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or
2108
                            ('test_other' in name and 'y_5' in name) or
2109
                            ('test_three' in name and name.endswith('float64')))
2110
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2111

2112
    def test_param_specific_decoration(self, device):
2113

2114
        def test_dec(func):
2115
            func._decorator_applied = True
2116
            return func
2117

2118
        class TestParametrized(TestCase):
2119
            @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"])
2120
            @parametrize("x", range(5))
2121
            @parametrize("y", [False, True])
2122
            def test_param(self, x, y):
2123
                pass
2124

2125
        device = self.device_type
2126
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2127
        device_cls = locals()[f'TestParametrized{device.upper()}']
2128

2129
        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2130
            should_apply = ('test_param_x_1_y_True' in name)
2131
            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2132

2133
    def test_dtypes_composition_valid(self, device):
2134
        # Test checks that @parametrize and @dtypes compose as expected when @parametrize
2135
        # doesn't set dtype.
2136

2137
        device = self.device_type
2138

2139
        class TestParametrized(TestCase):
2140
            @dtypes(torch.float32, torch.float64)
2141
            @parametrize("x", range(3))
2142
            def test_parametrized(self, x, dtype):
2143
                pass
2144

2145
        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2146

2147
        device_cls = locals()[f'TestParametrized{device.upper()}']
2148
        expected_test_names = [name.format(device_cls.__name__, device) for name in (
2149
            '{}.test_parametrized_x_0_{}_float32',
2150
            '{}.test_parametrized_x_0_{}_float64',
2151
            '{}.test_parametrized_x_1_{}_float32',
2152
            '{}.test_parametrized_x_1_{}_float64',
2153
            '{}.test_parametrized_x_2_{}_float32',
2154
            '{}.test_parametrized_x_2_{}_float64')
2155
        ]
2156
        test_names = _get_test_names_for_test_class(device_cls)
2157
        self.assertEqual(sorted(expected_test_names), sorted(test_names))
2158

2159
    def test_dtypes_composition_invalid(self, device):
2160
        # Test checks that @dtypes cannot be composed with parametrization decorators when they
2161
        # also try to set dtype.
2162

2163
        device = self.device_type
2164

2165
        class TestParametrized(TestCase):
2166
            @dtypes(torch.float32, torch.float64)
2167
            @parametrize("dtype", [torch.int32, torch.int64])
2168
            def test_parametrized(self, dtype):
2169
                pass
2170

2171
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2172
            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2173

2174
        # Verify proper error behavior with @ops + @dtypes, as both try to set dtype.
2175

2176
        class TestParametrized(TestCase):
2177
            @dtypes(torch.float32, torch.float64)
2178
            @ops(op_db)
2179
            def test_parametrized(self, op, dtype):
2180
                pass
2181

2182
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2183
            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2184

2185
    def test_multiple_handling_of_same_param_error(self, device):
2186
        # Test that multiple decorators handling the same param errors out.
2187
        # Both @modules and @ops handle the dtype param.
2188

2189
        class TestParametrized(TestCase):
2190
            @ops(op_db)
2191
            @modules(module_db)
2192
            def test_param(self, device, dtype, op, module_info, training):
2193
                pass
2194

2195
        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2196
            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2197

2198
    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
2199
    def test_subtest_expected_failure(self, device, x):
2200
        if x == 2:
2201
            raise RuntimeError('Boom')
2202

2203
    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
2204
    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
2205
    def test_two_things_subtest_expected_failure(self, device, x, y):
2206
        if x == 1 or y == 6:
2207
            raise RuntimeError('Boom')
2208

2209

2210
instantiate_parametrized_tests(TestTestParametrization)
2211
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
2212

2213

2214
class TestImports(TestCase):
    """Import-time hygiene checks for ``torch`` and related packages.

    Most tests launch a fresh interpreter via ``_check_python_output`` so that
    modules, warnings or logging state already present in the test process
    cannot mask what a clean ``import torch`` does.
    """

    @classmethod
    def _check_python_output(cls, program) -> str:
        """Run ``program`` with ``python -W always -c`` and return its output.

        stderr is merged into stdout, and ``-W always`` makes every warning
        visible, so callers can assert on everything the child printed.

        Raises:
            subprocess.CalledProcessError: if the child exits with a non-zero status.
        """
        return subprocess.check_output(
            [sys.executable, "-W", "always", "-c", program],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")

    def test_circular_dependencies(self) -> None:
        """Check that every module inside torch can be imported.

        Prevents the regression reported in https://github.com/pytorch/pytorch/issues/77441
        """
        # Entries are module-name prefixes; a trailing "." matches only submodules,
        # not the package itself.
        ignored_modules = ["torch.utils.tensorboard",  # deps on tensorboard
                           "torch.distributed.elastic.rendezvous",  # deps on etcd
                           "torch.backends._coreml",  # depends on pycoreml
                           "torch.contrib.",  # something weird
                           "torch.testing._internal.distributed.",  # just fails
                           "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
                           "torch.onnx._internal",  # depends on onnx-script
                           "torch._inductor.runtime.triton_helpers",  # depends on triton
                           "torch._inductor.codegen.cuda",  # depends on cutlass
                           ]
        # See https://github.com/pytorch/pytorch/issues/77801
        if sys.version_info < (3, 9):
            ignored_modules.append("torch.utils.benchmark")
        if IS_WINDOWS or IS_MACOS or IS_JETSON:
            # Distributed is skipped entirely on macOS/Jetson; on Windows it is importable
            # except for its nn.api, optim and rpc subpackages.
            if IS_MACOS or IS_JETSON:
                ignored_modules.append("torch.distributed.")
            else:
                ignored_modules.append("torch.distributed.nn.api.")
                ignored_modules.append("torch.distributed.optim.")
                ignored_modules.append("torch.distributed.rpc.")
            ignored_modules.append("torch.testing._internal.dist_utils")
            # These end up with transitive dependencies on distributed
            ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop")
            ignored_modules.append("torch.testing._internal.common_fsdp")
            ignored_modules.append("torch.testing._internal.common_distributed")

        # str.startswith accepts a tuple of prefixes; build it once, after all appends.
        ignored_prefixes = tuple(ignored_modules)

        torch_dir = os.path.dirname(torch.__file__)
        for base, _folders, files in os.walk(torch_dir):
            # Dotted package path of `base`, relative to the directory containing torch/.
            prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
            for f in files:
                # Only Python sources are candidates; never import executable __main__ modules.
                if not f.endswith(".py") or f == "__main__.py":
                    continue
                mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
                if mod_name.startswith(ignored_prefixes):
                    continue
                try:
                    mod = importlib.import_module(mod_name)
                except Exception as e:
                    raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
                self.assertTrue(inspect.ismodule(mod))

    @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
    def test_lazy_imports_are_lazy(self) -> None:
        # None of torch._lazy_modules may be loaded as a side effect of `import torch`.
        out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
        self.assertEqual(out.strip(), "True")

    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
    def test_no_warning_on_import(self) -> None:
        # With `-W always`, any warning emitted during import would show up in the output.
        out = self._check_python_output("import torch")
        self.assertEqual(out, "")

    def test_not_import_sympy(self) -> None:
        out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
        self.assertEqual(out.strip(), "True",
                         "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
                         "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
                         "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
                         "If you hit this error, you may want to:\n"
                         "  - Refactor your code to avoid depending on sympy files you may not need to depend\n"
                         "  - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
                         "  - Import things that depend on SymPy locally")

    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
    @parametrize('path', ['torch', 'functorch'])
    def test_no_mutate_global_logging_on_import(self, path) -> None:
        # Calling logging.basicConfig, among other things, modifies the global
        # logging state. It is not OK to modify the global logging state on
        # `import torch` (or other submodules we own) because users do not expect it.
        # If the import left extra root handlers behind, the message below would be
        # printed more than once (or with extra formatting).
        expected = 'abcdefghijklmnopqrstuvwxyz'
        commands = [
            'import logging',
            f'import {path}',
            '_logger = logging.getLogger("torch_test_testing")',
            'logging.root.addHandler(logging.StreamHandler())',
            'logging.root.setLevel(logging.INFO)',
            f'_logger.info("{expected}")'
        ]
        out = self._check_python_output("; ".join(commands))
        self.assertEqual(out.strip(), expected)
2310

2311
class TestOpInfos(TestCase):
    """Tests for the ``SampleInput`` container used by OpInfo sample generators."""

    def test_sample_input(self) -> None:
        """Check the accepted and rejected ways of constructing a SampleInput."""
        # Distinct sentinels so identity/equality checks are unambiguous.
        a, b, c, d, e = (object() for _ in range(5))

        # Construction with natural syntax
        s = SampleInput(a, b, c, d=d, e=e)
        self.assertIs(s.input, a)
        self.assertEqual(s.args, (b, c))
        self.assertEqual(s.kwargs, dict(d=d, e=e))

        # Construction with explicit args and kwargs
        s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
        self.assertIs(s.input, a)
        self.assertEqual(s.args, (b,))
        self.assertEqual(s.kwargs, dict(c=c, d=d, e=e))

        # Construction with a mixed form will error
        with self.assertRaises(AssertionError):
            SampleInput(a, b, c, args=(d, e))

        with self.assertRaises(AssertionError):
            SampleInput(a, b, c, kwargs=dict(d=d, e=e))

        with self.assertRaises(AssertionError):
            SampleInput(a, args=(b, c), d=d, e=e)

        with self.assertRaises(AssertionError):
            SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))

        # Mixing metadata into "natural" construction will error
        with self.assertRaises(AssertionError):
            SampleInput(a, b, name="foo")

        with self.assertRaises(AssertionError):
            SampleInput(a, b, output_process_fn_grad=lambda x: x)

        with self.assertRaises(AssertionError):
            SampleInput(a, b, broadcasts_input=True)

        # But when only input is given, metadata is allowed for backward
        # compatibility
        s = SampleInput(a, broadcasts_input=True)
        self.assertIs(s.input, a)
        self.assertTrue(s.broadcasts_input)

    def test_sample_input_metadata(self) -> None:
        """Check SampleInput metadata defaults and in-place ``with_metadata``."""
        a, b = (object() for _ in range(2))
        s1 = SampleInput(a, b=b)
        # Defaults: identity output_process_fn_grad, no broadcasting, empty name.
        self.assertIs(s1.output_process_fn_grad(None), None)
        self.assertFalse(s1.broadcasts_input)
        self.assertEqual(s1.name, "")

        s2 = s1.with_metadata(
            output_process_fn_grad=lambda x: a,
            broadcasts_input=True,
            name="foo",
        )
        # with_metadata mutates and returns the same object rather than a copy.
        self.assertIs(s1, s2)
        self.assertIs(s2.output_process_fn_grad(None), a)
        self.assertTrue(s2.broadcasts_input)
        self.assertEqual(s2.name, "foo")
2372

2373

2374
# Tests that validate the various sample generating functions on each OpInfo.
class TestOpInfoSampleFunctions(TestCase):
    """Check that OpInfo sample/reference/error input functions return iterators."""

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_opinfo_sample_generators(self, device, dtype, op):
        # op.sample_inputs must return an Iterator (e.g. a generator), not a
        # materialized sequence such as a list.
        samples = op.sample_inputs(device, dtype)
        self.assertIsInstance(samples, Iterator)

    @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
    def test_opinfo_reference_generators(self, device, dtype, op):
        # op.reference_inputs must likewise return an Iterator, not a materialized sequence.
        samples = op.reference_inputs(device, dtype)
        self.assertIsInstance(samples, Iterator)

    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    def test_opinfo_error_generators(self, device, op):
        # op.error_inputs takes no dtype and must also return an Iterator.
        samples = op.error_inputs(device)
        self.assertIsInstance(samples, Iterator)
2394

2395

2396
# Generate per-device variants of the OpInfo sample-function tests, and expand the
# @parametrize-decorated tests in TestImports.
instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
instantiate_parametrized_tests(TestImports)
2398

2399

2400
# Run the suite only when executed as a script, not when imported.
if __name__ == '__main__':
    run_tests()
2402

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.