7
import pytorch_test_common
11
from torch.nn import Module
12
from torch.onnx import producer_name, producer_version
13
from torch.onnx._globals import GLOBALS
14
from torch.testing._internal import common_utils
17
# Verify a loaded ONNX ModelProto against an expected list of ops.
# For each graph node it checks op_type against ops[i]["op_name"] and, when
# present, every field of ops[i]["attributes"][j] against the node's j-th
# attribute via getattr. Also validates producer name/version and the opset
# version, and runs onnx.checker.check_model.
# NOTE(review): this chunk is a corrupted extraction — the bare number lines
# below are a line-number gutter fused into the text, indentation is stripped,
# and several original lines (e.g. the signature closer and the assert/graph
# setup between the visible fragments) are missing. Restore from the original
# file before editing; do not treat this as runnable code.
def check_onnx_opset_operator(
18
model, ops, opset_version=GLOBALS.export_onnx_opset_version
22
model.producer_name == producer_name
23
and model.producer_version == producer_version
24
and model.opset_import[0].version == opset_version
28
onnx.checker.check_model(model)
37
assert len(ops) == len(graph.node)
38
for i in range(0, len(ops)):
39
assert graph.node[i].op_type == ops[i]["op_name"]
40
if "attributes" in ops[i]:
41
attributes = ops[i]["attributes"]
42
assert len(attributes) == len(graph.node[i].attribute)
43
for j in range(0, len(attributes)):
44
for attribute_field in attributes[j].keys():
45
assert attributes[j][attribute_field] == getattr(
46
graph.node[i].attribute[j], attribute_field
50
# Export `module` once per requested opset version and check each exported
# graph with check_onnx_opset_operator. `ops` maps opset_version -> expected
# op list. Visible keyword plumbing: opset_version, input_names, dynamic_axes,
# and training (defaults to torch.onnx.TrainingMode.EVAL); the export itself
# writes to an in-memory buffer `f` that is reloaded via onnx.load.
# NOTE(review): corrupted extraction — the full signature, the export call,
# and the buffer setup lines are missing (gaps in the fused line-number
# gutter); presumably a torch.onnx.export(...) call fills `f` — confirm
# against the original file.
def check_onnx_opsets_operator(
55
training=torch.onnx.TrainingMode.EVAL,
59
for opset_version in opset_versions:
65
opset_version=opset_version,
67
input_names=input_names,
68
dynamic_axes=dynamic_axes,
70
model = onnx.load(io.BytesIO(f.getvalue()))
71
check_onnx_opset_operator(model, ops[opset_version], opset_version)
74
# Test suite pinning the exact ONNX ops (and attributes) emitted when
# exporting small modules at specific opset versions (8, 9, 10, 16), e.g.
# TopK, MaxPool, Upsample/Resize, Cast, Slice, Dropout, ConstantOfShape,
# GridSample, Flatten.
# NOTE(review): corrupted extraction — bare number lines are a fused
# line-number gutter, indentation is stripped, and many interior lines of
# every test (inner-class bodies, list/dict openers, export kwargs) are
# missing. Restore from the original file before modifying any test.
class TestONNXOpset(pytorch_test_common.ExportTestCase):
75
# Exports an IsNaN module at opsets 9 and 10 (IsNaN is expected at both).
def test_opset_fallback(self):
76
class MyModule(Module):
80
ops = [{"op_name": "IsNaN"}]
81
ops = {9: ops, 10: ops}
82
x = torch.tensor([1.0, float("nan"), 2.0])
83
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
86
class MyModule(Module):
88
return torch.topk(x, 3)
94
{"name": "axis", "i": -1, "type": 2},
95
{"name": "k", "i": 3, "type": 2},
100
{"op_name": "Constant"},
101
{"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
103
ops = {9: ops_9, 10: ops_10}
104
x = torch.arange(1.0, 6.0, requires_grad=True)
105
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
108
# Dynamic-k TopK via TorchScript; at opset 10 k becomes a graph input
# (Constant + Reshape + TopK with only the axis attribute).
class MyModuleDynamic(torch.jit.ScriptModule):
109
@torch.jit.script_method
110
def forward(self, input, k):
111
return torch.topk(input, k)
114
{"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]},
115
{"op_name": "Reshape"},
116
{"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]},
119
x = torch.arange(1.0, 6.0, requires_grad=True)
121
module = MyModuleDynamic()
122
check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10])
124
# MaxPool1d: opset 10 adds ceil_mode/dilations attributes vs opset 9;
# dilation is only exportable at opset 10.
def test_maxpool(self):
125
module = torch.nn.MaxPool1d(2, stride=1)
129
"op_name": "MaxPool",
131
{"name": "kernel_shape", "ints": [2], "type": 7},
132
{"name": "pads", "ints": [0, 0], "type": 7},
133
{"name": "strides", "ints": [1], "type": 7},
139
"op_name": "MaxPool",
141
{"name": "ceil_mode", "i": 0, "type": 2},
142
{"name": "dilations", "ints": [1], "type": 7},
143
{"name": "kernel_shape", "ints": [2], "type": 7},
144
{"name": "pads", "ints": [0, 0], "type": 7},
145
{"name": "strides", "ints": [1], "type": 7},
149
ops = {9: ops_9, 10: ops_10}
150
x = torch.randn(20, 16, 50)
151
check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10])
154
module = torch.nn.MaxPool1d(2, stride=1, dilation=2)
158
"op_name": "MaxPool",
160
{"name": "ceil_mode", "i": 0, "type": 2},
161
{"name": "dilations", "ints": [2], "type": 7},
162
{"name": "kernel_shape", "ints": [2], "type": 7},
163
{"name": "pads", "ints": [0, 0], "type": 7},
164
{"name": "strides", "ints": [1], "type": 7},
169
x = torch.randn(20, 16, 50)
170
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
172
# interpolate(nearest) exports as Upsample with a scales attribute at
# opset 8 but as Constant + Upsample (scales as input) at opset 9.
def test_upsample(self):
173
class MyModule(Module):
174
def forward(self, x):
175
size = [v * 2 for v in x.size()[2:]]
176
size = [int(i) for i in size]
177
return torch.nn.functional.interpolate(x, size=size, mode="nearest")
182
"op_name": "Upsample",
184
{"name": "mode", "s": (b"nearest"), "type": 3},
185
{"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6},
190
{"op_name": "Constant"},
192
"op_name": "Upsample",
193
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
196
ops = {8: ops8, 9: ops9}
197
x = torch.randn(2, 2, 2, 2)
198
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
200
# Opset 8 needs an explicit Cast (to=7, i.e. INT64) around constants that
# opset 9 handles natively.
def test_cast_constant(self):
201
class MyModule(Module):
202
def forward(self, x):
207
{"op_name": "Constant"},
208
{"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
211
ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}]
212
ops = {8: ops_8, 9: ops_9}
213
x = torch.ones(5, 6, dtype=torch.long)
214
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
216
# Slice: attributes (axes/ends/starts) at opset 9 vs inputs (Constants)
# at opset 10; also covers dynamic slice and flip.
def test_slice(self):
217
class MyModule(Module):
218
def forward(self, x):
225
{"name": "axes", "ints": [0], "type": 7},
226
{"name": "ends", "ints": [1], "type": 7},
227
{"name": "starts", "ints": [0], "type": 7},
232
{"op_name": "Constant"},
233
{"op_name": "Constant"},
234
{"op_name": "Constant"},
235
{"op_name": "Constant"},
236
{"op_name": "Slice", "attributes": []},
238
ops = {9: ops_9, 10: ops_10}
240
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
242
class DynamicSliceModel(torch.jit.ScriptModule):
243
@torch.jit.script_method
244
def forward(self, x):
245
return x[1 : x.size(0)]
247
module = DynamicSliceModel()
250
{"op_name": "Shape"},
251
{"op_name": "Constant"},
252
{"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
253
{"op_name": "Constant"},
254
{"op_name": "Constant"},
256
"op_name": "Unsqueeze",
257
"attributes": [{"name": "axes", "i": 0, "type": 7}],
259
{"op_name": "Constant"},
260
{"op_name": "Slice", "attributes": []},
263
check_onnx_opsets_operator(
269
dynamic_axes={"x": [0, 1]},
273
{"op_name": "Constant"},
274
{"op_name": "Constant"},
275
{"op_name": "Constant"},
276
{"op_name": "Constant"},
277
{"op_name": "Slice", "attributes": []},
280
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
283
class MyModule(Module):
284
def forward(self, x):
285
return torch.flip(x, dims=[0])
288
{"op_name": "Constant"},
289
{"op_name": "Constant"},
290
{"op_name": "Constant"},
291
{"op_name": "Constant"},
292
{"op_name": "Slice", "attributes": []},
297
x = torch.tensor(numpy.arange(6.0).reshape(2, 3))
298
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10])
300
# Dropout exports as Dropout(ratio=0.5) in TRAINING mode but folds to
# Identity in EVAL mode, at both opset 9 and 10.
def test_dropout(self):
301
class MyModule(Module):
304
self.dropout = torch.nn.Dropout(0.5)
306
def forward(self, x):
307
return self.dropout(x)
309
x = torch.randn(1, 2, 3)
316
"op_name": "Dropout",
317
"attributes": [{"name": "ratio", "f": 0.5, "type": 1}],
320
ops = {9: ops, 10: ops}
321
check_onnx_opsets_operator(
325
opset_versions=[9, 10],
326
training=torch.onnx.TrainingMode.TRAINING,
330
ops = [{"op_name": "Identity"}]
331
ops = {9: ops, 10: ops}
332
check_onnx_opsets_operator(
336
opset_versions=[9, 10],
337
training=torch.onnx.TrainingMode.EVAL,
341
# torch.full with a tensor fill value lowers to Constant + ConstantOfShape.
class MyModule(Module):
342
def forward(self, x):
343
return torch.full((3, 4), x)
346
{"op_name": "Constant"},
347
{"op_name": "ConstantOfShape"},
350
ops = {9: ops, 10: ops}
351
x = torch.tensor(12.0)
352
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
354
# Dynamic-size interpolate: opset 9 builds the size tensor with
# Shape/Gather/Unsqueeze/Concat feeding Upsample; opset 10 uses Resize.
def test_interpolate(self):
355
class MyModel(torch.nn.Module):
356
def forward(self, x):
357
size = [v * 2 for v in x.size()[2:]]
358
return torch.nn.functional.interpolate(x, size=size, mode="nearest")
361
{"op_name": "Shape"},
362
{"op_name": "Constant"},
363
{"op_name": "Gather"},
364
{"op_name": "Shape"},
365
{"op_name": "Constant"},
366
{"op_name": "Gather"},
367
{"op_name": "Constant"},
369
{"op_name": "Constant"},
371
{"op_name": "Unsqueeze"},
372
{"op_name": "Unsqueeze"},
373
{"op_name": "Concat"},
375
{"op_name": "Shape"},
376
{"op_name": "Slice"},
379
{"op_name": "Constant"},
380
{"op_name": "Concat"},
382
"op_name": "Upsample",
383
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
387
{"op_name": "Shape"},
388
{"op_name": "Constant"},
389
{"op_name": "Gather"},
390
{"op_name": "Shape"},
391
{"op_name": "Constant"},
392
{"op_name": "Gather"},
393
{"op_name": "Constant"},
395
{"op_name": "Constant"},
397
{"op_name": "Unsqueeze"},
398
{"op_name": "Unsqueeze"},
399
{"op_name": "Concat"},
401
{"op_name": "Shape"},
402
{"op_name": "Constant"},
403
{"op_name": "Constant"},
404
{"op_name": "Constant"},
405
{"op_name": "Slice"},
408
{"op_name": "Constant"},
409
{"op_name": "Concat"},
412
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
416
ops = {9: ops_9, 10: ops_10}
417
x = torch.randn(1, 2, 3, 4, requires_grad=True)
418
check_onnx_opsets_operator(
422
opset_versions=[9, 10],
424
dynamic_axes={"x": [0, 1, 2, 3]},
428
{"op_name": "Constant"},
429
{"op_name": "Shape"},
430
{"op_name": "Slice"},
433
{"op_name": "Constant"},
434
{"op_name": "Concat"},
436
"op_name": "Upsample",
437
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
441
{"op_name": "Constant"},
442
{"op_name": "Shape"},
443
{"op_name": "Constant"},
444
{"op_name": "Constant"},
445
{"op_name": "Constant"},
446
{"op_name": "Slice"},
449
{"op_name": "Constant"},
450
{"op_name": "Concat"},
451
{"op_name": "Resize"},
454
ops = {9: ops_9, 10: ops_10}
455
x = torch.randn(1, 2, 3, 4, requires_grad=True)
456
check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10])
458
class MyDynamicModel(torch.nn.Module):
459
def forward(self, x):
460
size = [v * 2 for v in x.size()[2:]]
462
size = [int(i) for i in size]
463
return torch.nn.functional.interpolate(x, size=size, mode="nearest")
466
{"op_name": "Constant"},
468
"op_name": "Upsample",
469
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
473
{"op_name": "Constant"},
476
"attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}],
479
ops = {9: ops_9, 10: ops_10}
480
x = torch.randn(20, 16, 50)
481
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])
483
# grid_sample exports as a single GridSample op at opset 16, exercised
# over the product of mode / padding_mode / align_corners settings.
def test_grid_sample(self):
484
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
485
ops = {16: [{"op_name": "GridSample"}]}
487
class MyModule(Module):
488
# NOTE(review): the parameter below is spelled "align_corers" and is
# never read — the body uses "align_corners", which must resolve from
# the enclosing (missing) scope. Looks like a latent typo; confirm
# against the original file before fixing.
def forward(self, x, grid, mode, padding_mode, align_corers):
489
return torch.nn.functional.grid_sample(
490
x, grid, mode, padding_mode, align_corners
493
for mode, padding_mode, align_corners in itertools.product(
494
("bilinear", "nearest", "bicubic"),
495
("zeros", "border", "reflection"),
499
torch.randn(n, c, h_in, w_in),
500
torch.randn(n, h_out, w_out, 2),
505
check_onnx_opsets_operator(
510
training=torch.onnx.TrainingMode.TRAINING,
512
check_onnx_opsets_operator(
517
training=torch.onnx.TrainingMode.EVAL,
520
# flatten of a 0-d tensor lowers to Constant + Reshape; a 1-d tensor is
# already flat and lowers to Identity, at both opset 9 and 10.
def test_flatten(self):
521
class MyModule(Module):
522
def forward(self, x):
523
return torch.flatten(x)
527
ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}]
528
ops_1d = [{"op_name": "Identity"}]
529
for shape in ([], [3]):
530
x = torch.randn(shape)
531
for opset_version in [9, 10]:
532
ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)}
533
check_onnx_opsets_operator(
534
module, x, ops, opset_versions=[opset_version]
538
# Standard PyTorch test entry point: run this file's tests under the
# common_utils test runner when executed directly.
# NOTE(review): the bare "539" line below is extraction residue (a fused
# line-number gutter), not code.
if __name__ == "__main__":
539
common_utils.run_tests()