1
# Owner(s): ["module: fx"]
5
from torch.fx.passes.infra.pass_base import PassBase, PassResult
6
from torch.fx.passes.infra.pass_manager import (
7
_topological_sort_passes,
10
this_before_that_pass_constraint,
12
from torch.testing._internal.common_utils import TestCase
# Pass that subclasses PassBase and returns a PassResult (the best scenario).
class ReplaceAddWithMulPass(PassBase):
    """FX pass that retargets every ``torch.add`` call_function node to ``torch.mul``."""

    def call(self, gm) -> PassResult:
        """Rewrite the add nodes of ``gm.graph`` in place.

        Args:
            gm: the graph module whose graph is rewritten.

        Returns:
            ``PassResult(gm, modified)`` where ``modified`` is True iff at
            least one node was retargeted.
        """
        # Bind the flag up front so the return below never sees an unbound name.
        modified = False
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = torch.mul
                modified = True
        return PassResult(gm, modified)
# Pass that is a plain callable and returns a PassResult.
def replace_mul_with_div_pass(gm) -> PassResult:
    """Retarget every ``torch.mul`` call_function node of ``gm`` to ``torch.div``.

    Args:
        gm: the graph module whose graph is rewritten in place.

    Returns:
        ``PassResult(gm, modified)`` where ``modified`` is True iff at least
        one node was retargeted.
    """
    # Bind the flag up front so the return below never sees an unbound name.
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.mul:
            node.target = torch.div
            modified = True
    return PassResult(gm, modified)
# Pass that subclasses PassBase but does not return a PassResult.
# It needs to be wrapped with pass_result_wrapper, or else it will fail.
class ReplaceDivWithSubPass(PassBase):
    """FX pass that retargets every ``torch.div`` call_function node to ``torch.sub``."""

    def call(self, gm) -> None:
        """Rewrite the div nodes of ``gm.graph`` in place.

        Deliberately returns ``None`` instead of a ``PassResult`` so the
        tests can exercise ``pass_result_wrapper``.
        """
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.div:
                node.target = torch.sub
# Pass that is a plain callable and does not return a PassResult.
# It needs to be wrapped with pass_result_wrapper, or else it will fail.
def replace_sub_with_add_pass(gm) -> None:
    """Retarget every ``torch.sub`` call_function node of ``gm`` to ``torch.add``.

    Rewrites ``gm.graph`` in place and deliberately returns ``None`` instead
    of a ``PassResult`` so the tests can exercise ``pass_result_wrapper``.
    """
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.sub:
            node.target = torch.add
class AddModule(torch.nn.Module):
    """Toy module whose forward uses only ``torch.add``.

    This gives the FX passes in this file add nodes to rewrite.
    """

    def forward(self, x):
        y = torch.add(x, x)
        z = torch.add(y, x)
        return z
60
class TestPassManager(TestCase):
61
def test_pass_manager(self):
# Purpose: run the four passes in order (add -> mul -> div -> sub -> add),
# presumably through a PassManager `pm`, and check that the final graph has
# only torch.add call_function nodes. The two passes that return None are
# wrapped with pass_result_wrapper.
# NOTE(review): this copy of the method is damaged: bare line-number
# statements, no indentation, unquoted docstring text and missing lines.
# `m`, `pm` and `res` are used below but never assigned here, and the four
# pass expressions have no enclosing call/brackets -- restore from version
# control.
63
Tests that the pass manager runs the passes correctly.
67
traced_m = torch.fx.symbolic_trace(m)
70
ReplaceAddWithMulPass(),
71
replace_mul_with_div_pass,
72
pass_result_wrapper(ReplaceDivWithSubPass()),
73
pass_result_wrapper(replace_sub_with_add_pass),
78
pm.validate_constraints()
79
self.assertEqual(len(pm.passes), 4)
82
modified_m = res.graph_module
83
assert isinstance(modified_m, fx.GraphModule)
85
# Check that all call_function nodes are adds again: the four passes map
# add -> mul -> div -> sub -> add.
86
for node in modified_m.graph.nodes:
87
if node.op == "call_function":
88
self.assertEqual(node.target, torch.add)
90
def test_this_before_that_pass_constraint(self):
# Purpose: check that PassManager.validate_constraints() rejects a
# constraint that contradicts the order of the passes.
# NOTE(review): the docstring quotes and indentation are missing in this copy.
92
Tests the construction of constraints
94
passes = [lambda x: 2 * x for _ in range(10)]
95
pm = PassManager(passes)
97
# add unfulfillable constraint: passes[-1] must run before passes[0], the
# opposite of their order in `passes`, so validation is expected to raise
# RuntimeError.
98
pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
100
with self.assertRaises(RuntimeError):
101
pm.validate_constraints()
103
def test_pass_manager_checks(self):
# Purpose: register a user check function on a PassManager built from the
# add->mul and mul->div passes, and expect a ValueError from the check.
# NOTE(review): this copy is damaged. `m` is never assigned, and the body of
# the final `with` block is missing, so what is expected to raise cannot be
# read from here. Presumably it is `pm(traced_m)` with checks run after each
# pass (after the first pass the nodes are torch.mul, not torch.div).
# Confirm against version control.
105
Tests that users can add in check functions correctly
108
traced_m = fx.symbolic_trace(m)
109
pm = PassManager(passes=[ReplaceAddWithMulPass(), replace_mul_with_div_pass])
111
# Check function: raises ValueError if any call_function node targets
# something other than torch.div.
def check_div_target(graph_module):
112
for node in graph_module.graph.nodes:
113
if node.op == "call_function" and node.target != torch.div:
114
raise ValueError("Target should be div!")
116
pm.add_checks(check_div_target)
118
with self.assertRaises(ValueError):
121
def test_pass_manager_bad_checks(self):
# Purpose: a check function with an extra parameter must be rejected by
# pm.add_checks with TypeError.
# NOTE(review): the body of check_bad_args and the assignment of `pm` are
# missing from this copy -- restore from version control.
123
Checks that we error if we pass in a check function with the wrong parameters
126
# Takes (graph_module, i): one parameter more than a single-argument check
# function such as check_div_target in test_pass_manager_checks.
def check_bad_args(graph_module, i):
130
self.assertRaises(TypeError, pm.add_checks, check_bad_args)
132
def test_topological_sort(self):
# Purpose: check that _topological_sort_passes does three things:
#   - keeps the input order when there are no constraints;
#   - orders passes to satisfy a constraint DAG;
#   - raises RuntimeError on a cycle.
# NOTE(review): this copy is missing the definitions of pass0..pass5 and the
# `constraints = [` / closing-bracket lines. Also, `sorted` shadows the builtin.
134
Tests that passes are correctly ordered based on contraints.
155
# Not passing any constraints should keep the original order
156
passes = [pass0, pass1, pass2, pass3, pass4, pass5]
157
sorted = _topological_sort_passes(passes, [])
158
self.assertEqual(sorted, passes)
160
# Constraint graph built below (a -> b means pass a must run before pass b):
163
#     5 -> 0, 5 -> 2, 4 -> 0, 4 -> 1, 2 -> 3, 3 -> 1
164
# Which has a possible topological order of: [4, 5, 0, 2, 3, 1]
165
passes = [pass0, pass1, pass2, pass3, pass4, pass5]
167
this_before_that_pass_constraint(pass5, pass0),
168
this_before_that_pass_constraint(pass5, pass2),
169
this_before_that_pass_constraint(pass4, pass0),
170
this_before_that_pass_constraint(pass4, pass1),
171
this_before_that_pass_constraint(pass2, pass3),
172
this_before_that_pass_constraint(pass3, pass1),
174
sorted = _topological_sort_passes(passes, constraints)
175
self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1])
177
# A circular dependency (passes[0] -> passes[1] -> passes[2] -> passes[0])
# should make _topological_sort_passes raise RuntimeError listing the passes.
178
passes = [pass0, pass1, pass2]
180
this_before_that_pass_constraint(passes[0], passes[1]),
181
this_before_that_pass_constraint(passes[1], passes[2]),
182
this_before_that_pass_constraint(passes[2], passes[0]),
184
with self.assertRaises(RuntimeError) as e:
185
_topological_sort_passes(passes, constraints)
186
expected_error_msg = (
187
f"Circular dependency detected within the following passes: {passes}"
189
self.assertEqual(e.exception.args[0], expected_error_msg)
191
def test_pass_manager_error(self):
193
Tests error catching + debug
196
def pass_fail(graph_module):
197
raise RuntimeError("bad")
200
traced_m = torch.fx.symbolic_trace(m)
203
ReplaceAddWithMulPass(),
204
replace_mul_with_div_pass,
205
ReplaceDivWithSubPass(),
206
pass_result_wrapper(replace_sub_with_add_pass),
210
# Comment out this line to see the actual error message
212
"ReplaceDivWithSubPass.*ReplaceAddWithMulPass.*replace_mul_with_div_pass"
214
with self.assertRaisesRegex(Exception, error_msg):
219
ReplaceAddWithMulPass(),
220
replace_mul_with_div_pass,
221
pass_result_wrapper(ReplaceDivWithSubPass()),
222
pass_result_wrapper(replace_sub_with_add_pass),
227
# Comment out this line to see the actual error message
228
error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass"
229
with self.assertRaisesRegex(Exception, error_msg):