pytorch
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5
6import torch
7from torch.testing import FileCheck
8
9
10# Make the helper files in test/ importable
11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
if __name__ == "__main__":
    # Direct execution is rejected: the error text tells the user to launch
    # the tests through test/test_jit.py instead of this file.
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
22
23
24class TestFunctionalBlocks(JitTestCase):
25def test_subgraph_creation(self):
26def fn(x, y, z):
27x = x + 1
28y = y + 1
29z = z + 1
30z.add_(2)
31z = z * z
32y = y * z
33if y < 2:
34y = y + 5
35return x + y + z
36
37graph = torch.jit.script(fn).graph
38self.run_pass("create_functional_graphs", graph)
39
40# all uses of x and y should be sunk
41FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(
42r"%x"
43).run(graph)
44FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(
45r"%y"
46).run(graph)
47
48# Don't allow any outputs which escape scope, so there is one final addition in the graph
49FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(
50graph
51)
52
53# z + 1, z.add_(2) considered non functional, z = z * z should be considered functional
54FileCheck().check("add").check("add_").check_not("mul").check(
55"FunctionalGraph"
56).run(graph)
57