# Header of the repository web view: pytorch (fork) / test_onnx_opset.py — 654 lines, 21.2 KB
# Owner(s): ["module: onnx"]

import io
import itertools

import onnx

import pytorch_test_common

import torch
import torch.onnx
from torch.nn import Module
from torch.onnx import producer_name, producer_version
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils


def check_onnx_opset_operator(
    model, ops, opset_version=GLOBALS.export_onnx_opset_version
):
    """Assert that an exported ONNX ``model`` contains exactly the expected ops.

    Args:
        model: the loaded ONNX model (as returned by ``onnx.load``).
        ops: one dict per node of ``model.graph.node``, in graph order. Each
            dict needs an ``"op_name"`` entry, compared against the node's
            ``op_type``. An optional ``"attributes"`` entry lists one dict per
            node attribute, in the node's attribute order; every key of such
            a dict names an attribute field (``"name"``, ``"i"``, ``"ints"``,
            ``"f"``, ``"s"``, ``"type"``, ...) whose value must compare equal.
            Attribute fields that are not listed are not checked.
        opset_version: expected version of ``model.opset_import[0]``. The
            default is read from ``GLOBALS`` once, when this module is
            imported, not on every call.

    Raises:
        AssertionError: if the producer metadata, the opset version, the
            node count, a node's op type, or a listed attribute field does
            not match. Errors from ``onnx.checker.check_model`` propagate
            unchanged.
    """
    # Producer metadata and opset version are asserted separately so a
    # failure reports which of them differs.
    assert model.producer_name == producer_name, (
        f"producer_name: expected {producer_name!r}, got {model.producer_name!r}"
    )
    assert model.producer_version == producer_version, (
        f"producer_version: expected {producer_version!r}, "
        f"got {model.producer_version!r}"
    )
    actual_opset = model.opset_import[0].version
    assert actual_opset == opset_version, (
        f"opset version: expected {opset_version}, got {actual_opset}"
    )

    # Validate the model against the ONNX schema.
    onnx.checker.check_model(model)

    nodes = model.graph.node
    # The expected list must describe every node, so the counts must agree
    # before the pairwise comparison below is meaningful.
    assert len(ops) == len(nodes), (
        f"expected {len(ops)} nodes {[op.get('op_name') for op in ops]}, "
        f"got {len(nodes)} nodes {[node.op_type for node in nodes]}"
    )

    for node_index, (expected, node) in enumerate(zip(ops, nodes)):
        expected_op = expected["op_name"]
        assert node.op_type == expected_op, (
            f"node {node_index}: expected op {expected_op!r}, got {node.op_type!r}"
        )
        if "attributes" not in expected:
            continue  # only the op type was specified for this node

        expected_attrs = expected["attributes"]
        actual_attrs = node.attribute
        assert len(expected_attrs) == len(actual_attrs), (
            f"node {node_index} ({expected_op}): expected {len(expected_attrs)} "
            f"attributes, got {[attr.name for attr in actual_attrs]}"
        )
        for attr_index, (expected_attr, actual_attr) in enumerate(
            zip(expected_attrs, actual_attrs)
        ):
            for field, expected_value in expected_attr.items():
                actual_value = getattr(actual_attr, field)
                assert expected_value == actual_value, (
                    f"node {node_index} ({expected_op}) attribute {attr_index} "
                    f"field {field!r}: expected {expected_value!r}, "
                    f"got {actual_value!r}"
                )


def check_onnx_opsets_operator(
    module,
    x,
    ops,
    opset_versions,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    dynamic_axes=None,
):
    """Export ``module`` at each opset in ``opset_versions`` and check its ops.

    For every version the module is exported into an in-memory buffer with
    ``torch.onnx.export``, loaded back with ``onnx.load``, and compared with
    ``ops[opset_version]`` via ``check_onnx_opset_operator``.

    Args:
        module: the module (eager or ``torch.jit.ScriptModule``) to export.
        x: example input(s), passed to ``torch.onnx.export`` as ``args``.
        ops: mapping from opset version to the expected-ops list accepted by
            ``check_onnx_opset_operator``.
        opset_versions: iterable of opset versions to export and check.
        training: ``torch.onnx.TrainingMode`` forwarded to the exporter.
        input_names: forwarded to ``torch.onnx.export``.
        dynamic_axes: forwarded to ``torch.onnx.export``.

    Raises:
        KeyError: if ``ops`` has no entry for one of ``opset_versions``;
            raised before exporting at that version.
        AssertionError: if an exported graph does not match its expectation.
    """
    for opset_version in opset_versions:
        # Look the expectation up before exporting so a missing entry fails
        # fast with a message that names the opset.
        try:
            expected_ops = ops[opset_version]
        except KeyError:
            raise KeyError(
                f"no expected ops for opset {opset_version}; "
                f"available opsets: {sorted(ops)}"
            ) from None

        buffer = io.BytesIO()
        torch.onnx.export(
            module,
            x,
            buffer,
            opset_version=opset_version,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        # Rewind and read the serialized model straight from the buffer.
        buffer.seek(0)
        model = onnx.load(buffer)
        check_onnx_opset_operator(model, expected_ops, opset_version)


class TestONNXOpset(pytorch_test_common.ExportTestCase):
    """Check which ONNX operators are exported for each opset version.

    Every test exports a small module with ``check_onnx_opsets_operator`` and
    compares the exported graph, node by node, with a per-opset list of
    expected ``op_name`` values and (optionally) node attributes. In the
    attribute dicts, ``"type"`` is the numeric ONNX ``AttributeProto`` type
    code as used below: 1 = float (``"f"``), 2 = int (``"i"``),
    3 = string (``"s"``), 4 = tensor, 6 = floats, 7 = ints.
    """

    def test_opset_fallback(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.isnan(x)

        ops = [{"op_name": "IsNaN"}]
        ops = {9: ops, 10: ops}
        x = torch.tensor([1.0, float("nan"), 2.0])
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_topk(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.topk(x, 3)

        # Opset 9 carries k as a node attribute; the opset 10 expectation has
        # no k attribute and an extra Constant node instead.
        ops_9 = [
            {
                "op_name": "TopK",
                "attributes": [
                    {"name": "axis", "i": -1, "type": 2},
                    {"name": "k", "i": 3, "type": 2},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # test with dynamic k
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, k):
                return torch.topk(input, k)

        ops_10 = [
            {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
            {"op_name": "Reshape"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        module = MyModuleDynamic()
        check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])

    def test_maxpool(self):
        module = torch.nn.MaxPool1d(2, stride=1)

        ops_9 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [1], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])

        # add test with dilations
        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)

        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [2], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_upsample(self):
        class MyModule(Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        module = MyModule()
        ops8 = [
            {
                "op_name": "Upsample",
                "attributes": [
                    {"name": "mode", "s": b"nearest", "type": 3},
                    {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
                ],
            }
        ]
        ops9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {8: ops8, 9: ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_cast_constant(self):
        class MyModule(Module):
            def forward(self, x):
                return x - 1

        module = MyModule()
        ops_8 = [
            {"op_name": "Constant"},
            {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
            {"op_name": "Sub"},
        ]
        ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
        ops = {8: ops_8, 9: ops_9}
        x = torch.ones(5, 6, dtype=torch.long)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])

    def test_slice(self):
        class MyModule(Module):
            def forward(self, x):
                return x[0:1]

        ops_9 = [
            {
                "op_name": "Slice",
                "attributes": [
                    {"name": "axes", "ints": [0], "type": 7},
                    {"name": "ends", "ints": [1], "type": 7},
                    {"name": "starts", "ints": [0], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(0)]

        module = DynamicSliceModel()
        x = torch.rand(1, 2)
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {
                "op_name": "Unsqueeze",
                # `axes` is an INTS attribute (type 7), so compare the `ints`
                # field; the scalar `i` field is always 0 for such an
                # attribute and would match regardless of the actual axes.
                "attributes": [{"name": "axes", "ints": [0], "type": 7}],
            },
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(
            module,
            x,
            ops,
            opset_versions=[10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        # Without dynamic axes the slice bounds are all constants.
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

    def test_flip(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        # Same float64 2x3 tensor [[0, 1, 2], [3, 4, 5]] that
        # torch.tensor(numpy.arange(6.0).reshape(2, 3)) produced, without
        # the numpy round-trip.
        x = torch.arange(6.0, dtype=torch.double).reshape(2, 3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])

    def test_dropout(self):
        class MyModule(Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        x = torch.randn(1, 2, 3)

        # we should only export the onnx Dropout op in training mode; test both modes

        # test training mode
        ops = [
            {
                "op_name": "Dropout",
                "attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
            }
        ]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.TRAINING,
        )

        # test eval mode
        ops = [{"op_name": "Identity"}]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.EVAL,
        )

    def test_full(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.full((3, 4), x)

        ops = [
            {"op_name": "Constant"},
            {"op_name": "ConstantOfShape"},
            {"op_name": "Add"},
        ]
        ops = {9: ops, 10: ops}
        x = torch.tensor(12.0)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

    def test_interpolate(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        # Fully dynamic input: the output size is computed in-graph.
        ops_9 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Constant"},
            {"op_name": "Mul"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Cast"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(
            MyModel(),
            x,
            ops,
            opset_versions=[9, 10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        # Static input shape.
        ops_9 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Cast"},
            {"op_name": "Div"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            # Check the mode attribute here too, like every other
            # Resize/Upsample expectation in this test.
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]

        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])

        class MyDynamicModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                # work around for now: turn the dynamic sizes into constant
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": b"nearest", "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])

    def test_affine_grid(self):
        class MyModule(Module):
            def __init__(self, align_corners):
                super().__init__()
                self.align_corners = align_corners

            def forward(self, theta, size):
                return torch.nn.functional.affine_grid(
                    theta, size, align_corners=self.align_corners
                )

        opset_version = 20
        ops_2d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }

        ops_3d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }
        # 2D affine
        theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
        size_2d = torch.Size([1, 1, 2, 2])
        # 3D affine
        theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
        size_3d = torch.Size([1, 1, 2, 2, 2])

        for inputs, align_corners in itertools.product(
            ((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
            (True, False),
        ):
            theta, size, ops = inputs
            args = (theta, size)
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                args,
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.TRAINING,
            )
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                args,
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.EVAL,
            )

    def test_grid_sample(self):
        class MyModule(torch.nn.Module):
            def __init__(self, mode, padding_mode, align_corners):
                super().__init__()
                self.mode = mode
                self.padding_mode = padding_mode
                self.align_corners = align_corners

            def forward(self, x, grid):
                return torch.nn.functional.grid_sample(
                    x,
                    grid,
                    mode=self.mode,
                    padding_mode=self.padding_mode,
                    align_corners=self.align_corners,
                )

        def check_eval_and_training(
            ops, opset_version, mode, padding_mode, align_corners, x_shape, grid_shape
        ):
            # One set of random inputs, exported in training and eval mode.
            args = (torch.randn(*x_shape), torch.randn(*grid_shape))
            for training in (
                torch.onnx.TrainingMode.TRAINING,
                torch.onnx.TrainingMode.EVAL,
            ):
                check_onnx_opsets_operator(
                    MyModule(
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=align_corners,
                    ),
                    args,
                    ops,
                    opset_versions=[opset_version],
                    training=training,
                )

        n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
        for mode, padding_mode, align_corners, opset_version in itertools.product(
            ("bilinear", "nearest", "bicubic"),
            ("zeros", "border", "reflection"),
            (True, False),
            (16, 20),
        ):
            ops = {opset_version: [{"op_name": "GridSample"}]}
            # 4-D input (N, C, H_in, W_in) with an (N, H_out, W_out, 2) grid.
            check_eval_and_training(
                ops,
                opset_version,
                mode,
                padding_mode,
                align_corners,
                (n, c, h_in, w_in),
                (n, h_out, w_out, 2),
            )
            # The 5-D case, (N, C, D_in, H_in, W_in) input with an
            # (N, D_out, H_out, W_out, 3) grid, is only exercised for
            # opset 20 and non-bicubic modes.
            if opset_version == 20 and mode != "bicubic":
                check_eval_and_training(
                    ops,
                    opset_version,
                    mode,
                    padding_mode,
                    align_corners,
                    (n, c, d_in, h_in, w_in),
                    (n, d_out, h_out, w_out, 3),
                )

    def test_flatten(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flatten(x)

        module = MyModule()

        # A 0-d input needs a Reshape; a 1-D input flattens to itself.
        ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
        ops_1d = [{"op_name": "Identity"}]
        for shape in ([], [3]):
            x = torch.randn(shape)
            for opset_version in [9, 10]:
                ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
                check_onnx_opsets_operator(
                    module, x, ops, opset_versions=[opset_version]
                )


if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    common_utils.run_tests()

# Cookie notice captured from the hosting web page (not part of the test module):
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.