test_fx_passes.py
# Owner(s): ["module: onnx"]
import torch
import torch._dynamo
import torch.fx

from torch._custom_op import impl as custom_op
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later.
        # See :func:`set_node_name` for name collision renaming logic.
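        # Expected behavior (checked by the assertions below): `set_node_name` hands the base
        # name to `nodes[0]`; the node currently holding "tensor" must then be renamed, which
        # can in turn collide with "tensor.1", and so on, until every name is unique again.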
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            if i == 0:
                node.name = base_name
            else:
                node.name = f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in gm.graph.nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(
            nodes
        ), f"Expected all names to be unique, got {nodes}"

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[1].name}"
56
        assert len({node.name for node in nodes}) == len(
57
            nodes
58
        ), f"Expected all names to be unique, got {nodes}"
59

60
    def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
61
        @custom_op.custom_op("mylibrary::foo_op")
62
        def foo_op(x: torch.Tensor) -> torch.Tensor:
63
            ...
64

65
        @custom_op.custom_op("mylibrary::bar_op")
66
        def bar_op(x: torch.Tensor) -> torch.Tensor:
67
            ...
68

69
        @foo_op.impl_abstract()
70
        def foo_op_impl_abstract(x):
71
            return torch.empty_like(x)
72

73
        @foo_op.impl("cpu")
74
        def foo_op_impl(x):
75
            return x + 1
76

77
        @bar_op.impl_abstract()
78
        def bar_op_impl_abstract(x):
79
            return torch.empty_like(x)
80

81
        @bar_op.impl("cpu")
82
        def bar_op_impl(x):
83
            return x + 2
84

85
        torch._dynamo.allow_in_graph(foo_op)
86
        torch._dynamo.allow_in_graph(bar_op)
87

88
        def func(x, y, z):
89
            return foo_op(x) + bar_op(y) + z
90

91
        x = torch.randn(3)
92
        y = torch.randn(3)
93
        z = torch.randn(3)
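        # Neither custom op has an ONNX implementation registered, so the export below is
        # expected to fail with an error that lists both ops as unsupported FX nodes.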
        with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
            torch.onnx.dynamo_export(func, x, y, z)
        inner_exception = ctx.exception.__cause__
        self.assertRegex(
            str(inner_exception),
            r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
        )

        torch._dynamo.reset()


@common_utils.instantiate_parametrized_tests
class TestModularizePass(common_utils.TestCase):
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_submodule_output_is_unused(
        self, is_exported_program
    ):
123
        # This is an ill-formed model, but exporter must not crash.
124
        # It is illegal for submodule to have zero output. For modularization pass it can happen
125
        # when the submodule output is unused, so no inner node is connected to any outer
126
        # nodes.
127
        # However, this also means the entire submodule should be erased by DCE. Hence
128
        # it should never occur.
129
        #
130
        # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
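        # The assertions at the end check that the used GELU shows up as an exported ONNX
        # local function while the unused ReLU leaves none behind.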
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.unused_relu = torch.nn.ReLU()
                self.used_gelu = torch.nn.GELU()

            def forward(self, x, y):
                result = self.used_gelu(x + y)
                unused_relu_result = self.unused_relu(x)
                return result

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names
        )
        self.assertFalse(any("ReLU" in name for name in function_proto_names))

    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
        self, is_exported_program
    ):
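        # Each call to the shared ReLU submodule is expected to become its own ONNX local
        # function, suffixed _1 and _2, which is what the assertions at the end check.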
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                out = x + y
                out = self.relu(out)
                out = out + x
                out = self.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names)
        self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names)

    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
        self, is_exported_program
    ):
        # Minified repro from basic_gnn_edgecnn.
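        # The same ReLU instance is reached both through InnerModule's forward and by calling
        # `self.inner_module.relu` directly, so the assertions expect ReLU functions for both
        # call sites plus a function for the InnerModule wrapper itself.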
        class InnerModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inner_module = InnerModule()

            def forward(self, x, y):
                out = x + y
                out = self.inner_module(out)
                out = out + x
                out = self.inner_module.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names
        )
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names
        )
        # The local module's qualified name is unstable in the test environment; it depends on
        # how the tests are invoked.
        self.assertTrue(
            any("InnerModule_inner_module_1" in name for name in function_proto_names)
        )


if __name__ == "__main__":
    common_utils.run_tests()
