1
# Owner(s): ["module: fx"]
5
from typing import Callable
8
import torch.nn.functional as F
9
from torch.fx import symbolic_trace
10
from torch.fx.experimental.proxy_tensor import make_fx
13
import os
import sys

# Make the parent of the directory containing this file (symlinks resolved)
# importable, so the imports that follow can find modules located there.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
17
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
18
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
19
SubgraphMatcherWithNameNodeMap,
21
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
22
from torch.testing._internal.jit_utils import JitTestCase
25
class WrapperModule(torch.nn.Module):
    """Adapt a plain callable into an ``nn.Module``.

    ``forward`` passes all positional and keyword arguments straight through to
    the wrapped callable and returns its result unchanged. In this file it is
    used to hand plain pattern/target functions to
    ``capture_pre_autograd_graph``, which takes a module.
    """

    def __init__(self, fn: Callable):
        # nn.Module's internal state must be initialized before attributes are
        # assigned or the module is called.
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)
34
# Tests for torch.fx subgraph matching with SubgraphMatcher (including the
# ignore_literals option) and SubgraphMatcherWithNameNodeMap.
#
# NOTE(review): this copy of the class does not look runnable as-is.
# Indentation is stripped, each code line is preceded by a bare integer that
# appears to be its original line number (e.g. `35`), and gaps in that
# numbering show lines were dropped; the gaps that leave names unbound or
# checks unasserted are flagged below. Restore from the upstream source
# before running.
class TestMatcher(JitTestCase):
35
def test_subgraph_matcher_with_attributes(self):
# PatternModel's parameters differ from LargeModel's in name (_weight_1 /
# _bias_1 vs _weight / _bias) and shape (5x5 vs 3x3); the test expects the
# addmm pattern to match exactly once regardless.
36
class LargeModel(torch.nn.Module):
37
def __init__(self) -> None:
# NOTE(review): numbering skips 38; presumably the missing line is
# super().__init__(), which nn.Module needs before Parameters are assigned.
39
self._weight = torch.nn.Parameter(torch.ones(3, 3))
40
self._bias = torch.nn.Parameter(torch.ones(3, 3))
# NOTE(review): numbering skips 41-42; the header of the method that returns
# below (presumably `def forward(self, x):`) is missing, leaving `x` unbound.
43
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
# Tabular dump of the LargeModel graph traced below, for reference:
46
# opcode name target args kwargs
47
# ------------- ------------- ------------------ ------------------- --------
48
# placeholder x x () {}
49
# get_attr _bias _bias () {}
50
# get_attr _weight _weight () {}
51
# call_function addmm_default aten.addmm.default (_bias, x, _weight) {}
52
# output output output (addmm_default,) {}
53
large_model_graph = symbolic_trace(LargeModel()).graph
55
class PatternModel(torch.nn.Module):
56
def __init__(self) -> None:
# NOTE(review): as in LargeModel, line 57 (presumably super().__init__())
# and lines 60-61 (presumably the forward header) are missing.
58
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
59
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
62
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
64
pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph
66
subgraph_matcher = SubgraphMatcher(pattern_graph)
67
match_result = subgraph_matcher.match(large_model_graph)
68
self.assertEqual(len(match_result), 1)
70
def test_subgraph_matcher_with_list(self):
# The original passes the literal 5 in the view size list where the pattern
# uses `z`; the test expects exactly one match.
# NOTE(review): the function headers (line 71 and lines 75-76, presumably
# `def original(x, y):` and `def pattern(x, y, z):`) are missing.
72
return torch.ops.aten.view(x, [5, y.shape[0]])
74
original_graph = torch.fx.symbolic_trace(original).graph
77
return torch.ops.aten.view(x, [z, y.shape[0]])
79
pattern_graph = torch.fx.symbolic_trace(pattern).graph
81
subgraph_matcher = SubgraphMatcher(pattern_graph)
82
match_result = subgraph_matcher.match(original_graph)
83
self.assertEqual(len(match_result), 1)
85
def test_subgraph_matcher_with_list_bad(self):
# The list arguments differ in length between the original
# ([1, y.shape[0]] plus a two-element list) and the pattern (a three-element
# list plus a one-element list), so the test expects no match.
# NOTE(review): the `original` / `pattern` headers (presumably lines 86 and
# 93) and the closing parentheses of both calls (presumably lines 89 and 96)
# are missing.
87
return torch.ops.aten._reshape_alias_copy.default(
88
x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
91
original_graph = torch.fx.symbolic_trace(original).graph
94
return torch.ops.aten._reshape_alias_copy.default(
95
x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
98
pattern_graph = torch.fx.symbolic_trace(pattern).graph
100
subgraph_matcher = SubgraphMatcher(pattern_graph)
101
match_result = subgraph_matcher.match(original_graph)
102
self.assertEqual(len(match_result), 0)
104
def test_subgraph_matcher_ignore_literals(self):
# NOTE(review): the definitions of `original` (lines 105-107) and `pattern`
# (lines 110-113) are missing. The assertions expect no match by default and
# exactly one match with ignore_literals=True, so presumably the two differ
# only in a literal constant -- TODO confirm against upstream.
108
original_graph = make_fx(original)(torch.ones(3, 3)).graph
109
original_graph.eliminate_dead_code()
114
pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
115
pattern_graph.eliminate_dead_code()
117
subgraph_matcher = SubgraphMatcher(pattern_graph)
118
match_result = subgraph_matcher.match(original_graph)
119
self.assertEqual(len(match_result), 0)
121
subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
122
match_result = subgraph_matcher.match(original_graph)
123
self.assertEqual(len(match_result), 1)
125
def test_variatic_arg_matching(self):
# A pattern traced with every max_pool2d_with_indices argument as an explicit
# input is matched against its own graph and against make_fx traces of
# MaxPool2d built with only stride, only padding, and both; each case is
# expected to match exactly once.
126
inputs = (torch.randn(20, 16, 50, 32),)
128
def maxpool(x, kernel_size, stride, padding, dilation):
129
return torch.ops.aten.max_pool2d_with_indices.default(
130
x, kernel_size, stride, padding, dilation
# NOTE(review): the closing parenthesis of the call above (line 131) is missing.
133
maxpool_graph = torch.fx.symbolic_trace(maxpool).graph
135
maxpool_matcher = SubgraphMatcher(maxpool_graph)
136
match_result = maxpool_matcher.match(maxpool_graph)
137
self.assertEqual(len(match_result), 1)
139
# Graph only contains "stride" argument
140
maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval()
141
maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph
142
match_s_result = maxpool_matcher.match(maxpool_s_graph)
143
self.assertEqual(len(match_s_result), 1)
145
# Graph only contains "padding" argument
146
maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1)
147
maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph
148
match_p_result = maxpool_matcher.match(maxpool_p_graph)
149
self.assertEqual(len(match_p_result), 1)
151
# Graph only contains "stride, padding" argument
152
maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
153
maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph
154
match_sp_result = maxpool_matcher.match(maxpool_sp_graph)
155
self.assertEqual(len(match_sp_result), 1)
157
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
158
def test_split_to_graph_and_name_node_map(self):
159
"""Testing the internal helper function for splitting the pattern graph"""
# Splitting the name->node dict output off the captured pattern must leave
# its first two outputs numerically unchanged.
160
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
161
_split_to_graph_and_name_node_map,
# NOTE(review): the closing parenthesis of this import (around line 162) is
# missing.
164
def pattern(x, weight):
165
conv = F.conv2d(x, weight)
# NOTE(review): `relu` is used below, but its assignment (line 166) is missing.
167
relu_mul_by_two = relu * 2
168
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
170
from torch._export import capture_pre_autograd_graph
# NOTE(review): the opening of `example_inputs` (lines 171-172) and its
# closing line (175) are missing; presumably the two tensors below are its
# elements.
173
torch.randn(1, 3, 3, 3) * 10,
174
torch.randn(3, 3, 3, 3),
176
pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
177
before_split_res = pattern_gm(*example_inputs)
178
pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
179
after_split_res = pattern_gm(*example_inputs)
180
self.assertEqual(before_split_res[0], after_split_res[0])
181
self.assertEqual(before_split_res[1], after_split_res[1])
183
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
184
def test_matcher_with_name_node_map_function(self):
185
"""Testing SubgraphMatcherWithNameNodeMap with function pattern"""
# Each match's name_node_map must expose the pattern's "conv" and "relu"
# entries, and metadata written through it must be visible on the same node
# in target_gm.graph.
187
def target_graph(x, weight):
# NOTE(review): lines 188-189 and 191-194 of target_graph's body are missing,
# so the target's full structure is not visible here.
190
conv = F.conv2d(x, weight)
195
def pattern(x, weight):
196
conv = F.conv2d(x, weight)
# NOTE(review): `relu` is used below, but its assignment (line 197) is missing.
198
relu_mul_by_two = relu * 2
199
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
201
from torch._export import capture_pre_autograd_graph
# NOTE(review): the opening of `example_inputs` (lines 202-203) and its
# closing line (206) are missing; presumably the two tensors below are its
# elements.
204
torch.randn(1, 3, 3, 3) * 10,
205
torch.randn(3, 3, 3, 3),
207
pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
208
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
209
target_gm = capture_pre_autograd_graph(
210
WrapperModule(target_graph), example_inputs
212
internal_matches = matcher.match(target_gm.graph)
# NOTE(review): nothing asserts that internal_matches is non-empty. If the
# matcher finds no match, this loop never runs and the test passes
# vacuously; consider asserting len(internal_matches) >= 1 first.
213
for internal_match in internal_matches:
214
name_node_map = internal_match.name_node_map
215
assert "conv" in name_node_map
216
assert "relu" in name_node_map
217
name_node_map["conv"].meta["custom_annotation"] = "annotation"
218
# check if we correctly annotated the target graph module
219
for n in target_gm.graph.nodes:
220
if n == name_node_map["conv"]:
# NOTE(review): the call wrapping the condition below (line 221, presumably
# self.assertTrue( ) and its closing line (224) are missing, so the
# annotation is not actually asserted in this copy.
222
"custom_annotation" in n.meta
223
and n.meta["custom_annotation"] == "annotation"
226
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
227
def test_matcher_with_name_node_map_module(self):
228
"""Testing SubgraphMatcherWithNameNodeMap with module pattern"""
# Same check as the function-pattern test, but the pattern is an nn.Module
# whose name->node dict exposes the linear output and the input "x".
230
class M(torch.nn.Module):
231
def __init__(self) -> None:
# NOTE(review): line 232 (presumably super().__init__()) is missing.
233
self.linear = torch.nn.Linear(5, 5)
235
def forward(self, x):
236
return self.linear(x)
238
class Pattern(torch.nn.Module):
239
def __init__(self) -> None:
# NOTE(review): line 240 (presumably super().__init__()) is missing.
241
self.linear = torch.nn.Linear(5, 5)
243
def forward(self, x):
244
linear = self.linear(x)
245
# Note: we can't put "weight": self.linear.weight in dictionary since
246
# nn.Parameter is not an allowed output type in dynamo
247
return linear, {"linear": linear, "x": x}
249
from torch._export import capture_pre_autograd_graph
251
example_inputs = (torch.randn(3, 5),)
252
pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs)
253
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
254
target_gm = capture_pre_autograd_graph(M(), example_inputs)
255
internal_matches = matcher.match(target_gm.graph)
# NOTE(review): as in the function-pattern test, an empty internal_matches
# makes this test pass vacuously; consider asserting it is non-empty.
256
for internal_match in internal_matches:
257
name_node_map = internal_match.name_node_map
258
assert "linear" in name_node_map
259
assert "x" in name_node_map
260
name_node_map["linear"].meta["custom_annotation"] = "annotation"
261
# check if we correctly annotated the target graph module
262
for n in target_gm.graph.nodes:
263
if n == name_node_map["linear"]:
# NOTE(review): the call wrapping the condition below (line 264, presumably
# self.assertTrue( ) and its closing line are missing, so the annotation is
# not actually asserted in this copy.
265
"custom_annotation" in n.meta
266
and n.meta["custom_annotation"] == "annotation"
270
if __name__ == "__main__":