pytorch
1#!/usr/bin/env python3
2# Owner(s): ["module: internals"]
3
4import torch5from torch.testing._internal.common_utils import run_tests, TestCase6
7
class TestComparisonUtils(TestCase):
    """Tests for ``torch._assert_tensor_metadata`` on a one-element float tensor.

    The three positional arguments after the tensor are exercised as
    (sizes, strides, dtype).  A ``None`` in any slot is expected to skip that
    comparison, while a mismatching value is expected to raise ``RuntimeError``.
    """

    @staticmethod
    def _one_elem():
        # Shape (1,), stride (1,); under the default dtype this is torch.float.
        return torch.tensor([0.5])

    def test_all_equal_no_assert(self):
        tensor = self._one_elem()
        torch._assert_tensor_metadata(tensor, [1], [1], torch.float)

    def test_all_equal_no_assert_nones(self):
        tensor = self._one_elem()
        # Every slot is None, so nothing is compared and nothing should raise.
        torch._assert_tensor_metadata(tensor, None, None, None)

    def test_assert_dtype(self):
        tensor = self._one_elem()
        # Only the dtype is checked, and int32 differs from the float tensor.
        self.assertRaises(
            RuntimeError,
            torch._assert_tensor_metadata,
            tensor,
            None,
            None,
            torch.int32,
        )

    def test_assert_strides(self):
        tensor = self._one_elem()
        # The real stride is (1,), so [3] must be rejected.
        self.assertRaises(
            RuntimeError,
            torch._assert_tensor_metadata,
            tensor,
            None,
            [3],
            torch.float,
        )

    def test_assert_sizes(self):
        tensor = self._one_elem()
        # The real size is (1,), so [3] must be rejected even though
        # strides and dtype match.
        self.assertRaises(
            RuntimeError,
            torch._assert_tensor_metadata,
            tensor,
            [3],
            [1],
            torch.float,
        )
35
if __name__ == "__main__":
    # Only run the suite when executed as a script, not when imported.
    run_tests()