pytorch
45 строк · 1.2 Кб
1import unittest
2
3import benchmark_cpp_extension # noqa: F401
4import torch
5
6
class TestConsumeOp(unittest.TestCase):
    """JIT-tracing tests for ``torch.ops.operator_benchmark._consume``.

    Each test traces a small function that calls ``_consume`` on the result
    of an ATen op several times in a loop, then checks two things:

    * the traced graph text mentions the wrapped ATen op once per loop
      iteration, and
    * running the traced function on a fresh input gives the same values as
      calling the ATen op directly.

    NOTE(review): the op is registered by the ``benchmark_cpp_extension``
    import; presumably the per-iteration count guards against the repeated
    computation being dropped from the trace -- confirm against the
    extension's source.
    """

    def test_jit_consume_op(self):
        """Trace ``_consume(torch.sum(x))`` in a loop and verify the graph."""
        num_calls = 6

        def consume_sum(x):
            # Only the value from the final iteration is returned.
            for _ in range(num_calls):
                out = torch.ops.operator_benchmark._consume(torch.sum(x))
            return out

        traced = torch.jit.trace(consume_sum, torch.rand(2, 2))

        # Count the sum nodes in the textual graph before running the trace.
        sum_nodes = str(traced.graph).count("aten::sum")

        sample = torch.rand(2, 2)
        actual = traced(sample)
        self.assertEqual(actual, torch.sum(sample))
        self.assertEqual(sum_nodes, num_calls)

    def test_jit_consume_op_for_list_input(self):
        """Trace ``_consume(torch.chunk(x, 2))`` in a loop and verify the graph."""
        num_calls = 6

        def consume_chunks(x):
            # Only the value from the final iteration is returned.
            for _ in range(num_calls):
                out = torch.ops.operator_benchmark._consume(torch.chunk(x, 2))
            return out

        traced = torch.jit.trace(consume_chunks, torch.rand(2, 2))

        # Count the chunk nodes in the textual graph before running the trace.
        chunk_nodes = str(traced.graph).count("aten::chunk")

        sample = torch.rand(2, 2)
        actual_chunks = traced(sample)

        # Pairwise comparison stops at the shorter of the two sequences.
        chunks_match = all(
            map(torch.allclose, actual_chunks, torch.chunk(sample, 2))
        )
        self.assertTrue(chunks_match)
        self.assertEqual(chunk_nodes, num_calls)
46