import io
import itertools

import numpy
import onnx
import pytorch_test_common

import torch
from torch.nn import Module
from torch.onnx import producer_name, producer_version
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils
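

# Check that an exported model carries the expected producer metadata and opset
# version, passes the ONNX checker, and contains exactly the expected sequence
# of node op types and attributes.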
def check_onnx_opset_operator(
    model, ops, opset_version=GLOBALS.export_onnx_opset_version
):
    # Check that the model was produced by this exporter at the expected opset.
    assert (
        model.producer_name == producer_name
        and model.producer_version == producer_version
        and model.opset_import[0].version == opset_version
    )

    # Check the model against the ONNX schema.
    onnx.checker.check_model(model)

    # Check that every graph node has the expected op type and attributes.
    graph = model.graph
    assert len(ops) == len(graph.node)
    for i in range(0, len(ops)):
        assert graph.node[i].op_type == ops[i]["op_name"]
        if "attributes" in ops[i]:
            attributes = ops[i]["attributes"]
            assert len(attributes) == len(graph.node[i].attribute)
            for j in range(0, len(attributes)):
                for attribute_field in attributes[j].keys():
                    assert attributes[j][attribute_field] == getattr(
                        graph.node[i].attribute[j], attribute_field
                    )
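

# Export `module` at each opset in `opset_versions` and compare the resulting
# graph against the per-opset expectations in `ops`.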
def check_onnx_opsets_operator(
    module, x, ops, opset_versions,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None, dynamic_axes=None,
):
    for opset_version in opset_versions:
        f = io.BytesIO()
        torch.onnx.export(
            module, x, f,
            opset_version=opset_version,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        model = onnx.load(io.BytesIO(f.getvalue()))
        check_onnx_opset_operator(model, ops[opset_version], opset_version)


class TestONNXOpset(pytorch_test_common.ExportTestCase):
    def test_opset_fallback(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.isnan(x)

        ops = [{"op_name": "IsNaN"}]
        ops = {9: ops, 10: ops}
        x = torch.tensor([1.0, float("nan"), 2.0])
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
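
    # TopK: opset 9 encodes k as a node attribute, while opset 10 takes k as an
    # input (here produced by a Constant node).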
    def test_topk(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.topk(x, 3)

        ops_9 = [
            {
                "op_name": "TopK",
                "attributes": [
                    {"name": "axis", "i": -1, "type": 2},
                    {"name": "k", "i": 3, "type": 2},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # Dynamic k becomes an input to TopK at opset 10.
        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, k):
                return torch.topk(input, k)

        ops_10 = [
            {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
            {"op_name": "Reshape"},
            {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
        ]
        ops = {10: ops_10}
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        module = MyModuleDynamic()
        check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])
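
    # MaxPool: opset 10 adds the ceil_mode and dilations attributes, so the
    # expected attribute lists differ between opsets 9 and 10.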
    def test_maxpool(self):
        module = torch.nn.MaxPool1d(2, stride=1)

        ops_9 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [1], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])

        # Dilated pooling can only be exported at opset 10 and later.
        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)
        ops_10 = [
            {
                "op_name": "MaxPool",
                "attributes": [
                    {"name": "ceil_mode", "i": 0, "type": 2},
                    {"name": "dilations", "ints": [2], "type": 7},
                    {"name": "kernel_shape", "ints": [2], "type": 7},
                    {"name": "pads", "ints": [0, 0], "type": 7},
                    {"name": "strides", "ints": [1], "type": 7},
                ],
            }
        ]
        ops = {10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
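
    # Upsample: opset 8 stores scales as an attribute, while opset 9 passes
    # scales as a second input (the preceding Constant node).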
    def test_upsample(self):
        class MyModule(Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        module = MyModule()
        ops8 = [
            {
                "op_name": "Upsample",
                "attributes": [
                    {"name": "mode", "s": (b"nearest"), "type": 3},
                    {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
                ],
            }
        ]
        ops9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops = {8: ops8, 9: ops9}
        x = torch.randn(2, 2, 2, 2)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
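
    # At opset 8 an extra Cast node (to=7, i.e. INT64) appears between the
    # Constant and the Sub; at opset 9 no Cast is needed.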
    def test_cast_constant(self):
        class MyModule(Module):
            def forward(self, x):
                return x - 1

        module = MyModule()
        ops_8 = [
            {"op_name": "Constant"},
            {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
            {"op_name": "Sub"},
        ]
        ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
        ops = {8: ops_8, 9: ops_9}
        x = torch.ones(5, 6, dtype=torch.long)
        check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
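
    # Slice: opset 9 uses starts/ends/axes attributes, while opset 10 takes them
    # as inputs (the preceding Constant nodes), which also enables dynamic
    # start/end values.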
    def test_slice(self):
        class MyModule(Module):
            def forward(self, x):
                return x[:1]

        ops_9 = [
            {
                "op_name": "Slice",
                "attributes": [
                    {"name": "axes", "ints": [0], "type": 7},
                    {"name": "ends", "ints": [1], "type": 7},
                    {"name": "starts", "ints": [0], "type": 7},
                ],
            }
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(3)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

        # Dynamic slice via scripting: the end index depends on the input shape.
        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(0)]

        module = DynamicSliceModel()
        x = torch.rand(1, 2)
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {
                "op_name": "Unsqueeze",
                "attributes": [{"name": "axes", "i": 0, "type": 7}],
            },
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(
            module,
            x,
            ops,
            opset_versions=[10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        # Without dynamic axes the shape computation folds into constants.
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
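
    # torch.flip lowers at opset 10 to a Slice over Constant start/end/step
    # inputs (stepping backwards along dim 0).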
    def test_flip(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice", "attributes": []},
        ]
        ops = {10: ops_10}
        x = torch.tensor(numpy.arange(6.0).reshape(2, 3))
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])
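
    # Dropout exports as a Dropout node (with its ratio) in TRAINING mode and
    # collapses to Identity in EVAL mode.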
    def test_dropout(self):
        class MyModule(Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        x = torch.randn(1, 2, 3)
        ops = [
            {
                "op_name": "Dropout",
                "attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
            }
        ]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.TRAINING,
        )

        # In eval mode dropout is a no-op and exports as Identity.
        ops = [{"op_name": "Identity"}]
        ops = {9: ops, 10: ops}
        check_onnx_opsets_operator(
            MyModule(),
            x,
            ops,
            opset_versions=[9, 10],
            training=torch.onnx.TrainingMode.EVAL,
        )
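
    # torch.full should export via ConstantOfShape rather than a baked-in
    # constant at both opset 9 and opset 10.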
    def test_full(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.full((3, 4), x)

        ops = [
            {"op_name": "Constant"},
            {"op_name": "ConstantOfShape"},
        ]
        ops = {9: ops, 10: ops}
        x = torch.tensor(12.0)
        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
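
    # interpolate with a size computed from the input shape: opset 9 lowers to
    # Upsample (with scales built up in the graph), opset 10 lowers to Resize.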
    def test_interpolate(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        # Exported with dynamic axes on all dims of x.
        ops_9 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Gather"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Unsqueeze"},
            {"op_name": "Concat"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(
            MyModel(),
            x,
            ops,
            opset_versions=[9, 10],
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        # The same model exported without dynamic axes folds most of the size
        # computation into constants.
        ops_9 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Slice"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {"op_name": "Shape"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Constant"},
            {"op_name": "Slice"},
            {"op_name": "Constant"},
            {"op_name": "Concat"},
            {"op_name": "Resize"},
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])

        class MyDynamicModel(torch.nn.Module):
            def forward(self, x):
                size = [v * 2 for v in x.size()[2:]]
                size = [int(i) for i in size]
                return torch.nn.functional.interpolate(x, size=size, mode="nearest")

        ops_9 = [
            {"op_name": "Constant"},
            {
                "op_name": "Upsample",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops_10 = [
            {"op_name": "Constant"},
            {
                "op_name": "Resize",
                "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
            },
        ]
        ops = {9: ops_9, 10: ops_10}
        x = torch.randn(20, 16, 50)
        check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
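
    # affine_grid exports natively as an AffineGrid node (available from ONNX
    # opset 20); the Constant/Unsqueeze/Concat prelude assembles the size input.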
    def test_affine_grid(self):
        class MyModule(Module):
            def __init__(self, align_corners):
                super().__init__()
                self.align_corners = align_corners

            def forward(self, theta, size):
                return torch.nn.functional.affine_grid(
                    theta, size, align_corners=self.align_corners
                )

        opset_version = 20
        ops_2d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }
        ops_3d = {
            opset_version: [
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Constant"},
                {"op_name": "Unsqueeze"},
                {"op_name": "Concat"},
                {"op_name": "AffineGrid"},
            ]
        }
        theta_2d = torch.empty(1, 2, 3, dtype=torch.double)
        size_2d = torch.Size([1, 1, 2, 2])
        theta_3d = torch.empty(1, 3, 4, dtype=torch.double)
        size_3d = torch.Size([1, 1, 2, 2, 2])

        for inputs, align_corners in itertools.product(
            ((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)),
            (True, False),
        ):
            theta, size, ops = inputs
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                (theta, size),
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.TRAINING,
            )
            check_onnx_opsets_operator(
                MyModule(align_corners=align_corners),
                (theta, size),
                ops,
                opset_versions=[opset_version],
                training=torch.onnx.TrainingMode.EVAL,
            )
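
    # grid_sample exports as a single GridSample node. ONNX added GridSample at
    # opset 16; 5-D (volumetric) inputs are only supported from opset 20, and
    # bicubic mode is not available for the 5-D case.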
    def test_grid_sample(self):
        class MyModule(torch.nn.Module):
            def __init__(self, mode, padding_mode, align_corners):
                super().__init__()
                self.mode = mode
                self.padding_mode = padding_mode
                self.align_corners = align_corners

            def forward(self, x, grid):
                return torch.nn.functional.grid_sample(
                    x,
                    grid,
                    mode=self.mode,
                    padding_mode=self.padding_mode,
                    align_corners=self.align_corners,
                )

        for mode, padding_mode, align_corners, opset_version in itertools.product(
            ("bilinear", "nearest", "bicubic"),
            ("zeros", "border", "reflection"),
            (True, False),
            (16, 20),
        ):

            def test_eval_and_training(
                ops, opset_version, mode, padding_mode, align_corners, x_shape, grid
            ):
                args = (torch.randn(*x_shape), torch.randn(*grid))
                check_onnx_opsets_operator(
                    MyModule(mode=mode, padding_mode=padding_mode, align_corners=align_corners),
                    args,
                    ops,
                    opset_versions=[opset_version],
                    training=torch.onnx.TrainingMode.TRAINING,
                )
                check_onnx_opsets_operator(
                    MyModule(mode=mode, padding_mode=padding_mode, align_corners=align_corners),
                    args,
                    ops,
                    opset_versions=[opset_version],
                    training=torch.onnx.TrainingMode.EVAL,
                )

            ops = {opset_version: [{"op_name": "GridSample"}]}
            n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4
            test_eval_and_training(
                ops,
                opset_version,
                mode,
                padding_mode,
                align_corners,
                (n, c, h_in, w_in),
                (n, h_out, w_out, 2),
            )
            # Volumetric (5-D) grid_sample is only exportable at opset 20 and not
            # with bicubic mode.
            if opset_version == 20 and mode != "bicubic":
                test_eval_and_training(
                    ops,
                    opset_version,
                    mode,
                    padding_mode,
                    align_corners,
                    (n, c, d_in, h_in, w_in),
                    (n, d_out, h_out, w_out, 3),
                )
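
    # Flattening a 0-d tensor requires a Reshape (with a Constant shape input);
    # a 1-d tensor is already flat and exports as Identity.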
    def test_flatten(self):
        class MyModule(Module):
            def forward(self, x):
                return torch.flatten(x)

        module = MyModule()
        ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
        ops_1d = [{"op_name": "Identity"}]
        for shape in ([], [3]):
            x = torch.randn(shape)
            for opset_version in [9, 10]:
                ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
                check_onnx_opsets_operator(
                    module, x, ops, opset_versions=[opset_version]
                )


if __name__ == "__main__":
    common_utils.run_tests()