pytorch

Форк
0
45 строк · 1.2 Кб
1
import unittest
2

3
import benchmark_cpp_extension  # noqa: F401
4
import torch
5

6

7
class TestConsumeOp(unittest.TestCase):
8
    def test_jit_consume_op(self):
9
        iters = 6
10

11
        def foo(x):
12
            for i in range(iters):
13
                result = torch.ops.operator_benchmark._consume(torch.sum(x))
14
            return result
15

16
        r = torch.jit.trace(foo, (torch.rand(2, 2)))
17

18
        graph = str(r.graph)
19
        occurrence = graph.count("aten::sum")
20

21
        x = torch.rand(2, 2)
22
        value = r(x)
23
        self.assertEqual(value, torch.sum(x))
24
        self.assertEqual(occurrence, iters)
25

26
    def test_jit_consume_op_for_list_input(self):
27
        iters = 6
28

29
        def foo(x):
30
            for i in range(iters):
31
                result = torch.ops.operator_benchmark._consume(torch.chunk(x, 2))
32
            return result
33

34
        r = torch.jit.trace(foo, torch.rand(2, 2))
35

36
        graph = str(r.graph)
37
        occurrence = graph.count("aten::chunk")
38

39
        x = torch.rand(2, 2)
40
        value = r(x)
41

42
        self.assertTrue(
43
            all(torch.allclose(t1, t2) for t1, t2 in zip(value, torch.chunk(x, 2)))
44
        )
45
        self.assertEqual(occurrence, iters)
46

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.