pytorch

Форк
0
/
test_verification.py 
295 строк · 9.7 Кб
1
# Owner(s): ["module: onnx"]
2

3
import contextlib
4
import io
5
import tempfile
6
import unittest
7

8
import numpy as np
9
import onnx
10
import parameterized
11
import pytorch_test_common
12

13
import torch
14
from packaging import version
15
from torch.onnx import _constants, _experimental, verification
16
from torch.testing._internal import common_utils
17

18

19
class TestVerification(pytorch_test_common.ExportTestCase):
    """Tests for ``check_export_model_diff`` and ONNX-vs-PyTorch output comparison."""

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        """Exports that bake ``y.data`` in as a constant are expected to diverge."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # tensor.data() will be exported as a constant,
                # leading to wrong model output under different inputs.
                return x + y.data

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups
        )
        # The report must name the first diverging node (a prim::Constant)
        # and show source locations for both exported graphs.
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        """A Python loop over ``x.size(0)`` with differently sized ``x`` should diverge."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # Python-level loop whose trip count depends on x's first dim.
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        # Same y shape, but x has 2 rows in the first group and 4 in the second.
        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(4, 3), torch.randn(2, 3)), {}),
        ]

        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups, export_options
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*",
        )

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        """A model that exports consistently should produce an empty diff report."""

        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            SupportedModel(), test_input_groups
        )
        self.assertEqual(results, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        """One of four elements differs; a 0.3 tolerance is expected to accept that."""
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        # Differs from ort_outs only at [1][1] (1.0 vs 4.0).
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=0.3,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            pytorch_outs,
            options,
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        """Without an error-percentage allowance, the same single mismatch must raise."""
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=None,
        )
        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                ort_outs,
                pytorch_outs,
                options,
            )
125

126

127
@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    """Checks that ``verification.verify`` reports a deliberately wrong export.

    ``setUp`` replaces the ``aten::add`` symbolic with one that returns its
    first operand unchanged, so ONNX output no longer matches PyTorch.
    """

    opset_version: int

    def setUp(self):
        super().setUp()

        # Wrong on purpose: ignores `other` and `alpha` and returns `self`,
        # so `x + 1` is exported as if it were just `x`.
        def incorrect_add_symbolic_function(g, self, other, alpha):
            return self

        self.opset_version = _constants.ONNX_DEFAULT_OPSET
        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            incorrect_add_symbolic_function,
            opset_version=self.opset_version,
        )
        # Undo via addCleanup instead of tearDown: cleanup functions still run
        # when the base-class tearDown raises, so the broken symbolic cannot
        # leak into later tests in the same process. Arguments are captured now.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::add",
            opset_version=self.opset_version,
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        """``verify`` must raise because the exported add drops the ``+ 1``."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            verification.verify(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=onnx_backend),
            )
179

180

181
@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.REFERENCE},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls, idx, input_dicts: (
        f"{cls.__name__}_{input_dicts['onnx_backend'].name}"
    ),
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.find_mismatch`` on a deliberately mis-exported model.

    ``setUp`` registers an ``aten::relu`` symbolic that emits ``x + 1.0``
    instead of a ReLU. It then runs ``find_mismatch`` on a small MLP with
    two ReLUs and stores the result in ``self.graph_info``.
    """

    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        # Wrong on purpose: exports relu(x) as Add(x, Constant(1.0)).
        def incorrect_relu_symbolic_function(g, self):
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )
        # Register the undo immediately. If find_mismatch below raises,
        # unittest skips tearDown but still runs cleanups, so the broken relu
        # symbolic cannot leak into other tests. Arguments are captured now.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::relu",
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        super().tearDown()
        # The relu symbolic is unregistered by the cleanup added in setUp.
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        """Each mismatching leaf's printout should point at aten::relu's source file."""
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertGreater(len(mismatch_leaves), 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            self.assertRegex(
                f.getvalue(),
                r"(.|\n)*" r"aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
            )

    def test_find_all_mismatch_operators(self):
        """Both ReLUs should be isolated as single-node mismatching leaves."""
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        # Contains no relu, so the broken relu symbolic is never used.
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        """An exported repro of a mismatching leaf should fail validation on its own."""
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertGreater(len(mismatch_leaves), 0)
        leaf_info = mismatch_leaves[0]
        # Built outside the assertRaises block so only validate() can satisfy it.
        options = verification.VerificationOptions(backend=self.onnx_backend)
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                verification.OnnxTestCaseRepro(repro_dir).validate(options)
292

293

294
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test runner.
    common_utils.run_tests()
296

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.