pytorch

Форк
0
/
test_per_overload_api.py 
72 строки · 2.5 Кб
1
# Owner(s): ["module: unknown"]
2
import torch
3
import copy
4
from torch.testing._internal.common_utils import TestCase, run_tests
5

6
class TestPerOverloadAPI(TestCase):
7
    def test_basics_opoverloadpacket(self):
8
        # add is ony used as an example here. It is ok to update the test
9
        # if the semantics of add are modified in the future.
10
        add_packet = torch.ops.aten.add
11

12
        # class attributes
13
        self.assertEqual(add_packet.__name__, 'add')
14
        self.assertEqual(str(add_packet), 'aten.add')
15

16
        # callable
17
        self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
18

19
        # correct module
20
        self.assertEqual(add_packet.__module__, add_packet.op.__module__)
21

22
        # caching
23
        another_add_packet = torch.ops.aten.add
24
        self.assertEqual(id(add_packet), id(another_add_packet))
25

26
        # deepcopy is a no-op
27
        self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
28

29
        # pretty print
30
        self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
31

32
        self.assertRaises(AttributeError, lambda: add_packet.foo)
33

34
    def test_basics_opoverload(self):
35
        add_packet = torch.ops.aten.add
36
        add_tensoroverload = add_packet.Tensor
37

38
        # class attributes
39
        self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
40
        self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
41
        self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
42

43
        # deepcopy is a no-op
44
        self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
45

46
        # caching
47
        another_add_tensoroverload = torch.ops.aten.add.Tensor
48
        self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
49

50
        # pretty print
51
        self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
52

53
        # callable
54
        self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
55

56
        a = torch.tensor(2)
57
        b = torch.tensor(0)
58
        torch.ops.aten.add.out(a, a, out=b)
59
        self.assertEqual(b, torch.tensor(4))
60

61
        self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
62

63
    def test_decompose(self):
64
        x = torch.randn(2, 3)
65
        y = torch.randn(5, 3)
66
        self.assertEqual(
67
            torch.ops.aten.linear.default.decompose(x, y),
68
            torch.ops.aten.linear.default(x, y)
69
        )
70

71
if __name__ == '__main__':
72
    run_tests()
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.