15
from typing import Any, Callable, Iterator, List, Tuple
19
from torch.testing import make_tensor
20
from torch.testing._internal.common_utils import \
21
(IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22
parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf)
23
from torch.testing._internal.common_device_type import \
24
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
25
get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
26
deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes)
27
from torch.testing._internal.common_methods_invocations import op_db
28
from torch.testing._internal import opinfo
29
from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
30
from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
31
from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
35
class TestTesting(TestCase):
37
@dtypes(*all_types_and_complex_and(torch.bool, torch.half))
38
def test_assertEqual_numpy(self, device, dtype):
47
for test_size in test_sizes:
48
a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5)
50
msg = f'size: {test_size}'
51
self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
52
self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
53
self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)
55
def test_assertEqual_longMessage(self):
59
long_message = self.longMessage
62
self.longMessage = False
64
self.assertEqual(actual, expected)
65
except AssertionError as error:
66
default_msg = str(error)
68
raise AssertionError("AssertionError not raised")
70
self.longMessage = True
71
extra_msg = "sentinel"
72
with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")):
73
self.assertEqual(actual, expected, msg=extra_msg)
75
self.longMessage = long_message
77
def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
79
a = torch.tensor((test[0],), device=device, dtype=dtype)
80
b = torch.tensor((test[1],), device=device, dtype=dtype)
82
actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
84
self.assertEqual(actual.item(), expected)
86
def test_isclose_bool(self, device):
94
self._isclose_helper(tests, device, torch.bool, False)
97
torch.int8, torch.int16, torch.int32, torch.int64)
98
def test_isclose_integer(self, device, dtype):
105
self._isclose_helper(tests, device, dtype, False)
114
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
116
if dtype is torch.uint8:
127
self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
129
@onlyNativeDeviceTypes
130
@dtypes(torch.float16, torch.float32, torch.float64)
131
def test_isclose_float(self, device, dtype):
135
(float('inf'), float('inf'), True),
136
(-float('inf'), float('inf'), False),
137
(float('inf'), float('nan'), False),
138
(float('nan'), float('nan'), False),
139
(0, float('nan'), False),
143
self._isclose_helper(tests, device, dtype, False)
146
eps = 1e-2 if dtype is torch.half else 1e-6
154
(-.25 - eps, .5, False),
156
(.25 + eps, -.5, False),
159
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
163
(0, float('nan'), False),
164
(float('inf'), float('nan'), False),
165
(float('nan'), float('nan'), True),
168
self._isclose_helper(tests, device, dtype, True)
170
@unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
171
@dtypes(torch.complex64, torch.complex128)
172
def test_isclose_complex(self, device, dtype):
174
(complex(1, 1), complex(1, 1 + 1e-8), True),
175
(complex(0, 1), complex(1, 1), False),
176
(complex(1, 1), complex(1, 0), False),
177
(complex(1, 1), complex(1, float('nan')), False),
178
(complex(1, float('nan')), complex(1, float('nan')), False),
179
(complex(1, 1), complex(1, float('inf')), False),
180
(complex(float('inf'), 1), complex(1, float('inf')), False),
181
(complex(-float('inf'), 1), complex(1, float('inf')), False),
182
(complex(-float('inf'), 1), complex(float('inf'), 1), False),
183
(complex(float('inf'), 1), complex(float('inf'), 1), True),
184
(complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
187
self._isclose_helper(tests, device, dtype, False)
195
(complex(0, 0), complex(1, 0), True),
196
(complex(0, 0), complex(1 + eps, 0), False),
197
(complex(1, 0), complex(0, 0), False),
198
(complex(1, 0), complex(3, 0), True),
199
(complex(1 - eps, 0), complex(3, 0), False),
200
(complex(-.25, 0), complex(.5, 0), True),
201
(complex(-.25 - eps, 0), complex(.5, 0), False),
202
(complex(.25, 0), complex(-.5, 0), True),
203
(complex(.25 + eps, 0), complex(-.5, 0), False),
205
(complex(0, 0), complex(0, 1), True),
206
(complex(0, 0), complex(0, 1 + eps), False),
207
(complex(0, 1), complex(0, 0), False),
208
(complex(0, 1), complex(0, 3), True),
209
(complex(0, 1 - eps), complex(0, 3), False),
210
(complex(0, -.25), complex(0, .5), True),
211
(complex(0, -.25 - eps), complex(0, .5), False),
212
(complex(0, .25), complex(0, -.5), True),
213
(complex(0, .25 + eps), complex(0, -.5), False),
216
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
221
(complex(1, -1), complex(-1, 1), False),
222
(complex(1, -1), complex(2, -2), True),
223
(complex(-math.sqrt(2), math.sqrt(2)),
224
complex(-math.sqrt(.5), math.sqrt(.5)), True),
225
(complex(-math.sqrt(2), math.sqrt(2)),
226
complex(-math.sqrt(.501), math.sqrt(.499)), False),
227
(complex(2, 4), complex(1., 8.8523607), True),
228
(complex(2, 4), complex(1., 8.8523607 + eps), False),
229
(complex(1, 99), complex(4, 100), True),
231
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
235
(complex(1, 1), complex(1, float('nan')), False),
236
(complex(1, 1), complex(float('nan'), 1), False),
237
(complex(float('nan'), 1), complex(float('nan'), 1), True),
238
(complex(float('nan'), 1), complex(1, float('nan')), True),
239
(complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
241
self._isclose_helper(tests, device, dtype, True)
245
@dtypes(torch.bool, torch.uint8,
246
torch.int8, torch.int16, torch.int32, torch.int64,
247
torch.float16, torch.float32, torch.float64)
248
def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
249
t = torch.tensor((1,), device=device, dtype=dtype)
251
with self.assertRaises(RuntimeError):
252
torch.isclose(t, t, atol=-1, rtol=1)
253
with self.assertRaises(RuntimeError):
254
torch.isclose(t, t, atol=1, rtol=-1)
255
with self.assertRaises(RuntimeError):
256
torch.isclose(t, t, atol=-1, rtol=-1)
258
def test_isclose_equality_shortcut(self):
262
a = torch.tensor(2 ** 53, dtype=torch.int64)
265
self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))
267
@dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
268
def test_isclose_nan_equality_shortcut(self, device, dtype):
269
if dtype.is_floating_point:
272
a = complex(torch.nan, 0)
273
b = complex(0, torch.nan)
276
tests = [(a, b, expected)]
278
self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)
284
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
287
def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
289
stderr = TestCase.runWithPytorchAPIUsageStderr("""\
290
#!/usr/bin/env python3
293
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
295
class TestThatContainsCUDAAssertFailure(TestCase):
298
def test_throw_unrecoverable_cuda_exception(self):
299
x = torch.rand(10, device='cuda')
300
# cause unrecoverable CUDA exception, recoverable on CPU
301
y = x[torch.tensor([25])].cpu()
304
def test_trivial_passing_test_case_on_cpu_cuda(self):
305
x1 = torch.tensor([0., 1.], device='cuda')
306
x2 = torch.tensor([0., 1.], device='cpu')
307
self.assertEqual(x1, x2)
309
if __name__ == '__main__':
313
self.assertIn('CUDA error: device-side assert triggered', stderr)
315
self.assertIn('errors=1', stderr)
318
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
321
def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
323
stderr = TestCase.runWithPytorchAPIUsageStderr("""\
324
#!/usr/bin/env python3
327
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
328
from torch.testing._internal.common_device_type import instantiate_device_type_tests
330
class TestThatContainsCUDAAssertFailure(TestCase):
333
def test_throw_unrecoverable_cuda_exception(self, device):
334
x = torch.rand(10, device=device)
335
# cause unrecoverable CUDA exception, recoverable on CPU
336
y = x[torch.tensor([25])].cpu()
339
def test_trivial_passing_test_case_on_cpu_cuda(self, device):
340
x1 = torch.tensor([0., 1.], device=device)
341
x2 = torch.tensor([0., 1.], device='cpu')
342
self.assertEqual(x1, x2)
344
instantiate_device_type_tests(
345
TestThatContainsCUDAAssertFailure,
350
if __name__ == '__main__':
354
self.assertIn('CUDA error: device-side assert triggered', stderr)
356
self.assertIn('errors=1', stderr)
359
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
362
def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
364
stderr = TestCase.runWithPytorchAPIUsageStderr("""\
365
#!/usr/bin/env python3
368
from torch.testing._internal.common_utils import (run_tests, slowTest)
369
from torch.testing._internal.common_device_type import instantiate_device_type_tests
370
from torch.testing._internal.common_distributed import MultiProcessTestCase
372
class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):
375
def test_throw_unrecoverable_cuda_exception(self, device):
376
x = torch.rand(10, device=device)
377
# cause unrecoverable CUDA exception, recoverable on CPU
378
y = x[torch.tensor([25])].cpu()
381
def test_trivial_passing_test_case_on_cpu_cuda(self, device):
382
x1 = torch.tensor([0., 1.], device=device)
383
x2 = torch.tensor([0., 1.], device='cpu')
384
self.assertEqual(x1, x2)
386
instantiate_device_type_tests(
387
TestThatContainsCUDAAssertFailure,
392
if __name__ == '__main__':
396
self.assertIn('errors=2', stderr)
399
@onlyNativeDeviceTypes
400
def test_get_supported_dtypes(self, device):
404
ops_to_test = list(filter(lambda op: op.name in ['atan2', 'topk', 'xlogy'], op_db))
406
for op in ops_to_test:
407
dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type)
408
dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes)
409
if self.device_type == 'cpu':
412
dtypes = op.dtypesIfCUDA
414
self.assertTrue(set(dtypes) == set(dynamic_dtypes))
415
self.assertTrue(set(dtypes) == set(dynamic_dispatch.dispatch_fn()))
423
op.supported_dtypes("cpu").symmetric_difference(
424
op.supported_dtypes("cuda")
429
dtypes=OpDTypes.none,
431
def test_supported_dtypes(self, device, op):
432
self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda"))
433
self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0"))
435
op.supported_dtypes(torch.device("cuda")),
436
op.supported_dtypes(torch.device("cuda", index=1)),
439
instantiate_device_type_tests(TestTesting, globals())
442
class TestFrameworkUtils(TestCase):
444
@unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
445
@unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
446
def test_filtering_env_var(self):
448
test_filter_file_template = """\
449
#!/usr/bin/env python3
452
from torch.testing._internal.common_utils import (TestCase, run_tests)
453
from torch.testing._internal.common_device_type import instantiate_device_type_tests
455
class TestEnvironmentVariable(TestCase):
457
def test_trivial_passing_test(self, device):
458
x1 = torch.tensor([0., 1.], device=device)
459
x2 = torch.tensor([0., 1.], device='cpu')
460
self.assertEqual(x1, x2)
462
instantiate_device_type_tests(
463
TestEnvironmentVariable,
467
if __name__ == '__main__':
470
test_bases_count = len(get_device_type_test_bases())
472
env = dict(os.environ)
473
for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]:
476
_, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
477
self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii'))
480
env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
481
_, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
482
self.assertIn('Ran 1 test', stderr.decode('ascii'))
485
del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
486
env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
487
_, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
488
self.assertIn(f'Ran {test_bases_count-1} test', stderr.decode('ascii'))
491
env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
492
_, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
493
self.assertNotIn('OK', stderr.decode('ascii'))
496
def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
497
"""Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.
500
actual (Any): Actual input.
501
expected (Any): Expected input.
504
List[Tuple[Any, Any]]: Pair of example inputs, as well as the example inputs wrapped in sequences
505
(:class:`tuple`, :class:`list`), and mappings (:class:`dict`, :class:`~collections.OrderedDict`).
510
((actual,), (expected,)),
512
([actual], [expected]),
514
((actual,), [expected]),
516
({"t": actual}, {"t": expected}),
518
(collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
520
({"t": actual}, collections.OrderedDict([("t", expected)])),
522
([(actual,)], ([expected],)),
524
([{"t": actual}], (collections.OrderedDict([("t", expected)]),)),
526
({"t": [actual]}, collections.OrderedDict([("t", (expected,))])),
530
def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
531
"""Yields :func:`torch.testing.assert_close` with predefined positional inputs based on two examples.
535
Every test that does not test for a specific input should iterate over this to maximize the coverage.
538
actual (Any): Actual input.
539
expected (Any): Expected input.
542
Callable: :func:`torch.testing.assert_close` with predefined positional inputs.
544
for inputs in make_assert_close_inputs(actual, expected):
545
yield functools.partial(torch.testing.assert_close, *inputs)
548
class TestAssertClose(TestCase):
549
def test_mismatching_types_subclasses(self):
550
actual = torch.rand(())
551
expected = torch.nn.Parameter(actual)
553
for fn in assert_close_with_inputs(actual, expected):
556
def test_mismatching_types_type_equality(self):
557
actual = torch.empty(())
558
expected = torch.nn.Parameter(actual)
560
for fn in assert_close_with_inputs(actual, expected):
561
with self.assertRaisesRegex(TypeError, str(type(expected))):
562
fn(allow_subclasses=False)
564
def test_mismatching_types(self):
565
actual = torch.empty(2)
566
expected = actual.numpy()
568
for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)):
569
with self.assertRaisesRegex(TypeError, str(type(expected))):
570
fn(allow_subclasses=allow_subclasses)
572
def test_unknown_type(self):
576
for fn in assert_close_with_inputs(actual, expected):
577
with self.assertRaisesRegex(TypeError, str(type(actual))):
580
def test_mismatching_shape(self):
581
actual = torch.empty(())
582
expected = actual.clone().reshape((1,))
584
for fn in assert_close_with_inputs(actual, expected):
585
with self.assertRaisesRegex(AssertionError, "shape"):
588
@unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
589
def test_unknown_layout(self):
590
actual = torch.empty((2, 2))
591
expected = actual.to_mkldnn()
593
for fn in assert_close_with_inputs(actual, expected):
594
with self.assertRaisesRegex(ValueError, "layout"):
598
actual = torch.empty((2, 2), device="meta")
599
expected = torch.empty((2, 2), device="meta")
601
for fn in assert_close_with_inputs(actual, expected):
604
def test_mismatching_layout(self):
605
strided = torch.empty((2, 2))
606
sparse_coo = strided.to_sparse()
607
sparse_csr = strided.to_sparse_csr()
609
for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
610
for fn in assert_close_with_inputs(actual, expected):
611
with self.assertRaisesRegex(AssertionError, "layout"):
614
def test_mismatching_layout_no_check(self):
615
strided = torch.randn((2, 2))
616
sparse_coo = strided.to_sparse()
617
sparse_csr = strided.to_sparse_csr()
619
for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
620
for fn in assert_close_with_inputs(actual, expected):
621
fn(check_layout=False)
623
def test_mismatching_dtype(self):
624
actual = torch.empty((), dtype=torch.float)
625
expected = actual.clone().to(torch.int)
627
for fn in assert_close_with_inputs(actual, expected):
628
with self.assertRaisesRegex(AssertionError, "dtype"):
631
def test_mismatching_dtype_no_check(self):
632
actual = torch.ones((), dtype=torch.float)
633
expected = actual.clone().to(torch.int)
635
for fn in assert_close_with_inputs(actual, expected):
636
fn(check_dtype=False)
638
def test_mismatching_stride(self):
639
actual = torch.empty((2, 2))
640
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
642
for fn in assert_close_with_inputs(actual, expected):
643
with self.assertRaisesRegex(AssertionError, "stride"):
644
fn(check_stride=True)
646
def test_mismatching_stride_no_check(self):
647
actual = torch.rand((2, 2))
648
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
649
for fn in assert_close_with_inputs(actual, expected):
652
def test_only_rtol(self):
653
actual = torch.empty(())
654
expected = actual.clone()
656
for fn in assert_close_with_inputs(actual, expected):
657
with self.assertRaises(ValueError):
660
def test_only_atol(self):
661
actual = torch.empty(())
662
expected = actual.clone()
664
for fn in assert_close_with_inputs(actual, expected):
665
with self.assertRaises(ValueError):
668
def test_mismatching_values(self):
669
actual = torch.tensor(1)
670
expected = torch.tensor(2)
672
for fn in assert_close_with_inputs(actual, expected):
673
with self.assertRaises(AssertionError):
676
def test_mismatching_values_rtol(self):
678
actual = torch.tensor(1.0)
679
expected = torch.tensor(1.0 + eps)
681
for fn in assert_close_with_inputs(actual, expected):
682
with self.assertRaises(AssertionError):
683
fn(rtol=eps / 2, atol=0.0)
685
def test_mismatching_values_atol(self):
687
actual = torch.tensor(0.0)
688
expected = torch.tensor(eps)
690
for fn in assert_close_with_inputs(actual, expected):
691
with self.assertRaises(AssertionError):
692
fn(rtol=0.0, atol=eps / 2)
694
def test_matching(self):
695
actual = torch.tensor(1.0)
696
expected = actual.clone()
698
torch.testing.assert_close(actual, expected)
700
def test_matching_rtol(self):
702
actual = torch.tensor(1.0)
703
expected = torch.tensor(1.0 + eps)
705
for fn in assert_close_with_inputs(actual, expected):
706
fn(rtol=eps * 2, atol=0.0)
708
def test_matching_atol(self):
710
actual = torch.tensor(0.0)
711
expected = torch.tensor(eps)
713
for fn in assert_close_with_inputs(actual, expected):
714
fn(rtol=0.0, atol=eps * 2)
718
def test_matching_conjugate_bit(self):
719
actual = torch.tensor(complex(1, 1)).conj()
720
expected = torch.tensor(complex(1, -1))
722
for fn in assert_close_with_inputs(actual, expected):
725
def test_matching_nan(self):
730
(complex(nan, 0), complex(0, nan)),
731
(complex(nan, nan), complex(nan, 0)),
732
(complex(nan, nan), complex(nan, nan)),
735
for actual, expected in tests:
736
for fn in assert_close_with_inputs(actual, expected):
737
with self.assertRaises(AssertionError):
740
def test_matching_nan_with_equal_nan(self):
745
(complex(nan, 0), complex(0, nan)),
746
(complex(nan, nan), complex(nan, 0)),
747
(complex(nan, nan), complex(nan, nan)),
750
for actual, expected in tests:
751
for fn in assert_close_with_inputs(actual, expected):
754
def test_numpy(self):
755
tensor = torch.rand(2, 2, dtype=torch.float32)
756
actual = tensor.numpy()
757
expected = actual.copy()
759
for fn in assert_close_with_inputs(actual, expected):
762
def test_scalar(self):
763
number = torch.randint(10, size=()).item()
764
for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
765
check_dtype = type(actual) is type(expected)
767
for fn in assert_close_with_inputs(actual, expected):
768
fn(check_dtype=check_dtype)
771
actual = torch.tensor([True, False])
772
expected = actual.clone()
774
for fn in assert_close_with_inputs(actual, expected):
778
actual = expected = None
780
for fn in assert_close_with_inputs(actual, expected):
783
def test_none_mismatch(self):
786
for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
787
for fn in assert_close_with_inputs(actual, expected):
788
with self.assertRaises(AssertionError):
792
def test_docstring_examples(self):
793
finder = doctest.DocTestFinder(verbose=False)
794
runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
795
globs = dict(torch=torch)
796
doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
798
runner.run(doctests, out=lambda report: failures.append(report))
800
raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))
802
def test_default_tolerance_selection_mismatching_dtypes(self):
805
actual = torch.tensor(0.99, dtype=torch.bfloat16)
806
expected = torch.tensor(1.0, dtype=torch.float64)
808
for fn in assert_close_with_inputs(actual, expected):
809
fn(check_dtype=False)
811
class UnexpectedException(Exception):
812
"""The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus,
813
the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin
814
exception here to avoid triggering possible handling of them.
817
@unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
818
def test_unexpected_error_originate(self, _):
819
actual = torch.tensor(1.0)
820
expected = actual.clone()
822
with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
823
torch.testing.assert_close(actual, expected)
825
@unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
826
def test_unexpected_error_compare(self, _):
827
actual = torch.tensor(1.0)
828
expected = actual.clone()
830
with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
831
torch.testing.assert_close(actual, expected)
836
class TestAssertCloseMultiDevice(TestCase):
837
@deviceCountAtLeast(1)
838
def test_mismatching_device(self, devices):
839
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
840
actual = torch.empty((), device=actual_device)
841
expected = actual.clone().to(expected_device)
842
for fn in assert_close_with_inputs(actual, expected):
843
with self.assertRaisesRegex(AssertionError, "device"):
846
@deviceCountAtLeast(1)
847
def test_mismatching_device_no_check(self, devices):
848
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
849
actual = torch.rand((), device=actual_device)
850
expected = actual.clone().to(expected_device)
851
for fn in assert_close_with_inputs(actual, expected):
852
fn(check_device=False)
855
instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")
858
class TestAssertCloseErrorMessage(TestCase):
859
def test_identifier_tensor_likes(self):
860
actual = torch.tensor([1, 2, 3, 4])
861
expected = torch.tensor([1, 2, 5, 6])
863
for fn in assert_close_with_inputs(actual, expected):
864
with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")):
867
def test_identifier_scalars(self):
870
for fn in assert_close_with_inputs(actual, expected):
871
with self.assertRaisesRegex(AssertionError, re.escape("Scalars")):
874
def test_not_equal(self):
875
actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
876
expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)
878
for fn in assert_close_with_inputs(actual, expected):
879
with self.assertRaisesRegex(AssertionError, re.escape("not equal")):
880
fn(rtol=0.0, atol=0.0)
882
def test_not_close(self):
883
actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
884
expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)
886
for fn, (rtol, atol) in itertools.product(
887
assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5))
889
with self.assertRaisesRegex(AssertionError, re.escape("not close")):
890
fn(rtol=rtol, atol=atol)
892
def test_mismatched_elements(self):
893
actual = torch.tensor([1, 2, 3, 4])
894
expected = torch.tensor([1, 2, 5, 6])
896
for fn in assert_close_with_inputs(actual, expected):
897
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
900
def test_abs_diff(self):
901
actual = torch.tensor([[1, 2], [3, 4]])
902
expected = torch.tensor([[1, 2], [5, 4]])
904
for fn in assert_close_with_inputs(actual, expected):
905
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
908
def test_abs_diff_scalar(self):
912
for fn in assert_close_with_inputs(actual, expected):
913
with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")):
916
def test_rel_diff(self):
917
actual = torch.tensor([[1, 2], [3, 4]])
918
expected = torch.tensor([[1, 4], [3, 4]])
920
for fn in assert_close_with_inputs(actual, expected):
921
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")):
924
def test_rel_diff_scalar(self):
928
for fn in assert_close_with_inputs(actual, expected):
929
with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")):
932
def test_zero_div_zero(self):
933
actual = torch.tensor([1.0, 0.0])
934
expected = torch.tensor([2.0, 0.0])
936
for fn in assert_close_with_inputs(actual, expected):
939
with self.assertRaisesRegex(AssertionError, "((?!nan).)*"):
945
actual = torch.tensor((1, 2))
946
expected = torch.tensor((2, 2))
948
for fn in assert_close_with_inputs(actual, expected):
949
with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")):
950
fn(rtol=rtol, atol=0.0)
955
actual = torch.tensor((1, 2))
956
expected = torch.tensor((2, 2))
958
for fn in assert_close_with_inputs(actual, expected):
959
with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
960
fn(rtol=0.0, atol=atol)
962
def test_msg_str(self):
963
msg = "Custom error message!"
965
actual = torch.tensor(1)
966
expected = torch.tensor(2)
968
for fn in assert_close_with_inputs(actual, expected):
969
with self.assertRaisesRegex(AssertionError, msg):
972
def test_msg_callable(self):
973
msg = "Custom error message"
975
actual = torch.tensor(1)
976
expected = torch.tensor(2)
978
for fn in assert_close_with_inputs(actual, expected):
979
with self.assertRaisesRegex(AssertionError, msg):
980
fn(msg=lambda _: msg)
983
class TestAssertCloseContainer(TestCase):
984
def test_sequence_mismatching_len(self):
985
actual = (torch.empty(()),)
988
with self.assertRaises(AssertionError):
989
torch.testing.assert_close(actual, expected)
991
def test_sequence_mismatching_values_msg(self):
998
with self.assertRaisesRegex(AssertionError, re.escape("item [1]")):
999
torch.testing.assert_close(actual, expected)
1001
def test_mapping_mismatching_keys(self):
1002
actual = {"a": torch.empty(())}
1005
with self.assertRaises(AssertionError):
1006
torch.testing.assert_close(actual, expected)
1008
def test_mapping_mismatching_values_msg(self):
1009
t1 = torch.tensor(1)
1010
t2 = torch.tensor(2)
1012
actual = {"a": t1, "b": t1}
1013
expected = {"a": t1, "b": t2}
1015
with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")):
1016
torch.testing.assert_close(actual, expected)
1019
class TestAssertCloseSparseCOO(TestCase):
1020
def test_matching_coalesced(self):
1026
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
1027
expected = actual.clone()
1029
for fn in assert_close_with_inputs(actual, expected):
1032
def test_matching_uncoalesced(self):
1038
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
1039
expected = actual.clone()
1041
for fn in assert_close_with_inputs(actual, expected):
1044
def test_mismatching_sparse_dims(self):
1045
t = torch.randn(2, 3, 4)
1046
actual = t.to_sparse()
1047
expected = t.to_sparse(2)
1049
for fn in assert_close_with_inputs(actual, expected):
1050
with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")):
1053
def test_mismatching_nnz(self):
1058
actual_values = (1, 2)
1059
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1061
expected_indices = (
1065
expected_values = (1, 1, 1)
1066
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1068
for fn in assert_close_with_inputs(actual, expected):
1069
with self.assertRaisesRegex(AssertionError, re.escape("number of specified values in sparse COO tensors")):
1072
def test_mismatching_indices_msg(self):
1077
actual_values = (1, 2)
1078
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1080
expected_indices = (
1084
expected_values = (1, 2)
1085
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1087
for fn in assert_close_with_inputs(actual, expected):
1088
with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO indices")):
1091
def test_mismatching_values_msg(self):
1096
actual_values = (1, 2)
1097
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1099
expected_indices = (
1103
expected_values = (1, 3)
1104
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1106
for fn in assert_close_with_inputs(actual, expected):
1107
with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO values")):
1111
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
    """Tests ``assert_close`` behavior for sparse CSR tensors.

    Each test builds a pair of 2x2 CSR tensors and runs every callable yielded by
    ``assert_close_with_inputs`` (defined elsewhere in this file), which covers the
    different ways of passing the inputs to ``torch.testing.assert_close``.
    """

    def test_matching(self):
        # A tensor must always compare equal to a clone of itself.
        crow_indices = (0, 1, 2)
        col_indices = (1, 0)
        values = (1, 2)
        actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_crow_indices_msg(self):
        # Only the crow_indices differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (0, 1)
        actual_values = (1, 2)
        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = (0, 2, 2)
        expected_col_indices = actual_col_indices
        expected_values = actual_values
        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR crow_indices")):
                fn()

    def test_mismatching_col_indices_msg(self):
        # Only the col_indices differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (1, 0)
        actual_values = (1, 2)
        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = actual_crow_indices
        expected_col_indices = (1, 1)
        expected_values = actual_values
        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR col_indices")):
                fn()

    def test_mismatching_values_msg(self):
        # Only the values differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (1, 0)
        actual_values = (1, 2)
        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = actual_crow_indices
        expected_col_indices = actual_col_indices
        expected_values = (1, 3)
        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")):
                fn()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing")
class TestAssertCloseSparseCSC(TestCase):
    """Tests ``assert_close`` behavior for sparse CSC tensors.

    Mirrors the CSR tests: build 2x2 CSC tensor pairs and run every callable
    yielded by ``assert_close_with_inputs`` (defined elsewhere in this file).
    """

    def test_matching(self):
        # A tensor must always compare equal to a clone of itself.
        ccol_indices = (0, 1, 2)
        row_indices = (1, 0)
        values = (1, 2)
        actual = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        # Only the ccol_indices differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (0, 1)
        actual_values = (1, 2)
        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = (0, 2, 2)
        expected_row_indices = actual_row_indices
        expected_values = actual_values
        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC ccol_indices")):
                fn()

    def test_mismatching_row_indices_msg(self):
        # Only the row_indices differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (1, 0)
        actual_values = (1, 2)
        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = actual_ccol_indices
        expected_row_indices = (1, 1)
        expected_values = actual_values
        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC row_indices")):
                fn()

    def test_mismatching_values_msg(self):
        # Only the values differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (1, 0)
        actual_values = (1, 2)
        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = actual_ccol_indices
        expected_row_indices = actual_row_indices
        expected_values = (1, 3)
        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC values")):
                fn()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing")
class TestAssertCloseSparseBSR(TestCase):
    """Tests ``assert_close`` behavior for sparse BSR tensors.

    Same structure as the CSR tests, but values are 1x1 blocks as required by the
    blocked layout. Runs every callable yielded by ``assert_close_with_inputs``
    (defined elsewhere in this file).
    """

    def test_matching(self):
        # A tensor must always compare equal to a clone of itself.
        crow_indices = (0, 1, 2)
        col_indices = (1, 0)
        values = ([[1]], [[2]])
        actual = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_crow_indices_msg(self):
        # Only the crow_indices differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (0, 1)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = (0, 2, 2)
        expected_col_indices = actual_col_indices
        expected_values = actual_values
        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR crow_indices")):
                fn()

    def test_mismatching_col_indices_msg(self):
        # Only the col_indices differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (1, 0)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = actual_crow_indices
        expected_col_indices = (1, 1)
        expected_values = actual_values
        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR col_indices")):
                fn()

    def test_mismatching_values_msg(self):
        # Only the values differ; the failure message must name them.
        actual_crow_indices = (0, 1, 2)
        actual_col_indices = (1, 0)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))

        expected_crow_indices = actual_crow_indices
        expected_col_indices = actual_col_indices
        expected_values = ([[1]], [[3]])
        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR values")):
                fn()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing")
class TestAssertCloseSparseBSC(TestCase):
    """Tests ``assert_close`` behavior for sparse BSC tensors.

    Same structure as the CSC tests, but values are 1x1 blocks as required by the
    blocked layout. Runs every callable yielded by ``assert_close_with_inputs``
    (defined elsewhere in this file).
    """

    def test_matching(self):
        # A tensor must always compare equal to a clone of itself.
        ccol_indices = (0, 1, 2)
        row_indices = (1, 0)
        values = ([[1]], [[2]])
        actual = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        # Only the ccol_indices differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (0, 1)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = (0, 2, 2)
        expected_row_indices = actual_row_indices
        expected_values = actual_values
        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC ccol_indices")):
                fn()

    def test_mismatching_row_indices_msg(self):
        # Only the row_indices differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (1, 0)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = actual_ccol_indices
        expected_row_indices = (1, 1)
        expected_values = actual_values
        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC row_indices")):
                fn()

    def test_mismatching_values_msg(self):
        # Only the values differ; the failure message must name them.
        actual_ccol_indices = (0, 1, 2)
        actual_row_indices = (1, 0)
        actual_values = ([[1]], [[2]])
        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))

        expected_ccol_indices = actual_ccol_indices
        expected_row_indices = actual_row_indices
        expected_values = ([[1]], [[3]])
        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC values")):
                fn()
class TestAssertCloseQuantized(TestCase):
    """Tests ``assert_close`` behavior for quantized tensors.

    Runs every callable yielded by ``assert_close_with_inputs`` (defined
    elsewhere in this file) over quantized/non-quantized tensor pairs.
    """

    def test_mismatching_is_quantized(self):
        # A quantized tensor must not compare equal to a plain float tensor.
        actual = torch.tensor(1.0)
        expected = torch.quantize_per_tensor(actual, scale=1.0, zero_point=0, dtype=torch.qint32)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "is_quantized"):
                fn()

    def test_mismatching_qscheme(self):
        # Per-tensor vs. per-channel quantization of the same data must mismatch
        # on the quantization scheme.
        t = torch.tensor((1.0,))
        actual = torch.quantize_per_tensor(t, scale=1.0, zero_point=0, dtype=torch.qint32)
        # NOTE(review): the `axis` and `dtype` arguments were lost in extraction;
        # reconstructed as axis=0 / qint32 to match the per-tensor counterpart — confirm.
        expected = torch.quantize_per_channel(
            t,
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "qscheme"):
                fn()

    def test_matching_per_tensor(self):
        # A per-tensor quantized tensor must compare equal to its clone.
        actual = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32)
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def test_matching_per_channel(self):
        # A per-channel quantized tensor must compare equal to its clone.
        actual = torch.quantize_per_channel(
            torch.tensor((1.0,)),
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )
        expected = actual.clone()

        for fn in assert_close_with_inputs(actual, expected):
            fn()
class TestMakeTensor(TestCase):
1389
supported_dtypes = dtypes(
1391
torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
1392
torch.float16, torch.bfloat16, torch.float32, torch.float64,
1393
torch.complex32, torch.complex64, torch.complex128,
1397
@parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1398
@parametrize("splat_shape", [False, True])
1399
def test_smoke(self, dtype, device, shape, splat_shape):
1400
t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device)
1402
self.assertIsInstance(t, torch.Tensor)
1403
self.assertEqual(t.shape, shape)
1404
self.assertEqual(t.dtype, dtype)
1405
self.assertEqual(t.device, torch.device(device))
1408
@parametrize("requires_grad", [False, True])
1409
def test_requires_grad(self, dtype, device, requires_grad):
1410
make_tensor = functools.partial(
1411
torch.testing.make_tensor,
1414
requires_grad=requires_grad,
1417
if not requires_grad or dtype.is_floating_point or dtype.is_complex:
1419
self.assertEqual(t.requires_grad, requires_grad)
1421
with self.assertRaisesRegex(
1422
ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
1427
@parametrize("noncontiguous", [False, True])
1428
@parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1429
def test_noncontiguous(self, dtype, device, noncontiguous, shape):
1430
numel = functools.reduce(operator.mul, shape, 1)
1432
t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
1433
self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)
1437
"memory_format_and_shape",
1440
(torch.contiguous_format, (2, 3, 4)),
1441
(torch.channels_last, (2, 3, 4, 5)),
1442
(torch.channels_last_3d, (2, 3, 4, 5, 6)),
1443
(torch.preserve_format, (2, 3, 4)),
1446
def test_memory_format(self, dtype, device, memory_format_and_shape):
1447
memory_format, shape = memory_format_and_shape
1449
t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format)
1452
t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format)
1456
def test_noncontiguous_memory_format(self, dtype, device):
1457
with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
1458
torch.testing.make_tensor(
1463
memory_format=torch.channels_last,
1467
def test_exclude_zero(self, dtype, device):
1468
t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2)
1470
self.assertTrue((t != 0).all())
1473
def test_low_high_smoke(self, dtype, device):
1474
low_inclusive, high_exclusive = 0, 2
1476
t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1477
if dtype.is_complex:
1478
t = torch.view_as_real(t)
1480
self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1483
def test_low_high_default_smoke(self, dtype, device):
1484
low_inclusive, high_exclusive = {
1486
torch.uint8: (0, 10),
1487
**dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
1488
}.get(dtype, (-9, 9))
1490
t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1491
if dtype.is_complex:
1492
t = torch.view_as_real(t)
1494
self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1496
@parametrize("low_high", [(0, 0), (1, 0), (0, -1)])
1497
@parametrize("value_types", list(itertools.product([int, float], repeat=2)))
1499
def test_low_ge_high(self, dtype, device, low_high, value_types):
1500
low, high = (value_type(value) for value, value_type in zip(low_high, value_types))
1502
if low == high and (dtype.is_floating_point or dtype.is_complex):
1503
with self.assertWarnsRegex(
1505
"Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated",
1507
t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high)
1508
self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low))
1510
with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"):
1511
torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1514
@parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)])
1515
def test_low_high_nan(self, dtype, device, low_high):
1516
low, high = low_high
1518
with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
1519
torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1522
def test_low_high_outside_valid_range(self, dtype, device):
1523
make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device)
1525
def get_dtype_limits(dtype):
1526
if dtype is torch.bool:
1529
info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype)
1534
return int(info.min), int(info.max)
1536
lowest_inclusive, highest_inclusive = get_dtype_limits(dtype)
1538
with self.assertRaisesRegex(ValueError, ""):
1539
low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2)
1540
make_tensor(low=low, high=high)
1542
with self.assertRaisesRegex(ValueError, ""):
1543
make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4)
1545
@dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1546
def test_low_high_boolean_integral1(self, dtype, device):
1550
actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps)
1551
expected = torch.zeros(shape, dtype=dtype, device=device)
1553
torch.testing.assert_close(actual, expected)
1555
@dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1556
def test_low_high_boolean_integral2(self, dtype, device):
1558
if dtype is torch.bool:
1560
elif dtype is torch.int64:
1562
low = torch.iinfo(dtype).max - 1
1564
low = torch.iinfo(dtype).max
1567
actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
1568
expected = torch.full(shape, low, dtype=dtype, device=device)
1570
torch.testing.assert_close(actual, expected)
1573
instantiate_device_type_tests(TestMakeTensor, globals())
1576
def _get_test_names_for_test_class(test_cls):
1577
""" Convenience function to get all test names for a given test class. """
1578
test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__
1579
if key.startswith('test_')]
1580
return sorted(test_names)
1583
def _get_test_funcs_for_test_class(test_cls):
1584
""" Convenience function to get all (test function, parametrized_name) pairs for a given test class. """
1585
test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')]
1589
class TestTestParametrization(TestCase):
1590
def test_default_names(self):
1592
class TestParametrized(TestCase):
1593
@parametrize("x", range(5))
1594
def test_default_names(self, x):
1597
@parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1598
def test_two_things_default_names(self, x, y):
1601
instantiate_parametrized_tests(TestParametrized)
1603
expected_test_names = [
1604
'TestParametrized.test_default_names_x_0',
1605
'TestParametrized.test_default_names_x_1',
1606
'TestParametrized.test_default_names_x_2',
1607
'TestParametrized.test_default_names_x_3',
1608
'TestParametrized.test_default_names_x_4',
1609
'TestParametrized.test_two_things_default_names_x_1_y_2',
1610
'TestParametrized.test_two_things_default_names_x_2_y_3',
1611
'TestParametrized.test_two_things_default_names_x_3_y_4',
1613
test_names = _get_test_names_for_test_class(TestParametrized)
1614
self.assertEqual(expected_test_names, test_names)
1616
def test_name_fn(self):
1618
class TestParametrized(TestCase):
1619
@parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1620
def test_custom_names(self, bias):
1623
@parametrize("x", [1, 2], name_fn=str)
1624
@parametrize("y", [3, 4], name_fn=str)
1625
@parametrize("z", [5, 6], name_fn=str)
1626
def test_three_things_composition_custom_names(self, x, y, z):
1629
@parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1630
def test_two_things_custom_names_alternate(self, x, y):
1633
instantiate_parametrized_tests(TestParametrized)
1635
expected_test_names = [
1636
'TestParametrized.test_custom_names_bias',
1637
'TestParametrized.test_custom_names_no_bias',
1638
'TestParametrized.test_three_things_composition_custom_names_1_3_5',
1639
'TestParametrized.test_three_things_composition_custom_names_1_3_6',
1640
'TestParametrized.test_three_things_composition_custom_names_1_4_5',
1641
'TestParametrized.test_three_things_composition_custom_names_1_4_6',
1642
'TestParametrized.test_three_things_composition_custom_names_2_3_5',
1643
'TestParametrized.test_three_things_composition_custom_names_2_3_6',
1644
'TestParametrized.test_three_things_composition_custom_names_2_4_5',
1645
'TestParametrized.test_three_things_composition_custom_names_2_4_6',
1646
'TestParametrized.test_two_things_custom_names_alternate_1__2',
1647
'TestParametrized.test_two_things_custom_names_alternate_1__3',
1648
'TestParametrized.test_two_things_custom_names_alternate_1__4',
1650
test_names = _get_test_names_for_test_class(TestParametrized)
1651
self.assertEqual(expected_test_names, test_names)
1653
def test_subtest_names(self):
1655
class TestParametrized(TestCase):
1656
@parametrize("bias", [subtest(True, name='bias'),
1657
subtest(False, name='no_bias')])
1658
def test_custom_names(self, bias):
1661
@parametrize("x,y", [subtest((1, 2), name='double'),
1662
subtest((1, 3), name='triple'),
1663
subtest((1, 4), name='quadruple')])
1664
def test_two_things_custom_names(self, x, y):
1667
instantiate_parametrized_tests(TestParametrized)
1669
expected_test_names = [
1670
'TestParametrized.test_custom_names_bias',
1671
'TestParametrized.test_custom_names_no_bias',
1672
'TestParametrized.test_two_things_custom_names_double',
1673
'TestParametrized.test_two_things_custom_names_quadruple',
1674
'TestParametrized.test_two_things_custom_names_triple',
1676
test_names = _get_test_names_for_test_class(TestParametrized)
1677
self.assertEqual(expected_test_names, test_names)
1679
def test_apply_param_specific_decorators(self):
1683
func._decorator_applied = True
1686
class TestParametrized(TestCase):
1687
@parametrize("x", [subtest(1, name='one'),
1688
subtest(2, name='two', decorators=[test_dec]),
1689
subtest(3, name='three')])
1690
def test_param(self, x):
1693
instantiate_parametrized_tests(TestParametrized)
1695
for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1696
self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two')
1698
def test_compose_param_specific_decorators(self):
1702
func._decorator_applied = True
1705
class TestParametrized(TestCase):
1706
@parametrize("x", [subtest(1),
1707
subtest(2, decorators=[test_dec]),
1709
@parametrize("y", [subtest(False, decorators=[test_dec]),
1711
def test_param(self, x, y):
1714
instantiate_parametrized_tests(TestParametrized)
1716
for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1718
should_apply = ('x_2' in name) or ('y_False' in name)
1719
self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
1721
def test_modules_decorator_misuse_error(self):
1724
class TestParametrized(TestCase):
1726
def test_modules(self, module_info):
1729
with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1730
instantiate_parametrized_tests(TestParametrized)
1732
def test_ops_decorator_misuse_error(self):
1735
class TestParametrized(TestCase):
1737
def test_ops(self, module_info):
1740
with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1741
instantiate_parametrized_tests(TestParametrized)
1743
def test_multiple_handling_of_same_param_error(self):
1746
class TestParametrized(TestCase):
1747
@parametrize("x", range(3))
1748
@parametrize("x", range(5))
1749
def test_param(self, x):
1752
with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'):
1753
instantiate_parametrized_tests(TestParametrized)
1755
@parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
1756
def test_subtest_expected_failure(self, x):
1758
raise RuntimeError('Boom')
1760
@parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
1761
@parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
1762
def test_two_things_subtest_expected_failure(self, x, y):
1763
if x == 1 or y == 6:
1764
raise RuntimeError('Boom')
1767
class TestTestParametrizationDeviceType(TestCase):
1768
def test_unparametrized_names(self, device):
1772
device = self.device_type
1774
class TestParametrized(TestCase):
1775
def test_device_specific(self, device):
1778
@dtypes(torch.float32, torch.float64)
1779
def test_device_dtype_specific(self, device, dtype):
1782
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1784
device_cls = locals()[f'TestParametrized{device.upper()}']
1785
expected_test_names = [name.format(device_cls.__name__, device) for name in (
1786
'{}.test_device_dtype_specific_{}_float32',
1787
'{}.test_device_dtype_specific_{}_float64',
1788
'{}.test_device_specific_{}')
1790
test_names = _get_test_names_for_test_class(device_cls)
1791
self.assertEqual(expected_test_names, test_names)
1793
def test_empty_param_names(self, device):
1795
device = self.device_type
1797
class TestParametrized(TestCase):
1798
@parametrize("", [])
1799
def test_foo(self, device):
1802
@parametrize("", range(5))
1803
def test_bar(self, device):
1806
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1808
device_cls = locals()[f'TestParametrized{device.upper()}']
1809
expected_test_names = [name.format(device_cls.__name__, device) for name in (
1813
test_names = _get_test_names_for_test_class(device_cls)
1814
self.assertEqual(expected_test_names, test_names)
1816
def test_empty_param_list(self, device):
1819
device = self.device_type
1821
generator = (a for a in range(5))
1823
class TestParametrized(TestCase):
1824
@parametrize("x", generator)
1825
def test_foo(self, device, x):
1829
@parametrize("y", generator)
1830
def test_bar(self, device, y):
1833
with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'):
1834
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1836
def test_default_names(self, device):
1837
device = self.device_type
1839
class TestParametrized(TestCase):
1840
@parametrize("x", range(5))
1841
def test_default_names(self, device, x):
1844
@parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1845
def test_two_things_default_names(self, device, x, y):
1849
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1851
device_cls = locals()[f'TestParametrized{device.upper()}']
1852
expected_test_names = [name.format(device_cls.__name__, device) for name in (
1853
'{}.test_default_names_x_0_{}',
1854
'{}.test_default_names_x_1_{}',
1855
'{}.test_default_names_x_2_{}',
1856
'{}.test_default_names_x_3_{}',
1857
'{}.test_default_names_x_4_{}',
1858
'{}.test_two_things_default_names_x_1_y_2_{}',
1859
'{}.test_two_things_default_names_x_2_y_3_{}',
1860
'{}.test_two_things_default_names_x_3_y_4_{}')
1862
test_names = _get_test_names_for_test_class(device_cls)
1863
self.assertEqual(expected_test_names, test_names)
1865
def test_default_name_non_primitive(self, device):
1866
device = self.device_type
1868
class TestParametrized(TestCase):
1869
@parametrize("x", [1, .5, "foo", object()])
1870
def test_default_names(self, device, x):
1873
@parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
1874
def test_two_things_default_names(self, device, x, y):
1877
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1879
device_cls = locals()[f'TestParametrized{device.upper()}']
1880
expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
1881
'{}.test_default_names_x_1_{}',
1882
'{}.test_default_names_x_0_5_{}',
1883
'{}.test_default_names_x_foo_{}',
1884
'{}.test_default_names_x3_{}',
1885
'{}.test_two_things_default_names_x_1_y0_{}',
1886
'{}.test_two_things_default_names_x1_y_0_5_{}',
1887
'{}.test_two_things_default_names_x2_y2_{}')
1889
test_names = _get_test_names_for_test_class(device_cls)
1890
self.assertEqual(expected_test_names, test_names)
1892
def test_name_fn(self, device):
1893
device = self.device_type
1895
class TestParametrized(TestCase):
1896
@parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1897
def test_custom_names(self, device, bias):
1900
@parametrize("x", [1, 2], name_fn=str)
1901
@parametrize("y", [3, 4], name_fn=str)
1902
@parametrize("z", [5, 6], name_fn=str)
1903
def test_three_things_composition_custom_names(self, device, x, y, z):
1906
@parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1907
def test_two_things_custom_names_alternate(self, device, x, y):
1910
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1912
device_cls = locals()[f'TestParametrized{device.upper()}']
1913
expected_test_names = [name.format(device_cls.__name__, device) for name in (
1914
'{}.test_custom_names_bias_{}',
1915
'{}.test_custom_names_no_bias_{}',
1916
'{}.test_three_things_composition_custom_names_1_3_5_{}',
1917
'{}.test_three_things_composition_custom_names_1_3_6_{}',
1918
'{}.test_three_things_composition_custom_names_1_4_5_{}',
1919
'{}.test_three_things_composition_custom_names_1_4_6_{}',
1920
'{}.test_three_things_composition_custom_names_2_3_5_{}',
1921
'{}.test_three_things_composition_custom_names_2_3_6_{}',
1922
'{}.test_three_things_composition_custom_names_2_4_5_{}',
1923
'{}.test_three_things_composition_custom_names_2_4_6_{}',
1924
'{}.test_two_things_custom_names_alternate_1__2_{}',
1925
'{}.test_two_things_custom_names_alternate_1__3_{}',
1926
'{}.test_two_things_custom_names_alternate_1__4_{}')
1928
test_names = _get_test_names_for_test_class(device_cls)
1929
self.assertEqual(expected_test_names, test_names)
1931
def test_subtest_names(self, device):
1932
device = self.device_type
1934
class TestParametrized(TestCase):
1935
@parametrize("bias", [subtest(True, name='bias'),
1936
subtest(False, name='no_bias')])
1937
def test_custom_names(self, device, bias):
1940
@parametrize("x,y", [subtest((1, 2), name='double'),
1941
subtest((1, 3), name='triple'),
1942
subtest((1, 4), name='quadruple')])
1943
def test_two_things_custom_names(self, device, x, y):
1946
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1948
device_cls = locals()[f'TestParametrized{device.upper()}']
1949
expected_test_names = [name.format(device_cls.__name__, device) for name in (
1950
'{}.test_custom_names_bias_{}',
1951
'{}.test_custom_names_no_bias_{}',
1952
'{}.test_two_things_custom_names_double_{}',
1953
'{}.test_two_things_custom_names_quadruple_{}',
1954
'{}.test_two_things_custom_names_triple_{}')
1956
test_names = _get_test_names_for_test_class(device_cls)
1957
self.assertEqual(expected_test_names, test_names)
1959
def test_ops_composition_names(self, device):
1960
device = self.device_type
1962
class TestParametrized(TestCase):
1964
@parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1965
def test_op_parametrized(self, device, dtype, op, flag):
1968
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1970
device_cls = locals()[f'TestParametrized{device.upper()}']
1971
expected_test_names = []
1973
for dtype in op.supported_dtypes(torch.device(device).type):
1974
for flag_part in ('flag_disabled', 'flag_enabled'):
1975
expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}'
1976
expected_test_names.append(expected_name)
1978
test_names = _get_test_names_for_test_class(device_cls)
1979
self.assertEqual(sorted(expected_test_names), sorted(test_names))
1981
def test_modules_composition_names(self, device):
1982
device = self.device_type
1984
class TestParametrized(TestCase):
1986
@parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1987
def test_module_parametrized(self, device, dtype, module_info, training, flag):
1990
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1992
device_cls = locals()[f'TestParametrized{device.upper()}']
1993
expected_test_names = []
1994
for module_info in module_db:
1995
for dtype in module_info.dtypes:
1996
for flag_part in ('flag_disabled', 'flag_enabled'):
1997
expected_train_modes = (
1998
['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else [''])
1999
for training_part in expected_train_modes:
2000
expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format(
2001
device_cls.__name__, module_info.formatted_name,
2002
'_' + training_part if len(training_part) > 0 else '',
2003
flag_part, device, dtype_name(dtype))
2004
expected_test_names.append(expected_name)
2006
test_names = _get_test_names_for_test_class(device_cls)
2007
self.assertEqual(sorted(expected_test_names), sorted(test_names))
2009
def test_ops_decorator_applies_op_and_param_specific_decorators(self, device):
2017
func._decorator_applied = True
2020
test_op_info = OpInfo(
2023
dtypes=floating_types(),
2024
sample_inputs_func=lambda _: [],
2026
DecorateInfo(test_dec, 'TestParametrized', 'test_op_param',
2027
device_type='cpu', dtypes=[torch.float64],
2028
active_if=lambda p: p['x'] == 2)
2031
class TestParametrized(TestCase):
2032
@ops(op_db + [test_op_info])
2033
@parametrize("x", [2, 3])
2034
def test_op_param(self, device, dtype, op, x):
2037
@ops(op_db + [test_op_info])
2040
subtest(5, decorators=[test_dec])])
2041
def test_other(self, device, dtype, op, y):
2044
@decorateIf(test_dec, lambda p: p['dtype'] == torch.int16)
2046
def test_three(self, device, dtype, op):
2049
device = self.device_type
2050
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2051
device_cls = locals()[f'TestParametrized{device.upper()}']
2053
for test_func, name in _get_test_funcs_for_test_class(device_cls):
2054
should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
2055
('test_other' in name and 'y_5' in name) or
2056
('test_three' in name and name.endswith('_int16')))
2057
self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2059
def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):
2063
class TestModule(torch.nn.Module):
2064
def __init__(self) -> None:
2066
self.x = torch.nn.Parameter(torch.randn(3))
2068
def forward(self, y):
2072
func._decorator_applied = True
2075
test_module_info = ModuleInfo(
2077
module_inputs_func=lambda _: [],
2079
DecorateInfo(test_dec, 'TestParametrized', 'test_module_param',
2080
device_type='cpu', dtypes=[torch.float64],
2081
active_if=lambda p: p['x'] == 2)
2084
class TestParametrized(TestCase):
2085
@modules(module_db + [test_module_info])
2086
@parametrize("x", [2, 3])
2087
def test_module_param(self, device, dtype, module_info, training, x):
2090
@modules(module_db + [test_module_info])
2093
subtest(5, decorators=[test_dec])])
2094
def test_other(self, device, dtype, module_info, training, y):
2097
@decorateIf(test_dec, lambda p: p['dtype'] == torch.float64)
2099
def test_three(self, device, dtype, module_info):
2102
device = self.device_type
2103
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2104
device_cls = locals()[f'TestParametrized{device.upper()}']
2106
for test_func, name in _get_test_funcs_for_test_class(device_cls):
2107
should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or
2108
('test_other' in name and 'y_5' in name) or
2109
('test_three' in name and name.endswith('float64')))
2110
self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2112
def test_param_specific_decoration(self, device):
2115
func._decorator_applied = True
2118
class TestParametrized(TestCase):
2119
@decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"])
2120
@parametrize("x", range(5))
2121
@parametrize("y", [False, True])
2122
def test_param(self, x, y):
2125
device = self.device_type
2126
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2127
device_cls = locals()[f'TestParametrized{device.upper()}']
2129
for test_func, name in _get_test_funcs_for_test_class(device_cls):
2130
should_apply = ('test_param_x_1_y_True' in name)
2131
self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2133
def test_dtypes_composition_valid(self, device):
2137
device = self.device_type
2139
class TestParametrized(TestCase):
2140
@dtypes(torch.float32, torch.float64)
2141
@parametrize("x", range(3))
2142
def test_parametrized(self, x, dtype):
2145
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2147
device_cls = locals()[f'TestParametrized{device.upper()}']
2148
expected_test_names = [name.format(device_cls.__name__, device) for name in (
2149
'{}.test_parametrized_x_0_{}_float32',
2150
'{}.test_parametrized_x_0_{}_float64',
2151
'{}.test_parametrized_x_1_{}_float32',
2152
'{}.test_parametrized_x_1_{}_float64',
2153
'{}.test_parametrized_x_2_{}_float32',
2154
'{}.test_parametrized_x_2_{}_float64')
2156
test_names = _get_test_names_for_test_class(device_cls)
2157
self.assertEqual(sorted(expected_test_names), sorted(test_names))
2159
def test_dtypes_composition_invalid(self, device):
2163
device = self.device_type
2165
class TestParametrized(TestCase):
2166
@dtypes(torch.float32, torch.float64)
2167
@parametrize("dtype", [torch.int32, torch.int64])
2168
def test_parametrized(self, dtype):
2171
with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2172
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2176
class TestParametrized(TestCase):
2177
@dtypes(torch.float32, torch.float64)
2179
def test_parametrized(self, op, dtype):
2182
with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2183
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2185
def test_multiple_handling_of_same_param_error(self, device):
2189
class TestParametrized(TestCase):
2192
def test_param(self, device, dtype, op, module_info, training):
2195
with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2196
instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2198
@parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
2199
def test_subtest_expected_failure(self, device, x):
2201
raise RuntimeError('Boom')
2203
@parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
2204
@parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
2205
def test_two_things_subtest_expected_failure(self, device, x, y):
2206
if x == 1 or y == 6:
2207
raise RuntimeError('Boom')
2210
instantiate_parametrized_tests(TestTestParametrization)
2211
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
2214
class TestImports(TestCase):
2216
def _check_python_output(cls, program) -> str:
2217
return subprocess.check_output(
2218
[sys.executable, "-W", "always", "-c", program],
2219
stderr=subprocess.STDOUT,
2222
cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
2224
def test_circular_dependencies(self) -> None:
2225
""" Checks that all modules inside torch can be imported
2226
Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
2227
ignored_modules = ["torch.utils.tensorboard",
2228
"torch.distributed.elastic.rendezvous",
2229
"torch.backends._coreml",
2231
"torch.testing._internal.distributed.",
2232
"torch.ao.pruning._experimental.",
2233
"torch.onnx._internal",
2234
"torch._inductor.runtime.triton_helpers",
2235
"torch._inductor.codegen.cuda",
2238
if not sys.version_info >= (3, 9):
2239
ignored_modules.append("torch.utils.benchmark")
2240
if IS_WINDOWS or IS_MACOS or IS_JETSON:
2242
if IS_MACOS or IS_JETSON:
2243
ignored_modules.append("torch.distributed.")
2245
ignored_modules.append("torch.distributed.nn.api.")
2246
ignored_modules.append("torch.distributed.optim.")
2247
ignored_modules.append("torch.distributed.rpc.")
2248
ignored_modules.append("torch.testing._internal.dist_utils")
2250
ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop")
2251
ignored_modules.append("torch.testing._internal.common_fsdp")
2252
ignored_modules.append("torch.testing._internal.common_distributed")
2254
torch_dir = os.path.dirname(torch.__file__)
2255
for base, folders, files in os.walk(torch_dir):
2256
prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
2258
if not f.endswith(".py"):
2260
mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
2262
if f == "__main__.py":
2264
if any(mod_name.startswith(x) for x in ignored_modules):
2267
mod = importlib.import_module(mod_name)
2268
except Exception as e:
2269
raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
2270
self.assertTrue(inspect.ismodule(mod))
2272
@unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
2273
def test_lazy_imports_are_lazy(self) -> None:
2274
out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
2275
self.assertEqual(out.strip(), "True")
2277
@unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2278
def test_no_warning_on_import(self) -> None:
2279
out = self._check_python_output("import torch")
2280
self.assertEqual(out, "")
2282
def test_not_import_sympy(self) -> None:
2283
out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
2284
self.assertEqual(out.strip(), "True",
2285
"PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
2286
"See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
2287
"https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
2288
"If you hit this error, you may want to:\n"
2289
" - Refactor your code to avoid depending on sympy files you may not need to depend\n"
2290
" - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
2291
" - Import things that depend on SymPy locally")
2293
@unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2294
@parametrize('path', ['torch', 'functorch'])
2295
def test_no_mutate_global_logging_on_import(self, path) -> None:
2299
expected = 'abcdefghijklmnopqrstuvwxyz'
2303
'_logger = logging.getLogger("torch_test_testing")',
2304
'logging.root.addHandler(logging.StreamHandler())',
2305
'logging.root.setLevel(logging.INFO)',
2306
f'_logger.info("{expected}")'
2308
out = self._check_python_output("; ".join(commands))
2309
self.assertEqual(out.strip(), expected)
2311
class TestOpInfos(TestCase):
2312
def test_sample_input(self) -> None:
2313
a, b, c, d, e = (object() for _ in range(5))
2316
s = SampleInput(a, b, c, d=d, e=e)
2318
assert s.args == (b, c)
2319
assert s.kwargs == dict(d=d, e=e)
2322
s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
2324
assert s.args == (b,)
2325
assert s.kwargs == dict(c=c, d=d, e=e)
2328
with self.assertRaises(AssertionError):
2329
s = SampleInput(a, b, c, args=(d, e))
2331
with self.assertRaises(AssertionError):
2332
s = SampleInput(a, b, c, kwargs=dict(d=d, e=e))
2334
with self.assertRaises(AssertionError):
2335
s = SampleInput(a, args=(b, c), d=d, e=e)
2337
with self.assertRaises(AssertionError):
2338
s = SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))
2341
with self.assertRaises(AssertionError):
2342
s = SampleInput(a, b, name="foo")
2344
with self.assertRaises(AssertionError):
2345
s = SampleInput(a, b, output_process_fn_grad=lambda x: x)
2347
with self.assertRaises(AssertionError):
2348
s = SampleInput(a, b, broadcasts_input=True)
2352
s = SampleInput(a, broadcasts_input=True)
2354
assert s.broadcasts_input
2356
def test_sample_input_metadata(self) -> None:
2357
a, b = (object() for _ in range(2))
2358
s1 = SampleInput(a, b=b)
2359
self.assertIs(s1.output_process_fn_grad(None), None)
2360
self.assertFalse(s1.broadcasts_input)
2361
self.assertEqual(s1.name, "")
2363
s2 = s1.with_metadata(
2364
output_process_fn_grad=lambda x: a,
2365
broadcasts_input=True,
2368
self.assertIs(s1, s2)
2369
self.assertIs(s2.output_process_fn_grad(None), a)
2370
self.assertTrue(s2.broadcasts_input)
2371
self.assertEqual(s2.name, "foo")
2375
class TestOpInfoSampleFunctions(TestCase):
2377
@ops(op_db, dtypes=OpDTypes.any_one)
2378
def test_opinfo_sample_generators(self, device, dtype, op):
2380
samples = op.sample_inputs(device, dtype)
2381
self.assertIsInstance(samples, Iterator)
2383
@ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
2384
def test_opinfo_reference_generators(self, device, dtype, op):
2386
samples = op.reference_inputs(device, dtype)
2387
self.assertIsInstance(samples, Iterator)
2389
@ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
2390
def test_opinfo_error_generators(self, device, op):
2392
samples = op.error_inputs(device)
2393
self.assertIsInstance(samples, Iterator)
2396
instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
2397
instantiate_parametrized_tests(TestImports)
2400
if __name__ == '__main__':