# Owner(s): ["module: fx"]

from typing import Optional, Set, Type

import torch
import torch.fx

from torch.testing._internal.common_utils import TestCase


class TestDCE(TestCase):
    def _has_nodes_without_users(self, m: torch.fx.GraphModule):
        """Return True if the graph has a node that is not impure and has no users."""
        for node in m.graph.nodes:
            if node.is_impure():
                continue
            if len(node.users) == 0:
                return True
        return False

    def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
        """Count the placeholder nodes in the graph."""
        count = 0
        for node in m.graph.nodes:
            if node.op == "placeholder":
                count += 1
        return count

    def _run_dce_and_test(
        self,
        m: torch.nn.Module,
        expect_dce_changes: bool,
        modules_to_be_leafs: Optional[Set[Type]] = None,
    ):
        class TestTracer(torch.fx.Tracer):
            # Treat the requested module types as leaves so they show up as
            # call_module nodes in the traced graph instead of being traced through.
            def is_leaf_module(self, m, qualname):
                if modules_to_be_leafs and type(m) in modules_to_be_leafs:
                    return True
                return super().is_leaf_module(m, qualname)

        traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
        print(str(traced.graph))

        # Verify there are nodes without users (if expected).
        has_nodes_without_users = self._has_nodes_without_users(traced)
        if expect_dce_changes:
            self.assertTrue(has_nodes_without_users)
        else:
            self.assertFalse(has_nodes_without_users)

        # Get the original number of placeholders to verify it doesn't change
        # during DCE.
        orig_num_phs = self._get_num_placeholders(traced)
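        # Graph.eliminate_dead_code() erases nodes that have no users and are not
        # impure, returning True if the graph was changed.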
        changed = traced.graph.eliminate_dead_code()

        self.assertTrue(changed if expect_dce_changes else not changed)

        # Verify there are no nodes without users after DCE is run.
        self.assertFalse(self._has_nodes_without_users(traced))
        new_num_phs = self._get_num_placeholders(traced)
        self.assertEqual(orig_num_phs, new_num_phs)

        traced.recompile()
        # Make sure we run and get the same results before/after DCE.
        inputs = [torch.tensor([1.5])] * new_num_phs
        self.assertTrue(torch.equal(m(*inputs), traced(*inputs)))

    def test_simple(self):
        """
        Tests that a single node in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_chain(self):
        """
        Tests that a chain of two nodes in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * 7
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_getattr(self):
        """
        Tests that a getattr in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * self.attr_1
                return x + 11

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_placeholder(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_dead_placeholder_with_user(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature. Also verifies that a dead node that uses the
        placeholder is DCE'd.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                a = y + 2
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_keep_module_with_side_effects(self):
        """
        Test that DCE doesn't remove a module if it's specified as having side effects.
        """

        class ReLUImpure(torch.nn.ReLU):
            _is_impure = True
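            # torch.fx reports a call_module node as impure when the target module
            # defines `_is_impure = True`, so DCE should keep the call to this module.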

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = ReLUImpure()

            def forward(self, a: torch.Tensor) -> torch.Tensor:
                r = self.relu(a)
                return a * 2

        self._run_dce_and_test(
            TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
        )

    def test_keep_torch_assert(self):
        """
        Test that DCE doesn't remove torch._assert since it has side effects.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                torch._assert(torch.equal(a, a), "a must equal a")
                return a * 2

        # Note: we don't need to mark torch._assert as having side effects
        # because it is already known to be side-effectful.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)
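
# Assumed entry point (not part of the listing above): standalone PyTorch test
# files commonly execute their cases via run_tests() when run directly.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()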
