pytorch
1#!/usr/bin/env python3
2# Owner(s): ["module: internals"]
3
4import torch5from torch.testing._internal.common_utils import TestCase, run_tests6
class TestComparisonUtils(TestCase):
    """Tests for ``torch._assert_tensor_metadata``.

    Every case starts from the same one-element tensor, ``torch.tensor([0.5])``.
    It checks the metadata assertion against expected size, stride and dtype
    arguments. The tests expect ``None`` arguments to be ignored. They also
    expect a single mismatching argument to raise ``RuntimeError``.
    """

    @staticmethod
    def _single_half():
        # Fresh one-element tensor holding 0.5. The all-match test below
        # expects its size to be [1], its stride [1] and its dtype torch.float.
        return torch.tensor([0.5])

    def test_all_equal_no_assert(self):
        sample = self._single_half()
        # Size, stride and dtype all match, so no error is expected.
        torch._assert_tensor_metadata(sample, [1], [1], torch.float)

    def test_all_equal_no_assert_nones(self):
        sample = self._single_half()
        # No expectations are supplied, so the call must succeed.
        torch._assert_tensor_metadata(sample, None, None, None)

    def test_assert_dtype(self):
        sample = self._single_half()
        wrong_dtype = torch.int32  # deliberately differs from torch.float
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, None, None, wrong_dtype)

    def test_assert_strides(self):
        sample = self._single_half()
        wrong_stride = [3]  # the tensor's real stride is [1]
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, None, wrong_stride, torch.float)

    def test_assert_sizes(self):
        sample = self._single_half()
        wrong_size = [3]  # the tensor's real size is [1]
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, wrong_size, [1], torch.float)
34
if __name__ == "__main__":
    # Running the module directly hands control to PyTorch's shared test runner.
    run_tests()