pytorch

Форк
0
/
test_matcher_utils.py 
271 строка · 10.6 Кб
1
# Owner(s): ["module: fx"]
2

3
import os
4
import sys
5
from typing import Callable
6

7
import torch
8
import torch.nn.functional as F
9
from torch.fx import symbolic_trace
10
from torch.fx.experimental.proxy_tensor import make_fx
11

12

13
# Append the directory two levels above this file's resolved path to
# ``sys.path`` (done before the remaining imports below are executed).
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)
15
import unittest
16

17
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
18
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
19
    SubgraphMatcherWithNameNodeMap,
20
)
21
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
22
from torch.testing._internal.jit_utils import JitTestCase
23

24

25
class WrapperModule(torch.nn.Module):
    """Thin ``nn.Module`` adapter around a plain callable.

    The callable is stored on ``self.fn`` and ``forward`` delegates to it,
    passing positional and keyword arguments through unchanged.
    """

    def __init__(self, fn: Callable):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
32

33

34
class TestMatcher(JitTestCase):
35
    def test_subgraph_matcher_with_attributes(self):
36
        class LargeModel(torch.nn.Module):
37
            def __init__(self) -> None:
38
                super().__init__()
39
                self._weight = torch.nn.Parameter(torch.ones(3, 3))
40
                self._bias = torch.nn.Parameter(torch.ones(3, 3))
41

42
            def forward(self, x):
43
                return torch.ops.aten.addmm.default(self._bias, x, self._weight)
44

45
        # Large Model graph:
46
        # opcode         name           target              args                 kwargs
47
        # -------------  -------------  ------------------  -------------------  --------
48
        # placeholder    x              x                   ()                   {}
49
        # get_attr       _bias          _bias               ()                   {}
50
        # get_attr       _weight        _weight             ()                   {}
51
        # call_function  addmm_default  aten.addmm.default  (_bias, x, _weight)  {}
52
        # output         output         output              (addmm_default,)     {}
53
        large_model_graph = symbolic_trace(LargeModel()).graph
54

55
        class PatternModel(torch.nn.Module):
56
            def __init__(self) -> None:
57
                super().__init__()
58
                self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
59
                self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
60

61
            def forward(self, x):
62
                return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
63

64
        pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph
65

66
        subgraph_matcher = SubgraphMatcher(pattern_graph)
67
        match_result = subgraph_matcher.match(large_model_graph)
68
        self.assertEqual(len(match_result), 1)
69

70
    def test_subgraph_matcher_with_list(self):
71
        def original(x, y):
72
            return torch.ops.aten.view(x, [5, y.shape[0]])
73

74
        original_graph = torch.fx.symbolic_trace(original).graph
75

76
        def pattern(x, y, z):
77
            return torch.ops.aten.view(x, [z, y.shape[0]])
78

79
        pattern_graph = torch.fx.symbolic_trace(pattern).graph
80

81
        subgraph_matcher = SubgraphMatcher(pattern_graph)
82
        match_result = subgraph_matcher.match(original_graph)
83
        self.assertEqual(len(match_result), 1)
84

85
    def test_subgraph_matcher_with_list_bad(self):
86
        def original(x, y):
87
            return torch.ops.aten._reshape_alias_copy.default(
88
                x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
89
            )
90

91
        original_graph = torch.fx.symbolic_trace(original).graph
92

93
        def pattern(x, y, b):
94
            return torch.ops.aten._reshape_alias_copy.default(
95
                x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
96
            )
97

98
        pattern_graph = torch.fx.symbolic_trace(pattern).graph
99

100
        subgraph_matcher = SubgraphMatcher(pattern_graph)
101
        match_result = subgraph_matcher.match(original_graph)
102
        self.assertEqual(len(match_result), 0)
103

104
    def test_subgraph_matcher_ignore_literals(self):
105
        def original(x):
106
            return x + 1
107

108
        original_graph = make_fx(original)(torch.ones(3, 3)).graph
109
        original_graph.eliminate_dead_code()
110

111
        def pattern(x):
112
            return x + 2
113

114
        pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
115
        pattern_graph.eliminate_dead_code()
116

117
        subgraph_matcher = SubgraphMatcher(pattern_graph)
118
        match_result = subgraph_matcher.match(original_graph)
119
        self.assertEqual(len(match_result), 0)
120

121
        subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
122
        match_result = subgraph_matcher.match(original_graph)
123
        self.assertEqual(len(match_result), 1)
124

125
    def test_variatic_arg_matching(self):
126
        inputs = (torch.randn(20, 16, 50, 32),)
127

128
        def maxpool(x, kernel_size, stride, padding, dilation):
129
            return torch.ops.aten.max_pool2d_with_indices.default(
130
                x, kernel_size, stride, padding, dilation
131
            )
132

133
        maxpool_graph = torch.fx.symbolic_trace(maxpool).graph
134

135
        maxpool_matcher = SubgraphMatcher(maxpool_graph)
136
        match_result = maxpool_matcher.match(maxpool_graph)
137
        self.assertEqual(len(match_result), 1)
138

139
        # Graph only contains "stride" argument
140
        maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval()
141
        maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph
142
        match_s_result = maxpool_matcher.match(maxpool_s_graph)
143
        self.assertEqual(len(match_s_result), 1)
144

145
        # Graph only contains "padding" argument
146
        maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1)
147
        maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph
148
        match_p_result = maxpool_matcher.match(maxpool_p_graph)
149
        self.assertEqual(len(match_p_result), 1)
150

151
        # Graph only contains "stride, padding" argument
152
        maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
153
        maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph
154
        match_sp_result = maxpool_matcher.match(maxpool_sp_graph)
155
        self.assertEqual(len(match_sp_result), 1)
156

157
    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_split_to_graph_and_name_node_map(self):
        """Testing the internal helper function for splitting the pattern graph.

        The pattern returns two tensors followed by a ``{"conv", "relu"}``
        dict. After ``_split_to_graph_and_name_node_map`` the graph module
        must still compute the same first two outputs, and the returned
        mapping must carry the names the pattern put in its trailing dict.
        """
        from torch._export import capture_pre_autograd_graph
        from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
            _split_to_graph_and_name_node_map,
        )

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
        before_split_res = pattern_gm(*example_inputs)
        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
        after_split_res = pattern_gm(*example_inputs)

        # Numerics of the two tensor outputs must survive the split.
        self.assertEqual(before_split_res[0], after_split_res[0])
        self.assertEqual(before_split_res[1], after_split_res[1])

        # Previously the second return value was never inspected; check that
        # the names from the pattern's trailing dict are present in it.
        self.assertIn("conv", name_node_map)
        self.assertIn("relu", name_node_map)
182

183
    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_matcher_with_name_node_map_function(self):
        """Testing SubgraphMatcherWithNameNodeMap with function pattern.

        The pattern is conv -> relu -> (relu * 2) and exposes ``conv`` and
        ``relu`` through its trailing dict. Every match found in the target
        must expose both names, and annotating the node returned under
        ``"conv"`` must be visible on a node of the target graph.
        """
        from torch._export import capture_pre_autograd_graph

        def target_graph(x, weight):
            x = x * 2
            weight = weight * 3
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu2 = relu * 2
            return relu + relu2

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        target_gm = capture_pre_autograd_graph(
            WrapperModule(target_graph), example_inputs
        )
        internal_matches = matcher.match(target_gm.graph)

        # All checks below live inside the loop; without this guard the test
        # would pass vacuously if the matcher found nothing.
        self.assertGreater(len(internal_matches), 0)

        for internal_match in internal_matches:
            name_node_map = internal_match.name_node_map
            # unittest assertions instead of bare ``assert`` so the checks
            # are not stripped when Python runs with -O.
            self.assertIn("conv", name_node_map)
            self.assertIn("relu", name_node_map)

            conv_node = name_node_map["conv"]
            conv_node.meta["custom_annotation"] = "annotation"

            # check if we correctly annotated the target graph module:
            # the mapped node must be a node of the target graph and must
            # carry the annotation.
            annotated = [n for n in target_gm.graph.nodes if n == conv_node]
            self.assertEqual(len(annotated), 1)
            self.assertEqual(annotated[0].meta.get("custom_annotation"), "annotation")
225

226
    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
227
    def test_matcher_with_name_node_map_module(self):
228
        """Testing SubgraphMatcherWithNameNodeMap with module pattern"""
229

230
        class M(torch.nn.Module):
231
            def __init__(self) -> None:
232
                super().__init__()
233
                self.linear = torch.nn.Linear(5, 5)
234

235
            def forward(self, x):
236
                return self.linear(x)
237

238
        class Pattern(torch.nn.Module):
239
            def __init__(self) -> None:
240
                super().__init__()
241
                self.linear = torch.nn.Linear(5, 5)
242

243
            def forward(self, x):
244
                linear = self.linear(x)
245
                # Note: we can't put "weight": self.linear.weight in dictionary since
246
                # nn.Parameter is not an allowed output type in dynamo
247
                return linear, {"linear": linear, "x": x}
248

249
        from torch._export import capture_pre_autograd_graph
250

251
        example_inputs = (torch.randn(3, 5),)
252
        pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs)
253
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
254
        target_gm = capture_pre_autograd_graph(M(), example_inputs)
255
        internal_matches = matcher.match(target_gm.graph)
256
        for internal_match in internal_matches:
257
            name_node_map = internal_match.name_node_map
258
            assert "linear" in name_node_map
259
            assert "x" in name_node_map
260
            name_node_map["linear"].meta["custom_annotation"] = "annotation"
261
            # check if we correctly annotated the target graph module
262
            for n in target_gm.graph.nodes:
263
                if n == name_node_map["linear"]:
264
                    assert (
265
                        "custom_annotation" in n.meta
266
                        and n.meta["custom_annotation"] == "annotation"
267
                    )
268

269

270
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
272

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.