4
import copy

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
6
class TestPerOverloadAPI(TestCase):
7
def test_basics_opoverloadpacket(self):
10
add_packet = torch.ops.aten.add
13
self.assertEqual(add_packet.__name__, 'add')
14
self.assertEqual(str(add_packet), 'aten.add')
17
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
20
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
23
another_add_packet = torch.ops.aten.add
24
self.assertEqual(id(add_packet), id(another_add_packet))
27
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
30
self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
32
self.assertRaises(AttributeError, lambda: add_packet.foo)
34
def test_basics_opoverload(self):
35
add_packet = torch.ops.aten.add
36
add_tensoroverload = add_packet.Tensor
39
self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
40
self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
41
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
44
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
47
another_add_tensoroverload = torch.ops.aten.add.Tensor
48
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
51
self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
54
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
58
torch.ops.aten.add.out(a, a, out=b)
59
self.assertEqual(b, torch.tensor(4))
61
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
63
def test_decompose(self):
67
torch.ops.aten.linear.default.decompose(x, y),
68
torch.ops.aten.linear.default(x, y)
71
if __name__ == '__main__':
    # Run this module's tests via the PyTorch test-runner entry point.
    run_tests()