pytorch

Форк
0
/
test_onnx_opset.py 
539 строк · 17.2 Кб
1
# Owner(s): ["module: onnx"]
2

3
import io
4
import itertools
5

6
import onnx
7
import pytorch_test_common
8

9
import torch
10
import torch.onnx
11
from torch.nn import Module
12
from torch.onnx import producer_name, producer_version
13
from torch.onnx._globals import GLOBALS
14
from torch.testing._internal import common_utils
15

16

17
def check_onnx_opset_operator(model, ops, opset_version=None):
    """Assert that an exported ONNX ``model`` matches the expected ``ops``.

    Args:
        model: a loaded ONNX model (as produced by ``onnx.load`` in
            ``check_onnx_opsets_operator``).
        ops: one dict per node of ``model.graph.node``, in graph order. Each
            dict must contain ``"op_name"``, compared against the node's
            ``op_type``. It may also contain ``"attributes"``: a list with one
            dict per node attribute (same order as the node's attributes)
            whose keys are attribute field names (``"name"``, ``"type"``,
            ``"i"``, ``"ints"``, ...) and whose values must equal the
            corresponding field of the exported attribute.
        opset_version: expected version of the model's first opset import.
            ``None`` (the default) means ``GLOBALS.export_onnx_opset_version``
            read at call time.

    Raises:
        AssertionError: on any mismatch.
    """
    if opset_version is None:
        # Resolve lazily: a default expression in the signature would be
        # evaluated once at import time and miss later changes to GLOBALS.
        opset_version = GLOBALS.export_onnx_opset_version

    # check_onnx_components (separate asserts so a failure says which differs)
    assert model.producer_name == producer_name, (
        f"producer_name: expected {producer_name!r}, got {model.producer_name!r}"
    )
    assert model.producer_version == producer_version, (
        f"producer_version: expected {producer_version!r}, "
        f"got {model.producer_version!r}"
    )
    assert model.opset_import[0].version == opset_version, (
        f"opset: expected {opset_version}, got {model.opset_import[0].version}"
    )

    # check the schema with the onnx checker
    onnx.checker.check_model(model)

    # check target type and attributes, node by node and in order
    nodes = model.graph.node
    assert len(ops) == len(nodes), (
        f"expected {len(ops)} nodes, got {len(nodes)}: "
        f"{[node.op_type for node in nodes]}"
    )
    for index, (expected, node) in enumerate(zip(ops, nodes)):
        assert node.op_type == expected["op_name"], (
            f"node {index}: expected op {expected['op_name']!r}, "
            f"got {node.op_type!r}"
        )
        if "attributes" not in expected:
            # Attributes are optional in the expectation; op type suffices.
            continue
        expected_attributes = expected["attributes"]
        assert len(expected_attributes) == len(node.attribute), (
            f"node {index} ({node.op_type}): expected "
            f"{len(expected_attributes)} attributes, got {len(node.attribute)}"
        )
        for expected_attribute, attribute in zip(expected_attributes, node.attribute):
            for field, value in expected_attribute.items():
                # Keep the expected value on the left: for repeated fields the
                # comparison falls through to the container's own __eq__.
                assert value == getattr(attribute, field), (
                    f"node {index} ({node.op_type}) attribute "
                    f"{attribute.name!r}: field {field!r} expected {value!r}, "
                    f"got {getattr(attribute, field)!r}"
                )
48

49

50
def check_onnx_opsets_operator(
    module,
    x,
    ops,
    opset_versions,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    dynamic_axes=None,
):
    """Export ``module`` once per opset version and verify each exported graph.

    Args:
        module: the ``torch.nn.Module`` (or ScriptModule) to export.
        x: example input(s) forwarded to ``torch.onnx.export``.
        ops: mapping from opset version to the expected node list understood
            by ``check_onnx_opset_operator``; must have an entry for every
            version in ``opset_versions``.
        opset_versions: iterable of opset versions to export with.
        training: ``torch.onnx.TrainingMode`` passed to the exporter.
        input_names: optional graph input names passed to the exporter.
        dynamic_axes: optional dynamic-axes spec passed to the exporter.
    """
    for opset_version in opset_versions:
        # Export to an in-memory buffer; nothing touches the filesystem.
        buffer = io.BytesIO()
        torch.onnx.export(
            module,
            x,
            buffer,
            opset_version=opset_version,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        model = onnx.load(io.BytesIO(buffer.getvalue()))
        check_onnx_opset_operator(model, ops[opset_version], opset_version)
72

73

74
class TestONNXOpset(pytorch_test_common.ExportTestCase):
    """Check the ONNX node sequences torch.onnx.export emits per opset version.

    Every test builds a small module, lists the expected nodes for each opset
    version, and delegates export + comparison to ``check_onnx_opsets_operator``.
    Attribute ``"type"`` values are ONNX AttributeProto type codes:
    1=FLOAT, 2=INT, 3=STRING, 4=TENSOR, 6=FLOATS, 7=INTS.
    """

    def test_opset_fallback(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.isnan(x)

        ops = [{"op_name": "IsNaN"}]
        ops = {9: ops, 10: ops}
        x = torch.tensor([1.0, float("nan"), 2.0])
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_topk(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.topk(x, 3)

        # opset 9 carries k as an attribute; opset 10 feeds it as an input.
        ops_9 = [
            {
                "op_name": "TopK",
                "attributes": [
                    {"name": "axis", "i": -1, "type": 2},
                    {"name": "k", "i": 3, "type": 2},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # test with dynamic k
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, k):
                return torch.topk(input, k)

        ops_10 = [
            {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
            {"op_name": "Reshape"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        module = MyModuleDynamic()
        check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])

    def test_maxpool(self):
        module = torch.nn.MaxPool1d(2, stride=1)

        ops_9 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        # opset 10 adds the ceil_mode and dilations attributes.
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [1], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])

        # add test with dilations
        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)

        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [2], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_upsample(self):
        class MyModule(Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        module = MyModule()
        ops8 = [
            {
                "op_name": "Upsample",
                "attributes": [
                    {"name": "mode", "s": b"nearest", "type": 3},
                    {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
                ],
            }
        ]
        ops9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {8: ops8, 9: ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_cast_constant(self):
        class MyModule(Module):
            def forward(self, x):
                return x - 1

        module = MyModule()
        ops_8 = [
            {"op_name": "Constant"},
            {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
            {"op_name": "Sub"},
        ]
        ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
        ops = {8: ops_8, 9: ops_9}
        x = torch.ones(5, 6, dtype=torch.long)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_slice(self):
        class MyModule(Module):
            def forward(self, x):
                return x[0:1]

        # opset 9 carries axes/starts/ends as attributes; opset 10 as inputs.
        ops_9 = [
            {
                "op_name": "Slice",
                "attributes": [
                    {"name": "axes", "ints": [0], "type": 7},
                    {"name": "ends", "ints": [1], "type": 7},
                    {"name": "starts", "ints": [0], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(0)]

        module = DynamicSliceModel()
        x = torch.rand(1, 2)
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {
                "op_name": "Unsqueeze",
                # axes is an INTS attribute (type 7): check the "ints" field;
                # the scalar "i" field would always read 0 and prove nothing.
                "attributes": [{"name": "axes", "ints": [0], "type": 7}],
            },
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(
            module,
            x,
            ops,
            opset_versions=[10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_flip(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        # Same values and dtype (float64) as torch.tensor(numpy.arange(6.0)),
        # without the numpy round-trip.
        x = torch.arange(6.0, dtype=torch.float64).reshape(2, 3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])

    def test_dropout(self):
        class MyModule(Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        x = torch.randn(1, 2, 3)

        # we should only export the onnx Dropout op in training mode; test both modes

        # test training mode
        ops = [
            {
                "op_name": "Dropout",
                "attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
            }
        ]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.TRAINING,
        )

        # test eval mode
        ops = [{"op_name": "Identity"}]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.EVAL,
        )

    def test_full(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.full((3, 4), x)

        ops = [
            {"op_name": "Constant"},
            {"op_name": "ConstantOfShape"},
            {"op_name": "Add"},
        ]
        ops = {9: ops, 10: ops}
        x = torch.tensor(12.0)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_interpolate(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        # Fully dynamic input: output size is computed in-graph from Shape.
        ops_9 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(
            MyModel(),
            x,
            ops,
            opset_versions=[9, 10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        # Static input shape: the size computation folds into constants.
        ops_9 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {"op_name": "Resize"},
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])

        class MyDynamicModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                # work around for now: turn the dynamic sizes into constant
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])

    def test_grid_sample(self):
        n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
        ops = {16: [{"op_name": "GridSample"}]}

        class MyModule(Module):
            # The last parameter was previously misspelled "align_corers", so
            # the body silently read the enclosing loop variable instead of
            # the argument actually passed in.
            def forward(self, x, grid, mode, padding_mode, align_corners):
                return torch.nn.functional.grid_sample(
                    x, grid, mode, padding_mode, align_corners
                )

        for mode, padding_mode, align_corners in itertools.product(
            ("bilinear", "nearest", "bicubic"),
            ("zeros", "border", "reflection"),
            (True, False),
        ):
            args = (
                torch.randn(n, c, h_in, w_in),  # x
                torch.randn(n, h_out, w_out, 2),  # grid
                mode,
                padding_mode,
                align_corners,
            )
            check_onnx_opsets_operator(
                MyModule(),
                args,
                ops,
                opset_versions=[16],
                training=torch.onnx.TrainingMode.TRAINING,
            )
            check_onnx_opsets_operator(
                MyModule(),
                args,
                ops,
                opset_versions=[16],
                training=torch.onnx.TrainingMode.EVAL,
            )

    def test_flatten(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flatten(x)

        module = MyModule()

        # A 0-d input needs a Reshape; a 1-d input is already flat.
        ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
        ops_1d = [{"op_name": "Identity"}]
        for shape in ([], [3]):
            x = torch.randn(shape)
            for opset_version in [9, 10]:
                ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
                check_onnx_opsets_operator(
                    module, x, ops, opset_versions=[opset_version]
                )
536

537

538
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    common_utils.run_tests()
540

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.