Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--no-onnx: no onnx python dependency
--produce-onnx-test-data: generate onnx test data
--accept: accept onnx updates and overwrite models
23
import torch.nn.functional as F
26
from pytorch_test_common import (
33
from torch.autograd import Function, Variable
34
from torch.nn import functional, Module
35
from torch.onnx._internal import diagnostics
36
from torch.onnx.symbolic_helper import (
41
from torch.testing._internal import common_utils
42
from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack
44
unittest.TestCase.maxDiff = None
50
def export_to_pbtxt(model, inputs, *args, **kwargs):
    """Export *model* to the human-readable (pretty-string) ONNX text form.

    Thin wrapper over torch.onnx.export_to_pretty_string with the
    Google-protobuf text printer enabled; extra args/kwargs are forwarded.
    """
    return torch.onnx.export_to_pretty_string(
        model, inputs, *args, google_printer=True, **kwargs
    )
56
def export_to_pb(model, inputs, *args, **kwargs):
    """Export *model* to serialized ONNX protobuf bytes.

    NOTE(review): the buffer `f` was referenced but never created in the
    garbled original excerpt; an in-memory BytesIO buffer is the obvious
    intended target — confirm against upstream.
    """
    import io  # local import: keeps this reconstruction self-contained

    f = io.BytesIO()
    torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()
63
class FuncModule(Module):
    """Adapt a plain callable (plus optional parameters) into an nn.Module.

    The ONNX export entry points require an nn.Module; tests frequently pass
    bare lambdas, so this wrapper stores the callable and registers any extra
    tensors as parameters so the exporter treats them as model weights.
    """

    def __init__(self, f, params=None):
        # NOTE(review): super().__init__() and the `self.f` assignment were
        # missing from the garbled excerpt and have been restored — without
        # them forward() would raise AttributeError.
        super().__init__()
        self.f = f
        self.params = nn.ParameterList(list(params or []))

    def forward(self, *args):
        # Registered parameters are appended after the positional inputs.
        return self.f(*itertools.chain(args, self.params))
75
class TestOperators(common_utils.TestCase):
78
diagnostics.engine.clear()
80
# Assert that exporting callable/module `f` with inputs `args` matches the
# recorded expected pretty-printed graph, and (when test-data generation is
# enabled) write a full ONNX test case (model + input/output protobufs).
#
# NOTE(review): this span is a garbled extraction — the bare integer lines are
# leftover original line numbers, indentation is stripped, and several
# statements are missing from view (the else-branch assigning `m`, the
# computation of `outputs`, the `if _onnx_test:` guard, and some `with open`
# lines). Code is left byte-identical; do not trust control flow as shown.
def assertONNX(self, f, args, params=None, **kwargs):
83
# Plain callables are wrapped; presumably an unseen branch uses f directly
# when it is already an nn.Module — TODO confirm against upstream.
if isinstance(f, nn.Module):
86
m = FuncModule(f, params)
88
# Compare the textual export against the checked-in expected file.
onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
89
subname = kwargs.pop("subname", None)
90
self.assertExpected(onnx_model_pbtxt, subname)
92
# Binary export path: re-export and validate with the onnx checker.
onnx_model_pb = export_to_pb(m, args, **kwargs)
95
import onnx.numpy_helper
96
import onnx_test_common
98
model_def = onnx.ModelProto.FromString(onnx_model_pb)
99
onnx.checker.check_model(model_def)
101
# Derive the output directory name from the calling test function,
# e.g. test_basic -> test_operator_basic.
test_function = inspect.stack()[1][0].f_code.co_name
102
test_name = test_function[0:4] + "_operator" + test_function[4:]
103
output_dir = os.path.join(
104
onnx_test_common.pytorch_operator_dir, test_name
109
# Refuse to clobber previously generated test data.
assert not os.path.exists(output_dir), f"{output_dir} should not exist!"
110
os.makedirs(output_dir)
111
with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
112
file.write(model_def.SerializeToString())
113
data_dir = os.path.join(output_dir, "test_data_set_0")
114
os.makedirs(data_dir)
115
# Serialize each flattened input tensor as input_<i>.pb.
if isinstance(args, Variable):
117
for index, var in enumerate(flatten(args)):
118
tensor = onnx.numpy_helper.from_array(var.data.numpy())
120
os.path.join(data_dir, f"input_{index}.pb"), "wb"
122
file.write(tensor.SerializeToString())
124
# `outputs` is assigned in lines missing from this excerpt (presumably
# outputs = m(args)) — TODO confirm.
if isinstance(outputs, Variable):
126
for index, var in enumerate(flatten(outputs)):
127
tensor = onnx.numpy_helper.from_array(var.data.numpy())
129
os.path.join(data_dir, f"output_{index}.pb"), "wb"
131
file.write(tensor.SerializeToString())
133
# Assert that exporting `f` raises exception type `err` (matched via the
# expect-file machinery). NOTE(review): garbled excerpt — the bare integer
# lines are extraction artifacts and the branch assigning `m = f` for
# nn.Module inputs is missing from view; code left byte-identical.
def assertONNXRaises(self, err, f, args, params=None, **kwargs):
136
if isinstance(f, nn.Module):
139
# Presumably this is the else-branch in the original — TODO confirm.
m = FuncModule(f, params)
140
self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))
142
# Assert that exporting `f` raises `err` with a message matching regex `reg`.
# NOTE(review): garbled excerpt — bare integer lines are extraction artifacts
# and the nn.Module branch (`m = f`) is missing from view; code left
# byte-identical.
def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
145
if isinstance(f, nn.Module):
148
# Presumably the else-branch in the original — TODO confirm.
m = FuncModule(f, params)
149
with self.assertRaisesRegex(err, reg):
150
export_to_pbtxt(m, args, **kwargs)
152
def test_basic(self):
    """Export -sigmoid(tanh(x * (x + y))) and compare with the expected graph."""
    x = torch.tensor([0.4], requires_grad=True)
    y = torch.tensor([0.7], requires_grad=True)
    self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))
158
x = torch.tensor([0.0], requires_grad=True)
159
self.assertONNX(lambda x: x.view(1, 1), x)
161
def test_index(self):
    """Export integer indexing (x[0]) on a 2-D tensor."""
    x = torch.tensor([[0.0]], requires_grad=True)
    self.assertONNX(lambda x: x[0], x)
165
def test_type_as(self):
    """Export Tensor.type_as (a no-op cast when source and target match)."""
    x = torch.tensor([0.0], requires_grad=True)
    self.assertONNX(lambda x: x.type_as(x), x)
169
def test_addconstant(self):
    """Export scalar-constant addition on a double tensor."""
    x = torch.randn(2, 3, requires_grad=True).double()
    self.assertONNX(lambda x: x + 1, x)
173
def test_add_broadcast(self):
    """Export add with trailing-dimension broadcast: (2,3) + (3,)."""
    x = torch.randn(2, 3, requires_grad=True).double()
    y = torch.randn(3, requires_grad=True).double()
    self.assertONNX(operator.add, (x, y))
178
def test_add_left_broadcast(self):
179
x = torch.randn(3, requires_grad=True).double()
180
y = torch.randn(2, 3, requires_grad=True).double()
181
self.assertONNX(operator.add, (x, y))
183
def test_add_size1_broadcast(self):
    """Export add with a size-1 dimension broadcast: (2,3) + (2,1)."""
    x = torch.randn(2, 3, requires_grad=True).double()
    y = torch.randn(2, 1, requires_grad=True).double()
    self.assertONNX(operator.add, (x, y))
188
def test_add_size1_right_broadcast(self):
189
x = torch.randn(2, 3, requires_grad=True).double()
190
y = torch.randn(3, requires_grad=True).double()
191
self.assertONNX(operator.add, (x, y))
193
def test_add_size1_singleton_broadcast(self):
    """Export add with a leading singleton broadcast: (2,3) + (1,3)."""
    x = torch.randn(2, 3, requires_grad=True).double()
    y = torch.randn(1, 3, requires_grad=True).double()
    self.assertONNX(operator.add, (x, y))
199
x = torch.randn(2, 3, requires_grad=True).double()
200
self.assertONNX(lambda x: 1 - x, (x,))
202
def test_mul_bool(self):
    """Export element-wise multiplication of two boolean tensors."""
    x = torch.tensor([True, False, True, False])
    y = torch.tensor([True, True, False, False])
    self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))
207
def test_mul_fp_bool(self):
208
x = torch.tensor([9.4, 1.7, 3.6])
209
y = torch.tensor([True, True, False])
210
self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))
212
def test_transpose(self):
    """Export a transpose followed by its inverse transpose."""
    x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
    self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)
216
def test_chunk(self):
    """Export Tensor.chunk splitting a length-3 vector into 2 chunks."""
    x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    self.assertONNX(lambda x: x.chunk(2), x)
220
def test_split(self):
222
[[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
224
self.assertONNX(lambda x: torch.split(x, 2, 1), x)
226
def test_split_with_sizes(self):
228
[[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
230
self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)
232
def test_concat2(self):
    """Export torch.cat of two tensors along dim 1 (inputs passed as a tuple)."""
    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))
238
m1 = torch.randn(2, 3, requires_grad=True)
239
m2 = torch.randn(3, 4, requires_grad=True)
240
self.assertONNX(torch.mm, (m1, m2))
242
def test_addmm(self):
243
m1 = torch.randn(2, 3, requires_grad=True)
244
m2 = torch.randn(3, 4, requires_grad=True)
245
m3 = torch.randn(4, requires_grad=True)
247
lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3)
250
def test_permute2(self):
    """Export a 6-D permute with a non-trivial axis order."""
    x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
    self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)
256
[[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True
258
self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)
260
def test_params(self):
261
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
262
y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
264
lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
267
keep_initializers_as_inputs=True,
270
def test_params_onnx_irv4(self):
271
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
272
y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
274
lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
277
keep_initializers_as_inputs=False,
280
def test_symbolic_mismatch(self):
281
class MyFun(Function):
286
raise AssertionError()
289
def forward(ctx, x, y):
296
with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
297
export_to_pbtxt(FuncModule(MyFun().apply), (x, y))
300
def test_batchnorm(self):
    """Export BatchNorm2d with initializers kept as graph inputs (IR v3 style)."""
    x = torch.ones(2, 2, 2, 2, requires_grad=True)
    self.assertONNX(nn.BatchNorm2d(2), x, keep_initializers_as_inputs=True)
304
def test_batchnorm_onnx_irv4(self):
305
x = torch.ones(2, 2, 2, 2, requires_grad=True)
306
self.assertONNX(nn.BatchNorm2d(2), x)
308
def test_batchnorm_1d(self):
    """Export BatchNorm1d with initializers kept as graph inputs."""
    x = torch.ones(2, 2, requires_grad=True)
    self.assertONNX(nn.BatchNorm1d(2), x, keep_initializers_as_inputs=True)
312
def test_batchnorm_training(self):
313
x = torch.ones(2, 2, 2, 2, requires_grad=True)
317
training=torch.onnx.TrainingMode.TRAINING,
318
keep_initializers_as_inputs=True,
322
x = torch.ones(20, 16, 50, 40, requires_grad=True)
324
nn.Conv2d(16, 13, 3, bias=False), x, keep_initializers_as_inputs=True
327
def test_conv_onnx_irv4(self):
    """Export a bias-free Conv2d with weights as initializers (IR v4 default)."""
    x = torch.ones(20, 16, 50, 40, requires_grad=True)
    self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)
331
def test_conv_onnx_irv4_opset8(self):
336
x = torch.ones(1, 2, 5, 7, requires_grad=True)
337
conv_node = nn.Conv2d(2, 4, 3, bias=False)
338
conv_node.weight.data.fill_(1.0)
340
conv_node, x, opset_version=8, keep_initializers_as_inputs=False
343
def test_conv_variable_length(self):
344
x = torch.ones(5, 3, 6, 6, requires_grad=True)
345
model = torch.nn.Conv2d(3, 2, 3)
348
"input_1": [0, 2, 3],
349
"output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"},
351
model_proto_file = tempfile.NamedTemporaryFile()
355
model_proto_file.name,
357
input_names=["input_1"],
358
output_names=["output_1"],
359
dynamic_axes=dynamic_axes,
364
onnx_model = onnx.load(model_proto_file.name)
365
onnx.checker.check_model(onnx_model)
369
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
370
== "input_1_dynamic_axes_1"
373
onnx_model.graph.input[0].type.tensor_type.shape.dim[2].dim_param
374
== "input_1_dynamic_axes_2"
377
onnx_model.graph.input[0].type.tensor_type.shape.dim[3].dim_param
378
== "input_1_dynamic_axes_3"
383
onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param
384
== "output_1_variable_dim_0"
387
onnx_model.graph.output[0].type.tensor_type.shape.dim[1].dim_param
388
== "output_1_variable_dim_1"
391
def test_convtranspose(self):
392
x = torch.ones(2, 3, 4, 5, requires_grad=True)
395
3, 3, 3, stride=3, bias=False, padding=1, output_padding=2
398
keep_initializers_as_inputs=True,
401
def test_maxpool(self):
    """Export MaxPool1d(kernel=3, stride=2)."""
    x = torch.randn(20, 16, 50)
    self.assertONNX(nn.MaxPool1d(3, stride=2), x)
405
def test_maxpool_dilations(self):
406
x = torch.randn(20, 16, 50)
407
self.assertONNX(nn.MaxPool1d(2, stride=1, dilation=2), x, opset_version=10)
409
def test_avg_pool2d(self):
    """Export AvgPool2d(kernel=3, stride=2)."""
    x = torch.randn(20, 16, 50, 32)
    self.assertONNX(nn.AvgPool2d(3, stride=2), x)
413
def test_maxpool_indices(self):
    """Export MaxPool1d with return_indices=True (two-output MaxPool)."""
    x = torch.randn(20, 16, 50)
    self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
418
def test_at_op(self):
419
x = torch.randn(3, 4)
421
class MyFun(Function):
424
return g.at("add", x, x)
430
class MyModule(Module):
431
def forward(self, x):
432
return MyFun.apply(x)
437
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
441
x = torch.randn(3, 4, requires_grad=True)
442
self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)
444
def test_clip_min(self):
    """Export clamp with only a lower bound."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: x.clamp(min=-0.1), x)
448
def test_clip_max(self):
    """Export clamp with only an upper bound."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: x.clamp(max=0.1), x)
452
def test_hardtanh(self):
    """Export Hardtanh with custom min/max bounds."""
    x = torch.randn(3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)
457
x = torch.randn(3, 4, requires_grad=True)
458
self.assertONNX(lambda x: torch.full(x.shape, 2.0), x)
460
def test_full_like(self):
    """Export torch.full_like with a constant fill value."""
    x = torch.randn(3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.full_like(x, 2), x)
465
x = torch.randn(3, 4, requires_grad=True)
466
y = torch.randn(3, 4, requires_grad=True)
467
self.assertONNX(lambda x, y: torch.max(x, y), (x, y))
470
x = torch.randn(3, 4, requires_grad=True)
471
y = torch.randn(3, 4, requires_grad=True)
472
self.assertONNX(lambda x, y: torch.min(x, y), (x, y))
475
x = torch.randn(1, 2, 3, 4, requires_grad=True)
476
self.assertONNX(lambda x: torch.mean(x), x)
478
def test_reduced_mean(self):
    """Export mean reduced over a single dimension."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.mean(x, dim=2), x)
482
def test_reduced_mean_keepdim(self):
483
x = torch.randn(1, 2, 3, 4, requires_grad=True)
484
self.assertONNX(lambda x: torch.mean(x, dim=(2, 3), keepdim=True), x)
486
def test_mean_dtype(self):
487
x = torch.randn(1, 2, 3, 4, requires_grad=True)
488
self.assertONNX(lambda x: torch.mean(x, dtype=torch.double), x)
490
def test_reduced_mean_dtype(self):
491
x = torch.randn(1, 2, 3, 4, requires_grad=True)
492
self.assertONNX(lambda x: torch.mean(x, dim=0, dtype=torch.double), x)
495
x = torch.randn(1, 2, 3, 4, requires_grad=True)
496
self.assertONNX(lambda x: torch.sum(x), x)
498
def test_sum_dtype(self):
499
x = torch.randn(1, 2, 3, 4, requires_grad=True)
500
self.assertONNX(lambda x: torch.sum(x, dtype=torch.double), x)
502
def test_reduced_sum_dtype(self):
503
x = torch.randn(1, 2, 3, 4, requires_grad=True)
504
self.assertONNX(lambda x: torch.sum(x, dim=0, dtype=torch.double), x)
506
def test_reduced_sum(self):
    """Export sum reduced over multiple dimensions."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.sum(x, dim=(1, 2)), x)
510
def test_reduced_sum_keepdim(self):
511
x = torch.randn(1, 2, 3, 4, requires_grad=True)
512
self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)
515
x = torch.randn(1, 2, 3, 4, requires_grad=True)
516
self.assertONNX(lambda x: torch.prod(x), x)
518
def test_reduced_prod(self):
519
x = torch.randn(1, 2, 3, 4, requires_grad=True)
520
self.assertONNX(lambda x: torch.prod(x, dim=2), x)
522
def test_reduced_prod_keepdim(self):
523
x = torch.randn(1, 2, 3, 4, requires_grad=True)
524
self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)
526
def test_prod_dtype(self):
527
x = torch.randn(1, 2, 3, 4, requires_grad=True)
528
self.assertONNX(lambda x: torch.prod(x, dtype=torch.double), x)
530
def test_reduced_prod_dtype(self):
531
x = torch.randn(1, 2, 3, 4, requires_grad=True)
532
self.assertONNX(lambda x: torch.prod(x, dim=0, dtype=torch.double), x)
535
x = torch.randn(3, 4, requires_grad=True)
536
self.assertONNX(lambda x: torch.sqrt(x), x)
538
def test_rsqrt(self):
    """Export torch.rsqrt (reciprocal square root)."""
    x = torch.randn(3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.rsqrt(x), x)
542
def test_equal(self):
    """Export element-wise equality with broadcasting between int tensors."""
    x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
    y = torch.randn(1, 4, requires_grad=False).int()
    self.assertONNX(operator.eq, (x, y))
548
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
549
y = torch.randn(1, 4, requires_grad=False).int()
550
self.assertONNX(operator.lt, (x, y))
553
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
554
y = torch.randn(1, 4, requires_grad=False).int()
555
self.assertONNX(operator.gt, (x, y))
558
x = torch.randn(3, 4, requires_grad=False).int()
559
y = torch.randn(3, 4, requires_grad=False).int()
560
self.assertONNX(operator.le, (x, y))
563
x = torch.randn(3, 4, requires_grad=False).int()
564
y = torch.randn(3, 4, requires_grad=False).int()
565
self.assertONNX(operator.ge, (x, y))
568
x = torch.randn(3, 4, requires_grad=True)
569
self.assertONNX(lambda x: x.exp(), x)
572
x = torch.randn(3, 4, requires_grad=True)
573
self.assertONNX(lambda x: x.sin(), x)
576
x = torch.randn(3, 4, requires_grad=True)
577
self.assertONNX(lambda x: x.cos(), x)
580
x = torch.randn(3, 4, requires_grad=True)
581
self.assertONNX(lambda x: x.tan(), x)
584
x = torch.rand(3, 4, requires_grad=True)
585
self.assertONNX(lambda x: x.asin(), x)
588
x = torch.rand(3, 4, requires_grad=True)
589
self.assertONNX(lambda x: x.acos(), x)
591
def test_slice(self):
    """Export a static slice x[:, 1:2]."""
    x = torch.rand(3, 4, requires_grad=True)
    self.assertONNX(lambda x: x[:, 1:2], x)
595
def test_slice_dynamic(self):
596
x = torch.rand(3, 4, requires_grad=True)
597
self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10)
600
x = torch.rand(3, 4, requires_grad=True)
601
self.assertONNX(lambda x: x.sign(), x)
603
def test_narrow(self):
604
x = torch.randn(3, 3, requires_grad=True)
605
self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
608
x = torch.randn(3, 4, requires_grad=True)
609
self.assertONNX(lambda x: x.atan(), x)
611
def test_view_flatten(self):
    """Export a flatten expressed via view with sizes computed from the input."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)
615
def test_flatten(self):
616
x = torch.randn(1, 2, 3, 4, requires_grad=True)
617
self.assertONNX(lambda x: torch.flatten(x), x)
619
def test_flatten2D(self):
620
x = torch.randn(1, 2, 3, 4, requires_grad=True)
621
self.assertONNX(lambda x: torch.flatten(x, 1), x)
623
def test_isnan(self):
    """Export torch.isnan on a tensor containing a NaN."""
    x = torch.tensor([1, float("nan"), 2])
    self.assertONNX(lambda x: torch.isnan(x), x)
627
def test_argmax(self):
628
x = torch.randn(4, 4, requires_grad=True)
629
self.assertONNX(lambda x: torch.argmax(x, dim=1), x)
631
def test_logsoftmax(self):
    """Export LogSoftmax over the last dimension."""
    x = torch.randn(1, 2, 3, 4, requires_grad=True)
    self.assertONNX(nn.LogSoftmax(dim=3), x)
636
x = torch.randn(1, 2, 3, 4, requires_grad=True)
637
y = torch.randn(1, 2, 3, 4, requires_grad=True)
638
self.assertONNX(lambda x, y: x.pow(y), (x, y))
641
x = torch.randn(1, 2, 3, 4, requires_grad=True)
642
self.assertONNX(nn.ELU(), x)
645
x = torch.randn(1, 2, 3, 4, requires_grad=True)
646
self.assertONNX(nn.SELU(), x)
648
def test_repeat(self):
649
x = torch.randn(1, 2, 3, 4, requires_grad=True)
650
self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)
652
def test_repeat_dim_overflow(self):
653
x = torch.randn(1, 2, requires_grad=True)
654
self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)
656
def test_norm_p1(self):
657
x = torch.randn(1, 2, 3, 4, requires_grad=True)
658
self.assertONNX(lambda x: x.norm(p=1, dim=2), (x))
660
def test_norm_p2(self):
661
x = torch.randn(1, 2, 3, 4, requires_grad=True)
662
self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))
664
def test_upsample_nearest_scale(self):
665
x = torch.randn(1, 2, 3, 4, requires_grad=True)
667
lambda x: nn.functional.interpolate(
668
x, scale_factor=2.0, mode="nearest", recompute_scale_factor=False
673
def test_upsample_nearest_scale_default_scale_factor(self):
674
x = torch.randn(1, 2, 3, 4, requires_grad=True)
676
lambda x: nn.functional.interpolate(x, scale_factor=2.0, mode="nearest"), x
679
def test_upsample_nearest_size(self):
680
x = torch.randn(1, 2, 3, 4, requires_grad=True)
682
lambda x: nn.functional.interpolate(x, size=16, mode="nearest"), x
685
def test_unsqueeze(self):
686
x = torch.randn(3, 4, requires_grad=True)
687
self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)
689
def test_batchnorm_noaffine(self):
690
x = torch.randn(128, 128, 1, 1, requires_grad=True)
692
nn.BatchNorm2d(128, affine=False, momentum=0.3),
694
keep_initializers_as_inputs=True,
698
def test_embedding_bags(self):
699
emb_bag = nn.EmbeddingBag(10, 8)
700
input = torch.tensor([1, 2, 3, 4]).long()
701
offset = torch.tensor([0]).long()
705
keep_initializers_as_inputs=True,
706
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
709
def test_implicit_expand(self):
710
x = torch.randn(3, 4, requires_grad=True)
711
self.assertONNX(lambda x: x + 1, x)
713
def test_reduce_sum_negative_indices(self):
714
x = torch.randn(3, 4, requires_grad=True)
715
self.assertONNX(lambda x: x.sum(-1), x)
717
def test_randn(self):
718
x = torch.randn(1, 2, 3, 4)
719
self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)
722
x = torch.rand(1, 2, 3, 4)
723
self.assertONNX(lambda x: torch.rand(1, 2, 3, 4) + x, x)
725
def test_rrelu(self):
726
x = torch.randn(1, 2, 3, 4)
727
self.assertONNX(torch.nn.RReLU(), x)
729
def test_prelu(self):
730
x = torch.randn(1, 2, 3, 4)
731
self.assertONNX(torch.nn.PReLU(2), x, keep_initializers_as_inputs=True)
733
def test_log_sigmoid(self):
734
x = torch.randn(1, 2, 3, 4)
735
self.assertONNX(torch.nn.LogSigmoid(), x)
737
def test_linear(self):
738
x = torch.randn(3, 4)
740
torch.nn.Linear(4, 5, bias=True), x, keep_initializers_as_inputs=True
743
def test_empty_like(self):
744
x = torch.randn(5, 8, requires_grad=True)
745
self.assertONNX(lambda x: torch.empty_like(x), x)
747
def test_zeros_like(self):
748
x = torch.randn(5, 8, requires_grad=True)
749
self.assertONNX(lambda x: torch.zeros_like(x), x)
751
def test_ones_like(self):
752
x = torch.randn(6, 10, requires_grad=True)
753
self.assertONNX(lambda x: torch.ones_like(x), x)
755
def test_expand(self):
756
x = torch.randn(6, 1, requires_grad=True)
757
self.assertONNX(lambda x: x.expand(4, 6, 2), x)
760
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
761
y = torch.randn(1, 4, requires_grad=False).int()
762
self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))
764
def test_reducemax(self):
765
x = torch.randn(1, 2, 3, 4)
766
self.assertONNX(lambda x: torch.max(x), x)
768
def test_reducemin(self):
769
x = torch.randn(1, 2, 3, 4)
770
self.assertONNX(lambda x: torch.min(x), x)
773
x = torch.randn(1, 2, 3, 4)
774
self.assertONNX(lambda x: x.erf(), x)
776
def test_dropout(self):
    """Export dropout in eval mode (training=False), which becomes a no-op."""
    x = torch.randn(3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
780
def test_dropout_default(self):
781
x = torch.randn(3, 4, requires_grad=True)
791
def test_dropout_training(self):
792
x = torch.randn(3, 4, requires_grad=True)
794
lambda x: torch.max(functional.dropout(x)),
796
training=torch.onnx.TrainingMode.TRAINING,
799
def test_dropout_opset12(self):
800
x = torch.randn(3, 4, requires_grad=True)
802
lambda x: torch.max(functional.dropout(x, training=False)),
807
def test_dropout_training_opset12(self):
808
x = torch.randn(3, 4, requires_grad=True)
810
lambda x: torch.max(functional.dropout(x)),
813
training=torch.onnx.TrainingMode.TRAINING,
816
def test_nonzero(self):
818
[[[2.0, 2.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]], requires_grad=True
820
self.assertONNX(lambda x: torch.nonzero(x), x)
822
def test_gather(self):
823
data = torch.randn(3, 4, 3, requires_grad=True)
824
index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
825
self.assertONNX(lambda data, index: data.gather(1, index), (data, index))
827
def test_gather_opset11(self):
828
data = torch.randn(3, 4, 3, requires_grad=True)
829
index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
831
lambda data, index: data.gather(1, index), (data, index), opset_version=11
834
def test_scatter_add(self):
835
data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
836
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
837
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
839
lambda data, index: data.scatter_add(1, indices, values),
840
(data, (indices, values)),
843
def test_scatter_add_opset11(self):
844
data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
845
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
846
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
848
lambda data, index: data.scatter_add(1, indices, values),
849
(data, (indices, values)),
853
def test_scatter_add_opset16(self):
854
data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
855
indices = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
856
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
858
lambda data, index: data.scatter_add(1, indices, values),
859
(data, (indices, values)),
863
def test_master_opset(self):
864
x = torch.randn(2, 3).float()
865
y = torch.randn(2, 3).float()
866
self.assertONNX(operator.add, (x, y), opset_version=10)
869
x = torch.randn(2, 3, 4).float()
871
lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x
874
def test_cumsum(self):
    """Export torch.cumsum along dim 1 (requires opset >= 11)."""
    x = torch.randn(2, 3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
906
class MyModel(torch.nn.Module):
907
def forward(self, x_in):
909
x_out["test_key_out"] = torch.add(
910
x_in[list(x_in.keys())[0]], list(x_in.keys())[0]
914
x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
915
self.assertONNX(MyModel(), (x, {}))
917
def test_dict_str(self):
918
class MyModel(torch.nn.Module):
919
def forward(self, x_in):
921
x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
924
x = {"test_key_in": torch.randn(1, 2, 3)}
925
self.assertONNX(MyModel(), (x, {}))
927
def test_arange_dynamic(self):
928
class TestModel(torch.nn.Module):
929
def forward(self, input):
930
return torch.arange(input.shape[0], input.shape[0] + 5, 0.5)
932
input = torch.randn(5, 3, 2)
933
self.assertONNX(TestModel(), input, opset_version=11)
935
def test_bitshift(self):
936
class BitshiftModel(torch.nn.Module):
937
def forward(self, input):
938
return input >> 1, input >> 2
940
input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
941
self.assertONNX(BitshiftModel(), input, opset_version=11)
944
def test_layer_norm_aten(self):
945
model = torch.nn.LayerNorm([10, 10])
946
x = torch.randn(20, 5, 10, 10)
950
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
953
def test_pixel_shuffle(self):
954
x = torch.randn(2, 8, 3, 4).float()
956
lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11
959
def test_frobenius_norm(self):
960
x = torch.randn(2, 3, 4).float()
961
self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x)
963
def test_unfold(self):
964
x = torch.randn(2, 3, 4, requires_grad=True)
965
self.assertONNX(lambda x: x.unfold(dimension=2, size=2, step=2), x)
967
def test_remainder(self):
    """Export torch.remainder with a broadcast divisor."""
    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 1, 4)
    self.assertONNX(lambda x, y: torch.remainder(x, y), (x, y))
973
x = torch.randn(2, 3, 4)
974
y = torch.randn(2, 1, 4)
975
self.assertONNX(lambda x, y: torch.fmod(x, y), (x, y), opset_version=10)
978
x = torch.randn(2, 3, 4, 5, requires_grad=True)
979
self.assertONNX(lambda x: torch.nn.functional.gelu(x), x)
981
def test_unique(self):
982
x = torch.randint(3, (2, 3, 4, 5)).float()
984
lambda x: torch.unique(
985
x, dim=0, sorted=True, return_inverse=False, return_counts=True
991
def test_meshgrid(self):
992
x = torch.ones(3, requires_grad=True)
993
y = torch.zeros(4, requires_grad=True)
994
z = torch.ones(5, requires_grad=True)
995
self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))
997
def test_meshgrid_indexing(self):
998
x = torch.ones(3, requires_grad=True)
999
y = torch.zeros(4, requires_grad=True)
1000
z = torch.ones(5, requires_grad=True)
1002
lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"),
1007
def test_topk(self):
1008
x = torch.arange(1.0, 6.0, requires_grad=True)
1010
self.assertONNX(lambda x, k: torch.topk(x, k), (x, k), opset_version=10)
1012
def test_topk_smallest_unsorted(self):
1013
x = torch.arange(1.0, 6.0, requires_grad=True)
1016
lambda x, k: torch.topk(x, k, largest=False, sorted=False),
1021
def test_baddbmm(self):
    """Export torch.baddbmm (batched matmul plus additive input)."""
    x = torch.randn(10, 3, 5)
    b1 = torch.randn(10, 3, 4)
    b2 = torch.randn(10, 4, 5)
    self.assertONNX(lambda x, b1, b2: torch.baddbmm(x, b1, b2), (x, b1, b2))
1027
def test_round(self):
1028
x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
1029
self.assertONNX(lambda x: torch.round(x), x, opset_version=11)
1032
x = torch.ones((2, 2), requires_grad=True)
1033
self.assertONNX(lambda x: torch.scalar_tensor(x.dim()), x)
1037
x = torch.randn(2, 3, 5, 5, device=torch.device("cpu"))
1038
self.assertONNX(lambda x: torch.det(x), x, opset_version=11)
1039
self.assertONNX(lambda x: torch.linalg.det(x), x, opset_version=11)
1041
def test_softmaxcrossentropy(self):
    """Export CrossEntropyLoss (SoftmaxCrossEntropyLoss op, opset >= 12)."""
    x = torch.randn(3, 5)
    y = torch.empty(3, dtype=torch.long).random_(5)
    self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)
1046
def test_softmaxcrossentropy_ignore_index(self):
1047
x = torch.randn(3, 5)
1048
y = torch.empty(3, dtype=torch.long).random_(5)
1050
torch.nn.CrossEntropyLoss(ignore_index=1), (x, y), opset_version=12
1053
def test_softmaxcrossentropy_weights(self):
1054
x = torch.randn(3, 5)
1055
y = torch.empty(3, dtype=torch.long).random_(5)
1057
torch.nn.CrossEntropyLoss(weight=torch.randn(5)), (x, y), opset_version=12
1060
def test_softmaxcrossentropy_3d(self):
1061
x = torch.randn(3, 5, 2)
1062
y = torch.empty(3, 2, dtype=torch.long).random_(5)
1063
self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)
1065
def test_softmaxcrossentropy_3d_none(self):
1066
x = torch.randn(3, 5, 2)
1067
y = torch.empty(3, 2, dtype=torch.long).random_(5)
1069
torch.nn.CrossEntropyLoss(reduction="none"), (x, y), opset_version=12
1072
def test_softmaxcrossentropy_4d(self):
1073
x = torch.randn(3, 5, 2, 1)
1074
y = torch.empty(3, 2, 1, dtype=torch.long).random_(5)
1075
self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)
1077
def test_lstm_none_sequence_lens(self):
1078
"""Test symbolic shape inference for LSTM when the input sequence_lens = None."""
1079
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
1080
h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
1081
c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
1083
class LSTMModel(torch.nn.Module):
1086
self.rnn = torch.nn.LSTM(
1087
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
1090
def forward(self, x, h0, c0):
1091
a, b = self.rnn(x, (h0, c0))
1092
return torch.ones(b[0].shape)
1097
input_names=["x", "y"],
1098
dynamic_axes={"x": {0: "batch"}},
1102
def test_dynamic_axes_add(self):
1103
m1 = torch.randn(2, 3, requires_grad=True)
1104
m2 = torch.randn(2, 1, requires_grad=True)
1106
lambda x, y: torch.add(x, y),
1108
input_names=["input_1", "input_2"],
1109
dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}},
1113
def test_dynamic_axes_add_inputs_same_symbolic_shape(self):
1114
m1 = torch.randn(2, 3, requires_grad=True)
1116
lambda x: torch.add(x, x),
1118
input_names=["input_1"],
1119
dynamic_axes={"input_1": {1: "dim_1"}},
1123
def test_dynamic_axes_matmul(self):
1124
m1 = torch.randn(2, 2, 4, requires_grad=True)
1125
m2 = torch.randn(2, 4, 3, requires_grad=True)
1127
lambda x, y: torch.matmul(x, y),
1129
input_names=["input_1", "input_2"],
1130
dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}},
1134
def test_dynamic_axes_reduce_mean(self):
1135
m1 = torch.randn(2, 3, 4, requires_grad=True)
1137
lambda x: torch.mean(x, dim=1),
1139
input_names=["input"],
1140
dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}},
1144
def test_dynamic_axes_unchange(self):
1145
"""Test ProcessUnchangeNode in symbolic shape inference."""
1146
m1 = torch.randn(2, 3, requires_grad=True)
1148
lambda x: torch.softmax(x, dim=0),
1150
input_names=["input"],
1151
dynamic_axes={"input": {1: "dim_1"}},
1155
def test_aten_embedding_1(self):
1156
_onnx_opset_version = 12
1158
@parse_args("v", "v", "i", "b", "b")
1159
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
1160
custom_attributes_json = (
1162
f'"padding_idx":{str(padding_idx)},'
1163
f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
1164
f'"sparse":{str(sparse).lower()}'
1171
custom_attributes_json_s=custom_attributes_json,
1175
torch.onnx.register_custom_op_symbolic(
1176
"::embedding", embedding, _onnx_opset_version
1179
class Model(torch.nn.Module):
1182
self.emb = torch.nn.Embedding(4, 8)
1184
def forward(self, x, y):
1187
return torch.ones(res.shape[0])
1190
x = torch.ones(32, dtype=torch.long)
1191
y = torch.randn(1, 8)
1192
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)
1194
torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
1198
def test_aten_embedding_2(self):
1199
_onnx_opset_version = 12
1201
@parse_args("v", "v", "i", "b", "b")
1202
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
1203
custom_attributes_json = (
1205
f'"padding_idx":{str(padding_idx)},'
1206
f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
1207
f'"sparse":{str(sparse).lower()}'
1214
custom_attributes_json_s=custom_attributes_json,
1218
indices_shape = _get_tensor_sizes(indices)
1219
if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
1220
output_type = weight.type().with_sizes(
1221
indices_shape + [_get_tensor_dim_size(weight, 1)]
1223
output.setType(output_type)
1226
torch.onnx.register_custom_op_symbolic(
1227
"::embedding", embedding, _onnx_opset_version
1230
class Model(torch.nn.Module):
1233
self.emb = torch.nn.Embedding(4, 8)
1235
def forward(self, x, y):
1238
return torch.ones(res.shape[0])
1241
x = torch.ones(32, dtype=torch.long)
1242
y = torch.randn(1, 8)
1246
opset_version=_onnx_opset_version,
1247
input_names=["input_1", "input_2"],
1248
dynamic_axes={"input_1": {0: "dim_0"}, "input_2": {0: "dim_1", 1: "dim_2"}},
1249
keep_initializers_as_inputs=False,
1250
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
1253
torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
1270
def test_shape_value_map(self):
1271
class RSoftMax(torch.nn.Module):
1272
def __init__(self, radix, cardinality):
1275
self.cardinality = cardinality
1277
def forward(self, x):
1279
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
1280
x = F.softmax(x, dim=1)
1281
x = x.reshape(batch, -1)
1286
x = torch.randn(10, 1, 128, 1)
1288
RSoftMax(radix, cardinality),
1291
dynamic_axes={"x": {0: "dim_0"}},
1295
# Script entry point: strips the custom --no-onnx / --produce-onnx-test-data
# flags from the unittest argument list before handing control to the common
# test runner. NOTE(review): garbled excerpt — bare integer lines are
# extraction artifacts, and the `if _onnx_test:` guard plus the glob-based
# cleanup of stale test_operator_* directories are missing from view; code
# left byte-identical.
if __name__ == "__main__":
1296
no_onnx_dep_flag = "--no-onnx"
1297
# True when the onnx python package is expected to be importable.
_onnx_dep = no_onnx_dep_flag not in common_utils.UNITTEST_ARGS
1298
if no_onnx_dep_flag in common_utils.UNITTEST_ARGS:
1299
common_utils.UNITTEST_ARGS.remove(no_onnx_dep_flag)
1300
onnx_test_flag = "--produce-onnx-test-data"
1301
# True when end-to-end ONNX test data should be written to disk.
_onnx_test = onnx_test_flag in common_utils.UNITTEST_ARGS
1302
if onnx_test_flag in common_utils.UNITTEST_ARGS:
1303
common_utils.UNITTEST_ARGS.remove(onnx_test_flag)
1306
import onnx_test_common
1309
# Presumably part of a glob(...) that clears previously generated
# test_operator_* directories — surrounding lines missing from view.
os.path.join(onnx_test_common.pytorch_operator_dir, "test_operator_*")
1312
common_utils.run_tests()