pytorch

Форк
0
/
test_pass_infra.py 
230 строк · 7.4 Кб
1
# Owner(s): ["module: fx"]
2

3
import torch
4
import torch.fx as fx
5
from torch.fx.passes.infra.pass_base import PassBase, PassResult
6
from torch.fx.passes.infra.pass_manager import (
7
    _topological_sort_passes,
8
    pass_result_wrapper,
9
    PassManager,
10
    this_before_that_pass_constraint,
11
)
12
from torch.testing._internal.common_utils import TestCase
13

14

15
# Pass that uses PassBase and returns a PassResult (best scenario)
class ReplaceAddWithMulPass(PassBase):
    """Rewrite every ``torch.add`` call_function node into ``torch.mul``."""

    def call(self, gm: fx.GraphModule) -> PassResult:
        """Retarget the add nodes of ``gm`` in place.

        Args:
            gm: graph module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` holding ``gm`` and a flag that is ``True`` only
            when at least one node was retargeted.
        """
        modified = False
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = torch.mul
                modified = True
        return PassResult(gm, modified)
24

25

26
# Pass that is a callable and returns a PassResult
def replace_mul_with_div_pass(gm: fx.GraphModule) -> PassResult:
    """Rewrite every ``torch.mul`` call_function node of ``gm`` into ``torch.div``.

    Args:
        gm: graph module whose graph is rewritten in place.

    Returns:
        A ``PassResult`` holding ``gm`` and a flag that is ``True`` only when
        at least one node was retargeted.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.mul:
            node.target = torch.div
            modified = True
    return PassResult(gm, modified)
34

35

36
# Pass that is a PassBase and does not return a PassResult
# Need to wrap with pass_result_wrapper or else it will fail
class ReplaceDivWithSubPass(PassBase):
    """Rewrite every ``torch.div`` call_function node into ``torch.sub``."""

    def call(self, gm: fx.GraphModule) -> None:
        """Retarget the div nodes of ``gm`` in place.

        Deliberately returns nothing, so the return annotation is ``None``
        rather than ``PassResult``.
        """
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.div:
                node.target = torch.sub
43

44

45
# Pass that is a callable and does not return a PassResult
# Need to wrap with pass_result_wrapper or else it will fail
def replace_sub_with_add_pass(gm: fx.GraphModule) -> None:
    """Rewrite every ``torch.sub`` call_function node of ``gm`` into ``torch.add``.

    The graph is modified in place. Deliberately returns nothing, so the
    return annotation is ``None`` rather than ``PassResult``.
    """
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.sub:
            node.target = torch.add
51

52

53
class AddModule(torch.nn.Module):
    """Toy module made of two chained ``torch.add`` calls, used as pass input."""

    def forward(self, x):
        """Return ``(x + x) + x`` computed with two ``torch.add`` calls."""
        y = torch.add(x, x)
        z = torch.add(y, x)
        return z
58

59

60
class TestPassManager(TestCase):
    """Tests for ``PassManager``: running passes, constraints, checks, ordering and errors."""

    def test_pass_manager(self):
        """
        Tests that the pass manager runs the passes correctly.
        """

        m = AddModule()
        traced_m = fx.symbolic_trace(m)
        pm = PassManager(
            passes=[
                ReplaceAddWithMulPass(),
                replace_mul_with_div_pass,
                pass_result_wrapper(ReplaceDivWithSubPass()),
                pass_result_wrapper(replace_sub_with_add_pass),
            ],
            steps=5,
        )

        pm.validate_constraints()
        self.assertEqual(len(pm.passes), 4)

        res = pm(traced_m)
        modified_m = res.graph_module
        self.assertIsInstance(modified_m, fx.GraphModule)

        # The four passes form a cycle (add -> mul -> div -> sub -> add), so
        # every call_function node should be an add again after they have run.
        for node in modified_m.graph.nodes:
            if node.op == "call_function":
                self.assertEqual(node.target, torch.add)

    def test_this_before_that_pass_constraint(self):
        """
        Tests the construction of constraints
        """
        passes = [lambda x: 2 * x for _ in range(10)]
        pm = PassManager(passes)

        # add unfulfillable constraint: the last pass must run before the first
        pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))

        with self.assertRaises(RuntimeError):
            pm.validate_constraints()

    def test_pass_manager_checks(self):
        """
        Tests that users can add in check functions correctly
        """
        m = AddModule()
        traced_m = fx.symbolic_trace(m)
        pm = PassManager(passes=[ReplaceAddWithMulPass(), replace_mul_with_div_pass])

        def check_div_target(graph_module):
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target != torch.div:
                    raise ValueError("Target should be div!")

        pm.add_checks(check_div_target)

        with self.assertRaises(ValueError):
            pm(traced_m)

    def test_pass_manager_bad_checks(self):
        """
        Checks that we error if we pass in a check function with the wrong parameters
        """

        def check_bad_args(graph_module, i):
            pass

        pm = PassManager()
        self.assertRaises(TypeError, pm.add_checks, check_bad_args)

    def test_topological_sort(self):
        """
        Tests that passes are correctly ordered based on constraints.
        """

        def pass0(x):
            return x

        def pass1(x):
            return x + 1

        def pass2(x):
            return x + 2

        def pass3(x):
            return x + 3

        def pass4(x):
            return x + 4

        def pass5(x):
            return x + 5

        # Not passing any constraints should keep the original order
        passes = [pass0, pass1, pass2, pass3, pass4, pass5]
        sorted_passes = _topological_sort_passes(passes, [])
        self.assertEqual(sorted_passes, passes)

        # Graph that we are constructing:
        #     5 ---->  0  <---- 4
        #     |                 |
        #     +-> 2 -> 3 -> 1 <-+
        # Which has a possible topological order of: [4, 5, 0, 2, 3, 1]
        passes = [pass0, pass1, pass2, pass3, pass4, pass5]
        constraints = [
            this_before_that_pass_constraint(pass5, pass0),
            this_before_that_pass_constraint(pass5, pass2),
            this_before_that_pass_constraint(pass4, pass0),
            this_before_that_pass_constraint(pass4, pass1),
            this_before_that_pass_constraint(pass2, pass3),
            this_before_that_pass_constraint(pass3, pass1),
        ]
        sorted_passes = _topological_sort_passes(passes, constraints)
        self.assertEqual(sorted_passes, [pass4, pass5, pass0, pass2, pass3, pass1])

        # Circular dependency should raise a RuntimeError naming the passes
        passes = [pass0, pass1, pass2]
        constraints = [
            this_before_that_pass_constraint(passes[0], passes[1]),
            this_before_that_pass_constraint(passes[1], passes[2]),
            this_before_that_pass_constraint(passes[2], passes[0]),
        ]
        with self.assertRaises(RuntimeError) as e:
            _topological_sort_passes(passes, constraints)
        expected_error_msg = (
            f"Circular dependency detected within the following passes: {passes}"
        )
        self.assertEqual(e.exception.args[0], expected_error_msg)

    def test_pass_manager_error(self):
        """
        Tests error catching + debug
        """

        def pass_fail(graph_module):
            raise RuntimeError("bad")

        m = AddModule()
        traced_m = fx.symbolic_trace(m)

        # ReplaceDivWithSubPass is intentionally left unwrapped here: it does
        # not return a PassResult, which is the failure this case expects.
        pm = PassManager(
            passes=[
                ReplaceAddWithMulPass(),
                replace_mul_with_div_pass,
                ReplaceDivWithSubPass(),
                pass_result_wrapper(replace_sub_with_add_pass),
            ],
        )

        # Comment out this line to see the actual error message
        error_msg = (
            "ReplaceDivWithSubPass.*ReplaceAddWithMulPass.*replace_mul_with_div_pass"
        )
        with self.assertRaisesRegex(Exception, error_msg):
            pm(traced_m)

        pm = PassManager(
            passes=[
                ReplaceAddWithMulPass(),
                replace_mul_with_div_pass,
                pass_result_wrapper(ReplaceDivWithSubPass()),
                pass_result_wrapper(replace_sub_with_add_pass),
                pass_fail,
            ],
        )

        # Comment out this line to see the actual error message
        error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass"
        with self.assertRaisesRegex(Exception, error_msg):
            pm(traced_m)
231

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.