pytorch

Форк
0
/
test_verification.py 
298 строк · 9.7 Кб
1
# Owner(s): ["module: onnx"]
2

3
import contextlib
4
import io
5
import tempfile
6
import unittest
7

8
import numpy as np
9

10
import onnx
11
import parameterized
12
import pytorch_test_common
13
from packaging import version
14

15
import torch
16
from torch.onnx import _constants, _experimental, verification
17
from torch.testing._internal import common_utils
18

19

20
class TestVerification(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.check_export_model_diff`` and output comparison.

    The first group checks that exporting the same model with different input
    groups reports (or does not report) a graph diff; the second group checks
    how ``acceptable_error_percentage`` affects ``_compare_onnx_pytorch_outputs``.
    """

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # `y.data` is exported as a constant, so the exported graph
                # differs between input groups and gives wrong outputs for
                # inputs other than the ones used at export time.
                return x + y.data

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:",
        )

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # Python-level loop whose trip count depends on x.size(0); the
                # two input groups below use different sizes for x, so the
                # exported graphs are expected to diverge.
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(4, 3), torch.randn(2, 3)), {}),
        ]

        export_options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )
        results = verification.check_export_model_diff(
            UnexportableModel(), test_input_groups, export_options
        )
        self.assertRegex(
            results,
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*",
        )

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        test_input_groups = [
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
            ((torch.randn(2, 3), torch.randn(2, 3)), {}),
        ]

        results = verification.check_export_model_diff(
            SupportedModel(), test_input_groups
        )
        # An empty string means no diff was found between the input groups.
        self.assertEqual(results, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        # One of the four elements (25%) differs; with
        # acceptable_error_percentage=0.3 the comparison is expected to pass.
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=0.3,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            pytorch_outs,
            options,
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        # Same data as the test above, but with no tolerated error percentage
        # the single mismatching element must cause an AssertionError.
        ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        options = verification.VerificationOptions(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=None,
        )
        with self.assertRaises(AssertionError):
            verification._compare_onnx_pytorch_outputs(
                ort_outs,
                pytorch_outs,
                options,
            )
126

127

128
@common_utils.instantiate_parametrized_tests
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
    """Checks that ``verification.verify`` detects a deliberately wrong export.

    ``setUp`` overrides the ``aten::add`` symbolic with one that returns its
    first operand unchanged, so any exported addition yields wrong results.
    """

    opset_version: int

    def setUp(self):
        super().setUp()

        def incorrect_add_symbolic_function(g, self, other, alpha):
            # Deliberately wrong: ignores `other` and returns the left operand.
            return self

        self.opset_version = _constants.ONNX_DEFAULT_OPSET
        torch.onnx.register_custom_op_symbolic(
            "aten::add",
            incorrect_add_symbolic_function,
            opset_version=self.opset_version,
        )
        # Unregister via addCleanup instead of tearDown: cleanups run even when
        # setUp, the test, or tearDown raises, so the broken symbolic cannot
        # stay registered for later tests.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::add",
            opset_version=self.opset_version,
        )

    @common_utils.parametrize(
        "onnx_backend",
        [
            common_utils.subtest(
                verification.OnnxBackend.REFERENCE,
                decorators=[
                    unittest.skipIf(
                        version.Version(onnx.__version__) < version.Version("1.13"),
                        reason="Reference Python runtime was introduced in 'onnx' 1.13.",
                    )
                ],
            ),
            verification.OnnxBackend.ONNX_RUNTIME_CPU,
        ],
    )
    def test_verify_found_mismatch_when_export_is_wrong(
        self, onnx_backend: verification.OnnxBackend
    ):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        # With the broken add symbolic the ONNX output is `x` instead of
        # `x + 1`, so verification must fail on every backend.
        with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
            verification.verify(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=onnx_backend),
            )
180

181

182
@parameterized.parameterized_class(
    [
        # TODO: enable this when ONNX submodule catches up to >= 1.13.
        # {"onnx_backend": verification.OnnxBackend.REFERENCE},
        {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
    ],
    class_name_func=lambda cls, idx, input_dicts: (
        f"{cls.__name__}_{input_dicts['onnx_backend'].name}"
    ),
)
class TestFindMismatch(pytorch_test_common.ExportTestCase):
    """Tests for ``verification.find_mismatch`` on a model with a broken relu.

    ``setUp`` overrides the ``aten::relu`` symbolic with one that emits
    ``x + 1`` instead, then runs ``find_mismatch`` on a model containing two
    ``ReLU`` layers and stores the result in ``self.graph_info``.
    """

    onnx_backend: verification.OnnxBackend
    opset_version: int
    graph_info: verification.GraphInfo

    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_DEFAULT_OPSET

        def incorrect_relu_symbolic_function(g, self):
            # Deliberately wrong: exports `self + 1` instead of a Relu.
            return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))

        torch.onnx.register_custom_op_symbolic(
            "aten::relu",
            incorrect_relu_symbolic_function,
            opset_version=self.opset_version,
        )
        # Unregister via addCleanup rather than tearDown: cleanups also run if
        # setUp itself raises (e.g. find_mismatch below fails), whereas tearDown
        # is skipped in that case and the broken symbolic would stay registered
        # for later tests.
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            "aten::relu",
            opset_version=self.opset_version,
        )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(3, 4),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4, 5),
                    torch.nn.ReLU(),
                    torch.nn.Linear(5, 6),
                )

            def forward(self, x):
                return self.layers(x)

        self.graph_info = verification.find_mismatch(
            Model(),
            (torch.randn(2, 3),),
            opset_version=self.opset_version,
            options=verification.VerificationOptions(backend=self.onnx_backend),
        )

    def tearDown(self):
        # The custom relu symbolic is unregistered by the cleanup added in setUp.
        super().tearDown()
        delattr(self, "opset_version")
        delattr(self, "graph_info")

    def test_pretty_print_tree_visualizes_mismatch(self):
        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            self.graph_info.pretty_print_tree()
        self.assertExpected(f.getvalue())

    def test_preserve_mismatch_source_location(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        self.assertGreater(len(mismatch_leaves), 0)

        for leaf_info in mismatch_leaves:
            f = io.StringIO()
            with contextlib.redirect_stdout(f):
                leaf_info.pretty_print_mismatch(graph=True)
            # assertRegex uses re.search, so no leading/trailing wildcard is
            # needed; accept either path separator for the source file.
            self.assertRegex(
                f.getvalue(),
                r"aten::relu.*[/\\]torch[/\\]nn[/\\]functional\.py:[0-9]+",
            )

    def test_find_all_mismatch_operators(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()

        # One mismatching leaf per ReLU layer in the model built in setUp.
        self.assertEqual(len(mismatch_leaves), 2)

        for leaf_info in mismatch_leaves:
            self.assertEqual(leaf_info.essential_node_count(), 1)
            self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})

    def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
        self.maxDiff = None

        # This model contains no relu, so the broken symbolic is never used.
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        f = io.StringIO()
        with contextlib.redirect_stdout(f):
            verification.find_mismatch(
                Model(),
                (torch.randn(2, 3),),
                opset_version=self.opset_version,
                options=verification.VerificationOptions(backend=self.onnx_backend),
            )
        self.assertExpected(f.getvalue())

    def test_export_repro_for_mismatch(self):
        mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
        self.assertGreater(len(mismatch_leaves), 0)
        leaf_info = mismatch_leaves[0]
        options = verification.VerificationOptions(backend=self.onnx_backend)
        with tempfile.TemporaryDirectory() as temp_dir:
            repro_dir = leaf_info.export_repro(temp_dir)

            # Replaying the exported repro must reproduce the same mismatch.
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
                verification.OnnxTestCaseRepro(repro_dir).validate(options)
295

296

297
# Run this module's tests through PyTorch's common test runner.
if __name__ == "__main__":
    common_utils.run_tests()
299

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.