1
# Owner(s): ["module: onnx"]
3
from __future__ import annotations
11
from collections import OrderedDict
12
from typing import Dict, List, Optional, Tuple, Type, Union
16
import onnx_test_common
20
from model_defs import (
21
lstm_flattening_result,
22
rnn_model_with_packed_sequence,
25
from pytorch_test_common import (
32
skipIfQuantizationBackendQNNPack,
33
skipIfUnsupportedMaxOpsetVersion,
34
skipIfUnsupportedMinOpsetVersion,
35
skipIfUnsupportedOpsetVersion,
41
from torch import Tensor
42
from torch.nn.utils import rnn as rnn_utils
43
from torch.onnx import errors, verification
44
from torch.testing._internal import common_utils
45
from torch.testing._internal.common_utils import skipIfNoLapack
48
def _init_test_generalized_rcnn_transform():
51
image_mean = [0.485, 0.456, 0.406]
52
image_std = [0.229, 0.224, 0.225]
53
transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(
54
min_size, max_size, image_mean, image_std
60
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
61
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
62
rpn_anchor_generator = torchvision.models.detection.rpn.AnchorGenerator(
63
anchor_sizes, aspect_ratios
66
rpn_head = torchvision.models.detection.rpn.RPNHead(
67
out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
69
rpn_fg_iou_thresh = 0.7
70
rpn_bg_iou_thresh = 0.3
71
rpn_batch_size_per_image = 256
72
rpn_positive_fraction = 0.5
73
rpn_pre_nms_top_n = dict(training=2000, testing=1000)
74
rpn_post_nms_top_n = dict(training=2000, testing=1000)
76
rpn_score_thresh = 0.0
78
rpn = torchvision.models.detection.rpn.RegionProposalNetwork(
83
rpn_batch_size_per_image,
84
rpn_positive_fraction,
88
score_thresh=rpn_score_thresh,
93
def _construct_tensor_for_quantization_test(
94
shape: Tuple[int, ...],
95
offset: Optional[Union[int, float]] = None,
96
max_val: Optional[Union[int, float]] = None,
98
"""Helper function to generate weights and test inputs in a deterministic way.
100
Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated
101
test data for quantization tests can be flaky. To help stablize the test, this helper function is
102
used to generate weights and test inputs in a deterministic way.
105
shape (Tuple[int]): Shape for tensor to construct.
106
offset (Optional[Union[int, float]]): Offset to be added to the generated tensor.
107
max_val (Optional[Union[int, float]]): If any element within tensor has a larger absolute value than
108
max_val, the tensor will be scaled by max_val / tensor.abs().max(). This step is done after
111
tensor = torch.arange(np.prod(shape), dtype=torch.float).view(shape)
112
if offset is not None:
113
tensor = tensor + offset
114
if max_val is not None and tensor.abs().max() > max_val:
115
tensor = tensor * max_val / tensor.abs().max()
119
def _parameterized_class_attrs_and_values(
120
min_opset_version: int, max_opset_version: int
122
attrs = ("opset_version", "is_script", "keep_initializers_as_inputs")
124
input_values.extend(itertools.product((7, 8), (True, False), (True,)))
125
# Valid opset versions are defined in torch/onnx/_constants.py.
126
# Versions are intentionally set statically, to not be affected by changes elsewhere.
127
if min_opset_version < 9:
128
raise ValueError("min_opset_version must be >= 9")
131
range(min_opset_version, max_opset_version + 1),
136
return {"attrs": attrs, "input_values": input_values}
139
def _parametrize_rnn_args(arg_name):
141
"layers": {1: "unilayer", 3: "trilayer"},
142
"bidirectional": {True: "bidirectional", False: "forward"},
143
"initial_state": {True: "with_initial_state", False: "no_initial_state"},
145
0: "without_sequence_lengths",
146
1: "with_variable_length_sequences",
147
2: "with_batch_first_sequence_lengths",
149
"dropout": {0.2: "with_dropout", 0.0: "without_dropout"},
154
"arg_values": options[arg_name].keys(),
155
"name_fn": lambda val: options[arg_name][val],
159
@parameterized.parameterized_class(
160
**_parameterized_class_attrs_and_values(
161
onnx_test_common.MIN_ONNX_OPSET_VERSION, onnx_test_common.MAX_ONNX_OPSET_VERSION
163
class_name_func=onnx_test_common.parameterize_class_name,
165
@common_utils.instantiate_parametrized_tests
166
class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
167
def test_fuse_conv_bn1d(self):
168
class Fuse(torch.nn.Module):
171
self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
172
self.bn = torch.nn.BatchNorm1d(33)
174
def forward(self, x):
179
x = torch.randn(20, 16, 50, requires_grad=True)
180
self.run_test(model, (x,))
182
def test_fuse_conv_bn2d(self):
183
class Fuse(torch.nn.Module):
186
self.conv = torch.nn.Conv2d(
187
3, 2, kernel_size=1, stride=2, padding=3, bias=False
189
self.bn = torch.nn.BatchNorm2d(2)
191
def forward(self, x):
196
x = torch.randn(2, 3, 2, 2, requires_grad=True)
197
self.run_test(model, (x,))
199
def test_fuse_conv_bn3d(self):
200
class Fuse(torch.nn.Module):
203
self.conv = torch.nn.Conv3d(
204
3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False
206
self.bn = torch.nn.BatchNorm3d(2)
208
def forward(self, x):
213
x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
214
self.run_test(model, (x,), rtol=1e-3, atol=1e-6)
216
def test_fuse_conv_in_block(self):
217
class Fuse(torch.nn.Module):
220
self.conv = torch.nn.Conv1d(
228
self.bn = torch.nn.BatchNorm1d(5)
230
def forward(self, x):
231
results_available = True
234
results_available = False
236
if results_available:
243
x = torch.randn(2, 5, 9, requires_grad=True)
245
torch.jit.script(model),
248
dynamic_axes={"x": [0, 2]},
253
def test_conv_tbc(self):
254
from torch.nn.modules.utils import _single
256
class ConvTBC(torch.nn.Module):
257
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
259
self.in_channels = in_channels
260
self.out_channels = out_channels
261
self.kernel_size = _single(kernel_size)
262
self.padding = _single(padding)
264
self.weight = torch.nn.Parameter(
265
Tensor(self.kernel_size[0], in_channels, out_channels)
267
self.bias = torch.nn.Parameter(Tensor(out_channels))
268
self.reset_parameters()
270
def reset_parameters(self):
271
torch.nn.init.xavier_normal_(self.weight)
272
torch.nn.init.zeros_(self.bias)
274
def conv_tbc(self, input):
275
return torch.conv_tbc(
276
input.contiguous(), self.weight, self.bias, self.padding[0]
279
def forward(self, input):
280
return self.conv_tbc(input)
285
model = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
286
x = torch.randn(10, 7, in_channels, requires_grad=True)
287
self.run_test(model, (x,), atol=1e-5)
289
def test_reshape_constant_fold(self):
290
class Reshape(torch.nn.Module):
295
self.register_buffer("weight", torch.ones(5))
297
def forward(self, x):
298
scale_1 = self.weight.reshape(1, -1, 1, 1)
301
x = torch.randn(4, 5)
302
self.run_test(Reshape(), (x,), rtol=1e-3, atol=1e-5)
304
def run_word_language_model(self, model_name):
312
if model_name == "GRU":
313
model = word_language_model.RNNModelWithTensorHidden(
314
model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
316
elif model_name == "LSTM":
317
model = word_language_model.RNNModelWithTupleHidden(
318
model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
321
model = word_language_model.RNNModel(
322
model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
324
x = torch.arange(0, ntokens).long().view(-1, batchsize)
325
# Only support CPU version, since tracer is not working in GPU RNN.
326
self.run_test(model, (x, model.hidden))
328
def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor:
329
from PIL import Image
330
from torchvision import transforms
332
data_dir = os.path.join(os.path.dirname(__file__), "assets")
333
path = os.path.join(data_dir, *rel_path.split("/"))
334
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
336
return transforms.ToTensor()(image)
338
def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
340
[self.get_image("grace_hopper_517x606.jpg", (100, 320))],
341
[self.get_image("rgb_pytorch.png", (250, 380))],
344
def test_paste_mask_in_image(self):
345
masks = torch.rand(10, 1, 26, 26)
346
boxes = torch.rand(10, 4)
347
boxes[:, 2:] += torch.rand(10, 2)
350
from torchvision.models.detection.roi_heads import paste_masks_in_image
352
out = paste_masks_in_image(masks, boxes, o_im_s)
353
jit_trace = torch.jit.trace(
354
paste_masks_in_image,
355
(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]),
357
out_trace = jit_trace(
358
masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]
361
assert torch.all(out.eq(out_trace))
363
masks2 = torch.rand(20, 1, 26, 26)
364
boxes2 = torch.rand(20, 4)
365
boxes2[:, 2:] += torch.rand(20, 2)
368
from torchvision.models.detection.roi_heads import paste_masks_in_image
370
out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
371
out_trace2 = jit_trace(
372
masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]
375
assert torch.all(out2.eq(out_trace2))
377
def test_heatmaps_to_keypoints(self):
378
maps = torch.rand(10, 1, 26, 26)
379
rois = torch.rand(10, 4)
380
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
382
out = heatmaps_to_keypoints(maps, rois)
383
jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
384
out_trace = jit_trace(maps, rois)
386
assert torch.all(out[0].eq(out_trace[0]))
387
assert torch.all(out[1].eq(out_trace[1]))
389
maps2 = torch.rand(20, 2, 21, 21)
390
rois2 = torch.rand(20, 4)
391
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
393
out2 = heatmaps_to_keypoints(maps2, rois2)
394
out_trace2 = jit_trace(maps2, rois2)
396
assert torch.all(out2[0].eq(out_trace2[0]))
397
assert torch.all(out2[1].eq(out_trace2[1]))
399
def test_word_language_model_RNN_TANH(self):
400
self.run_word_language_model("RNN_TANH")
402
def test_word_language_model_RNN_RELU(self):
403
self.run_word_language_model("RNN_RELU")
405
@skipScriptTest() # scripting prim::unchecked_cast prim::setattr
406
def test_word_language_model_LSTM(self):
407
self.run_word_language_model("LSTM")
409
def test_word_language_model_GRU(self):
410
self.run_word_language_model("GRU")
412
def test_index_1d(self):
413
class MyModel(torch.nn.Module):
414
def forward(self, input):
417
m1 = torch.randn(3, 4, 5, 6, 7)
418
self.run_test(MyModel(), m1)
420
def test_index_2d_1dimslice(self):
421
class MyModel(torch.nn.Module):
422
def forward(self, input):
425
m1 = torch.randn(3, 4, 5, 6, 7)
426
self.run_test(MyModel(), m1)
428
def test_index_2d_sliceint(self):
429
class MyModel(torch.nn.Module):
430
def forward(self, input):
433
m1 = torch.randn(3, 4, 5, 6, 7)
434
self.run_test(MyModel(), m1)
436
def test_index_2d_neg_slice(self):
437
class MyModel(torch.nn.Module):
438
def forward(self, input):
439
return input[0:-1, :]
441
m1 = torch.randn(3, 4, 5, 6, 7)
442
self.run_test(MyModel(), m1)
444
@skipIfUnsupportedMinOpsetVersion(9)
445
def test_index_mask(self):
446
class MyModel(torch.nn.Module):
447
def forward(self, input):
448
return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]
450
m1 = torch.randn(3, 4, 5, 6, 7)
451
self.run_test(MyModel(), m1)
453
class MyModel(torch.nn.Module):
454
def forward(self, input):
455
return input[torch.tensor([0, 1, 0], dtype=torch.bool)]
457
m1 = torch.randn(3, 4, 5, 6, 7)
458
self.run_test(MyModel(), m1)
460
@skipIfUnsupportedMinOpsetVersion(9)
462
class Data(torch.jit.ScriptModule):
463
@torch.jit.script_method
464
def forward(self, x):
465
return x.new_zeros(x.data.size())
467
x = torch.randn(3, 4)
468
self.run_test(Data(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
469
self.run_test(Data(), x, remained_onnx_input_idx=[])
471
@skipIfUnsupportedMinOpsetVersion(11)
472
def test_index_mask_nd(self):
473
class MyModel(torch.nn.Module):
474
def forward(self, input):
475
return input[input > 0]
477
m1 = torch.randn(3, 4, 5, 6, 7)
478
self.run_test(MyModel(), m1)
482
class MyModel(torch.nn.Module):
483
def forward(self, x_in):
485
x_out["test_key_out"] = torch.add(
486
x_in[list(x_in.keys())[0]], list(x_in.keys())[0] # noqa: RUF015
490
x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
491
self.run_test(MyModel(), (x,))
494
def test_dict_str(self):
495
class MyModel(torch.nn.Module):
496
def forward(self, x_in):
498
x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
501
x = {"test_key_in": torch.randn(1, 2, 3)}
502
self.run_test(MyModel(), (x,))
504
@skipScriptTest() # User-defined class not supported
505
def test_dict_output(self):
506
class DictModelOutput(OrderedDict):
508
tuple_out: Optional[Tuple[Tensor]] = None
509
list_out: Optional[List[Tensor]] = None
511
class MyModel(torch.nn.Module):
512
def forward(self, a, b, c, d):
513
return DictModelOutput(
519
a = torch.randn(2, 3)
520
b = torch.randn(2, 3)
521
c = torch.randn(2, 3)
522
d = torch.randn(2, 3)
523
self.run_test(MyModel(), (a, b, c, d))
525
def test_tuple_output(self):
526
class MyModel(torch.nn.Module):
527
def forward(self, a, b, c, d):
530
a = torch.randn(2, 3)
531
b = torch.randn(2, 3)
532
c = torch.randn(2, 3)
533
d = torch.randn(2, 3)
534
self.run_test(MyModel(), (a, b, c, d))
536
def test_nested_tuple_output(self):
537
class MyModel(torch.nn.Module):
538
def forward(self, a, b, c, d):
539
return a, ((b,), (c, d))
541
a = torch.randn(2, 3)
542
b = torch.randn(2, 3)
543
c = torch.randn(2, 3)
544
d = torch.randn(2, 3)
545
self.run_test(MyModel(), (a, b, c, d))
547
def test_tuple_input(self):
548
class TupleModel(torch.nn.Module):
549
def forward(self, a: Tuple[Tensor, Tensor]):
552
x = (torch.randn(3, 4), torch.randn(4, 3))
553
self.run_test(TupleModel(), input_args=(x,))
555
def test_tuple_primitive_input(self):
556
class TupleModel(torch.nn.Module):
557
def forward(self, a: Tuple[int, Tensor], b):
558
return a[0], a[1] + b
560
x = (3, torch.randn(4, 3))
561
y = torch.randn(4, 3)
562
self.run_test(TupleModel(), input_args=(x, y))
564
def test_nested_tuple_input(self):
565
class NestedTupleModel(torch.nn.Module):
566
def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]):
567
return a + b[0] + b[1][0] + b[1][1]
569
x = torch.randn(4, 5)
570
y = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1)))
571
self.run_test(NestedTupleModel(), input_args=(x, y))
573
@skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21
574
@skipIfUnsupportedMinOpsetVersion(15)
575
def test_mixed_optional_default_none(self):
576
class Model(torch.nn.Module):
580
y: Optional[Tensor] = None,
581
z: Optional[Tensor] = None,
589
x = torch.randn(2, 3)
590
y = torch.randn(2, 3)
591
z = torch.randn(2, 3)
593
# Without kwargs dict.
594
self.run_test(model, (x, y, None))
595
self.run_test(model, (x, None, z))
597
self.run_test(model, (x,), {"y": y, "z": None})
598
self.run_test(model, (x,), {"y": None, "z": z})
599
self.run_test(model, (x,), {"z": z})
600
self.run_test(model, (x,), {"y": y})
602
@skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below.
603
@skipIfUnsupportedMinOpsetVersion(15)
604
def test_mixed_optional_default_tensor(self):
605
class Model(torch.nn.Module):
609
y: Optional[Tensor] = torch.ones(2, 3),
610
z: Optional[Tensor] = torch.zeros(2, 3),
618
x = torch.randn(2, 3)
619
y = torch.randn(2, 3)
620
z = torch.randn(2, 3)
623
self.run_test(model, (x, y, None))
624
self.run_test(model, (x, None, z))
626
@skipTraceTest() # tracing is verified with different set of inputs. See above.
627
@skipIfUnsupportedMinOpsetVersion(15)
628
def test_mixed_optional_default_tensor_script(self):
629
class Model(torch.nn.Module):
633
y: Optional[Tensor] = torch.ones(2, 3),
634
z: Optional[Tensor] = torch.zeros(2, 3),
642
x = torch.randn(2, 3)
643
y = torch.randn(2, 3)
644
z = torch.randn(2, 3)
645
model = torch.jit.script(Model())
647
self.run_test(model, (x, y, z), input_names=("x", "y", "z"))
648
self.run_test(model, (x,), {"y": y, "z": z}, input_names=("x", "y", "z"))
649
self.run_test(model, (x,), {"y": y}, input_names=("x", "y"))
651
for example_inputs, example_kwargs in (
654
((x,), {"y": y, "z": None}),
655
((x,), {"y": None, "z": z}),
657
with self.assertRaisesRegex(
658
ValueError, "args contained 1 None's after flattening."
661
model, example_inputs, example_kwargs, input_names=("x", "y", "z")
664
@skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21
665
@skipIfUnsupportedMinOpsetVersion(15)
666
def test_all_optional_default_none(self):
667
class Model(torch.nn.Module):
668
def forward(self, x: Optional[Tensor] = None, y: Optional[Tensor] = None):
674
return torch.tensor(-1.0)
676
x = torch.randn(2, 3)
678
self.run_test(model, (x, None))
683
# y disappears in tracing.
687
@skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below.
688
@skipIfUnsupportedMinOpsetVersion(15)
689
def test_all_optional_default_tensor(self):
690
class Model(torch.nn.Module):
693
x: Optional[Tensor] = torch.ones(2, 3),
694
y: Optional[Tensor] = torch.zeros(2, 3),
701
return torch.tensor(-1.0)
703
x = torch.randn(2, 3)
704
y = torch.randn(2, 3)
706
self.run_test(model, (x, None))
707
self.run_test(model, (None, y))
708
# tracing means y is never used so it's removed from the exported model inputs,
709
# and we fail when trying to run ORT.
710
with self.assertRaisesRegex(ValueError, "got too many positional inputs"):
711
self.run_test(model, (x, y))
713
@skipTraceTest() # tracing is verified with different set of inputs. See above.
714
@skipIfUnsupportedMinOpsetVersion(15)
715
def test_all_optional_default_tensor_script(self):
716
class Model(torch.nn.Module):
719
x: Optional[Tensor] = torch.ones(2, 3),
720
y: Optional[Tensor] = torch.zeros(2, 3),
727
return torch.tensor(-1.0)
729
x = torch.randn(2, 3)
730
y = torch.randn(2, 3)
731
model = torch.jit.script(Model())
733
# Optional supports None inputs
734
self.run_test(model, (x,))
735
# NOTE: default value is not supported on ONNX, so torch and ONNX has
737
with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
738
self.run_test(model, (), {"y": y}, input_names=["y"])
740
self.run_test(model, (x, y))
741
self.run_test(model, (), {"x": x, "y": y}, input_names=("x", "y"))
743
@skipIfUnsupportedMinOpsetVersion(9)
744
def test_logit(self):
745
class Logit(torch.nn.Module):
746
def __init__(self, eps):
750
def forward(self, x):
751
return x.logit(self.eps)
753
model = Logit(eps=1e-6)
754
self.run_test(model, torch.randn(1, 3, 640, 640))
756
class Atleast1d(torch.nn.Module):
757
def forward(self, t, w, x, y, z):
758
return torch.atleast_1d((t, w, x, y, z))
760
class Atleast2d(torch.nn.Module):
761
def forward(self, t, w, x, y, z):
762
return torch.atleast_2d((t, w, x, y, z))
764
class Atleast3d(torch.nn.Module):
765
def forward(self, t, w, x, y, z):
766
return torch.atleast_3d((t, w, x, y, z))
768
class Atleast1dTensor(torch.nn.Module):
769
def forward(self, x):
770
return torch.atleast_1d(x)
772
class Atleast2dTensor(torch.nn.Module):
773
def forward(self, x):
774
return torch.atleast_2d(x)
776
class Atleast3dTensor(torch.nn.Module):
777
def forward(self, x):
778
return torch.atleast_3d(x)
780
@skipScriptTest() # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
781
@skipIfUnsupportedMinOpsetVersion(11)
782
@common_utils.parametrize("module_class", (Atleast1d, Atleast2d, Atleast3d))
783
def test_atleast_nd_list_input(self, module_class: torch.nn.Module):
788
torch.randn(2, 3, 4),
789
torch.randn(2, 3, 4, 5),
791
self.run_test(module_class(), inputs)
793
@skipScriptTest() # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
794
@skipIfUnsupportedMinOpsetVersion(11)
795
@common_utils.parametrize(
796
"module_class", (Atleast1dTensor, Atleast2dTensor, Atleast3dTensor)
798
@common_utils.parametrize(
804
torch.randn(2, 3, 4),
805
torch.randn(2, 3, 4, 5),
808
def test_atleast_nd_single_tensor_input(
809
self, module_class: torch.nn.Module, inputs: torch.Tensor
811
self.run_test(module_class(), inputs)
813
@skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21
814
@skipIfUnsupportedMinOpsetVersion(15)
815
def test_mixed_optional(self):
816
class Model(torch.nn.Module):
817
def forward(self, x, y: Optional[Tensor]):
822
x = torch.randn(2, 3)
824
self.run_test(model, (x, None))
825
self.run_test(model, (x, x))
827
@skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21
828
@skipIfUnsupportedMinOpsetVersion(15)
829
def test_tuple_of_optional(self):
830
class Model(torch.nn.Module):
831
def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]):
838
x = torch.randn(2, 3)
839
y1 = torch.randn(2, 3)
840
self.run_test(Model(), (x, (None, y1)))
842
@skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below.
843
@skipIfUnsupportedMinOpsetVersion(15)
844
def test_tuple_of_optional_default_tensor(self):
845
class Model(torch.nn.Module):
849
y: Tuple[Optional[Tensor], Optional[Tensor]] = (
861
x = torch.randn(2, 3)
862
y1 = torch.randn(2, 3)
863
self.run_test(Model(), (x, (None, y1)))
865
@skipTraceTest() # tracing is verified with different set of inputs. See above.
866
@skipIfUnsupportedMinOpsetVersion(15)
867
def test_tuple_of_optional_default_tensor_script(self):
868
class Model(torch.nn.Module):
872
y: Tuple[Optional[Tensor], Optional[Tensor]] = (
884
x = torch.randn(2, 3)
885
y0 = torch.randn(2, 3)
886
y1 = torch.randn(2, 3)
887
model = torch.jit.script(Model())
888
with self.assertRaisesRegex(
889
ValueError, "args contained 1 None's after flattening."
891
self.run_test(model, (x, (None, y1)))
892
self.run_test(model, (x, (y0, y1)))
893
# export succeeds, but running ORT through run_test would fail because the exported model
894
# has the inputs flattened into 3 inputs.
896
model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version
899
def test_primitive_input_integer(self):
900
class Model(torch.nn.Module):
901
def forward(self, x: int, y):
905
y = torch.randint(10, (2, 3, 4))
906
self.run_test(Model(), (x, y))
909
def test_primitive_input_floating(self):
910
class Model(torch.nn.Module):
911
def forward(self, x: float, y):
915
y = torch.randn(2, 3, 4)
916
self.run_test(Model(), (x, y))
918
def test_primitive_input_bool(self):
919
class Model(torch.nn.Module):
920
def forward(self, flag: bool, x, y):
927
x = torch.randn(2, 3, 4)
928
y = torch.randn(2, 3, 4)
929
self.run_test(torch.jit.script(Model()), (flag, x, y))
931
@skipIfUnsupportedMinOpsetVersion(9)
932
def test_cste_script(self):
933
class MyModel(torch.jit.ScriptModule):
934
@torch.jit.script_method
935
def forward(self, x):
936
return torch.zeros(x.size(0)), torch.ones(
937
(x.size(1), x.size(0)), dtype=torch.int64
940
x = torch.randn(3, 4)
941
self.run_test(MyModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
942
self.run_test(MyModel(), x, remained_onnx_input_idx=[])
944
def test_scalar_tensor(self):
945
class test(torch.nn.Module):
946
def forward(self, input):
947
return torch.scalar_tensor(input.size(0)), torch.scalar_tensor(
948
input.size(1), dtype=torch.int64
951
x = torch.randn(2, 3, 4)
952
y = torch.randn(7, 8, 9)
957
additional_test_inputs=[y],
958
input_names=["input_1"],
959
dynamic_axes={"input_1": [0, 1, 2]},
962
def test_tensor(self):
963
class ScalarInputModel(torch.jit.ScriptModule):
964
@torch.jit.script_method
965
def forward(self, input):
966
return torch.tensor(input.shape[1])
968
x = torch.randn(3, 4)
970
ScalarInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
972
self.run_test(ScalarInputModel(), x, remained_onnx_input_idx=[])
974
class TensorInputModel(torch.jit.ScriptModule):
975
@torch.jit.script_method
976
def forward(self, input):
977
return torch.tensor([input.shape[0], input.shape[1]])
979
x = torch.randn(3, 4)
981
TensorInputModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
983
self.run_test(TensorInputModel(), x, remained_onnx_input_idx=[])
985
class FloatInputModel(torch.jit.ScriptModule):
986
@torch.jit.script_method
987
def forward(self, input):
988
return torch.tensor([float(input)])
991
self.run_test(FloatInputModel(), x)
993
class InputWithDtypeModel(torch.jit.ScriptModule):
994
@torch.jit.script_method
995
def forward(self, input):
996
return torch.tensor(input.shape[1], dtype=torch.long)
998
x = torch.randn(3, 4)
1000
InputWithDtypeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]}
1002
self.run_test(InputWithDtypeModel(), x, remained_onnx_input_idx=[])
1004
class MixedInputModel(torch.jit.ScriptModule):
1005
@torch.jit.script_method
1006
def forward(self, input):
1007
return torch.tensor([input.shape[0], int(input)])
1010
self.run_test(MixedInputModel(), x)
1012
def test_hardtanh(self):
1013
model = torch.nn.Hardtanh(-1.5, 2.5)
1014
x = torch.arange(-5, 5).to(dtype=torch.float32)
1015
self.run_test(model, x)
1017
def test_hardtanh_script_with_default_values(self):
1018
class MyModel(torch.jit.ScriptModule):
1019
@torch.jit.script_method
1020
def forward(self, x):
1021
return torch.nn.functional.hardtanh(x)
1023
x = torch.arange(-5, 5).to(dtype=torch.float32)
1024
self.run_test(MyModel(), x)
1026
def test_hardswish(self):
1027
model = torch.nn.Hardswish()
1029
x = torch.rand(3, 3).to(dtype=torch.float32)
1030
self.run_test(model, x)
1032
# Testing edge cases
1033
x = torch.tensor(3).to(dtype=torch.float32)
1034
self.run_test(model, x)
1035
x = torch.tensor(-3).to(dtype=torch.float32)
1036
self.run_test(model, x)
1038
def test_hardswish_script(self):
1039
class MyModel(torch.jit.ScriptModule):
1040
@torch.jit.script_method
1041
def forward(self, x):
1042
return torch.nn.functional.hardswish(x)
1044
x = torch.rand(3, 3).to(dtype=torch.float32)
1045
self.run_test(MyModel(), x)
1047
def test_hardsigmoid(self):
1048
model = torch.nn.Hardsigmoid()
1050
x = torch.rand(3, 3).to(dtype=torch.float32)
1051
self.run_test(model, x)
1054
x = torch.tensor(3).to(dtype=torch.float32)
1055
self.run_test(model, x)
1056
x = torch.tensor(-3).to(dtype=torch.float32)
1057
self.run_test(model, x)
1059
def test_tanhshrink(self):
1060
model = torch.nn.Tanhshrink()
1062
x = torch.rand(3, 3).to(dtype=torch.float32)
1063
self.run_test(model, x)
1065
@skipIfUnsupportedMinOpsetVersion(9)
1066
def test_hardshrink(self):
1067
model = torch.nn.Hardshrink()
1069
x = torch.rand(3, 3).to(dtype=torch.float32)
1070
self.run_test(model, x)
1072
# Testing edge cases
1073
x = torch.tensor(0.5).to(dtype=torch.float32)
1074
self.run_test(model, x)
1075
x = torch.tensor(-0.5).to(dtype=torch.float32)
1076
self.run_test(model, x)
1078
@skipIfUnsupportedMinOpsetVersion(9)
1079
def test_hardshrink_dtype(self):
1080
x = torch.rand(3, 3).to(dtype=torch.float64)
1081
self.run_test(torch.nn.Hardshrink(), x)
1083
@skipIfUnsupportedMinOpsetVersion(9)
1084
def test_softshrink(self):
1085
model = torch.nn.Softshrink()
1087
x = torch.rand(3, 3).to(dtype=torch.float32)
1088
self.run_test(model, x)
1090
# Testing edge cases
1091
x = torch.tensor(0.5).to(dtype=torch.float32)
1092
self.run_test(model, x)
1093
x = torch.tensor(-0.5).to(dtype=torch.float32)
1094
self.run_test(model, x)
1096
@skipIfUnsupportedMinOpsetVersion(9)
1097
def test_softshrink_dtype(self):
1098
x = torch.rand(3, 3).to(dtype=torch.float64)
1099
self.run_test(torch.nn.Softshrink(), x)
1101
def test_clamp(self):
1102
class ClampModel(torch.nn.Module):
1103
def forward(self, x):
1104
return x.clamp(-0.5, 0.5)
1106
x = torch.randn(3, 4)
1107
self.run_test(ClampModel(), x)
1109
class ClampMinModel(torch.nn.Module):
1110
def forward(self, x):
1111
return x.clamp(min=-0.5)
1113
x = torch.randn(3, 4)
1114
self.run_test(ClampMinModel(), x)
1116
class ClampMaxModel(torch.nn.Module):
1117
def forward(self, x):
1118
return x.clamp(max=0.5)
1120
x = torch.randn(3, 4)
1121
self.run_test(ClampMaxModel(), x)
1123
@skipIfUnsupportedMinOpsetVersion(8)
1124
def test_clamp_dyn(self):
1125
class ClampMaxModel(torch.jit.ScriptModule):
1126
@torch.jit.script_method
1127
def forward(self, x):
1128
return x.clamp(None, x.size(0))
1130
x = torch.arange(16).view(4, 4).float()
1131
self.run_test(ClampMaxModel(), x)
1133
class ClampMinModel(torch.jit.ScriptModule):
1134
@torch.jit.script_method
1135
def forward(self, x):
1136
return x.clamp(x.size(0), None)
1138
x = torch.arange(16).view(4, 4).float()
1139
self.run_test(ClampMinModel(), x)
1141
class ClampMinMaxModel(torch.jit.ScriptModule):
1142
@torch.jit.script_method
1143
def forward(self, x):
1144
return x.clamp(x.size(0), x.size(1))
1146
x = torch.arange(16).view(2, 8).float()
1147
self.run_test(ClampMinMaxModel(), x)
1149
class ClampTensorModel(torch.nn.Module):
1150
def forward(self, x, min, max):
1151
return x.clamp(min, max)
1153
x = torch.randn(3, 4)
1154
y = torch.randn(3, 4)
1155
z = torch.randn(3, 4)
1156
self.run_test(ClampTensorModel(), (x, y, z))
1158
class ClampTensorMinModel(torch.nn.Module):
1159
def forward(self, x, min):
1160
return x.clamp(min=min)
1162
self.run_test(ClampTensorMinModel(), (x, y))
1164
class ClampTensorMaxModel(torch.nn.Module):
1165
def forward(self, x, max):
1166
return x.clamp(max=max)
1168
self.run_test(ClampTensorMaxModel(), (x, z))
1170
@skipIfUnsupportedMinOpsetVersion(9)
1171
def test_full_trace(self):
1172
class FullModel(torch.nn.Module):
1173
def forward(self, x):
1174
return torch.full((3, 4), x, dtype=torch.long)
1176
x = torch.tensor(12)
1177
self.run_test(FullModel(), x)
1179
@skipIfUnsupportedMinOpsetVersion(9)
1180
def test_full_script(self):
1181
class FullModelScripting(torch.jit.ScriptModule):
1182
@torch.jit.script_method
1183
def forward(self, x):
1184
return torch.full((3, 4), x, dtype=torch.long)
1186
x = torch.tensor(12)
1187
self.run_test(FullModelScripting(), x)
1189
def test_fuse_addmm(self):
1190
class AddmmModel(torch.nn.Module):
1191
def forward(self, x):
1192
return torch.mm(x, x) + x
1194
x = torch.ones(3, 3)
1195
self.run_test(AddmmModel(), x)
1197
def test_maxpool(self):
1198
model = torch.nn.MaxPool1d(2, stride=1)
1199
x = torch.randn(20, 16, 50)
1200
self.run_test(model, x)
1202
def test_conv(self):
1203
class TraceModel(torch.nn.Module):
1206
self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
1207
self.conv2 = torch.nn.Conv2d(
1208
16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
1210
self.conv3 = torch.nn.Conv3d(
1211
16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
1214
def forward(self, input1, input2, input3):
1215
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
1217
x1 = torch.randn(20, 16, 50)
1218
x2 = torch.randn(20, 16, 50, 50)
1219
x3 = torch.randn(20, 16, 10, 50, 50)
1221
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
1223
def test_conv_str_padding(self):
1224
class TraceModel(torch.nn.Module):
1227
self.conv1 = torch.nn.Conv1d(16, 33, 3, padding="valid")
1228
self.conv2 = torch.nn.Conv2d(
1229
16, 33, (3, 5), stride=1, padding="valid", dilation=(3, 1)
1231
self.conv3 = torch.nn.Conv3d(
1232
16, 33, (3, 5, 2), stride=1, padding="same"
1235
def forward(self, input1, input2, input3):
1236
return self.conv1(input1), self.conv2(input2), self.conv3(input3)
1238
x1 = torch.randn(20, 16, 50)
1239
x2 = torch.randn(20, 16, 50, 50)
1240
x3 = torch.randn(20, 16, 10, 50, 50)
1242
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
1244
def test_conv_shape_inference(self):
1245
class Model(torch.nn.Module):
1248
self.conv2 = torch.nn.Conv2d(
1249
16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
1252
def forward(self, input):
1253
return self.conv2(input) + 2
1255
x = torch.randn(20, 16, 50, 100)
1257
Model(), x, atol=10e-5, input_names=["x"], dynamic_axes={"x": [0]}
def test_conv_transpose(self):
    # ConvTranspose1d/2d/3d with non-trivial stride/padding/dilation,
    # exported together from one traced module.
    class TraceModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
            self.conv2 = torch.nn.ConvTranspose2d(
                16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
            )
            self.conv3 = torch.nn.ConvTranspose3d(
                16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
            )

        def forward(self, input1, input2, input3):
            return self.conv1(input1), self.conv2(input2), self.conv3(input3)

    x1 = torch.randn(20, 16, 10)
    x2 = torch.randn(20, 16, 10, 10)
    x3 = torch.randn(20, 16, 10, 10, 10)

    # Note: 10e-5 == 1e-4.
    self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
# Exports `NumpyTranspose` on a (4, 7) input.
# NOTE(review): original lines 1284-1285 are missing from this copy (gap in the
# interleaved numbering), so forward() has no visible body; presumably it
# returned the numpy-style transpose `x.T` -- restore from upstream.
1281
def test_numpy_T(self):
1282
class NumpyTranspose(torch.nn.Module):
1283
def forward(self, x):
1286
self.run_test(NumpyTranspose(), torch.randn(4, 7))
# Scripted module whose forward ends in x.transpose(0, 1). The visible tail of
# the run_test call makes axes 0 and 2 of "x" dynamic and re-checks the export
# on a second input `y` of a different shape.
# NOTE(review): original lines 1292-1293, 1295, 1298, 1300, 1303-1306 and 1309
# are missing from this copy. They presumably hold the __init__ header, the
# start of forward() (self.conv is defined but unused in the visible code), and
# the opening/closing of the run_test call -- restore from upstream.
1288
# Converting Transpose requires the input shape to be known.
1289
# The following test only works when ONNX shape inference is enabled.
1290
def test_transpose_infer_shape(self):
1291
class TransposeModule(torch.jit.ScriptModule):
1294
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
1296
@torch.jit.script_method
1297
def forward(self, x):
1299
return x.transpose(0, 1)
1301
x = torch.randn(32, 3, 64, 64)
1302
y = torch.randn(16, 3, 8, 64)
1307
dynamic_axes={"x": [0, 2]},
1308
additional_test_inputs=[y],
# Shared helper for the squeeze tests. It builds a module that squeezes dim `d`,
# or every size-1 dim when `d` is None (plain torch.squeeze(x)), and exports it
# on x1. `x2` is normalised to a (possibly empty) list and passed as
# additional_test_inputs to a run_test call that makes all three axes of
# "input" dynamic. A plain run_test(Squeeze(d), x1) call is also visible.
# NOTE(review): original lines 1314-1316, 1320, 1322, 1324-1327 and 1331-1332
# are missing from this copy. They hold the rest of __init__ (presumably storing
# self.d), part of forward()'s branching, and the condition/opening of the
# dynamic-axes run_test call. How the two run_test calls relate cannot be told
# from here -- restore from upstream.
1311
def squeeze_model_tests(self, d, x1, x2):
1312
class Squeeze(torch.nn.Module):
1313
def __init__(self, d):
1317
def forward(self, x):
1318
if self.d is not None:
1319
return torch.squeeze(x, dim=self.d)
1321
return torch.squeeze(x)
1323
x2 = [] if x2 is None else [x2]
1328
input_names=["input"],
1329
dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}},
1330
additional_test_inputs=x2,
1333
self.run_test(Squeeze(d), x1)
def test_squeeze_without_no_op(self):
    # Dim 1 of the input has size 1, so squeezing it always removes it.
    x = torch.randn(2, 1, 4)
    self.squeeze_model_tests(1, x, None)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_dynamic(self):
    # Dim 1 has size 1 in the export input, but size 2 in the additional test
    # input, where squeezing it is a no-op.
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(1, x_squeeze, x_noop)
def test_squeeze_neg_without_no_op(self):
    # Negative dim: -2 addresses dim 1 of a 3-D input, which has size 1 here.
    x = torch.randn(2, 1, 4)
    self.squeeze_model_tests(-2, x, None)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_neg(self):
    # Negative dim -2 has size 1 in x_squeeze but size 2 in x_noop.
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(-2, x_squeeze, x_noop)
def test_squeeze_all_dims(self):
    # d=None squeezes every size-1 dim; x_noop has none, so it is unchanged.
    x_squeeze = torch.randn(2, 1, 4)
    x_noop = torch.randn(2, 2, 3)
    self.squeeze_model_tests(None, x_squeeze, x_noop)
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_no_op(self):
    # Roles reversed: squeezing dim 2 is a no-op for the export input
    # (size 4) but removes a dim from the additional input (size 1).
    x_noop = torch.randn(2, 1, 4)
    x_squeeze = torch.randn(2, 2, 1)
    self.squeeze_model_tests(2, x_noop, x_squeeze)
# The zeros tensor's shape comes from runtime tensor values (d1[0], d2[0]), so
# whether its leading dim is 1 is only known at run time. The two runs swap
# which input (d1 == [1] or d3 == [3]) drives that dim.
# NOTE(review): original lines 1371-1372 are missing from this copy; forward()
# has no visible return -- presumably it squeezes/returns `t`. Restore from upstream.
1366
@skipIfUnsupportedMinOpsetVersion(11)
1367
def test_squeeze_runtime_dim(self):
1368
class Squeeze(torch.nn.Module):
1369
def forward(self, d1, d2):
1370
t = torch.zeros(d1[0], d2[0])
1373
d1 = torch.tensor([1])
1374
d3 = torch.tensor([3])
1375
d4 = torch.tensor([4])
1376
self.run_test(Squeeze(), (d1, d4), additional_test_inputs=[(d3, d4)])
1377
self.run_test(Squeeze(), (d3, d4), additional_test_inputs=[(d1, d3)])
def test_squeeze(self):
    # Squeeze with a static negative dim (-2 is the size-1 dim of x).
    class Squeeze(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x, dim=-2)

    x = torch.randn(2, 1, 4)
    self.run_test(Squeeze(), x)
# Squeeze whose dim is passed as a runtime int argument to forward().
# NOTE(review): `dim` is used in the final run_test call, but its definition
# (original line 1394, a gap in the numbering) is missing from this copy; as-is
# this raises NameError. Restore from upstream.
1387
@skipIfUnsupportedMinOpsetVersion(13)
1388
def test_squeeze_dynamic_dim(self):
1389
class Squeeze(torch.nn.Module):
1390
def forward(self, x, dim: int):
1391
return torch.squeeze(x, dim)
1393
x = torch.randn(2, 1, 4)
1395
self.run_test(Squeeze(), (x, dim))
def test_unsqueeze(self):
    # Insert a new size-1 dim at position -2 of a 3-D input.
    class Unsqueeze(torch.nn.Module):
        def forward(self, x):
            return torch.unsqueeze(x, dim=-2)

    x = torch.randn(2, 3, 4)
    self.run_test(Unsqueeze(), x)
# Unsqueeze whose dim is passed as a runtime int argument to forward().
# NOTE(review): `dim` is used in the final run_test call, but its definition
# (original line 1412, a gap in the numbering) is missing from this copy; as-is
# this raises NameError. Restore from upstream.
1405
@skipIfUnsupportedMinOpsetVersion(13)
1406
def test_unsqueeze_dynamic_dim(self):
1407
class Unsqueeze(torch.nn.Module):
1408
def forward(self, x, dim: int):
1409
return torch.unsqueeze(x, dim)
1411
x = torch.randn(2, 1, 4)
1413
self.run_test(Unsqueeze(), (x, dim))
def test_maxpool_default_stride(self):
    # Functional max_pool2d given only a kernel size, so stride takes its
    # default (equal to the kernel size).
    class MaxPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, 2)

    model = MaxPoolModel()
    x = torch.randn(10, 20, 16, 50)
    self.run_test(model, x)
# AdaptiveMaxPool1d to output size 5. Note that `(5)` is the int 5, not a
# 1-tuple. The visible tail of the run_test call makes axis 0 of "x" dynamic
# and re-checks on `y`, which has a different batch size.
# NOTE(review): original lines 1429-1432 and 1435 are missing from this copy,
# presumably the opening and closing of the run_test call -- restore from upstream.
1424
@skipIfUnsupportedMinOpsetVersion(8)
1425
def test_maxpool_adaptive(self):
1426
model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
1427
x = torch.randn(20, 16, 50, requires_grad=True)
1428
y = torch.randn(32, 16, 50, requires_grad=True)
1433
dynamic_axes={"x": [0]},
1434
additional_test_inputs=[y],
def test_maxpool_2d(self):
    # MaxPool2d with kernel 5 and asymmetric padding (1 on H, 2 on W).
    model = torch.nn.MaxPool2d(5, padding=(1, 2))
    x = torch.randn(1, 20, 16, 50, requires_grad=True)
    self.run_test(model, x)
def test_maxpool_1d_ceil(self):
    # Kernel 3, stride 2, ceil_mode=True on a length-50 sequence.
    model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)
def test_maxpool_2d_ceil(self):
    # Kernel 3, stride 2, ceil_mode=True on a 50x32 spatial input.
    model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50, 32)
    self.run_test(model, x)
def test_maxpool_3d_ceil(self):
    # Kernel 3, stride 2, ceil_mode=True on a 50x44x31 volume.
    model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
    x = torch.randn(20, 16, 50, 44, 31)
    self.run_test(model, x)
# MaxPool2d(ceil_mode=True) -> 1x1 Conv2d -> BatchNorm2d, exported with the
# spatial axes (2 and 3) of both input and output marked dynamic.
# NOTE(review): the attribute `self.avgpool` holds a MaxPool2d; the name is
# misleading.
# NOTE(review): original lines 1461, 1466, 1468, 1471-1472, 1474-1476 and 1480
# are missing from this copy. They presumably hold super().__init__(), a closing
# paren, the model instantiation, and the opening/closing of the run_test
# call -- restore from upstream.
1457
@skipIfUnsupportedMinOpsetVersion(10)
1458
def test_maxpool_dynamic(self):
1459
class test(torch.nn.Module):
1460
def __init__(self, in_channels, out_channels):
1462
norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
1463
self.avgpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
1464
self.conv = torch.nn.Conv2d(
1465
in_channels, out_channels, kernel_size=1, stride=1, bias=False
1467
self.norm = norm_layer(out_channels)
1469
def forward(self, x):
1470
return self.norm(self.conv(self.avgpool(x)))
1473
inputs = torch.randn(2, 8, 64, 64)
1477
input_names=["input_0"],
1478
dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
1479
output_names=["output_0"],
# TODO: Enable maxpool-ceil family after ONNX 1.15.1+ is bumped
@skipIfUnsupportedMaxOpsetVersion(9)
def test_maxpool_1d_ceil_corner(self):
    # Corner case: kernel 1 with stride 2 under ceil_mode.
    model = torch.nn.MaxPool1d(
        kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=False
    )
    x = torch.randn(1, 3, 32)
    self.run_test(model, x)
# 2-D counterpart of the 1-D ceil-mode corner case (no indices returned).
# NOTE(review): original lines 1494-1497 and 1499 are missing from this copy --
# most of the MaxPool2d constructor arguments and its closing paren. Restore
# from upstream.
1491
@skipIfUnsupportedMaxOpsetVersion(9)
1492
def test_maxpool_2d_ceil_corner(self):
1493
model = torch.nn.MaxPool2d(
1498
return_indices=False,
1500
x = torch.randn(1, 3, 32, 32)
1501
self.run_test(model, x)
# 3-D ceil-mode corner case with an uneven kernel [7, 8, 4] (no indices returned).
# NOTE(review): original lines 1507-1510 and 1512 are missing from this copy --
# the remaining MaxPool3d constructor arguments and its closing paren. Restore
# from upstream.
1503
@skipIfUnsupportedMaxOpsetVersion(9)
1504
def test_maxpool_3d_ceil_corner(self):
1505
model = torch.nn.MaxPool3d(
1506
kernel_size=[7, 8, 4],
1511
return_indices=False,
1513
x = torch.randn(1, 3, 51, 52, 45)
1514
self.run_test(model, x)
@skipIfUnsupportedMaxOpsetVersion(9)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_1d_ceil_corner_with_indices(self):
    # Same ceil-mode corner case as the 1-D test, but the module also returns
    # the argmax indices.
    model = torch.nn.MaxPool1d(
        kernel_size=1, dilation=1, stride=2, ceil_mode=True, return_indices=True
    )
    x = torch.randn(1, 3, 32)
    self.run_test(model, x)
# 2-D ceil-mode corner case that also returns the argmax indices.
# NOTE(review): original lines 1529-1532 and 1534 are missing from this copy --
# most of the MaxPool2d constructor arguments and its closing paren. Restore
# from upstream.
1525
@skipIfUnsupportedMaxOpsetVersion(9)
1526
@skipIfUnsupportedMinOpsetVersion(8)
1527
def test_maxpool_2d_ceil_corner_with_indices(self):
1528
model = torch.nn.MaxPool2d(
1533
return_indices=True,
1535
x = torch.randn(1, 3, 32, 32)
1536
self.run_test(model, x)
# 3-D ceil-mode corner case (kernel [7, 8, 4]) that also returns the indices.
# NOTE(review): original lines 1543-1546 and 1548 are missing from this copy --
# the remaining MaxPool3d constructor arguments and its closing paren. Restore
# from upstream.
1538
@skipIfUnsupportedMaxOpsetVersion(9)
1539
@skipIfUnsupportedMinOpsetVersion(8)
1540
def test_maxpool_3d_ceil_corner_with_indices(self):
1541
model = torch.nn.MaxPool3d(
1542
kernel_size=[7, 8, 4],
1547
return_indices=True,
1549
x = torch.randn(1, 3, 51, 52, 45)
1550
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
    # return_indices=True makes the module return (values, indices).
    model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
    # Dilated 1-D max pooling: kernel 2, stride 1, dilation 2.
    model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)
def test_avgpool_default_stride(self):
    # Functional avg_pool2d given only a kernel size, so stride takes its
    # default (equal to the kernel size).
    class AvgPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.avg_pool2d(x, 2)

    model = AvgPoolModel()
    x = torch.randn(10, 20, 16, 50)
    self.run_test(model, x)
def test_avgpool(self):
    # Plain 1-D average pooling: kernel 2, stride 1.
    model = torch.nn.AvgPool1d(2, stride=1)
    x = torch.randn(20, 16, 50)
    self.run_test(model, x)
def test_avgpool_1d_ceil(self):
    # Kernel 3, stride 2, ceil_mode=True on a length-7 input, so the last
    # window is partial.
    model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
    x = torch.randn(1, 1, 7)
    self.run_test(model, x)
# AvgPool2d parametrized over `padding` and `count_include_pad`.
# NOTE(review): original lines 1587-1589, 1592-1593, 1596-1598 and 1600 are
# missing from this copy -- the "padding" parametrize values, the
# count_include_pad values, and most AvgPool2d constructor arguments. Restore
# from upstream.
1583
# TODO: ceil_mode is not included in the test, because of
1584
# https://github.com/microsoft/onnxruntime/issues/16203
1585
# ORT and PyTorch compute the last value differently when ceil_mode is set.
1586
@common_utils.parametrize(
1590
@common_utils.parametrize(
1591
"count_include_pad",
1594
def test_avgpool_2d(self, padding, count_include_pad):
1595
model = torch.nn.AvgPool2d(
1599
count_include_pad=count_include_pad,
1601
x = torch.randn(20, 16, 50, 32)
1602
self.run_test(model, x)
# AvgPool3d export. The visible tail of the run_test call makes axes 0 and 1 of
# "x" dynamic and re-checks on `y` (different batch and channel sizes).
# NOTE(review): the TODO below says ceil_mode is left out of this test, yet the
# model is built with ceil_mode=True. Confirm which is intended.
# NOTE(review): original lines 1612-1615 and 1618 are missing from this copy,
# presumably the opening and closing of the run_test call -- restore from upstream.
1604
# TODO: ceil_mode is not included in the test, because of
1605
# https://github.com/microsoft/onnxruntime/issues/16203
1606
# ORT and PyTorch compute the last value differently when ceil_mode is set.
1607
@skipIfUnsupportedMinOpsetVersion(19)
1608
def test_avgpool_3d_ceil(self):
1609
model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
1610
x = torch.randn(20, 16, 50, 44, 31)
1611
y = torch.randn(32, 8, 50, 44, 31)
1616
dynamic_axes={"x": [0, 1]},
1617
additional_test_inputs=[y],
# AvgPool2d(ceil_mode=True, count_include_pad=False) -> 1x1 Conv2d ->
# BatchNorm2d, exported with the spatial axes (2 and 3) of both input and
# output marked dynamic.
# NOTE(review): original lines 1624, 1628, 1631, 1633, 1636-1637, 1639-1641 and
# 1645 are missing from this copy. They presumably hold super().__init__(), the
# closing parens, the model instantiation, and the opening/closing of the
# run_test call -- restore from upstream.
1620
@skipIfUnsupportedMinOpsetVersion(10)
1621
def test_avgpool_dynamic(self):
1622
class test(torch.nn.Module):
1623
def __init__(self, in_channels, out_channels):
1625
norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
1626
self.avgpool = torch.nn.AvgPool2d(
1627
(2, 2), stride=2, ceil_mode=True, count_include_pad=False
1629
self.conv = torch.nn.Conv2d(
1630
in_channels, out_channels, kernel_size=1, stride=1, bias=False
1632
self.norm = norm_layer(out_channels)
1634
def forward(self, x):
1635
return self.norm(self.conv(self.avgpool(x)))
1638
inputs = torch.randn(2, 8, 64, 64)
1642
input_names=["input_0"],
1643
dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
1644
output_names=["output_0"],
# Scripted modules branching on Tensor.is_floating_point().
# The first module returns x.new_zeros(x.shape) on both branches. It is checked
# with all input dims dynamic, and again with no graph inputs expected to
# remain (remained_onnx_input_idx=[]).
# NOTE(review): original lines 1655, 1657, 1659, 1661, 1665-1666 and 1668-1671
# are missing from this copy. These include the opening/closing of the first
# run_test call and the second forward()'s body around `a` (which is used but
# never defined in the visible code) -- restore from upstream.
1647
@skipIfUnsupportedMinOpsetVersion(9)
1648
def test_floating_point(self):
1649
class FloatingPoint(torch.jit.ScriptModule):
1650
@torch.jit.script_method
1651
def forward(self, x):
1652
if x.is_floating_point():
1653
return x.new_zeros(x.shape)
1654
return x.new_zeros(x.shape)
1656
x = torch.randn(2, 3, 4)
1658
FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
1660
self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])
1662
class FloatingPoint(torch.jit.ScriptModule):
1663
@torch.jit.script_method
1664
def forward(self, x):
1667
if a.is_floating_point():
1672
x = torch.randn(2, 3, 4)
1673
self.run_test(FloatingPoint(), x)
# The branches of the visible forward() return tensors of different rank
# (x.shape[1:] vs x.shape), which is why the test needs opset >= 11. The second
# module is exercised with an int32 input.
# NOTE(review): original lines 1681-1682, 1686-1687, 1689, 1691, 1693, 1697-1698
# and 1700-1703 are missing from this copy. These include the definition of `a`
# (used but never defined in the visible code), the opening/closing of the first
# run_test call, and most of the second forward() -- restore from upstream.
1675
# Operator rank mismatch between outputs of two branches for opsets below 11.
1676
@skipIfUnsupportedMinOpsetVersion(11)
1677
def test_floating_point_infer_dtype(self):
1678
class FloatingPoint(torch.jit.ScriptModule):
1679
@torch.jit.script_method
1680
def forward(self, x):
1683
if a.is_floating_point():
1684
return x.new_zeros(x.shape[1:])
1685
return x.new_zeros(x.shape)
1688
x = torch.randn(2, 3, 4)
1690
FloatingPoint(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
1692
self.run_test(FloatingPoint(), x, remained_onnx_input_idx=[])
1694
class FloatingPoint(torch.jit.ScriptModule):
1695
@torch.jit.script_method
1696
def forward(self, x):
1699
if a.is_floating_point():
1704
x = torch.randn(2, 3, 4).to(torch.int32)
1705
self.run_test(FloatingPoint(), x)
# Iterating with enumerate() over a list of tensors produces a prim::min op.
# The helper builds one full_like tensor per list element, filled with its index.
# NOTE(review): original lines 1709, 1711, 1713, 1716-1717, 1722, 1725 and
# 1729-1731 are missing from this copy. These include what is presumably a
# decorator on list_append, the definition of `temp` (used but not defined in
# the visible code), the enumerate() argument, the helper's return, and the body
# of M.forward -- restore from upstream.
1707
@skipIfUnsupportedMinOpsetVersion(12)
1708
def test_prim_min(self):
1710
def list_append(boxes: List[Tensor]):
1712
for i, b in enumerate(
1714
): # enumerate creates a prim::min op in the torch graph
1715
temp.append(torch.full_like(b[:, 1], i))
1718
class Min(torch.nn.Module):
1719
def forward(self, x):
1720
boxes = [x for _ in range(3)]
1721
return list_append(boxes)
1723
x = torch.rand(5, 5)
1724
self.run_test(Min(), (x,))
1726
class M(torch.jit.ScriptModule):
1727
@torch.jit.script_method
1728
def forward(self, x):
1732
x = torch.arange(6, dtype=torch.int64)
1733
self.run_test(M(), (x,))
# Elementwise arithmetic on a float tensor.
# NOTE(review): original lines 1738-1743 are missing from this copy, so
# forward() has no visible body -- restore from upstream.
1735
def test_arithmetic(self):
1736
class ArithmeticModule(torch.nn.Module):
1737
def forward(self, x):
1744
x = torch.randn(2, 3, 4)
1745
self.run_test(ArithmeticModule(), x)
# Arithmetic mixing a tensor with a Python int argument (`y: int`). A second
# module takes only x and is checked with no graph inputs expected to remain.
# NOTE(review): original lines 1750-1755, 1757, 1759 and 1762-1765 are missing
# from this copy. These include both forward() bodies and the definition of `y`
# (used but not defined in the visible code) -- restore from upstream.
1747
def test_arithmetic_prim_long(self):
1748
class ArithmeticModule(torch.nn.Module):
1749
def forward(self, x, y: int):
1756
x = torch.randn(2, 3, 4)
1758
self.run_test(ArithmeticModule(), (x, y))
1760
class ArithmeticModule(torch.nn.Module):
1761
def forward(self, x):
1766
x = torch.randn(2, 3, 4)
1767
self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
# Arithmetic mixing a tensor with a Python float argument (`y: float`). The
# second module returns x.shape[1] / 2, a Python float computed from a shape,
# and is checked with no graph inputs expected to remain.
# NOTE(review): original lines 1773-1778, 1780, 1782, 1785-1786 and 1788 are
# missing from this copy. These include the first forward() body, the definition
# of `y`, and the start of the second forward(). Line 1769, directly above this
# def, is also missing (possibly a decorator). Restore from upstream.
1770
def test_arithmetic_prim_float(self):
1771
class ArithmeticModule(torch.nn.Module):
1772
def forward(self, x, y: float):
1779
x = torch.randn(2, 3, 4)
1781
self.run_test(ArithmeticModule(), (x, y))
1783
class ArithmeticModule(torch.nn.Module):
1784
def forward(self, x):
1787
return x.shape[1] / 2
1789
x = torch.randn(2, 3, 4)
1790
self.run_test(ArithmeticModule(), x, remained_onnx_input_idx=[])
# Arithmetic mixing a tensor with int, bool and float Python arguments, then a
# module that takes only int arguments.
# NOTE(review): original lines 1796-1802, 1804-1806, 1808 and 1811-1814 are
# missing from this copy. These include both forward() bodies and the
# definitions of y, z, t and of the int `x`/`y` used in the second call --
# restore from upstream.
1793
def test_arithmetic_prim_bool(self):
1794
class ArithmeticModule(torch.nn.Module):
1795
def forward(self, x, y: int, z: bool, t: float):
1803
x = torch.randn(2, 3, 4)
1807
self.run_test(ArithmeticModule(), (x, y, z, t))
1809
class ArithmeticModule(torch.nn.Module):
1810
def forward(self, x: int, y: int):
1815
self.run_test(ArithmeticModule(), (x, y))
# Nested tuple output containing None entries. The `reason=` string below
# describes how always-None outputs are handled in trace vs. script mode.
# NOTE(review): the `reason=` keyword belongs to a decorator whose opening
# (original lines ~1817-1818) and closing paren (line 1822) are missing from
# this copy -- restore from upstream.
1819
reason="In trace: Outputs that are always None are removed. \
1820
In script: Outputs that are always None are removed before opset 15. \
1821
After opset 15, we replace the None in output with Optional node.",
1823
def test_tuple_with_none_outputs(self):
1824
class TupleModel(torch.nn.Module):
1825
def forward(self, x):
1826
return (x, (x, None, (x, None)))
1828
x = torch.randn(3, 4)
1829
self.run_test(TupleModel(), (x,))
# Scripted arithmetic on a (2, 3) input whose intermediate dtype must be
# recovered by ONNX shape inference.
# NOTE(review): original lines 1837-1843 are missing from this copy, so the
# scripted forward() has no visible body -- restore from upstream.
1831
# In scripting, the first transpose node does not carry shape and dtype info.
1832
# The following test only works when ONNX shape inference is enabled.
1833
def test_arithmetic_infer_dtype(self):
1834
class ArithmeticModule(torch.jit.ScriptModule):
1835
@torch.jit.script_method
1836
def forward(self, x):
1844
x = torch.randn(2, 3)
1845
self.run_test(ArithmeticModule(), x)
# Floor division across int64/float64 tensors and int/float scalars, with x
# holding negative values. Skipped: see the issue in the skip reason.
# NOTE(review): original lines 1851-1853, 1859 and 1864-1865 are missing from
# this copy. They hold the opening of the returned tuple, its first entries,
# one middle entry, and the closing paren -- restore from upstream.
1847
@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
1848
def test_floor_div(self):
1849
class FloorDivModule(torch.nn.Module):
1850
def forward(self, x, y):
1854
x.to(dtype=torch.float64) // 3,
1855
x.to(dtype=torch.float64) // 2.0,
1856
x.to(dtype=torch.int64) // 3,
1857
x.to(dtype=torch.int64) // 2.0,
1858
x // (y + 1.0).to(dtype=torch.int64),
1860
x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
1861
x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
1862
x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
1863
x.to(dtype=torch.int64) // y,
1866
x = torch.arange(-2, 4).reshape(2, 3, 1)
1867
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
1868
self.run_test(FloorDivModule(), (x, y))
@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
def test_floor_div_script(self):
    # Scripted floor division by an int, a float and a tensor divisor; x holds
    # negative integers, where floor and truncation differ.
    class FloorDivModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            return x // 3, x // 2.0, x // y

    x = torch.arange(-2, 4).reshape(2, 3, 1)
    y = torch.randn(2, 3, 4)
    self.run_test(FloorDivModule(), (x, y))
@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
@skipIfUnsupportedMinOpsetVersion(9)
def test_floordiv(self):
    # Floor division of two dimension sizes feeding a new tensor's shape.
    # Checked once with all input dims dynamic, and once with no graph inputs
    # expected to remain (remained_onnx_input_idx=[]).
    class FloordivModule(torch.nn.Module):
        def forward(self, x):
            return x.new_zeros(x.size(2) // x.size(1))

    x = torch.randn(2, 3, 4)
    self.run_test(
        FloordivModule(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
    )
    self.run_test(FloordivModule(), (x,), remained_onnx_input_idx=[])
# True division (`/` and torch.true_divide) of int tensors, then of the same
# values cast to float. y is 1..24, so there is no division by zero.
# NOTE(review): no `def` header precedes this class in this copy (original lines
# 1893-1894 are missing), so these statements have no enclosing test method.
# Line 1898 is also missing. Restore from upstream.
1895
class DivModule(torch.nn.Module):
1896
def forward(self, x, y):
1897
return x / y, torch.true_divide(x, y)
1899
x = torch.randn(2, 3, 4).to(torch.int)
1900
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
1901
self.run_test(DivModule(), (x, y))
1902
self.run_test(DivModule(), (x.float(), y.float()))
# Note: div cannot (generally) be exported via scripting
# since its type promotion logic is dependent on knowing the scalar types
# of the input tensors. That is, the ONNX graph is dependent on the
# data type of the inputs. This makes it appropriate for tracing only.
def test_div_promotion_trace(self):
    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            return x / y, torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    # int / int promotes to the default floating dtype, so the traced graph is
    # checked under both a float and a double default.
    with common_utils.set_default_dtype(torch.float):
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

    with common_utils.set_default_dtype(torch.double):
        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
# In scripting x, y do not carry shape and dtype info.
# The following test only works when onnx shape inference is enabled.
def test_div_promotion_script(self):
    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            # Add transpose to hide shape/type information.
            # Otherwise shape and type are still available from the input.
            x = x.transpose(1, 2)
            y = y.transpose(1, 2)
            return x / y, torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    # 1. x,y are int, and output is float.
    #    This can be handled by the default case, where both are cast to float.
    #    It works even if type of x, y are unknown.
    with common_utils.set_default_dtype(torch.float):
        self.run_test(torch.jit.script(DivModule()), (x, y))

    # 2. x,y are int, and output is double.
    #    This can be handled by the default case, where both are cast to double.
    #    It works even if type of x, y are unknown.
    with common_utils.set_default_dtype(torch.double):
        self.run_test(torch.jit.script(DivModule()), (x, y))

    # 3. x is int, y is double, and output is double.
    #    This can only be handled when both type of x and y are known.
    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
    self.run_test(torch.jit.script(DivModule()), (x, y))
def test_div_rounding_mode(self):
    # Each module computes the same division through both Tensor.div and
    # torch.div, for a single rounding mode (None, "trunc" or "floor").
    class TrueDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode=None),
                torch.div(x, y, rounding_mode=None),
            )

    class TruncDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode="trunc"),
                torch.div(x, y, rounding_mode="trunc"),
            )

    class FloorDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                x.div(y, rounding_mode="floor"),
                torch.div(x, y, rounding_mode="floor"),
            )

    modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]

    def check_all_modules(x, y):
        # Export every module eagerly, as a traced module and as a scripted one.
        for module in modules:
            self.run_test(module, (x, y))
            self.run_test(torch.jit.trace(module, (x, y)), (x, y))
            self.run_test(torch.jit.script(module), (x, y))

    # Integer inputs: x may be negative (where trunc and floor differ); y is
    # 1..24, so there is no division by zero.
    x = (torch.randn(2, 3, 4) * 100).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
    check_all_modules(x, y)

    # Float inputs: y lies in [0.1, 10.1), again away from zero. These are
    # generated after the integer checks so the RNG draw order is unchanged.
    x = torch.randn(2, 3, 4)
    y = torch.rand(2, 3, 4) * 10.0 + 0.1
    check_all_modules(x, y)
# Slicing under tracing.
# NOTE(review): original lines 1998-2000 are missing from this copy. They hold
# forward()'s body and the definition of `x` used by run_test -- restore from
# upstream.
1995
def test_slice_trace(self):
1996
class MyModule(torch.nn.Module):
1997
def forward(self, x):
2001
self.run_test(MyModule(), x)
# Negative-bound slicing on a (3, 4, 5) input.
# NOTE(review): original lines 2006-2007 are missing from this copy, so
# forward() has no visible body -- restore from upstream.
2003
def test_slice_neg(self):
2004
class NegSlice(torch.nn.Module):
2005
def forward(self, x):
2008
x = torch.randn(3, 4, 5)
2009
self.run_test(NegSlice(), x)
def test_slice_neg_large(self):
    # Negative slice bounds on dim 2 plus a negative scalar index on the last
    # dim of a 5-D input.
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, -3:-1, :, -1]

    x = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), x)
def test_slice_neg_large_negone(self):
    # Only a negative scalar index (-1) on the last dim of a 5-D input.
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, :, :, -1]

    x = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), x)
# In-place slice assignment whose end bound is y.size(0) of another input.
# NOTE(review): original lines 2032-2033 are missing from this copy; forward()
# has no visible return (presumably `x`) -- restore from upstream.
2027
@skipIfUnsupportedMinOpsetVersion(11)
2028
def test_slice_with_input_index(self):
2029
class InputIndexSlice(torch.nn.Module):
2030
def forward(self, x, y):
2031
x[: y.size(0), 0, :] = y
2034
x = torch.zeros((56, 6, 256))
2035
y = torch.rand((22, 256))
2036
self.run_test(InputIndexSlice(), (x, y))
# Slicing with a 1-element int64 tensor used as an index (trace only).
# NOTE(review): original lines 2043-2045 are missing from this copy, so
# forward() has no visible body -- restore from upstream.
2038
@skipIfUnsupportedMinOpsetVersion(11)
2039
@skipScriptTest() # Torchscript doesn't support 1d index.
2040
def test_slice_with_1d_input_index(self):
2041
class InputIndexSlice(torch.nn.Module):
2042
def forward(self, x, y):
2046
x = torch.zeros((56, 6, 256))
2047
y = torch.tensor([5], dtype=torch.int64)
2048
self.run_test(InputIndexSlice(), (x, y))
# Slice assignment whose end bound (y) and step (z) are 0-d int64 tensor inputs.
# NOTE(review): original lines 2055-2056 are missing from this copy; forward()
# has no visible return (presumably `x`) -- restore from upstream.
2050
@skipIfUnsupportedMinOpsetVersion(11)
2051
def test_slice_with_input_step_size(self):
2052
class InputIndexSlice(torch.nn.Module):
2053
def forward(self, x, y, z):
2054
x[:y:z, 0::z, :] = 1
2057
x = torch.zeros((56, 6, 256))
2058
y = torch.tensor(5, dtype=torch.int64)
2059
z = torch.tensor(2, dtype=torch.int64)
2060
self.run_test(InputIndexSlice(), (x, y, z))
# Slices whose bounds depend on x.size() and a loop index, collected into a
# tuple. Exported with all axes of input and output dynamic and re-checked on a
# differently shaped `y`.
# NOTE(review): original lines 2067-2068, 2071, 2074, 2076 and 2081 are missing
# from this copy. These include the `results` list and loop header (`results`
# and `i` are used but undefined in the visible code), and the opening/closing
# of the run_test call -- restore from upstream.
2062
@skipIfUnsupportedMinOpsetVersion(10)
2063
@skipScriptTest() # scripting tuple/list append
2064
def test_slice_dynamic(self):
2065
class DynamicSliceExportMod(torch.nn.Module):
2066
def forward(self, x):
2069
results.append(x[: x.size(0) - i, i : x.size(2), i:3])
2070
return tuple(results)
2072
x = torch.rand(5, 5, 5)
2073
y = torch.randn(6, 7, 8)
2075
DynamicSliceExportMod(),
2077
additional_test_inputs=[y],
2078
input_names=["input_1"],
2079
output_names=["output_1"],
2080
dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_script(self):
    # Scripted slice whose end bound comes from x.size(1) at run time.
    class DynamicSliceModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x[1 : x.size(1)]

    x = torch.rand(1, 2)
    self.run_test(DynamicSliceModel(), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_shape_script(self):
    # Slices the shape tuple itself (with a size-dependent end) to build a new
    # tensor. Checked with all dims dynamic, and with no graph inputs expected
    # to remain.
    class DynamicSliceModel(torch.nn.Module):
        def forward(self, x):
            return x.new_zeros(x.shape[1 : x.size(2)])

    x = torch.rand(1, 2, 3, 4)
    self.run_test(
        DynamicSliceModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}
    )
    self.run_test(DynamicSliceModel(), x, remained_onnx_input_idx=[])
# Open-ended slices (`i:`) combined with a size-derived scalar index, collected
# into a tuple and exported with all axes of input and output dynamic.
# NOTE(review): original lines 2110-2111, 2114, 2116, 2118 and 2120 are missing
# from this copy. These include the `results` list and loop header (`results`
# and `i` are used but undefined in the visible code), and the run_test
# call's other lines -- restore from upstream.
2105
@skipIfUnsupportedMinOpsetVersion(10)
2106
@skipScriptTest() # scripting tuple/list append
2107
def test_slice_dynamic_to_end(self):
2108
class DynamicSliceExportMod(torch.nn.Module):
2109
def forward(self, x):
2112
results.append(x[:, i:, x.size(2) - 5])
2113
return tuple(results)
2115
x = torch.rand(5, 5, 5)
2117
DynamicSliceExportMod(),
2119
dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
def test_square(self):
    # Elementwise torch.square.
    class Square(torch.nn.Module):
        def forward(self, x):
            return torch.square(x)

    x = torch.randn(2, 3, 4)
    self.run_test(Square(), x)
# arange driven by input.shape[0], exported eagerly and as a scripted module,
# with axis 0 of the input and first output dynamic; re-checked on `y`.
# NOTE(review): three output names are declared but only two tuple entries are
# visible. Original lines 2134, 2136, 2138-2139, 2142-2144, 2149-2150, 2152 and
# 2157 are missing from this copy; they presumably hold the tuple's
# opening/middle/closing lines and the run_test calls' other lines. Restore
# from upstream.
2130
@skipIfUnsupportedMinOpsetVersion(9)
2131
def test_arange_dynamic(self):
2132
class ArangeModel(torch.nn.Module):
2133
def forward(self, input):
2135
torch.arange(input.shape[0]),
2137
torch.arange(start=input.shape[0], end=input.shape[0] + 5),
2140
x = torch.randn(5, 3, 2)
2141
y = torch.randn(8, 3, 2)
2145
additional_test_inputs=[y],
2146
input_names=["input_1"],
2147
output_names=["output_1", "output_2", "output_3"],
2148
dynamic_axes={"input_1": [0], "output_1": [0]},
2151
torch.jit.script(ArangeModel()),
2153
additional_test_inputs=[y],
2154
input_names=["input_1"],
2155
output_names=["output_1", "output_2", "output_3"],
2156
dynamic_axes={"input_1": [0], "output_1": [0]},
# arange with a tensor `end` writing into an int64 `out=` tensor. Note that
# `(x)` is just x, not a 1-tuple.
# NOTE(review): `x` is used by run_test but its definition (original lines
# 2165-2166) is missing from this copy -- restore from upstream.
2159
@skipIfUnsupportedMinOpsetVersion(9)
2160
def test_dynamic_arange_out(self):
2161
class ArangeOutModel(torch.nn.Module):
2162
def forward(self, end):
2163
out_t = torch.tensor([1], dtype=torch.int64)
2164
return torch.arange(end, out=out_t)
2167
self.run_test(ArangeOutModel(), (x))
# arange from start.size(0) to a tensor `end` into an int64 `out=` tensor.
# Checked with all dims of "x" dynamic, and with only input index 1 (`y`)
# expected to remain in the graph.
# NOTE(review): original lines 2175, 2177-2178, 2180 and 2183 are missing from
# this copy. These include the definition of `y` and the opening/closing of the
# first run_test call -- restore from upstream.
2169
@skipIfUnsupportedMinOpsetVersion(9)
2170
def test_dynamic_arange_start_out(self):
2171
class ArangeStartOutModel(torch.nn.Module):
2172
def forward(self, start, end):
2173
out_t = torch.tensor([1], dtype=torch.int64)
2174
return torch.arange(start.size(0), end, out=out_t)
2176
x = torch.randn(2, 3, 4)
2179
ArangeStartOutModel(),
2181
input_names=["x", "y"],
2182
dynamic_axes={"x": [0, 1, 2]},
2184
self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1])
@skipIfUnsupportedMinOpsetVersion(9)
def test_linspace(self):
    # linspace with start/end/steps all passed as 0-d tensors.
    class LinspaceModel(torch.nn.Module):
        def forward(self, start, end, steps):
            return torch.linspace(start, end, steps)

    x = torch.tensor(3, dtype=torch.float)
    y = torch.tensor(10, dtype=torch.float)
    z = torch.tensor(5, dtype=torch.int)
    self.run_test(LinspaceModel(), (x, y, z))
@skipIfUnsupportedMinOpsetVersion(9)
def test_linspace_negative_start(self):
    # linspace spanning zero: start -1, end 1, 6 steps.
    class LinspaceModel(torch.nn.Module):
        def forward(self, start, end, steps):
            return torch.linspace(start, end, steps)

    x = torch.tensor(-1, dtype=torch.float)
    y = torch.tensor(1, dtype=torch.float)
    z = torch.tensor(6, dtype=torch.int)
    self.run_test(LinspaceModel(), (x, y, z))
# arange with a float end (8.5) writing into a float `out=` tensor: first with
# only `end`, then from start.size(0) with a 1.5 step. Note that `(y)` is just
# y, not a 1-tuple.
# NOTE(review): original lines 2214, 2217, 2222, 2225-2227 and 2230 are missing
# from this copy, presumably blank lines and the opening/arguments/closing of
# the first ArangeModelStep run_test call -- restore from upstream.
2208
@skipIfUnsupportedMinOpsetVersion(9)
2209
def test_arange_with_floats_out(self):
2210
class ArangeModelEnd(torch.nn.Module):
2211
def forward(self, end):
2212
out_t = torch.tensor([1], dtype=torch.float)
2213
return torch.arange(end, out=out_t)
2215
y = torch.tensor(8.5, dtype=torch.float)
2216
self.run_test(ArangeModelEnd(), (y))
2218
class ArangeModelStep(torch.nn.Module):
2219
def forward(self, start, end):
2220
out_t = torch.tensor([1], dtype=torch.float)
2221
return torch.arange(start.size(0), end, 1.5, out=out_t)
2223
x = torch.randn(2, 3, 4)
2224
y = torch.tensor(8.5, dtype=torch.float)
2228
input_names=["x", "y"],
2229
dynamic_axes={"x": [0, 1, 2]},
2231
self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
# arange with a float end (8.5) and no dtype. Cases covered: end only; start
# from start.size(0) with step 1.5; a negative step -1.5 counting down; and
# start/end only. Each size-dependent variant is checked with "x" dynamic and
# with only input index 1 expected to remain.
# NOTE(review): original lines 2238, 2241, 2245, 2248-2250, 2253, 2255, 2259,
# 2262, 2264, 2267, 2269, 2273 and 2276-2278 are missing from this copy,
# presumably blank lines and the opening/arguments/closing of several run_test
# calls -- restore from upstream.
2233
@skipIfUnsupportedMinOpsetVersion(9)
2234
def test_arange_with_floats(self):
2235
class ArangeModelEnd(torch.nn.Module):
2236
def forward(self, end):
2237
return torch.arange(end)
2239
y = torch.tensor(8.5, dtype=torch.float)
2240
self.run_test(ArangeModelEnd(), (y))
2242
class ArangeModelStep(torch.nn.Module):
2243
def forward(self, start, end):
2244
return torch.arange(start.size(0), end, 1.5)
2246
x = torch.randn(2, 3, 4)
2247
y = torch.tensor(8.5, dtype=torch.float)
2251
input_names=["x", "y"],
2252
dynamic_axes={"x": [0, 1, 2]},
2254
self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
2256
class ArangeModelStepNeg(torch.nn.Module):
2257
def forward(self, start, end):
2258
return torch.arange(end, start.size(0), -1.5)
2260
x = torch.randn(2, 3, 4)
2261
y = torch.tensor(8.5, dtype=torch.float)
2263
ArangeModelStepNeg(),
2265
input_names=["x", "y"],
2266
dynamic_axes={"x": [0, 1, 2]},
2268
self.run_test(ArangeModelStepNeg(), (x, y), remained_onnx_input_idx=[1])
2270
class ArangeModelStart(torch.nn.Module):
2271
def forward(self, start, end):
2272
return torch.arange(start.size(0), end)
2274
x = torch.randn(2, 3, 4)
2275
y = torch.tensor(8.5, dtype=torch.float)
2279
input_names=["x", "y"],
2280
dynamic_axes={"x": [0, 1, 2]},
2282
self.run_test(ArangeModelStart(), (x, y), remained_onnx_input_idx=[1])
# arange with a float end (8.5) but an explicit dtype=torch.int64 override:
# first end only, then from start.size(0) with step 1.5.
# NOTE(review): original lines 2289, 2292, 2296, 2299-2301 and 2304 are missing
# from this copy, presumably blank lines and the opening/arguments/closing of
# the first ArangeModelStep run_test call -- restore from upstream.
2284
@skipIfUnsupportedMinOpsetVersion(9)
2285
def test_arange_with_floats_override(self):
2286
class ArangeModelEnd(torch.nn.Module):
2287
def forward(self, end):
2288
return torch.arange(end, dtype=torch.int64)
2290
y = torch.tensor(8.5, dtype=torch.float)
2291
self.run_test(ArangeModelEnd(), (y))
2293
class ArangeModelStep(torch.nn.Module):
2294
def forward(self, start, end):
2295
return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)
2297
x = torch.randn(2, 3, 4)
2298
y = torch.tensor(8.5, dtype=torch.float)
2302
input_names=["x", "y"],
2303
dynamic_axes={"x": [0, 1, 2]},
2305
self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_out(self):
    # arange with a float end (8.5) writing into a float `out=` tensor.
    class ArangeOutModel(torch.nn.Module):
        def forward(self, end):
            out_t = torch.tensor([1], dtype=torch.float)
            return torch.arange(end, out=out_t)

    x = torch.tensor(8.5, dtype=torch.float)
    # Pass x directly: the original `(x)` was not a tuple either.
    self.run_test(ArangeOutModel(), x)
# arange from start.size(0) to a float end into a float `out=` tensor. Checked
# with all dims of "x" dynamic, and with only input index 1 (`y`) expected to
# remain in the graph.
# NOTE(review): original lines 2323, 2326, 2328 and 2331 are missing from this
# copy, presumably a blank line and the opening/inputs/closing of the first
# run_test call -- restore from upstream.
2317
@skipIfUnsupportedMinOpsetVersion(11)
2318
def test_arange_start_out(self):
2319
class ArangeStartOutModel(torch.nn.Module):
2320
def forward(self, start, end):
2321
out_t = torch.tensor([1], dtype=torch.float)
2322
return torch.arange(start.size(0), end, out=out_t)
2324
x = torch.randn(2, 3, 4)
2325
y = torch.tensor(8.5, dtype=torch.float)
2327
ArangeStartOutModel(),
2329
input_names=["x", "y"],
2330
dynamic_axes={"x": [0, 1, 2]},
2332
self.run_test(ArangeStartOutModel(), (x, y), remained_onnx_input_idx=[1])
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_no_type(self):
    # arange without an explicit dtype, given a float 0-d tensor end, called
    # with one and with two positional arguments.
    class ArangeModel(torch.nn.Module):
        def forward(self, end):
            return torch.arange(end), torch.arange(0, end)

    x = torch.tensor(6.2, dtype=torch.float)
    self.run_test(ArangeModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_size(self):
    # Tensors built from input.size()/input.shape. Checked with all dims
    # dynamic, and with no graph inputs expected to remain.
    class SizeModel(torch.nn.Module):
        def forward(self, input):
            return (
                torch.arange(input.size(0)),
                torch.arange(input.size(-1)),
                torch.ones(input.shape),
            )

    x = torch.randn(5, 3, 2)
    self.run_test(SizeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
    self.run_test(SizeModel(), x, remained_onnx_input_idx=[])
@skipIfUnsupportedMinOpsetVersion(9)
@skipScriptTest()  # x.stride() not scriptable
def test_as_strided(self):
    # Two as_strided views: one with constant size/stride/offset, and one with
    # size/stride derived from the input's own size() and stride().
    class Model(torch.nn.Module):
        def forward(self, x):
            chunk_size = list(x.size())
            chunk_size[1] = chunk_size[1] * 2 - 1
            chunk_stride = list(x.stride())
            chunk_stride[1] = chunk_stride[1] // 2
            return x.as_strided(
                (3, 3, 3), (1, 4, 2), storage_offset=2
            ), x.as_strided(chunk_size, chunk_stride)

    x = torch.randn(5, 8, 7)
    self.run_test(Model(), x)
@skipScriptTest()  # Ellipses followed by tensor indexing not scriptable
def test_tensor_index_advanced_indexing_ellipsis(self):
    # An ellipsis followed by two advanced (tensor) indices on the last two
    # dims of a 5-D input.
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]

    m1 = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), (m1,))
# Three models mixing slices, None and advanced (tensor) indices on a 5-D input.
# NOTE(review): original lines 2385-2386, 2388, 2391-2392, 2395, 2398, 2400-2401,
# 2403, 2406-2407, 2409-2410, 2412-2413 and 2415 are missing from this copy.
# They hold the `return input[` openers, the short slice/None entries of each
# index list, and the closing brackets -- restore from upstream.
2382
def test_tensor_index_advanced_indexing(self):
2383
class MyModel(torch.nn.Module):
2384
def forward(self, input):
2387
torch.tensor([[0, 2], [1, 1]]),
2389
torch.tensor([2, 1]),
2390
torch.tensor([0, 3]),
2393
m1 = torch.randn(3, 4, 5, 6, 7)
2394
self.run_test(MyModel(), (m1,))
2396
class MyModel(torch.nn.Module):
2397
def forward(self, input):
2399
:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])
2402
self.run_test(MyModel(), (m1,))
2404
class MyModel(torch.nn.Module):
2405
def forward(self, input):
2408
torch.tensor([0, 2]),
2411
torch.tensor([[1], [4]]),
2414
self.run_test(MyModel(), (m1,))
# Two consecutive advanced (tensor) indices, preceded by a full slice and
# followed by None, on a 5-D input.
# NOTE(review): original lines 2419 and 2421-2422 are missing from this copy,
# presumably the `return input[` opener and the closing bracket -- restore from
# upstream.
2416
def test_tensor_index_advanced_indexing_consecutive(self):
2417
class MyModel(torch.nn.Module):
2418
def forward(self, input):
2420
:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None
2423
m1 = torch.randn(3, 4, 5, 6, 7)
2424
self.run_test(MyModel(), (m1,))
# index_put of a length-4 update into row `ind` (== [1]) of a (3, 4) tensor.
# NOTE(review): original lines 2430-2432 are missing from this copy, so
# forward() has no visible body -- restore from upstream.
2426
@skipIfUnsupportedMinOpsetVersion(11)
2427
def test_index_put(self):
2428
class IndexPutModel(torch.nn.Module):
2429
def forward(self, x, ind, update):
2433
x = torch.randn(3, 4)
2434
ind = torch.tensor([1], dtype=torch.long)
2435
update = torch.ones(4)
2436
self.run_test(IndexPutModel(), (x, ind, update))
# index_put of a single scalar value (True, then a 0-d tensor 5.5) at randomly
# drawn, possibly repeated, int64 indices into a 1-D tensor of length 100.
# NOTE(review): original lines 2443-2444, 2448 and 2452-2453 are missing from
# this copy; both forward() methods have no visible return (presumably `mask`)
# -- restore from upstream.
2438
@skipIfUnsupportedMinOpsetVersion(11)
2439
def test_index_put_singular(self):
2440
class IndexPutBoolModel(torch.nn.Module):
2441
def forward(self, mask, indices):
2442
mask[indices] = True
2445
mask = torch.zeros(100, dtype=torch.bool)
2446
indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
2447
self.run_test(IndexPutBoolModel(), (mask, indices))
2449
class IndexPutFloatModel(torch.nn.Module):
2450
def forward(self, mask, indices):
2451
mask[indices] = torch.tensor(5.5)
2454
mask = torch.rand(100, dtype=torch.float)
2455
indices = (torch.rand(50) * mask.shape[0]).to(torch.int64)
2456
self.run_test(IndexPutFloatModel(), (mask, indices))
2458
@skipIfUnsupportedMinOpsetVersion(11)
2459
def test_index_put_accumulate(self):
2460
class IndexPutModel(torch.nn.Module):
2461
def forward(self, x, ind, update):
2462
return x.index_put((ind,), update, accumulate=True)
2464
x = torch.randn(3, 4)
2465
ind = torch.tensor([2], dtype=torch.long)
2466
update = torch.ones(4)
2467
self.run_test(IndexPutModel(), (x, ind, update))
2469
@skipIfUnsupportedMinOpsetVersion(11)
2470
def test_index_put_slice_index(self):
2471
class IndexPutModel(torch.nn.Module):
2472
def forward(self, x, update):
2473
x[1:2, 1:3, torch.tensor([1])] += update
2476
x = torch.randn(3, 4, 5)
2477
update = torch.tensor([10, 15]).view(1, 2, 1)
2478
self.run_test(IndexPutModel(), (x, update))
2480
class IndexPutModel2(torch.nn.Module):
2481
def forward(self, x, update):
2482
x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
2485
x = torch.randn(3, 4, 5)
2486
update = torch.randn(2, 5)
2487
self.run_test(IndexPutModel2(), (x, update))
2489
class IndexPutModel3(torch.nn.Module):
2490
def forward(self, x, update):
2491
x[torch.tensor([0, 2]), 1:2] += update
2494
x = torch.randn(3, 4, 5)
2495
update = torch.tensor([10, 15]).view(2, 1, 1)
2496
self.run_test(IndexPutModel3(), (x, update))
2498
class IndexPutModel4(torch.nn.Module):
2499
def forward(self, x, update):
2500
x[torch.tensor([0, 2]), 2] += update
2503
x = torch.randn(3, 4, 5)
2504
update = torch.tensor([10, 15]).view(2, 1)
2505
self.run_test(IndexPutModel4(), (x, update))
2507
class IndexPutModel5(torch.nn.Module):
2508
def forward(self, x, update):
2509
x[1:3, torch.tensor([0, 2]), 2] += update
2512
x = torch.randn(3, 4, 5)
2513
update = torch.tensor([10, 15]).view(2, 1)
2514
self.run_test(IndexPutModel5(), (x, update))
2516
class IndexPutModel6(torch.nn.Module):
2517
def forward(self, x, update):
2521
x = torch.randn(3, 4, 5)
2522
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
2523
self.run_test(IndexPutModel6(), (x, update))
2525
class IndexPutModel7(torch.nn.Module):
2526
def forward(self, x, update):
2530
x = torch.randn(3, 4, 5)
2531
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
2532
self.run_test(IndexPutModel7(), (x, update))
2534
class IndexPutModel8(torch.nn.Module):
2535
def forward(self, x, update):
2539
x = torch.randn(3, 4, 5)
2540
update = torch.arange(3 * 5).to(torch.float).view(3, 5)
2541
self.run_test(IndexPutModel8(), (x, update))
2543
class IndexPutModel9(torch.nn.Module):
2544
def forward(self, poses):
2546
x = poses[:, :, 0] - (w - 1) // 2
2547
boxes = torch.zeros([poses.shape[0], 17, 4])
2551
x = torch.zeros([2, 17, 3], dtype=torch.int64)
2552
self.run_test(IndexPutModel9(), (x,))
2554
class IndexPutModel10(torch.nn.Module):
2555
def forward(self, x, ind, update):
2556
x[ind, 1:3] = update.view(1, 1, 1, 5).expand(2, 2, 2, 5)
2559
x = torch.randn(3, 4, 5)
2560
ind = torch.tensor([[0, 2], [1, 1]])
2561
update = torch.randn(5)
2562
self.run_test(IndexPutModel10(), (x, ind, update))
2564
@skipIfUnsupportedMinOpsetVersion(11)
2565
@skipScriptTest() # Ellipses followed by tensor indexing not scriptable
2566
def test_index_put_ellipsis(self):
2567
class IndexPutModel(torch.nn.Module):
2568
def forward(self, x, update):
2569
x[..., torch.tensor([2, 1, 3]), 2:4] += update
2572
x = torch.randn(3, 4, 5, 6, 7)
2573
update = torch.randn(3, 1, 1, 3, 2)
2574
self.run_test(IndexPutModel(), (x, update))
2576
class IndexPutModel2(torch.nn.Module):
2577
def forward(self, x, update):
2578
x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
2581
x = torch.randn(3, 4, 5, 6, 7)
2582
update = torch.randn(4, 1, 3, 2)
2583
self.run_test(IndexPutModel2(), (x, update))
2585
@skipIfUnsupportedMinOpsetVersion(11)
2586
def test_index_put_loop(self):
2588
def ngram_attention_bias(
2589
sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype
2592
(ngram, sequence_length), device=device, dtype=dtype
2594
for stream_idx in range(ngram):
2595
for i in range(sequence_length):
2597
bias[stream_idx, i] = 5
2601
for stream_idx in range(ngram):
2602
for i in range(sequence_length):
2603
bias[stream_idx, i] = 5
2607
class ScriptModel(torch.nn.Module):
2611
self.max_target_positions = 512
2613
def forward(self, hidden_states):
2614
seq_length, batch_size = hidden_states.shape[:2]
2615
predict_causal_mask = ngram_attention_bias(
2616
self.max_target_positions,
2618
hidden_states.device,
2619
hidden_states.dtype,
2621
predict_causal_mask = predict_causal_mask[:, :seq_length]
2622
return predict_causal_mask
2624
x = torch.randn(6, 2)
2625
y = torch.randn(4, 1)
2630
dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}},
2631
additional_test_inputs=[y],
2634
@skipIfUnsupportedMinOpsetVersion(11)
2635
def test_copy_(self):
2636
class CopyModel(torch.nn.Module):
2637
def forward(self, x, data):
2641
x = torch.randn(3, 4)
2642
update = torch.randn(2, 4)
2643
self.run_test(CopyModel(), (x, update))
2645
# mixed slice and select
2646
class CopyModel2(torch.nn.Module):
2647
def forward(self, x, data):
2651
x = torch.randn(3, 4)
2652
update = torch.tensor([0], dtype=torch.float32)
2653
self.run_test(CopyModel2(), (x, update))
2655
update = torch.tensor([2, 3], dtype=torch.float32)
2656
self.run_test(CopyModel2(), (x, update))
2658
update = torch.randn(2)
2659
self.run_test(CopyModel2(), (x, update))
2661
class CopyModel3(torch.nn.Module):
2662
def forward(self, x, data):
2666
x = torch.randn(3, 4)
2667
update = torch.tensor([0], dtype=torch.float32)
2668
self.run_test(CopyModel3(), (x, update))
2670
update = torch.tensor([2, 3], dtype=torch.float32)
2671
self.run_test(CopyModel3(), (x, update))
2673
update = torch.randn(2)
2674
self.run_test(CopyModel3(), (x, update))
2676
class CopyModel4(torch.nn.Module):
2677
def forward(self, x, ind, data):
2681
x = torch.randn(3, 4)
2682
ind = torch.tensor(2)
2683
data = torch.randn(4)
2684
self.run_test(CopyModel4(), (x, ind, data))
2686
class CopyModel5(torch.nn.Module):
2687
def forward(self, x, mask):
2688
if mask is not None:
2692
x = torch.randn(3, 4)
2693
mask = torch.randn(3, 1)
2694
self.run_test(CopyModel5(), (x, mask))
2696
@skipIfUnsupportedMinOpsetVersion(11)
2697
@skipScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape)
2698
def test_copy_tracing(self):
2699
class CopyModel(torch.nn.Module):
2700
def forward(self, x, data):
2704
x = torch.randn(3, 4)
2705
update = torch.randn(1, 2)
2706
self.run_test(CopyModel(), (x, update))
2708
@skipIfUnsupportedMinOpsetVersion(11)
2709
def test_copy_ellipsis(self):
2710
class CopyModel(torch.nn.Module):
2711
def forward(self, x, update):
2715
x = torch.randn(2, 3, 4)
2716
update = torch.ones(1)
2717
self.run_test(CopyModel(), (x, update))
2719
x = torch.randn(2, 3, 4, 5, 6)
2720
update = torch.ones(1)
2721
self.run_test(CopyModel(), (x, update))
2723
@skipIfUnsupportedMinOpsetVersion(11)
2724
def test_copy_ellipsis_script(self):
2725
class CopyModel(torch.nn.Module):
2726
def forward(self, x, update):
2727
# Insert reshape node to ensure no shape/type info for
2728
# x in scripting, without onnx shape inference.
2729
x = x.reshape(4, 3, 5, 6)
2730
x[2, ..., 1:3] = update
2733
x = torch.randn(3, 4, 5, 6)
2735
update = torch.ones(1)
2736
self.run_test(CopyModel(), (x, update))
2738
@skipIfUnsupportedMinOpsetVersion(10)
2739
def test_flip(self):
2740
class MyModule(torch.nn.Module):
2741
def forward(self, x):
2742
return torch.flip(x, dims=[0])
2744
x = torch.tensor(np.arange(6.0).reshape(2, 3))
2745
self.run_test(MyModule(), x)
2747
@skipIfUnsupportedMinOpsetVersion(9)
2748
def test_randint(self):
2749
class RandInt(torch.nn.Module):
2750
def forward(self, x):
2751
randint = torch.randint(1, 10, x.shape)
2755
x = torch.randn(2, 3, 4)
2756
self.run_test(RandInt(), x)
2758
@skipIfUnsupportedMinOpsetVersion(9)
2759
def test_randint_value(self):
2760
class RandInt(torch.nn.Module):
2761
def forward(self, x):
2762
# This randint call always returns 3
2763
return torch.randint(3, 4, x.shape) + x
2765
x = torch.randn(2, 3, 4)
2766
self.run_test(RandInt(), x)
2768
@skipIfUnsupportedMinOpsetVersion(9)
2769
def test_randint_like(self):
2770
class RandInt(torch.nn.Module):
2771
def forward(self, x):
2772
# This randint call always returns 3
2773
return torch.randint_like(x, 3, 4) + x
2775
x = torch.randn(2, 3, 4)
2776
self.run_test(RandInt(), x)
2778
def test_randn(self):
2779
class RandN(torch.nn.Module):
2780
def forward(self, x):
2781
return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))
2783
x = torch.randn(2, 3, 4)
2784
self.run_test(RandN(), x)
2786
def test_rand(self):
2787
class Rand(torch.nn.Module):
2788
def forward(self, x):
2789
return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))
2791
x = torch.randn(2, 3, 4)
2792
self.run_test(Rand(), x)
2794
def test_randn_dtype(self):
2795
class RandN(torch.nn.Module):
2796
def forward(self, x):
2797
# The resulting node's dtype should be double.
2800
* torch.randn(2, 3, 4, dtype=torch.double)
2801
* torch.tensor(0, dtype=torch.float32)
2804
x = torch.randn(2, 3, 4)
2805
self.run_test(RandN(), x)
2807
def test_rand_dtype(self):
2808
class Rand(torch.nn.Module):
2809
def forward(self, x):
2810
# The resulting node's dtype should be double.
2813
* torch.rand(2, 3, 4, dtype=torch.double)
2814
* torch.tensor(0, dtype=torch.float32)
2817
x = torch.randn(2, 3, 4)
2818
self.run_test(Rand(), x)
2820
@skipIfUnsupportedMinOpsetVersion(9)
2821
def test_randn_dynamic_size(self):
2822
class RandN(torch.nn.Module):
2823
def forward(self, x):
2824
return torch.mul(x, torch.randn(x.size()).size(1))
2826
x = torch.randn(2, 3, 4)
2827
self.run_test(RandN(), x)
2829
@skipIfUnsupportedMinOpsetVersion(9)
2830
def test_rand_dynamic_size(self):
2831
class Rand(torch.nn.Module):
2832
def forward(self, x):
2833
return torch.mul(x, torch.rand(x.size()).size(1))
2835
x = torch.randn(2, 3, 4)
2836
self.run_test(Rand(), x)
2838
def test_randn_like(self):
2839
class RandNLike(torch.nn.Module):
2840
def forward(self, x):
2841
return torch.mul(x, torch.randn_like(x).size(0))
2843
x = torch.randn(2, 3, 4)
2844
self.run_test(RandNLike(), x)
2845
self.run_test(torch.jit.script(RandNLike()), x)
2847
def test_rand_like(self):
2848
class RandLike(torch.nn.Module):
2849
def forward(self, x):
2850
return torch.mul(x, torch.rand_like(x).size(0))
2852
x = torch.randn(2, 3, 4)
2853
self.run_test(RandLike(), x)
2854
self.run_test(torch.jit.script(RandLike()), x)
2856
def test_randn_like_dtype(self):
2857
class RandNLike(torch.nn.Module):
2858
def forward(self, x):
2859
# The resulting node's dtype should be double.
2862
* torch.randn_like(x, dtype=torch.double)
2863
* torch.tensor(0, dtype=torch.float32)
2866
x = torch.randn(2, 3, 4)
2867
self.run_test(RandNLike(), x)
2869
def test_rand_like_dtype(self):
2870
class RandLike(torch.nn.Module):
2871
def forward(self, x):
2872
# The resulting node's dtype should be double.
2875
* torch.rand_like(x, dtype=torch.double)
2876
* torch.tensor(0, dtype=torch.float32)
2879
x = torch.randn(2, 3, 4)
2880
self.run_test(RandLike(), x)
2882
def test_bernoulli(self):
2883
class Bernoulli(torch.nn.Module):
2884
def forward(self, x):
2885
return torch.mul(x, torch.bernoulli(x).size(0))
2887
x = torch.empty(3, 3).uniform_(0, 1)
2888
self.run_test(Bernoulli(), x)
2890
x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1)
2891
self.run_test(Bernoulli(), x)
2893
def test_bernoulli_p(self):
2894
class Bernoulli_float(torch.nn.Module):
2895
def forward(self, x):
2896
return torch.mul(x, torch.bernoulli(x, 0.2).size(0))
2898
class Bernoulli_tensor(torch.nn.Module):
2899
def forward(self, x):
2900
return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0))
2902
x = torch.rand(3, 3)
2903
self.run_test(Bernoulli_float(), x)
2904
self.run_test(Bernoulli_tensor(), x)
2906
x = torch.rand(2, 3, 3, dtype=torch.double)
2907
self.run_test(Bernoulli_float(), x)
2908
self.run_test(Bernoulli_tensor(), x)
2910
@unittest.skip("Bug in ORT, skip test until rel-1.11.")
2911
@skipIfUnsupportedMinOpsetVersion(14)
2912
def test_reshape_allowzero(self):
2913
class ReshapeModel(torch.nn.Module):
2914
def forward(self, x):
2915
x = x.reshape(3, 4, 0)
2918
x = torch.randn(0, 3, 4)
2919
self.run_test(ReshapeModel(), x)
2921
def test_reshape_different_rank(self):
2922
class ReshapeModel(torch.nn.Module):
2923
def forward(self, x):
2924
x = x.reshape(-1, 2, 4, 4, 5, 5)
2927
x = torch.randn(1, 32, 5, 5)
2928
self.run_test(ReshapeModel(), x)
2930
def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
2931
class MyModel(torch.nn.Module):
2943
def __init__(self, mode, use_size, is_upsample, align_corners):
2946
self.use_size = use_size
2947
self.is_upsample = is_upsample
2948
self.align_corners = align_corners
2949
self.scale = 2.0 if self.is_upsample else 0.5
2950
self.size = 24 if self.is_upsample else 2
2952
self.scale_array = [2.3]
2953
self.size_array = [16]
2955
self.scale_array = [2.3, 3.1]
2956
self.size_array = [16, 32]
2958
self.scale_array = [2.3, 3.1, 4.6]
2959
self.size_array = [16, 32, 64]
2961
def forward(self, x):
2963
if self.align_corners:
2964
return torch.nn.functional.interpolate(
2965
x, mode=self.mode, size=self.size, align_corners=True
2966
), torch.nn.functional.interpolate(
2967
x, mode=self.mode, size=self.size_array, align_corners=True
2969
return torch.nn.functional.interpolate(
2970
x, mode=self.mode, size=self.size
2971
), torch.nn.functional.interpolate(
2972
x, mode=self.mode, size=self.size_array
2974
if self.align_corners:
2975
return torch.nn.functional.interpolate(
2978
scale_factor=self.scale,
2979
recompute_scale_factor=False,
2980
), torch.nn.functional.interpolate(
2983
scale_factor=self.scale_array,
2984
recompute_scale_factor=False,
2986
return torch.nn.functional.interpolate(
2989
scale_factor=self.scale,
2990
recompute_scale_factor=False,
2991
), torch.nn.functional.interpolate(
2994
scale_factor=self.scale_array,
2995
recompute_scale_factor=False,
2998
model = MyModel(mode, use_size, is_upsample, align_corners)
2999
self.run_test(model, x, atol=1e-6)
3001
def _interpolate_tests(self, is_upsample):
3002
# - cubic mode is not supported for opsets below 11;
3003
# - linear mode does not match for opsets below 11;
3004
modes = ["nearest", "linear", "bicubic"]
3005
if self.opset_version < 11:
3008
torch.randn(1, 2, 6, requires_grad=True),
3009
torch.randn(1, 2, 4, 6, requires_grad=True),
3010
torch.randn(1, 2, 4, 4, 6, requires_grad=True),
3016
# TODO: enable bicubic downsample when ORT precision loss fixed
3017
if mode == "bicubic" and xi.dim() != 4:
3019
elif mode == "linear":
3021
# TODO : enable when linear mode is implemented for 1d inputs in ORT
3026
# TODO : enable when linear mode is implemented for 3d inputs in ORT
3027
mode_i = "trilinear"
3029
self._interpolate(xi, mode_i, True, is_upsample)
3030
# test with align_corners if supported
3031
if mode != "nearest":
3032
self._interpolate(xi, mode_i, True, is_upsample, True)
3033
# the following cases, require dynamic sizes/scales,
3034
# which which is not supported for opset_version < 9
3035
if self.opset_version >= 9:
3036
self._interpolate(xi, mode_i, True, is_upsample)
3037
# test with align_corners if supported
3038
if mode != "nearest":
3039
self._interpolate(xi, mode_i, False, is_upsample, True)
3040
self._interpolate(xi, mode_i, False, is_upsample)
3042
# ONNX export failed on interpolate scripting because dynamic size not supported for opsets below 9.
3043
@skipIfUnsupportedMinOpsetVersion(9)
3044
def test_interpolate_upsample(self):
3045
self._interpolate_tests(True)
3047
@skipIfUnsupportedMaxOpsetVersion(8)
3048
@skipScriptTest() # Scripting supported for opsets > 8. See test_interpolate_upsample
3049
def test_interpolate_upsample_trace(self):
3050
self._interpolate_tests(True)
3052
@skipIfUnsupportedMinOpsetVersion(9)
3053
def test_interpolate_function_substitution(self):
3054
class ScriptModel(torch.jit.ScriptModule):
3055
@torch.jit.script_method
3056
def forward(self, x):
3057
return torch.nn.functional.interpolate(
3058
x, mode="nearest", scale_factor=2.0
3061
class ScriptModule(torch.jit.ScriptModule):
3064
self.submodule = ScriptModel()
3066
@torch.jit.script_method
3067
def forward(self, input):
3068
return self.submodule(input)
3070
x = torch.randn(1, 2, 4, 4, 6)
3071
self.run_test(ScriptModule(), (x,))
3074
def script_method(x):
3075
return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0)
3077
class TracingModule(torch.nn.Module):
3078
def forward(self, x):
3079
return script_method(x)
3081
self.run_test(TracingModule(), (x,))
3083
@skipIfUnsupportedMinOpsetVersion(10)
3084
def test_interpolate_downsample(self):
3085
self._interpolate_tests(False)
3087
@skipIfUnsupportedMinOpsetVersion(11)
3088
def test_interpolate_half_pixel(self):
3089
# testing whether it uses "half_pixel" or "pytorch_half_pixel"
3090
# see https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
3092
class MyModel(torch.nn.Module):
3093
def __init__(self, mode, size):
3098
def forward(self, x):
3099
return torch.nn.functional.interpolate(
3100
x, mode=self.mode, size=self.size
3103
modes = ["linear", "bicubic"]
3105
torch.randn(1, 2, 6, requires_grad=True),
3106
torch.randn(1, 2, 4, 6, requires_grad=True),
3107
torch.randn(1, 2, 4, 4, 6, requires_grad=True),
3112
if mode == "bicubic" and xi.dim() != 4:
3114
elif mode == "linear":
3118
mode_i = "trilinear"
3119
for i in range(xi.dim() - 2):
3120
size = list(xi.shape[2:])
3122
self.run_test(MyModel(mode_i, size), xi)
3124
@skipIfUnsupportedMinOpsetVersion(11)
3125
def test_interpolate_no_shape(self):
3126
class MyModel(torch.jit.ScriptModule):
3127
@torch.jit.script_method
3128
def forward(self, x, y):
3130
out1 = torch.nn.functional.interpolate(
3131
x, mode="bilinear", size=(16, 16), align_corners=False
3133
out2 = torch.nn.functional.interpolate(
3134
x, mode="nearest", size=(int(y.size(0)), int(y.size(1)))
3138
x = torch.randn(1, 2, 4, 4, requires_grad=True)
3139
y = torch.randn(16, 16, requires_grad=True)
3143
input_names=["x", "y"],
3144
dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]},
3146
self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0])
3148
@skipScriptTest() # scripting raises OnnxRuntimeError
3149
def test_interpolate_adaptive_pooling_error(self):
3150
x = torch.randn(1, 2, 6, requires_grad=True)
3151
with self.assertRaises(RuntimeError) as cm:
3152
self._interpolate(x, "area", True, True)
3154
with self.assertRaises(RuntimeError) as cm:
3155
self._interpolate(x, "area", False, True)
3157
def test_groupnorm(self):
3158
model = torch.nn.GroupNorm(3, 6, 0.002)
3159
x = torch.randn(4, 6, 36, 36, 18)
3160
self.run_test(model, x)
3162
model = torch.nn.GroupNorm(1, 6, 0.002)
3163
x = torch.randn(4, 6, 180, 180)
3164
self.run_test(model, x)
3166
model = torch.nn.GroupNorm(6, 6, 0.002)
3167
x = torch.randn(4, 6, 180, 180)
3168
self.run_test(model, x)
3170
def test_groupnorm_noaffine(self):
3171
model = torch.nn.GroupNorm(4, 8, 0.002, affine=False)
3172
x = torch.randn(3, 8, 224, 224)
3173
self.run_test(model, x)
3175
model = torch.nn.GroupNorm(1, 6, 0.002, affine=False)
3176
x = torch.randn(4, 6, 180, 180)
3177
self.run_test(model, x)
3179
model = torch.nn.GroupNorm(6, 6, 0.002, affine=False)
3180
x = torch.randn(4, 6, 180, 180)
3181
self.run_test(model, x)
3183
@skipIfUnsupportedMinOpsetVersion(9)
3184
def test_list_unpack_scripted(self):
3185
class ListUnpack(torch.nn.Module):
3186
def forward(self, x):
3188
return x.new_zeros((a, b))
3190
x = torch.randn(2, 3)
3192
torch.jit.script(ListUnpack()),
3195
dynamic_axes={"x": [0, 1]},
3197
self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
3199
@skipIfUnsupportedMinOpsetVersion(9)
3200
def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
3203
class PackUnpack(torch.nn.Module):
3204
"""Create and unpack a list of tensors.
3206
When scripted, it should produce a graph similar to
3209
graph(%self : __torch__.PackUnpack,
3212
%packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
3213
%c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
3218
def forward(self, a, b):
3224
torch.jit.script(PackUnpack()),
3225
(torch.tensor(0), torch.tensor([42])),
3226
remained_onnx_input_idx=[0],
3229
@skipIfUnsupportedMinOpsetVersion(9)
3230
def test_list_unpack_slice_scripted(self):
3231
class ListUnpackSlice(torch.nn.Module):
3232
def forward(self, x):
3234
return x.new_zeros((a, b))
3236
x = torch.randn(2, 3, 4, 5)
3238
torch.jit.script(ListUnpackSlice()),
3241
dynamic_axes={"x": [0, 1, 2, 3]},
3244
torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
3249
class PowModule(torch.nn.Module):
3250
def forward(self, x, y):
3253
x = torch.randn(2, 3, 4)
3254
y = torch.randn(2, 3, 4)
3255
self.run_test(PowModule(), (x, y))
3257
x = torch.randint(10, (2, 3, 4))
3258
y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
3259
self.run_test(PowModule(), (x, y))
3261
x = torch.randint(10, (2, 3, 4))
3262
y = torch.randint(10, (2, 3, 4))
3263
self.run_test(PowModule(), (x, y))
3265
x = torch.randn(2, 3, 4).to(dtype=torch.float64)
3266
y = torch.randint(10, (2, 3, 4))
3267
self.run_test(PowModule(), (x, y))
3269
class PowModule2(torch.nn.Module):
3270
def forward(self, x):
3271
return torch.pow(2, x)
3273
x = torch.randn(1, 10)
3274
self.run_test(PowModule2(), (x,))
3276
x = torch.randint(10, (2, 3, 4))
3277
self.run_test(PowModule2(), (x,))
3279
x = torch.randn(1, 10).to(dtype=torch.float64)
3280
self.run_test(PowModule2(), (x,))
3282
class PowModule3(torch.nn.Module):
3283
def forward(self, x, y):
3284
return y[torch.pow(2, x)]
3286
x = torch.randint(5, (2, 3, 4))
3288
self.run_test(PowModule3(), (x, y))
3290
# the arithmeticOps(Add\Sub\Mul\Div\Gemm\Pow\Mod) with low precision include unit8 will be failed in ORT
3291
# add to(dtype=torch.long) to avoid ORT output type does not match expected type.
3292
# will be fixed in ONNX version 14.
3293
@skipIfUnsupportedMaxOpsetVersion(13)
3295
def test_arithmeticOps_with_low_precision(self):
3296
class AddModule(torch.nn.Module):
3297
def forward(self, x, y):
3300
class SubModule(torch.nn.Module):
3301
def forward(self, x, y):
3304
class MulModule(torch.nn.Module):
3305
def forward(self, x, y):
3308
class DivModule(torch.nn.Module):
3309
def forward(self, x, y):
3312
class PowModule(torch.nn.Module):
3313
def forward(self, x, y):
3316
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3317
y = torch.tensor([2, 3, 5], dtype=torch.uint8)
3318
z = torch.tensor([1], dtype=torch.uint8)
3319
self.run_test(AddModule(), (x, y))
3320
self.run_test(SubModule(), (x, y))
3321
self.run_test(MulModule(), (x, y))
3322
self.run_test(DivModule(), (x, y))
3323
self.run_test(PowModule(), (x, z))
3325
x = torch.tensor([2, 3, 5], dtype=torch.int8)
3326
y = torch.tensor([2, 3, 5], dtype=torch.int8)
3327
z = torch.tensor([1], dtype=torch.int8)
3328
self.run_test(AddModule(), (x, y))
3329
self.run_test(SubModule(), (x, y))
3330
self.run_test(MulModule(), (x, y))
3331
self.run_test(DivModule(), (x, y))
3332
self.run_test(PowModule(), (x, z))
3334
x = torch.tensor([2, 3, 5], dtype=torch.int16)
3335
y = torch.tensor([2, 3, 5], dtype=torch.int16)
3336
z = torch.tensor([1], dtype=torch.int16)
3337
self.run_test(AddModule(), (x, y))
3338
self.run_test(SubModule(), (x, y))
3339
self.run_test(MulModule(), (x, y))
3340
self.run_test(DivModule(), (x, y))
3341
self.run_test(PowModule(), (x, z))
3343
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3344
y = torch.tensor([2, 3, 5], dtype=torch.float32)
3345
z = torch.tensor([1], dtype=torch.float64)
3346
self.run_test(AddModule(), (x, y))
3347
self.run_test(SubModule(), (x, y))
3348
self.run_test(MulModule(), (x, y))
3349
self.run_test(DivModule(), (x, y))
3350
self.run_test(PowModule(), (x, z))
3352
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3353
y = torch.tensor([2, 3, 5], dtype=torch.int64)
3354
z = torch.tensor([1], dtype=torch.int32)
3355
self.run_test(AddModule(), (x, y))
3356
self.run_test(SubModule(), (x, y))
3357
self.run_test(MulModule(), (x, y))
3358
self.run_test(DivModule(), (x, y))
3359
self.run_test(PowModule(), (x, z))
3361
def test_mul_bool(self):
3362
class MyModel(torch.nn.Module):
3363
def forward(self, x, y):
3364
return torch.mul(x, y)
3366
x_t = torch.tensor([True, False, True, False])
3367
y_t = torch.tensor([True, True, False, False])
3368
z_t = torch.tensor([1.0, 2.0, 3.0, 0.0])
3369
self.run_test(MyModel(), (x_t, y_t))
3370
self.run_test(MyModel(), (x_t, z_t))
3371
self.run_test(MyModel(), (z_t, y_t))
3373
# fmod was added in version 10
3374
@skipIfUnsupportedMinOpsetVersion(10)
3375
@skipIfUnsupportedMaxOpsetVersion(13)
3376
def test_mod_with_low_precision(self):
3377
class ModModule(torch.nn.Module):
3378
def forward(self, x, y):
3379
return torch.fmod(x, y).to(dtype=torch.long)
3381
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3382
y = torch.tensor([2, 3, 5], dtype=torch.uint8)
3383
self.run_test(ModModule(), (x, y))
3385
x = torch.tensor([2, 3, 5], dtype=torch.int8)
3386
y = torch.tensor([2, 3, 5], dtype=torch.int8)
3387
self.run_test(ModModule(), (x, y))
3389
x = torch.tensor([2, 3, 5], dtype=torch.int16)
3390
y = torch.tensor([2, 3, 5], dtype=torch.int16)
3391
self.run_test(ModModule(), (x, y))
3393
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3394
y = torch.tensor([2, 3, 5], dtype=torch.int32)
3395
self.run_test(ModModule(), (x, y))
3397
x = torch.tensor([2, 3, 5], dtype=torch.uint8)
3398
y = torch.tensor([2, 3, 5], dtype=torch.float64)
3399
self.run_test(ModModule(), (x, y))
3401
@skipIfUnsupportedMinOpsetVersion(9)
3402
def test_empty_constant_shape(self):
3403
class Zeros(torch.nn.Module):
3404
def forward(self, x):
3409
x = torch.tensor(42.0)
3410
self.run_test(Zeros(), x)
3412
class Ones(torch.nn.Module):
3413
def forward(self, x):
3418
x = torch.tensor(42.0)
3419
self.run_test(Ones(), x)
3421
class Full(torch.nn.Module):
3422
def forward(self, x):
3423
y = torch.full((), 1.0)
3427
x = torch.tensor(42.0)
3428
self.run_test(Full(), x)
3430
class Empty(torch.nn.Module):
3431
def forward(self, x):
3432
y = torch.empty(()).fill_(0)
3436
x = torch.tensor(42.0)
3437
self.run_test(Empty(), x)
3440
class StandardDeviation(torch.nn.Module):
3441
def forward(self, input):
3442
return torch.std(input, unbiased=False)
3444
x = torch.randn(2, 3, 4)
3445
model = StandardDeviation()
3446
self.run_test(model, x)
3448
class StandardDeviationUnbiased(torch.nn.Module):
3449
def forward(self, input):
3450
return torch.std(input, unbiased=True)
3452
model = StandardDeviationUnbiased()
3453
self.run_test(model, x)
3455
def test_std_along_dims(self):
3456
class StandardDeviation(torch.nn.Module):
3457
def forward(self, input):
3458
return torch.std(input, dim=(0, 1), unbiased=False)
3460
x = torch.randn(2, 3, 4)
3461
model = StandardDeviation()
3462
self.run_test(model, x)
3464
class StandardDeviationUnbiased(torch.nn.Module):
3465
def forward(self, input):
3466
return torch.std(input, dim=(0, 1), unbiased=True)
3468
x = torch.randn(2, 3, 4)
3469
model = StandardDeviationUnbiased()
3470
self.run_test(model, x)
3472
def test_std_keepdim(self):
3473
class StandardDeviation(torch.nn.Module):
3474
def forward(self, input):
3475
return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)
3477
x = torch.randn(2, 3, 4)
3478
model = StandardDeviation()
3479
self.run_test(model, x)
3481
class StandardDeviationUnbiased(torch.nn.Module):
3482
def forward(self, input):
3483
return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True)
3485
x = torch.randn(2, 3, 4)
3486
model = StandardDeviationUnbiased()
3487
self.run_test(model, x)
3489
def test_std_correction(self):
3490
class StandardDeviation(torch.nn.Module):
3491
def forward(self, input):
3492
return torch.std(input, dim=(0, 1), correction=3, keepdim=True)
3494
x = torch.randn(2, 3, 4)
3495
model = StandardDeviation()
3496
self.run_test(model, x)
3499
class Variance(torch.nn.Module):
3500
def forward(self, input):
3501
return torch.var(input, unbiased=False)
3503
x = torch.randn(2, 3, 4)
3505
self.run_test(model, x)
3507
class VarianceUnbiased(torch.nn.Module):
3508
def forward(self, input):
3509
return torch.var(input, unbiased=True)
3511
model = VarianceUnbiased()
3512
self.run_test(model, x)
3514
class VarianceSqrt(torch.nn.Module):
3515
def forward(self, input):
3516
y = torch.var(input, 1)
3517
return torch.sqrt(y + 1e-8)
3519
x = torch.randn(1, 2, 3, 300, 300)
3520
model = VarianceSqrt()
3521
self.run_test(model, x)
3523
def test_var_along_dims(self):
3524
class Variance(torch.nn.Module):
3525
def forward(self, input):
3526
return torch.var(input, dim=(0, 1), unbiased=False)
3528
x = torch.randn(2, 3, 4)
3530
self.run_test(model, x)
3532
class VarianceUnbiased(torch.nn.Module):
3533
def forward(self, input):
3534
return torch.var(input, dim=(0, 1), unbiased=True)
3536
x = torch.randn(2, 3, 4)
3537
model = VarianceUnbiased()
3538
self.run_test(model, x)
3540
def test_var_keepdim(self):
3541
class Variance(torch.nn.Module):
3542
def forward(self, input):
3543
return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True)
3545
x = torch.randn(2, 3, 4)
3547
self.run_test(model, x)
3549
class VarianceUnbiased(torch.nn.Module):
3550
def forward(self, input):
3551
return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True)
3553
x = torch.randn(2, 3, 4)
3554
model = VarianceUnbiased()
3555
self.run_test(model, x)
3557
def test_var_correction(self):
3558
class Variance(torch.nn.Module):
3559
def forward(self, input):
3560
return torch.var(input, dim=(0, 1), correction=3, keepdim=True)
3562
x = torch.randn(2, 3, 4)
3564
self.run_test(model, x)
3566
def test_var_mean(self):
3567
class Variance(torch.nn.Module):
3568
def forward(self, input):
3569
return torch.var_mean(input, unbiased=False)
3571
x = torch.randn(2, 3, 4)
3573
self.run_test(model, x)
3575
class VarianceUnbiased(torch.nn.Module):
3576
def forward(self, input):
3577
return torch.var_mean(input, unbiased=True)
3579
model = VarianceUnbiased()
3580
self.run_test(model, x)
3582
def test_var_mean_along_dims(self):
3583
class Variance(torch.nn.Module):
3584
def forward(self, input):
3585
return torch.var_mean(input, dim=(0, 1), unbiased=False)
3587
x = torch.randn(2, 3, 4)
3589
self.run_test(model, x)
3591
class VarianceUnbiased(torch.nn.Module):
3592
def forward(self, input):
3593
return torch.var_mean(input, dim=(0, 1), unbiased=True)
3595
x = torch.randn(2, 3, 4)
3596
model = VarianceUnbiased()
3597
self.run_test(model, x)
3599
def test_var_mean_mixed_dims(self):
3600
class ReverseDims(torch.nn.Module):
3601
def forward(self, input):
3602
return torch.var_mean(input, dim=(2, 1), unbiased=False)
3604
x = torch.randn(2, 3, 4)
3605
model = ReverseDims()
3606
self.run_test(model, x)
3608
class SkipDims(torch.nn.Module):
3609
def forward(self, input):
3610
return torch.var_mean(input, dim=(0, 2), unbiased=False)
3612
x = torch.randn(2, 3, 4)
3614
self.run_test(model, x)
3616
class NonZeroDims(torch.nn.Module):
3617
def forward(self, input):
3618
return torch.var_mean(input, dim=(1, 2), unbiased=False)
3620
x = torch.randn(2, 3, 4)
3621
model = NonZeroDims()
3622
self.run_test(model, x)
3624
def test_var_mean_keepdim(self):
3625
class Variance(torch.nn.Module):
3626
def forward(self, input):
3627
return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True)
3629
x = torch.randn(2, 3, 4)
3631
self.run_test(model, x)
3633
class VarianceUnbiased(torch.nn.Module):
3634
def forward(self, input):
3635
return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True)
3637
x = torch.randn(2, 3, 4)
3638
model = VarianceUnbiased()
3639
self.run_test(model, x)
3641
def test_var_mean_correction(self):
3642
class Variance(torch.nn.Module):
3643
def forward(self, input):
3644
return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)
3646
x = torch.randn(2, 3, 4)
3648
self.run_test(model, x)
3650
def test_std_mean(self):
3651
class StandardDeviation(torch.nn.Module):
3652
def forward(self, input):
3653
return torch.std_mean(input, unbiased=False)
3655
x = torch.randn(2, 3, 4)
3656
model = StandardDeviation()
3657
self.run_test(model, x)
3659
class StandardDeviationUnbiased(torch.nn.Module):
3660
def forward(self, input):
3661
return torch.std_mean(input, unbiased=True)
3663
model = StandardDeviationUnbiased()
3664
self.run_test(model, x)
3666
def test_std_mean_along_dims(self):
3667
class StandardDeviation(torch.nn.Module):
3668
def forward(self, input):
3669
return torch.std_mean(input, dim=(0, 1), unbiased=False)
3671
x = torch.randn(2, 3, 4)
3672
model = StandardDeviation()
3673
self.run_test(model, x)
3675
class VarianceUnbiased(torch.nn.Module):
3676
def forward(self, input):
3677
return torch.std_mean(input, dim=(0, 1), unbiased=True)
3679
x = torch.randn(2, 3, 4)
3680
model = VarianceUnbiased()
3681
self.run_test(model, x)
3683
def test_std_mean_keepdim(self):
3684
class StandardDeviation(torch.nn.Module):
3685
def forward(self, input):
3686
return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True)
3688
x = torch.randn(2, 3, 4)
3689
model = StandardDeviation()
3690
self.run_test(model, x)
3692
class StandardDeviationUnbiased(torch.nn.Module):
3693
def forward(self, input):
3694
return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True)
3696
x = torch.randn(2, 3, 4)
3697
model = StandardDeviationUnbiased()
3698
self.run_test(model, x)
3700
def test_std_mean_correction(self):
3701
class StandardDeviation(torch.nn.Module):
3702
def forward(self, input):
3703
return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)
3705
x = torch.randn(2, 3, 4)
3706
model = StandardDeviation()
3707
self.run_test(model, x)
3709
def test_bitshift(self):
3710
class BitshiftModel(torch.nn.Module):
3711
def forward(self, input):
3715
input >> torch.tensor([1, 2]),
3719
input = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
3720
self.run_test(BitshiftModel(), input)
3722
# uint8 not implemented in ORT for Mul used in
3723
# exporting bitshift for opset_version < 10
3724
@skipIfUnsupportedMinOpsetVersion(11)
3725
def test_bitshift_uint8(self):
3726
class BitshiftModel(torch.nn.Module):
3727
def forward(self, input, input2):
3731
input2 >> torch.tensor([1, 2], dtype=torch.uint8),
3735
input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
3736
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
3737
self.run_test(BitshiftModel(), (input, input2))
3739
def test_narrow(self):
3740
class NarrowModel(torch.nn.Module):
3741
def forward(self, input):
3742
return torch.narrow(input, 0, 0, 2)
3744
x = torch.randn(3, 3, requires_grad=True)
3745
self.run_test(NarrowModel(), x)
3747
@skipIfUnsupportedMinOpsetVersion(11)
3748
def test_narrow_dynamic(self):
3749
class NarrowModel(torch.nn.Module):
3750
def forward(self, input):
3751
return torch.narrow(input, 0, 0, input.shape[0] - 1)
3753
x = torch.randn(3, 3, requires_grad=True)
3754
self.run_test(NarrowModel(), x)
3756
@skipIfUnsupportedMinOpsetVersion(9)
3757
def test_index_fill(self):
3758
class IndexFillModel(torch.nn.Module):
3759
def forward(self, input):
3760
index = torch.tensor([2, 0])
3761
return input.index_fill(2, index, -1)
3763
x = torch.randn(3, 4, 5, requires_grad=True)
3764
self.run_test(IndexFillModel(), x)
3766
@skipIfUnsupportedMinOpsetVersion(9)
3767
def test_index_copy(self):
3768
class IndexCopyModel(torch.nn.Module):
3769
def __init__(self, dim):
3773
def forward(self, input):
3774
index = torch.tensor([2, 0])
3775
source = torch.ones(3, 2, 5)
3776
return input.index_copy(self.dim, index, source)
3778
x = torch.randn(3, 4, 5, requires_grad=True)
3780
self.run_test(IndexCopyModel(dim), x)
3782
def test_select(self):
3783
class Select(torch.nn.Module):
3784
def forward(self, x):
3787
x = torch.randn(3, 4)
3788
self.run_test(Select(), x)
3790
def test_select_negative_index(self):
3791
class Select(torch.nn.Module):
3792
def forward(self, x):
3795
x = torch.randn(3, 4)
3796
self.run_test(Select(), x)
3798
def test_index_select_constant_scaler_index(self):
3799
class IndexSelectScalerIndexModel(torch.nn.Module):
3800
def forward(self, x):
3802
return torch.index_select(x, 1, torch.tensor(index))
3804
x = torch.randn(3, 4)
3805
self.run_test(IndexSelectScalerIndexModel(), x)
3807
def test_index_select_scaler_index(self):
3808
class IndexSelectScalerIndexModel(torch.nn.Module):
3809
def __init__(self, index_base):
3811
self.index_base = torch.tensor(index_base)
3813
def forward(self, x, index_offset):
3814
index = self.index_base + index_offset
3815
return torch.index_select(x, 1, index)
3817
x = torch.randn(3, 4)
3819
index_offset = torch.tensor(offset)
3821
self.run_test(IndexSelectScalerIndexModel(base), (x, index_offset))
3823
def test_take(self):
3824
class TakeModel(torch.nn.Module):
3825
def forward(self, x, y):
3826
return torch.take(x, y)
3828
x = torch.randn(6, 4, 3, 3)
3829
y = torch.tensor([4, 1, 7, 15, 63])
3830
self.run_test(TakeModel(), (x, y))
3832
def test_topk(self):
3833
class MyModule(torch.nn.Module):
3834
def forward(self, x):
3835
return torch.topk(x, 3)
3837
x = torch.arange(1.0, 6.0, requires_grad=True)
3838
self.run_test(MyModule(), x)
3840
@skipIfUnsupportedMinOpsetVersion(10)
3841
def test_topk_int32_k(self):
3842
class Model(torch.nn.Module):
3843
def forward(self, x, k):
3844
return torch.topk(x, k)
3846
x = torch.arange(1.0, 6.0)
3847
k = torch.tensor(3, dtype=torch.int32)
3848
self.run_test(Model(), (x, k))
3850
@skipIfUnsupportedMinOpsetVersion(11)
3851
def test_topk_smallest_unsorted(self):
3852
class MyModule(torch.nn.Module):
3853
def forward(self, x, k):
3854
# When sorted=False, order of elements in the outout tensors
3855
# are not expected to match between PyTorch and ORT
3856
topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
3857
topk_sorted = torch.topk(x, k, largest=False, sorted=True)
3858
return topk_sorted, torch.sort(topk_unsorted.values).values
3860
x = torch.arange(1.0, 6.0, requires_grad=True)
3862
self.run_test(MyModule(), (x, k))
3864
@skipIfUnsupportedMinOpsetVersion(10)
3865
def test_topk_script(self):
3866
class MyModuleDynamic(torch.jit.ScriptModule):
3867
@torch.jit.script_method
3868
def forward(self, x, k):
3869
return torch.topk(x, k)
3871
x = torch.arange(1.0, 6.0, requires_grad=True)
3873
self.run_test(MyModuleDynamic(), (x, k))
3875
@skipScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
3876
@skipIfUnsupportedMinOpsetVersion(11) # Clip op min is an input since opset 11.
3877
def test_auto_grad(self):
3878
class MyClip(torch.autograd.Function):
3880
def forward(ctx, input, scalar):
3881
ctx.save_for_backward(input)
3882
return input.clamp(min=scalar)
3884
class MyRelu(torch.autograd.Function):
3886
def forward(ctx, input):
3887
ctx.save_for_backward(input)
3888
return input.clamp(min=0)
3890
def symbolic_python_op(
3891
ctx: torch.onnx.SymbolicContext, g: torch._C.Graph, *args, **kwargs
3894
name = kwargs["name"]
3895
if name == "MyClip":
3896
return g.op("Clip", args[0], args[1], outputs=n.outputsSize())
3897
elif name == "MyRelu":
3898
return g.op("Relu", args[0], outputs=n.outputsSize())
3900
# TODO(justinchuby): Remove reference to internal names in symbolic_helper
3901
return torch.onnx.symbolic_helper._unimplemented(
3902
"prim::PythonOp", "unknown node kind: " + name
3905
torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
3906
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
3908
class MyClipModule(torch.nn.Module):
3909
def forward(self, x, min):
3910
return MyClip.apply(x, min)
3912
x = torch.randn(3, 3)
3913
min = torch.tensor([0.0])
3914
self.run_test(MyClipModule(), (x, min))
3916
class MyReluModule(torch.nn.Module):
3917
def forward(self, x):
3918
return MyRelu.apply(x)
3920
x = torch.randn(3, 3)
3921
self.run_test(MyReluModule(), x)
3923
def test_clip_int(self):
3924
class MyClipInt(torch.nn.Module):
3925
def forward(self, x):
3926
return torch.clamp(x, 0, 1)
3928
self.run_test(MyClipInt(), torch.randn(3, 3).to(torch.int64))
3930
def test_relu_int(self):
3931
self.run_test(torch.nn.ReLU(), torch.randn(3, 3).to(torch.int32))
3933
def test_pad_int(self):
3934
class MyPadInt(torch.nn.Module):
3935
def forward(self, x):
3936
return torch.nn.functional.pad(x, (1, 1))
3938
self.run_test(MyPadInt(), torch.randn(3, 3).to(torch.int32))
3940
def test_min_int(self):
3941
class MyMinInt(torch.nn.Module):
3942
def forward(self, x):
3943
return torch.min(x, x + 1)
3945
self.run_test(MyMinInt(), torch.randn(3, 3).to(torch.int32))
3947
def test_max_int(self):
3948
class MyMaxnInt(torch.nn.Module):
3949
def forward(self, x):
3950
return torch.max(x, x + 1)
3952
self.run_test(MyMaxnInt(), torch.randn(3, 3).to(torch.int32))
3954
@skipIfUnsupportedOpsetVersion([7])
3955
def test_normalize(self):
3956
class Model(torch.nn.Module):
3957
def forward(self, x):
3958
return torch.nn.functional.normalize(x)
3960
x = torch.randn(3, 3)
3961
self.run_test(Model(), x)
3963
def test_norm_with_dtype(self):
3964
class Model(torch.nn.Module):
3965
def forward(self, x):
3966
# TODO(bowbao): There is a slight gap in today's test infrastructure
3967
# to directly test aten ops. OpInfo `torch.norm`` in `common_methods_invocations.py`
3968
# will not decompose to below aten op.
3969
return torch.ops.aten.norm(
3970
x, p=2, dim=[1], keepdim=True, dtype=torch.float64
3973
x = torch.randn(3, 3)
3974
self.run_test(Model(), x)
3976
def test_layer_norm(self):
3977
# As layer_norm works on the last D dimension, please keep
3978
# this test case at least three dimension to prevent the
3979
# situation of axis=2 mapping to the same axis as axis=-2
3980
for elementwise_affine in (True, False):
3981
for bias in (True, False):
3982
model = torch.nn.LayerNorm(
3983
[10, 10, 10], elementwise_affine=elementwise_affine, bias=bias
3985
x = torch.randn(20, 5, 10, 10, 10)
3986
self.run_test(model, x)
3988
def test_batchnorm1d(self):
3989
x = torch.randn(10, 10)
3990
model = torch.nn.BatchNorm1d(10, affine=True)
3991
self.run_test(model, x)
3993
x = torch.randn(10, 10, 128)
3994
self.run_test(model, x)
3996
def test_batchnorm1d_noaffine(self):
3997
x = torch.randn(10, 10)
3998
model = torch.nn.BatchNorm1d(10, affine=False)
3999
self.run_test(model, x)
4001
x = torch.randn(10, 10, 128)
4002
self.run_test(model, x)
4004
def test_batchnorm1d_norunningstats(self):
4005
x = torch.randn(10, 10)
4006
model = torch.nn.BatchNorm1d(10, track_running_stats=False)
4007
self.run_test(model, x)
4009
x = torch.randn(10, 10, 128)
4010
self.run_test(model, x)
4012
def test_batchnorm2d(self):
4013
x = torch.randn(10, 3, 128, 128)
4014
model = torch.nn.BatchNorm2d(3, affine=True)
4015
self.run_test(model, x)
4017
def test_batchnorm2d_noaffine(self):
4018
x = torch.randn(10, 3, 128, 128)
4019
model = torch.nn.BatchNorm2d(3, affine=False)
4020
self.run_test(model, x)
4022
def test_batchnorm2d_norunningstats(self):
4023
x = torch.randn(10, 3, 128, 128)
4024
model = torch.nn.BatchNorm2d(3, track_running_stats=False)
4025
self.run_test(model, x)
4027
def test_batchnorm3d(self):
4028
x = torch.randn(10, 3, 64, 64, 64)
4029
model = torch.nn.BatchNorm3d(3, affine=True)
4030
self.run_test(model, x)
4032
def test_batchnorm3d_noaffine(self):
4033
x = torch.randn(10, 3, 64, 64, 64)
4034
model = torch.nn.BatchNorm3d(3, affine=False)
4035
self.run_test(model, x)
4037
@skipIfUnsupportedMinOpsetVersion(
4039
) # Because ConstantOfShape op is not supported for opset < 9
4040
def test_instancenorm1d_runningstats(self):
4041
x = torch.randn(10, 5, 128)
4042
model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=True)
4043
self.run_test(model, x)
4045
model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=True)
4046
self.run_test(model, x)
4048
def test_instancenorm1d_norunningstats(self):
4049
x = torch.randn(10, 5, 128)
4050
model = torch.nn.InstanceNorm1d(5, affine=True, track_running_stats=False)
4051
self.run_test(model, x)
4053
model = torch.nn.InstanceNorm1d(5, affine=False, track_running_stats=False)
4054
self.run_test(model, x)
4056
@skipIfUnsupportedMinOpsetVersion(
4058
) # Because ConstantOfShape op is not supported for opset < 9
4059
def test_instancenorm2d_runningstats(self):
4060
x = torch.randn(10, 3, 128, 128)
4061
model = torch.nn.InstanceNorm2d(3, affine=True, track_running_stats=True)
4062
self.run_test(model, x)
4064
model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=True)
4065
self.run_test(model, x)
4067
def test_instancenorm2d_norunningstats(self):
4068
x = torch.randn(10, 3, 128, 128)
4069
model = torch.nn.InstanceNorm2d(3, affine=True, track_running_stats=False)
4070
self.run_test(model, x)
4072
model = torch.nn.InstanceNorm2d(3, affine=False, track_running_stats=False)
4073
self.run_test(model, x)
4075
@skipIfUnsupportedMinOpsetVersion(
4077
) # Because ConstantOfShape op is not supported for opset < 9
4078
def test_instancenorm3d_runningstats(self):
4079
x = torch.randn(10, 3, 64, 64, 64)
4080
model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=True)
4081
self.run_test(model, x)
4083
model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=True)
4084
self.run_test(model, x)
4086
def test_instancenorm3d_norunningstats(self):
4087
x = torch.randn(10, 3, 64, 64, 64)
4088
model = torch.nn.InstanceNorm3d(3, affine=True, track_running_stats=False)
4089
self.run_test(model, x)
4091
model = torch.nn.InstanceNorm3d(3, affine=False, track_running_stats=False)
4092
self.run_test(model, x)
4094
@skipIfUnsupportedMinOpsetVersion(9)
4095
def test_scatter_with_scalar(self):
4096
class ScatterModel(torch.nn.Module):
4097
def forward(self, input, indices):
4099
return input.scatter(1, indices, values)
4101
input = torch.tensor(
4102
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float64
4104
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4105
self.run_test(ScatterModel(), input_args=(input, indices))
4107
@skipIfUnsupportedMinOpsetVersion(9)
4108
def test_scatter_with_scalar_different_types(self):
4109
# Tests the case when scalar src (updates values) type is different
4110
# from self type. Happens only with scalar src - PyTorch does not
4111
# allow this when src is a tensor.
4112
class ScatterModel(torch.nn.Module):
4113
def forward(self, input, indices):
4115
return input.scatter(1, indices, values)
4117
input = torch.tensor(
4118
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32
4120
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4121
self.run_test(ScatterModel(), input_args=(input, indices))
4123
@skipIfUnsupportedMinOpsetVersion(9)
4124
def test_scatter(self):
4125
class ScatterModel(torch.nn.Module):
4126
def forward(self, input, indices, values):
4127
return input.scatter(1, indices, values)
4129
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4130
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4131
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4132
self.run_test(ScatterModel(), input_args=(input, indices, values))
4134
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4135
indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
4136
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4137
self.run_test(ScatterModel(), (input, indices, values))
4139
input = torch.zeros(3, 4, 5, 6)
4140
indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
4141
indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6)
4142
values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6)
4143
self.run_test(ScatterModel(), (input, indices, values))
4145
input = torch.zeros(3, 4, 2)
4146
indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]])
4147
values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2)
4148
self.run_test(ScatterModel(), (input, indices, values))
4150
@skipIfUnsupportedMinOpsetVersion(9)
4151
def test_scatter_add(self):
4152
class ScatterModel(torch.nn.Module):
4153
def forward(self, input, indices, values):
4154
return input.scatter_add(1, indices, values)
4156
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4157
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4158
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4159
self.run_test(ScatterModel(), input_args=(input, indices, values))
4162
def scatter_sum(src: Tensor, index: Tensor):
4164
out = torch.zeros(size, dtype=src.dtype)
4165
return out.scatter_add_(1, index, src)
4167
class ScatterModel(torch.nn.Module):
4168
def forward(self, src, index):
4169
return scatter_sum(src, index)
4171
src = torch.rand(3, 2)
4172
index = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64)
4173
self.run_test(ScatterModel(), (src, index))
4175
@skipIfUnsupportedMinOpsetVersion(16)
4176
def test_scatter_add_index_not_unique(self):
4177
class ScatterModel(torch.nn.Module):
4178
def forward(self, input, indices, values):
4179
return input.scatter_add(1, indices, values)
4181
input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
4182
indices = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64)
4183
values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
4184
self.run_test(ScatterModel(), input_args=(input, indices, values))
4187
def scatter_sum(src: Tensor, index: Tensor):
4189
out = torch.zeros(size, dtype=src.dtype)
4190
return out.scatter_add_(1, index, src)
4192
class ScatterModel(torch.nn.Module):
4193
def forward(self, src, index):
4194
return scatter_sum(src, index)
4196
src = torch.rand(3, 2)
4197
index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
4198
self.run_test(ScatterModel(), (src, index))
4200
@skipIfUnsupportedMinOpsetVersion(16)
4201
def test_scatter_add_different_size_index_src(self):
4202
class ScatterModel(torch.nn.Module):
4203
def forward(self, input, indices, src):
4204
return input.scatter_add(0, indices, src)
4206
src = torch.ones((2, 5))
4207
input = torch.zeros(3, 5, dtype=src.dtype)
4208
indices = torch.tensor([[0, 1, 2, 0, 0]])
4209
self.run_test(ScatterModel(), input_args=(input, indices, src))
4211
@common_utils.parametrize(
4214
common_utils.subtest(
4215
[torch.ones((1, 5)), torch.tensor([[0, 1, 2, 0, 0]])],
4216
name="src_indices_dynamic_combination1",
4218
common_utils.subtest(
4219
[torch.ones((2, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
4220
name="src_indices_dynamic_combination2",
4222
common_utils.subtest(
4223
[torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
4224
name="src_indices_dynamic_combination3",
4226
common_utils.subtest(
4227
[torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0], [1, 0, 2, 1]])],
4228
name="src_indices_dynamic_combination4",
4232
@skipIfUnsupportedMinOpsetVersion(16)
4233
def test_scatter_add_dynamic_index(self, src, indices):
4234
class ScatterModel(torch.nn.Module):
4235
def forward(self, input, indices, src):
4236
return input.scatter_add(0, indices, src)
4238
input = torch.zeros(3, 5, dtype=src.dtype)
4241
input_args=(input, indices, src),
4242
input_names=["input", "indices", "src"],
4243
dynamic_axes={"indices": {0: "a", 1: "b"}, "src": {0: "c", 1: "d"}},
4246
@skipIfUnsupportedMinOpsetVersion(16)
4247
def test_scatter_reduce(self):
4248
class Model(torch.nn.Module):
4252
def forward(self, x, index, input):
4253
y_max = input.scatter_reduce(0, index, x, reduce="amax")
4254
y_sum = input.scatter_reduce(0, index, x, reduce="sum")
4255
y_min = input.scatter_reduce(0, index, x, reduce="amin")
4256
y_mul = input.scatter_reduce(0, index, x, reduce="prod")
4257
return y_max, y_sum, y_min, y_mul
4262
src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
4263
index = torch.tensor([0, 1, 0, 1, 2, 1])
4264
input = torch.tensor([1.0, 2.0, 3.0, 8.0])
4266
self.run_test(model, (src, index, input))
4268
@skipIfUnsupportedMinOpsetVersion(16)
4269
def test_scatter_reduce_self_rank_zero(self):
4270
class Model(torch.nn.Module):
4274
def forward(self, x, index, input):
4275
y_max = input.scatter_reduce(0, index, x, reduce="amax")
4276
y_sum = input.scatter_reduce(0, index, x, reduce="sum")
4277
y_min = input.scatter_reduce(0, index, x, reduce="amin")
4278
y_mul = input.scatter_reduce(0, index, x, reduce="prod")
4279
return y_max, y_sum, y_min, y_mul
4284
empty_tensor = torch.tensor([])
4285
empty_idx = torch.tensor([], dtype=torch.int64)
4287
self.run_test(model, (empty_tensor, empty_idx, empty_tensor))
4289
@skipIfUnsupportedMinOpsetVersion(9)
4290
def test_bucketize(self):
4291
class BucketModel(torch.nn.Module):
4292
def forward(self, input, boundaries):
4293
return torch.bucketize(input, boundaries), torch.bucketize(
4294
input, boundaries, right=True
4297
input = torch.tensor([[2, 5, 10], [6, 8, 3]])
4298
boundaries = torch.tensor([1, 5, 7, 8, 10])
4299
self.run_test(BucketModel(), (input, boundaries))
4301
@skipIfUnsupportedMinOpsetVersion(9)
4302
def test_one_hot(self):
4303
class OneHot(torch.nn.Module):
4304
def __init__(self, num_classes):
4306
self.num_classes = num_classes
4308
def forward(self, x):
4309
return torch.nn.functional.one_hot(x, self.num_classes)
4311
x = torch.arange(10)
4312
self.run_test(OneHot(15), (x))
4314
class OneHot(torch.nn.Module):
4315
def forward(self, x, num_classes):
4316
num_classes = num_classes.to(torch.int32)
4317
return torch.nn.functional.one_hot(x, num_classes[0])
4319
x = torch.arange(10)
4320
num_classes = 15 * torch.ones(1)
4321
self.run_test(OneHot(), (x, num_classes))
4323
@skipIfUnsupportedMinOpsetVersion(9)
4324
def test_gather(self):
4325
class GatherModel(torch.nn.Module):
4326
def forward(self, input, indices):
4327
return input.gather(1, indices)
4329
input = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
4330
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
4331
self.run_test(GatherModel(), input_args=(input, indices))
4333
@skipScriptTest() # Scripting error: Cannot instantiate nn module
4334
def test_gather_constant_fold(self):
4335
class GatherModule(torch.nn.Module):
4338
self.register_buffer("weight", torch.ones(5))
4339
# torch.nn.Embedding is converted to ONNX::Gather.
4340
# Constant folding will be triggerred for constant inputs.
4341
# This pattern is common for constant mask inputs in transformer models.
4342
self.embed = torch.nn.Embedding(8, 3)
4344
def forward(self, x):
4345
# shape is of rank 0
4346
shape = self.weight.shape[0]
4348
y = torch.ones(1, 4, dtype=torch.long)
4349
return x.clamp(min=m), self.embed(y)
4352
self.run_test(GatherModule(), (x,))
4354
class GatherModule(torch.nn.Module):
4357
self.register_buffer("weight", torch.ones(2))
4359
def forward(self, x):
4360
# shape is of rank 0
4361
shape = self.weight.shape[0]
4362
pad = [1, shape, shape, shape]
4363
zero_pad = torch.nn.ZeroPad2d(pad)
4366
x = torch.randn(1, 3, 2)
4367
self.run_test(GatherModule(), (x,))
4369
class GatherModule(torch.nn.Module):
4372
self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
4374
def forward(self, x):
4378
x = torch.randn(1, 3, 224, 224)
4383
"input": {0: "batch", 2: "height", 3: "width"},
4384
"output": {0: "batch", 1: "class", 2: "height", 3: "width"},
4386
input_names=["input"],
4387
output_names=["output"],
4390
@skipIfUnsupportedOpsetVersion([13])
4391
@skipIfUnsupportedMinOpsetVersion(9)
4392
def test_expand(self):
4393
class ExpandModel(torch.nn.Module):
4394
def forward(self, input):
4395
return input.expand(2, 3, -1)
4397
input = torch.randn(2, 1, 4)
4398
self.run_test(ExpandModel(), input_args=(input))
4400
class ExpandInferDimModel(torch.nn.Module):
4401
def forward(self, input):
4402
return input.expand(-1, input.size(0))
4404
input = torch.randn(3, 1)
4405
self.run_test(ExpandInferDimModel(), input_args=(input))
4407
class ExpandTensorSizeModel(torch.nn.Module):
4408
def forward(self, input, size):
4409
return input.expand(size)
4411
input = torch.randn(
4414
size = torch.tensor(-1)
4415
self.run_test(ExpandTensorSizeModel(), input_args=(input, size))
4417
@skipIfUnsupportedMinOpsetVersion(11) # index_put is supported in opsets >= 11
4418
def test_dynamic_expand_as(self):
4419
class Model(torch.nn.Module):
4420
def forward(self, x):
4421
x[:, x.size(0) :] = 0
4424
x = torch.ones(2, 5)
4425
x2 = torch.randn(3, 4)
4430
dynamic_axes={"x": [0, 1]},
4431
additional_test_inputs=[x2],
4434
class Model(torch.nn.Module):
4435
def forward(self, x):
4436
x[:, x.size(0) :] = torch.tensor([1, 2, 3])
4439
x = torch.ones(2, 5, 3)
4440
x2 = torch.randn(3, 4, 3)
4445
dynamic_axes={"x": [0, 1, 2]},
4446
additional_test_inputs=[x2],
4449
class Model(torch.nn.Module):
4450
def forward(self, x):
4451
aa = torch.tensor([[0], [1], [2]])
4452
return aa.expand_as(x)
4454
x = torch.ones(3, 2)
4455
x2 = torch.randn(3, 5)
4460
dynamic_axes={"x": [0, 1]},
4461
additional_test_inputs=[x2],
4464
def test_multinomial(self):
4465
class Multinomial(torch.nn.Module):
4466
def forward(self, weight):
4467
return torch.multinomial(weight, 3, replacement=True)
4469
class MultinomialNoReplacement(torch.nn.Module):
4470
def forward(self, weight):
4471
return torch.multinomial(weight, 1)
4473
weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
4474
self.run_test(Multinomial(), (weight,))
4475
self.run_test(MultinomialNoReplacement(), (weight,))
4477
def _test_reduced_ops(self, op):
4478
class ReducedOpModule(torch.nn.Module):
4479
def forward(self, input):
4480
return op(input, dim=-1)
4482
if op != torch.mean: # torch.mean only supports float types
4483
x = torch.randint(10, (4, 4), dtype=torch.uint8)
4484
self.run_test(ReducedOpModule(), x)
4486
x = torch.randint(10, (4, 4), dtype=torch.int8)
4487
self.run_test(ReducedOpModule(), x)
4489
x = torch.randint(10, (4, 4), dtype=torch.int16)
4490
self.run_test(ReducedOpModule(), x)
4492
x = torch.randint(10, (4, 4), dtype=torch.int32)
4493
self.run_test(ReducedOpModule(), x)
4495
x = torch.randint(10, (4, 4), dtype=torch.int64)
4496
self.run_test(ReducedOpModule(), x)
4498
# torch.mean only supports float types
4499
# ORT does not support double ReduceProd for double
4500
if op != torch.prod and op != torch.mean:
4501
x = torch.randn(4, 5, dtype=torch.double)
4502
self.run_test(ReducedOpModule(), x)
4504
if op != torch.prod: # torch.prod not implemented for Half
4505
x = torch.randn(4, 4, dtype=torch.half)
4506
self.run_test(ReducedOpModule(), x)
4508
x = torch.randn(4, 5, dtype=torch.float)
4509
self.run_test(ReducedOpModule(), x)
4511
def test_reduced_sum(self):
4512
return self._test_reduced_ops(op=torch.sum)
4514
def test_reduced_mean(self):
4515
return self._test_reduced_ops(op=torch.mean)
4517
def test_reduced_prod(self):
4518
return self._test_reduced_ops(op=torch.prod)
4520
def test_reduced_sum_dtypes(self):
4521
class NoDimModel(torch.nn.Module):
4522
def forward(self, input):
4523
return input.sum(dtype=torch.float)
4525
class DimModel(torch.nn.Module):
4526
def forward(self, input):
4527
return input.sum(dim=-1, dtype=torch.float)
4529
input = torch.randn((4, 4), dtype=torch.half)
4530
self.run_test(NoDimModel(), input)
4531
self.run_test(DimModel(), input)
4533
def test_reduced_min_max(self):
4534
class ReducedMinMaxModule(torch.nn.Module):
4535
def forward(self, input):
4536
return torch.min(input, dim=-1)[0], torch.max(input, dim=0)[0]
4538
x = torch.randint(10, (4, 4), dtype=torch.int32)
4539
self.run_test(ReducedMinMaxModule(), x)
4541
x = torch.randint(10, (4, 4), dtype=torch.int64)
4542
self.run_test(ReducedMinMaxModule(), x)
4544
x = torch.randn(4, 5, dtype=torch.float)
4545
self.run_test(ReducedMinMaxModule(), x)
4547
def test_reduce_log_sum_exp(self):
4548
class ReduceLogSumExpModel(torch.nn.Module):
4549
def forward(self, input):
4550
a = torch.logsumexp(input, dim=0)
4551
b = torch.logsumexp(input, dim=(0, 1))
4554
x = torch.randn(4, 4, requires_grad=True)
4555
self.run_test(ReduceLogSumExpModel(), x)
4557
def test_softmax(self):
4558
for i in range(-4, 3):
4559
model = torch.nn.Softmax(dim=i)
4560
input = torch.randn(3, 4, 5, 6)
4561
self.run_test(model, input)
4563
class SoftmaxUnknownRank(torch.nn.Module):
4564
def __init__(self, i):
4566
self.softmax = torch.nn.Softmax(dim=i)
4568
def forward(self, x):
4569
return self.softmax(x.reshape(3, 4, 5, 6))
4571
model = torch.jit.script(SoftmaxUnknownRank(i))
4572
self.run_test(model, input)
4574
def test_softmax_large_values(self):
4575
input = torch.tensor(
4576
[[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]]
4578
for i in range(-2, 1):
4579
model = torch.nn.Softmax(dim=i)
4580
self.run_test(model, input)
4582
class SoftmaxUnknownRank(torch.nn.Module):
4583
def __init__(self, i):
4585
self.softmax = torch.nn.Softmax(dim=i)
4587
def forward(self, x):
4588
return self.softmax(x.reshape(3, 3))
4590
model = torch.jit.script(SoftmaxUnknownRank(i))
4591
self.run_test(model, input)
4593
def test_logsoftmax(self):
4594
for i in range(7)[2:]:
4595
model = torch.nn.LogSoftmax(dim=i - 1)
4596
dims = [2] * (i - 2) + [3, 4]
4597
input = torch.ones(*dims, requires_grad=True)
4598
self.run_test(model, input)
4600
def test_logsoftmax_dim(self):
4601
for i in range(-4, 3):
4602
model = torch.nn.LogSoftmax(dim=i)
4603
input = torch.randn(3, 4, 5, 6)
4604
self.run_test(model, input)
4606
def test_logsoftmax_dtype(self):
4607
class Model(torch.nn.Module):
4608
def forward(self, x):
4609
return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64)
4611
x = torch.randn(3, 4, 5, requires_grad=True)
4612
self.run_test(Model(), x)
4614
def test_softplus(self):
4615
class BetaOneModel(torch.nn.Module):
4616
def forward(self, x):
4617
return torch.nn.functional.softplus(x)
4619
x = torch.randn(3, 4, 5, requires_grad=True)
4620
self.run_test(BetaOneModel(), x)
4622
class BetaModel(torch.nn.Module):
4623
def forward(self, x):
4624
return torch.nn.functional.softplus(x, beta=2)
4626
x = torch.randn(3, 4, 5, requires_grad=True)
4627
self.run_test(BetaModel(), x)
4629
class BetaFloatModel(torch.nn.Module):
4630
def forward(self, x):
4631
return torch.nn.functional.softplus(x, beta=1.7)
4633
x = torch.randn(3, 4, 5, requires_grad=True)
4634
self.run_test(BetaFloatModel(), x)
4636
@skipIfUnsupportedMinOpsetVersion(9)
4637
def test_lstm_no_hidden(self):
4638
class LSTMModel(torch.nn.Module):
4641
self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)
4643
def forward(self, x):
4646
input = torch.randn((10, 16, 16))
4647
self.run_test(LSTMModel(), (input,))
4649
@skipIfUnsupportedMinOpsetVersion(9)
4650
def test_lstm_proj_no_hidden(self):
4651
class LSTMModel(torch.nn.Module):
4654
self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8)
4656
def forward(self, x):
4659
input = torch.randn((10, 16, 16))
4660
with self.assertRaises(RuntimeError):
4661
self.run_test(LSTMModel(), (input,))
4663
@skipIfUnsupportedMinOpsetVersion(9)
4664
def test_lstm(self):
4665
class LSTMModel(torch.nn.Module):
4668
self.rnn = torch.nn.LSTM(
4669
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4672
def forward(self, x, h0, c0):
4673
return self.rnn(x, (h0, c0))
4675
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4676
h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
4677
c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
4678
self.run_test(LSTMModel(), (input, h0, c0))
4680
@skipIfUnsupportedMinOpsetVersion(9)
4681
def test_lstm_cell(self):
4682
class LSTMCellModel(torch.nn.Module):
4683
def __init__(self, bias):
4685
self.lstm_cell = torch.nn.LSTMCell(
4686
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias
4689
def forward(self, x, h0, c0):
4690
return self.lstm_cell(x, (h0, c0))
4692
input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
4693
h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
4694
c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
4695
for bias in [True, False]:
4696
self.run_test(LSTMCellModel(bias), (input, h0, c0))
4698
@skipIfUnsupportedMinOpsetVersion(9)
4699
def test_lstm_default_init_state(self):
4700
class LSTMModel(torch.nn.Module):
4703
self.rnn = torch.nn.LSTM(
4704
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4707
def forward(self, x):
4710
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4711
self.run_test(LSTMModel(), input)
4713
@skipIfUnsupportedMinOpsetVersion(9)
4714
def test_lstm_fixed_batch_size(self):
4715
class LSTMModel(torch.nn.Module):
4718
self.lstm = torch.nn.LSTM(
4719
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4721
self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE
4723
def forward(self, input):
4724
batch_size = input.size()[1]
4725
h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4726
c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4727
return self.lstm(input, (h0, c0))
4729
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4730
# verify with different input of same batch size
4731
input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4733
LSTMModel(), input, fixed_batch_size=True, additional_test_inputs=[input2]
4736
@skipIfUnsupportedMinOpsetVersion(9)
4737
def test_lstm_post_fix_init_state(self):
4738
class LSTMModel(torch.nn.Module):
4741
self.lstm = torch.nn.LSTM(
4742
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
4744
self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE
4746
def forward(self, input):
4747
batch_size = input.size()[1]
4748
h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4749
c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
4750
return self.lstm(input, (h0, c0))
4753
input = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
4754
# verify with different input of different batch size
4755
input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4759
input_names=["input.1"],
4760
dynamic_axes={"input.1": {0: "seq", 1: "batch"}},
4761
additional_test_inputs=[input2],
4764
def test_lstm_constant_folding(self):
4765
class LstmNet(torch.nn.Module):
4766
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4768
self.lstm = torch.nn.LSTM(
4769
input_size, hidden_size, num_layers, bidirectional=bidirectional
4772
def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
4773
return self.lstm(input, initial_state)
4775
def get_LstmNet_model_and_inputs(
4776
input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4778
num_directions = 2 if bidirectional else 1
4779
model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
4780
input = torch.randn(seq_len, batch_size, input_size)
4781
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4782
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4783
return model, (input, (h0, c0))
4786
model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
4787
self.run_test(model1, input1, do_constant_folding=True)
4790
model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
4791
self.run_test(model2, input2, do_constant_folding=True)
4793
@skipIfUnsupportedMinOpsetVersion(9)
4794
def test_lstm_no_bias(self):
4795
class LstmNet(torch.nn.Module):
4796
def __init__(self, num_layers, bidirectional):
4798
self.lstm = torch.nn.LSTM(
4803
bidirectional=bidirectional,
4806
def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
4807
return self.lstm(input, initial_state)
4809
def get_LstmNet_model_and_inputs(num_layers, bidirectional):
4810
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
4811
num_directions = 2 if bidirectional else 1
4812
model = LstmNet(num_layers, bidirectional)
4813
h0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
4814
c0 = torch.randn(num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
4815
return model, (input, (h0, c0))
4817
num_layers = [1, 1, 2, 3]
4818
bidirectional = [True, False, True, False]
4819
models_and_inputs = [
4820
get_LstmNet_model_and_inputs(n, b)
4821
for n, b in zip(num_layers, bidirectional)
4823
for model, input in models_and_inputs:
4824
self.run_test(model, input)
4826
@skipIfUnsupportedMinOpsetVersion(9)
4827
def test_lstm_sequence(self):
4828
class LstmNet(torch.nn.Module):
4831
self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4832
self.linear1 = torch.nn.Linear(8 * 2, 8)
4833
self.rnn2 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4834
self.linear2 = torch.nn.Linear(8 * 2, 8)
4836
def forward(self, input):
4837
rnn_output1, _ = self.rnn1(input)
4838
linear_output1 = self.linear1(rnn_output1)
4839
rnn_output2, _ = self.rnn2(linear_output1)
4840
linear_output2 = self.linear2(rnn_output2)
4841
return linear_output2
4843
input = torch.zeros((1, 100, 8), dtype=torch.float32)
4847
input_names=["input"],
4848
output_names=["output"],
4850
"input": {0: "batch_size", 1: "w", 2: "h"},
4851
"output": {0: "batch_size", 1: "w", 2: "h"},
4856
def test_rnn_no_bias(self):
4857
def make_model(layers, packed_sequence):
4858
batch_first = True if packed_sequence == 2 else False
4859
model = torch.nn.RNN(
4863
bidirectional=False,
4864
batch_first=batch_first,
4868
if packed_sequence == 1:
4869
model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence(
4872
if packed_sequence == 2:
4873
model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence(
4878
def make_input(batch_size, layers, packed_sequence):
4879
batch_first = True if packed_sequence == 2 else False
4880
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
4881
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
4882
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
4883
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
4886
h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE)
4888
if packed_sequence != 0:
4889
inputs.append(torch.IntTensor(seq_lengths))
4890
if len(inputs) == 1:
4893
input = tuple(inputs)
4896
layers = [1, 3, 1, 3, 1, 3]
4897
packed_sequence = [0, 0, 1, 1, 2, 2]
4898
models = [make_model(l, p) for l, p in zip(layers, packed_sequence)]
4900
make_input(RNN_BATCH_SIZE, l, p) for l, p in zip(layers, packed_sequence)
4903
for model, input in zip(models, inputs):
4904
self.run_test(model, input)
4906
def test_gru_no_bias(self):
4907
class GruNet(torch.nn.Module):
4908
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4910
self.mygru = torch.nn.GRU(
4914
bidirectional=bidirectional,
4918
def forward(self, input, initial_state):
4919
out = self.mygru(input, initial_state)
4922
def get_GruNet_model_and_inputs(
4923
input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4925
num_directions = 2 if bidirectional else 1
4926
model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4927
input = torch.randn(seq_len, batch_size, input_size)
4928
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4929
return model, (input, h0)
4932
hidden_size = [3, 4]
4936
bidirectional = [True, False]
4937
models_and_inputs = [
4938
get_GruNet_model_and_inputs(i, h, n, b, s, bi)
4939
for i, h, n, b, s, bi in zip(
4940
input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4943
for model, input in models_and_inputs:
4944
self.run_test(model, input, do_constant_folding=True)
4946
def test_gru_constant_folding(self):
4947
class GruNet(torch.nn.Module):
4948
def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4950
self.mygru = torch.nn.GRU(
4951
input_size, hidden_size, num_layers, bidirectional=bidirectional
4954
def forward(self, input, initial_state):
4955
out = self.mygru(input, initial_state)
4958
def get_GruNet_model_and_inputs(
4959
input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4961
num_directions = 2 if bidirectional else 1
4962
model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4963
input = torch.randn(seq_len, batch_size, input_size)
4964
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4965
return model, (input, h0)
4968
model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
4969
self.run_test(model1, input1, do_constant_folding=True)
4972
model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
4973
self.run_test(model2, input2, do_constant_folding=True)
4975
@skipIfUnsupportedMinOpsetVersion(8)
4976
def test_max_tensors(self):
4977
class MaxModel(torch.nn.Module):
4978
def forward(self, input, other):
4979
return torch.max(input, other)
4982
x = torch.randn(4, 4, requires_grad=True)
4983
y = torch.randn(4, 1, requires_grad=True)
4984
self.run_test(model, (x, y))
4986
def test_amax_amin(self):
4987
class Model(torch.nn.Module):
4988
def forward(self, x):
4989
return torch.amax(x, dim=0, keepdim=True), torch.amin(
4990
x, dim=[0, 1], keepdim=False
4994
x = torch.randn(4, 4)
4995
self.run_test(model, x)
4997
def test_aminmax(self):
4998
class Model(torch.nn.Module):
4999
def forward(self, x):
5000
return torch.aminmax(x, dim=1, keepdim=True), torch.aminmax(
5005
x = torch.randn(3, 4)
5006
self.run_test(model, x)
5008
@skipIfUnsupportedMinOpsetVersion(9)
5009
def test_arange_end(self):
5010
class ArangeScript(torch.jit.ScriptModule):
5011
@torch.jit.script_method
5012
def forward(self, a):
5013
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5015
x = torch.randn(3, 4, requires_grad=True)
5016
outputs = ArangeScript()(x)
5017
self.run_test(ArangeScript(), x)
5019
class ArangeModel(torch.nn.Module):
5020
def forward(self, a):
5021
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5023
self.run_test(ArangeModel(), x)
5025
@skipIfUnsupportedMinOpsetVersion(11)
5026
def test_arange_end_notype(self):
5027
class ArangeScript(torch.jit.ScriptModule):
5028
@torch.jit.script_method
5029
def forward(self, a):
5030
return torch.arange(a.size(0))
5032
x = torch.randn(3, 4, requires_grad=True)
5033
outputs = ArangeScript()(x)
5034
self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5035
self.run_test(ArangeScript(), x, remained_onnx_input_idx=[])
5037
class ArangeModel(torch.nn.Module):
5038
def forward(self, a):
5039
return torch.arange(a.size(0))
5041
self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5042
self.run_test(ArangeModel(), x, remained_onnx_input_idx=[])
5044
@skipIfUnsupportedMinOpsetVersion(9)
5045
def test_arange_start_end(self):
5046
class ArangeScript(torch.jit.ScriptModule):
5047
@torch.jit.script_method
5048
def forward(self, a):
5049
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5051
x = torch.randn(3, 4, requires_grad=True)
5052
self.run_test(ArangeScript(), x)
5054
class ArangeModel(torch.nn.Module):
5055
def forward(self, a):
5056
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5058
self.run_test(ArangeModel(), x)
5060
@skipIfUnsupportedMinOpsetVersion(11)
5061
def test_arange_start_end_notype(self):
5062
class ArangeScript(torch.jit.ScriptModule):
5063
@torch.jit.script_method
5064
def forward(self, a):
5065
return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5067
x = torch.randn(3, 4, requires_grad=True)
5068
self.run_test(ArangeScript(), x)
5070
class ArangeModel(torch.nn.Module):
5071
def forward(self, a):
5072
return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5074
self.run_test(ArangeModel(), x)
5076
@skipIfUnsupportedMinOpsetVersion(9)
5077
def test_arange_start_end_step(self):
5078
class ArangeScript(torch.jit.ScriptModule):
5079
@torch.jit.script_method
5080
def forward(self, a):
5083
2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5088
x = torch.randn(3, 4, requires_grad=True)
5089
self.run_test(ArangeScript(), x)
5091
class ArangeModel(torch.nn.Module):
5092
def forward(self, a):
5095
2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5100
self.run_test(ArangeModel(), x)
5102
@skipIfUnsupportedMinOpsetVersion(11)
5103
def test_arange_start_end_step_notype(self):
5104
class ArangeScript(torch.jit.ScriptModule):
5105
@torch.jit.script_method
5106
def forward(self, a):
5108
torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5112
x = torch.randn(3, 4, requires_grad=True)
5113
self.run_test(ArangeScript(), x)
5115
class ArangeModel(torch.nn.Module):
5116
def forward(self, a):
5118
torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5122
self.run_test(ArangeModel(), x)
5124
@skipIfUnsupportedMinOpsetVersion(9)
5125
def test__dim_arange(self):
5126
class DimArange(torch.nn.Module):
5127
def forward(self, input):
5128
return torch._dim_arange(input, 1)
5130
x = torch.ones(5, 6)
5131
self.run_test(DimArange(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5132
remained_onnx_input_idx = None if self.opset_version < 11 else []
5133
self.run_test(DimArange(), x, remained_onnx_input_idx=remained_onnx_input_idx)
5135
def _test_compare_ops(self, model, num_inputs):
5136
x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
5137
x_int = torch.randint(10, (3, 4), dtype=torch.int32)
5139
y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
5140
y_int = torch.randint(10, (3, 4), dtype=torch.int32)
5141
self.run_test(model, (x_float, y_float))
5142
self.run_test(model, (x_float, y_int))
5143
self.run_test(model, (x_int, y_float))
5144
self.run_test(model, (x_int, y_int))
5146
self.run_test(model, x_float)
5147
self.run_test(model, x_int)
5149
@skipIfUnsupportedMinOpsetVersion(9)
5150
def test_and_or_xor(self):
5151
class MyModel(torch.nn.Module):
5152
def forward(self, x, y):
5153
return x ^ y, x | y, x & y, ~x
5155
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5156
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5157
self.run_test(MyModel(), input_args=(x, y))
5159
@skipIfUnsupportedMinOpsetVersion(9)
5160
def test_logical_and(self):
5161
class AndModel(torch.nn.Module):
5162
def forward(self, x, y):
5163
return torch.logical_and(x, y)
5165
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5166
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5167
self.run_test(AndModel(), input_args=(x, y))
5169
x = torch.randint(10, (5, 5), dtype=torch.int32)
5170
y = torch.randint(10, (5, 5), dtype=torch.int32)
5171
self.run_test(AndModel(), input_args=(x, y))
5173
x = torch.randint(10, (5, 5), dtype=torch.double)
5174
y = torch.randint(10, (5, 5), dtype=torch.double)
5175
self.run_test(AndModel(), input_args=(x, y))
5177
x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5178
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5179
self.run_test(AndModel(), input_args=(x, y))
5181
@skipIfUnsupportedMinOpsetVersion(9)
5182
def test_logical_or(self):
5183
class OrModel(torch.nn.Module):
5184
def forward(self, x, y):
5185
return torch.logical_or(x, y)
5187
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5188
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5189
self.run_test(OrModel(), input_args=(x, y))
5191
x = torch.randint(10, (5, 5), dtype=torch.int32)
5192
y = torch.randint(10, (5, 5), dtype=torch.int32)
5193
self.run_test(OrModel(), input_args=(x, y))
5195
x = torch.randint(10, (5, 5), dtype=torch.double)
5196
y = torch.randint(10, (5, 5), dtype=torch.double)
5197
self.run_test(OrModel(), input_args=(x, y))
5199
x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5200
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5201
self.run_test(OrModel(), input_args=(x, y))
5203
@skipIfUnsupportedMinOpsetVersion(9)
5204
def test_logical_xor(self):
5205
class XorModel(torch.nn.Module):
5206
def forward(self, x, y):
5207
return torch.logical_xor(x, y)
5209
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5210
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5211
self.run_test(XorModel(), input_args=(x, y))
5213
x = torch.randint(10, (5, 5), dtype=torch.int32)
5214
y = torch.randint(10, (5, 5), dtype=torch.int32)
5215
self.run_test(XorModel(), input_args=(x, y))
5217
x = torch.randint(10, (5, 5), dtype=torch.double)
5218
y = torch.randint(10, (5, 5), dtype=torch.double)
5219
self.run_test(XorModel(), input_args=(x, y))
5221
x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5222
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5223
self.run_test(XorModel(), input_args=(x, y))
5225
@skipIfUnsupportedMinOpsetVersion(9)
5226
def test_logical_not(self):
5227
class NotModel(torch.nn.Module):
5228
def forward(self, x):
5229
return torch.logical_not(x)
5231
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5232
self.run_test(NotModel(), input_args=(x,))
5234
x = torch.randint(10, (5, 5), dtype=torch.int32)
5235
self.run_test(NotModel(), input_args=(x,))
5237
x = torch.randint(10, (5, 5), dtype=torch.double)
5238
self.run_test(NotModel(), input_args=(x,))
5240
x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5241
self.run_test(NotModel(), input_args=(x,))
5243
@skipIfUnsupportedMinOpsetVersion(11) # float equal added after opset 11
5245
class EqualModel(torch.nn.Module):
5246
def forward(self, input, other):
5247
return input == other
5249
self._test_compare_ops(EqualModel(), 2)
5252
class GreaterModel(torch.nn.Module):
5253
def forward(self, input, other):
5254
return input > other
5256
self._test_compare_ops(GreaterModel(), 2)
5258
@skipIfUnsupportedMinOpsetVersion(9)
5260
class GreaterOrEqualModel(torch.nn.Module):
5261
def forward(self, input, other):
5262
return input >= other
5264
self._test_compare_ops(GreaterOrEqualModel(), 2)
5266
def test_gt_scalar(self):
5267
class GreaterModel(torch.nn.Module):
5268
def forward(self, input):
5271
self._test_compare_ops(GreaterModel(), 1)
5273
def test_gt_primitive(self):
5274
class GreaterModel(torch.nn.Module):
5279
def forward(self, x: int):
5283
self.run_test(GreaterModel(), (x,))
5285
@skipIfUnsupportedMinOpsetVersion(9)
5286
def test_ge_scalar(self):
5287
class GreaterOrEqualModel(torch.nn.Module):
5288
def forward(self, input):
5291
self._test_compare_ops(GreaterOrEqualModel(), 1)
5294
class LessModel(torch.nn.Module):
5295
def forward(self, input, other):
5296
return input > other
5298
self._test_compare_ops(LessModel(), 2)
5300
@skipIfUnsupportedMinOpsetVersion(9)
5302
class LessOrEqualModel(torch.nn.Module):
5303
def forward(self, input, other):
5304
return input <= other
5306
self._test_compare_ops(LessOrEqualModel(), 2)
5308
def test_lt_scalar(self):
5309
class LessModel(torch.nn.Module):
5310
def forward(self, input):
5313
self._test_compare_ops(LessModel(), 1)
5315
@skipIfUnsupportedMinOpsetVersion(9)
5316
def test_le_scalar(self):
5317
class LessOrEqualModel(torch.nn.Module):
5318
def forward(self, input):
5321
self._test_compare_ops(LessOrEqualModel(), 1)
5323
def test_matmul(self):
5324
class MatmulModel(torch.nn.Module):
5325
def forward(self, input, other):
5326
return torch.matmul(input, other)
5328
x = torch.randn(3, 4, requires_grad=True)
5329
y = torch.randn(4, 5, requires_grad=True)
5330
self.run_test(MatmulModel(), (x, y))
5332
x = torch.randint(10, (3, 4))
5333
y = torch.randint(10, (4, 5))
5334
self.run_test(MatmulModel(), (x, y))
5336
def test_matmul_batch(self):
5337
class MatmulModel(torch.nn.Module):
5338
def forward(self, input, other):
5339
return torch.matmul(input, other)
5341
x = torch.randn(2, 3, 4, requires_grad=True)
5342
y = torch.randn(2, 4, 5, requires_grad=True)
5343
self.run_test(MatmulModel(), (x, y))
5345
x = torch.randint(10, (2, 3, 4))
5346
y = torch.randint(10, (2, 4, 5))
5347
self.run_test(MatmulModel(), (x, y))
5349
def _argmin_argmax_model(self, input):
5350
class ArgminArgmaxModel(torch.nn.Module):
5351
def forward(self, input):
5353
torch.argmin(input),
5354
torch.argmax(input),
5355
torch.argmin(input, keepdim=True),
5356
torch.argmax(input, keepdim=True),
5357
torch.argmin(input, dim=0, keepdim=True),
5358
torch.argmax(input, dim=1, keepdim=True),
5361
self.run_test(ArgminArgmaxModel(), input)
5363
@skipIfUnsupportedMinOpsetVersion(9)
5364
def test_argmin_argmax(self):
5365
input = torch.randn(7, 3, 5)
5366
self._argmin_argmax_model(input)
5368
# Argmin and Argmax with "select_last_index" is not supprted before opset 12
5369
# "select_last_index" was added in opset 12 to deal with corner case where the
5370
# same value appears multiple times in the tensor
5371
@skipIfUnsupportedMinOpsetVersion(12)
5372
def test_argmin_argmax_select_last_index(self):
5373
input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
5374
self._argmin_argmax_model(input)
5376
input = torch.ones(7, 3, 5)
5377
self._argmin_argmax_model(input)
5379
def test_repeat(self):
5380
class RepeatModel(torch.nn.Module):
5381
def forward(self, x, y):
5382
x2 = x.repeat(y.shape[0], 1)
5386
x = torch.tensor([1, 2, 3])
5387
y = torch.tensor([4, 5, 8, 9])
5388
self.run_test(RepeatModel(), (x, y))
5390
@skipIfUnsupportedMinOpsetVersion(9)
5391
def test_repeat_interleave(self):
5392
class FlattenModel(torch.nn.Module):
5393
def forward(self, x):
5394
return x.repeat_interleave(2)
5396
for shape in ([3], [3, 4], [2, 3, 4]):
5397
x = torch.randn(shape)
5398
self.run_test(FlattenModel(), (x,))
5400
class DimsModel(torch.nn.Module):
5401
def forward(self, x):
5402
return x.repeat_interleave(4, dim=1)
5404
x = torch.tensor([[1, 2], [3, 4]])
5405
self.run_test(DimsModel(), (x,))
5407
class DimsModel2(torch.nn.Module):
5408
def forward(self, x):
5409
repeats = torch.tensor([4])
5410
return torch.repeat_interleave(x, repeats, dim=1)
5412
x = torch.tensor([[1, 2], [3, 4]])
5413
self.run_test(DimsModel2(), (x,))
5415
class RepeatsDimsModel(torch.nn.Module):
5416
def forward(self, x):
5417
repeats = torch.tensor([1, 2])
5418
return torch.repeat_interleave(x, repeats, dim=0)
5420
x = torch.tensor([[1, 2], [3, 4]])
5421
self.run_test(RepeatsDimsModel(), (x,))
5423
class RepeatsDimsModel2(torch.nn.Module):
5424
def forward(self, x):
5425
repeats = torch.tensor([1, 2])
5426
return torch.repeat_interleave(x, repeats, dim=1)
5428
x = torch.tensor([[1, 2], [3, 4]])
5429
self.run_test(RepeatsDimsModel2(), (x,))
5431
@skipIfUnsupportedMinOpsetVersion(9)
5432
def test_repeat_interleave_noop(self):
5433
class Model(torch.nn.Module):
5434
def forward(self, x):
5435
return x.repeat_interleave(1, dim=1)
5437
x = torch.randn(4, 1, 8)
5438
self.run_test(Model(), (x,))
5440
@skipIfUnsupportedMinOpsetVersion(13)
5441
def test_dynamic_repeat_interleave(self):
5442
class SingleDynamicModel(torch.nn.Module):
5443
def forward(self, x):
5444
repeats = torch.tensor(4)
5445
return torch.repeat_interleave(x, repeats, dim=1)
5447
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5448
another_x = torch.tensor([[7, 8], [5, 6]])
5450
SingleDynamicModel(),
5452
additional_test_inputs=[another_x],
5453
input_names=["input_1"],
5454
dynamic_axes={"input_1": {1: "w"}},
5457
class NegDynamicModel(torch.nn.Module):
5458
def forward(self, x):
5459
repeats = torch.tensor(4)
5460
return torch.repeat_interleave(x, repeats, dim=-1)
5462
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5463
another_x = torch.tensor([[7, 8], [5, 6]])
5467
additional_test_inputs=[another_x],
5468
input_names=["input_1"],
5469
dynamic_axes={"input_1": {1: "w"}},
5472
class SingleDynamicModelFloat(torch.nn.Module):
5473
def forward(self, x):
5474
repeats = torch.tensor([4])
5475
return torch.repeat_interleave(x, repeats, dim=0)
5477
x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
5478
another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
5480
SingleDynamicModelFloat(),
5482
additional_test_inputs=[another_x],
5483
input_names=["input_1"],
5484
dynamic_axes={"input_1": {0: "h"}},
5487
class DynamicRepeatsModel(torch.nn.Module):
5488
def forward(self, x, repeats):
5489
return torch.repeat_interleave(x, repeats, dim=1)
5491
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5492
another_x = torch.tensor([[7, 8], [5, 6]])
5493
repeats = torch.tensor([2])
5494
another_repeats = torch.tensor([4])
5496
DynamicRepeatsModel(),
5498
additional_test_inputs=[(another_x, another_repeats)],
5499
input_names=["input_1", "repeats_1"],
5500
dynamic_axes={"input_1": {1: "w"}, "repeats_1": {0: "r"}},
5503
class DynamicRepeatsModel2(torch.nn.Module):
5504
def forward(self, x, repeats):
5505
return torch.repeat_interleave(x, repeats, dim=1)
5507
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5508
repeats = torch.tensor([2])
5509
another_repeats = torch.tensor([4])
5511
DynamicRepeatsModel2(),
5513
additional_test_inputs=[(x, another_repeats)],
5514
input_names=["input_1", "repeats_1"],
5515
dynamic_axes={"repeats_1": {0: "r"}},
5518
class DynamicFlattenModel(torch.nn.Module):
5519
def forward(self, x):
5520
return x.repeat_interleave(2)
5522
x = torch.tensor([1, 2, 3])
5524
DynamicFlattenModel(),
5526
input_names=["input_1"],
5527
dynamic_axes={"input_1": {0: "w"}},
5530
@skipIfUnsupportedMinOpsetVersion(13)
5531
def test_multiple_dynamic_repeat_interleave(self):
5532
class DynamicRepeatsModel(torch.nn.Module):
5533
def forward(self, x, repeats):
5534
return torch.repeat_interleave(x, repeats, dim=1)
5536
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5537
repeats = torch.tensor([2, 3, 4])
5538
another_repeats = torch.tensor([4, 3, 2])
5540
DynamicRepeatsModel(),
5542
additional_test_inputs=[(x, another_repeats)],
5543
input_names=["input_1", "repeats_1"],
5544
dynamic_axes={"repeats_1": {0: "r"}},
5547
class DynamicRepeatsModel2(torch.nn.Module):
5548
def forward(self, x, repeats):
5549
return torch.repeat_interleave(x, repeats, dim=0)
5551
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5552
repeats = torch.tensor([2, 3])
5553
another_repeats = torch.tensor([4, 3])
5555
DynamicRepeatsModel2(),
5557
additional_test_inputs=[(x, another_repeats)],
5558
input_names=["input_1", "repeats_1"],
5559
dynamic_axes={"repeats_1": {0: "r"}},
5562
def test_view(self):
5563
class ViewModel(torch.nn.Module):
5564
def forward(self, input):
5565
return input.view(4, 24)
5567
x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
5568
self.run_test(ViewModel(), x)
5570
def test_view_dynamic(self):
5571
class ViewModel(torch.nn.Module):
5572
def forward(self, input, other):
5573
return input.view(other.shape)
5575
x = torch.randn(2, 3, 4)
5576
shape = torch.randn(6, 4)
5580
input_names=["x", "shape"],
5581
dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]},
5583
self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0])
5585
def test_view_dynamic_zero_dim(self):
5586
class ViewModel(torch.nn.Module):
5587
def forward(self, input):
5588
input = input.view(-1, 2)
5589
return input.view(1, -1)
5592
another_x = torch.empty((0,))
5596
additional_test_inputs=[another_x],
5597
input_names=["input_1"],
5605
def test_view_as(self):
5606
class ViewModel(torch.nn.Module):
5607
def forward(self, input, other):
5608
return input.view_as(other)
5610
x = torch.randn(2, 3, 4)
5611
y = torch.randn(6, 4)
5612
self.run_test(ViewModel(), (x, y))
5614
def test_linear(self):
5615
class LinearModel(torch.nn.Module):
5618
self.fc = torch.nn.Linear(16, 16)
5620
def forward(self, x):
5625
x = torch.randn(3, 16)
5626
self.run_test(LinearModel(), (x,))
5628
class LinearModel(torch.nn.Module):
5629
def forward(self, input, weight, bias):
5630
return torch.nn.functional.linear(input, weight, bias)
5633
x = torch.randn(2, 2)
5634
y = torch.randn(2, 2)
5636
self.run_test(LinearModel(), (x, y, z))
5639
x = torch.randn(3, 3, 3)
5640
y = torch.randn(3, 3)
5642
self.run_test(LinearModel(), (x, y, z))
5645
def test_weight_norm(self):
5646
# addmm for 3-d inputs converts to onnx::MatMul
5647
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
5648
x = torch.randn(3, 4, 5, requires_grad=True)
5649
self.run_test(model, x)
5651
# addmm for 2-d inputs converts to onnx::Gemm
5652
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
5653
x = torch.randn(4, 5, requires_grad=True)
5654
self.run_test(model, x)
5656
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
5657
x = torch.randn(1, 1, 5, requires_grad=True)
5658
self.run_test(model, x)
5660
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
5661
x = torch.randn(1, 1, 5, requires_grad=True)
5662
self.run_test(model, x)
5664
model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name="weight")
5665
x = torch.randn(3, 3, 5, requires_grad=True)
5666
self.run_test(model, x)
5669
def test_weight_norm_nodim(self):
5670
# addmm for 3-d inputs converts to onnx::MatMul
5671
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
5672
x = torch.randn(3, 4, 5, requires_grad=True)
5673
self.run_test(model, x)
5675
# addmm for 2-d inputs converts to onnx::Gemm
5676
model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
5677
x = torch.randn(4, 5, requires_grad=True)
5678
self.run_test(model, x)
5680
def test_flatten(self):
5681
class FlattenModel(torch.nn.Module):
5682
def forward(self, input):
5683
return torch.flatten(input)
5685
model = FlattenModel()
5687
# flatten with 4d input
5688
x = torch.randint(10, (1, 2, 3, 4))
5689
self.run_test(model, x)
5691
# flatten with 0d input
5693
self.run_test(model, x)
5695
# flatten with 1d input
5697
self.run_test(model, x)
5699
def test_flatten2d(self):
5700
class FlattenModel(torch.nn.Module):
5701
def forward(self, input):
5702
return torch.flatten(input, 1)
5704
x = torch.randint(10, (1, 2, 3, 4))
5705
self.run_test(FlattenModel(), x)
5707
def test_flatten2d_neg(self):
5708
class FlattenModel(torch.nn.Module):
5709
def forward(self, x):
5711
torch.flatten(x, 1, -1),
5712
torch.flatten(x, 0, -2),
5713
torch.flatten(x, 1, -2),
5716
x = torch.randint(10, (1, 2, 3, 4))
5717
self.run_test(FlattenModel(), x)
5719
@skipIfUnsupportedMinOpsetVersion(9)
5720
def test_flatten_dynamic_axes(self):
5721
class MyModule(torch.nn.Module):
5722
def forward(self, x):
5723
return torch.flatten(x, start_dim=2, end_dim=3)
5726
x = torch.randn(batch_size, 5, 4, 5)
5727
y = torch.randn(5, 5, 4, 5)
5732
additional_test_inputs=[y],
5733
input_names=["input"],
5734
output_names=["output"],
5735
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
5738
@skipIfUnsupportedMinOpsetVersion(11)
5739
def test_getitem(self):
5740
class GetItemModel(torch.jit.ScriptModule):
5741
@torch.jit.script_method
5742
def forward(self, x, y, z, ind):
5743
# this will create prim::ListConstruct(x, y, z) + aten::__getitem__
5747
x = torch.randn(3, 4, 5)
5748
y = torch.randn(1, 4, 5)
5749
z = torch.randn(2, 4, 5)
5750
ind = torch.tensor(1, dtype=torch.long)
5751
self.run_test(GetItemModel(), (x, y, z, ind))
5753
ind = torch.tensor(-2, dtype=torch.long)
5754
self.run_test(GetItemModel(), (x, y, z, ind))
5757
def test_item(self):
5758
class M(torch.nn.Module):
5759
def forward(self, x, y, i: int):
5760
return int(x[y[i]].item())
5762
x = torch.arange(6, dtype=torch.float)
5763
y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
5765
self.run_test(torch.jit.script(M()), (x, y, i))
5767
@skipScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable.
5768
@skipIfUnsupportedMinOpsetVersion(9)
5769
def test_nonzero(self):
5770
class NonzeroModel(torch.nn.Module):
5771
def forward(self, x):
5772
return x.nonzero(), x.nonzero(as_tuple=True)
5774
x = torch.randn(60).index_fill_(0, torch.randint(0, 60, (20,)), 0).view(3, 4, 5)
5775
self.run_test(NonzeroModel(), (x,))
5777
def test_unbind(self):
5778
class UnbindModel(torch.nn.Module):
5779
def forward(self, input):
5780
_, out, _ = input.unbind()
5783
x = torch.randn(3, 4, 5)
5784
self.run_test(UnbindModel(), x)
5786
class UnbindModel2(torch.nn.Module):
5787
def forward(self, input):
5788
_, out, _, _ = input.unbind(1)
5791
x = torch.randn(3, 4, 5)
5792
self.run_test(UnbindModel2(), x)
5794
class UnbindModel3(torch.nn.Module):
5795
def forward(self, input):
5796
_, out, _, _ = input.unbind(-2)
5799
x = torch.randn(3, 4, 5)
5800
self.run_test(UnbindModel3(), x)
5802
@skipIfUnsupportedMinOpsetVersion(11)
5804
class LenModel(torch.jit.ScriptModule):
5805
@torch.jit.script_method
5806
def forward(self, input):
5807
return len(input.unbind()) + input
5809
x = torch.randn(4, 5)
5813
input_names=["input"],
5814
dynamic_axes={"input": {0: "seq"}},
5815
additional_test_inputs=(torch.randn(5, 5),),
5818
@skipIfUnsupportedMinOpsetVersion(9)
5819
def test_len_list(self):
5820
class LenListModel(torch.jit.ScriptModule):
5821
@torch.jit.script_method
5822
def forward(self, input):
5823
return torch.ones(len(input.shape))
5825
x = torch.randn(4, 5)
5826
self.run_test(LenListModel(), x, remained_onnx_input_idx=[])
5828
@skipIfUnsupportedMinOpsetVersion(11)
5829
def test_unbind_dynamic(self):
5830
class UnbindModel(torch.jit.ScriptModule):
5831
@torch.jit.script_method
5832
def forward(self, input):
5833
return input.unbind()[1]
5835
x = torch.randn(3, 4, 5)
5836
self.run_test(UnbindModel(), x)
5838
class UnbindModel2(torch.jit.ScriptModule):
5839
@torch.jit.script_method
5840
def forward(self, input):
5841
return input.unbind(-1)[1]
5843
x = torch.randn(3, 4, 5)
5844
self.run_test(UnbindModel2(), x)
5846
@skipScriptTest() # scripting tests run for opsets > 11. See: test_split_script
5847
def test_split(self):
5848
class SplitModel(torch.nn.Module):
5849
def forward(self, input):
5850
return input.split([2, 1, 2]), input.split([3, 2])[0]
5852
x = torch.randn(5, 4, 3)
5853
self.run_test(SplitModel(), x)
5855
class SplitModel2(torch.nn.Module):
5856
def forward(self, input):
5857
return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
5859
x = torch.randn(5, 4, 3)
5860
self.run_test(SplitModel2(), x)
5862
class SplitModel3(torch.nn.Module):
5863
def forward(self, input):
5864
return input.split([2, 1, 2])
5866
x = torch.randn(5, 4, 3)
5867
self.run_test(SplitModel3(), x)
5869
@skipIfUnsupportedMinOpsetVersion(11)
5870
def test_split_script(self):
5871
class SplitModel(torch.nn.Module):
5872
def forward(self, input):
5873
return input.split([2, 1, 2]), input.split([3, 2])[0]
5875
x = torch.randn(5, 4, 3)
5876
self.run_test(SplitModel(), x)
5878
class SplitModel2(torch.nn.Module):
5879
def forward(self, input):
5880
return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
5882
x = torch.randn(5, 4, 3)
5883
self.run_test(SplitModel2(), x)
5885
class SplitModel3(torch.nn.Module):
5886
def forward(self, input):
5887
return input.split([2, 1, 2])
5889
x = torch.randn(5, 4, 3)
5890
self.run_test(SplitModel3(), x)
5892
@skipIfUnsupportedMinOpsetVersion(11)
5894
def test_split_size_as_list(self):
5895
class SplitModel(torch.nn.Module):
5896
def forward(self, input, split_sizes: List[int]):
5898
split_list: List[Tensor] = input.split(split_sizes)
5900
for ob in split_list:
5901
out.append(ob) # noqa: PERF402
5902
return torch.cat(out, dim=0)
5904
x = torch.randn(6, 4, 3)
5905
split_sizes = [torch.tensor(2), torch.tensor(4)]
5906
self.run_test(SplitModel(), (x, split_sizes))
5908
@skipIfUnsupportedMinOpsetVersion(11)
5909
def test_split_size_with_slice(self):
5910
class SplitModule(torch.nn.Module):
5911
def forward(self, x, y, t):
5912
splits = (x.size(1), y.size(1))
5913
out, out2 = torch.split(t, splits, dim=1)
5916
x = torch.randn(2, 3)
5917
y = torch.randn(2, 4)
5918
t = torch.randn(2, 7)
5922
input_names=["x", "y", "t"],
5923
dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
5925
self.run_test(SplitModule(), (x, y, t), remained_onnx_input_idx=[2])
5927
@skipIfUnsupportedMinOpsetVersion(11)
5928
def test_split_dynamic(self):
5929
class SplitModel(torch.jit.ScriptModule):
5930
@torch.jit.script_method
5931
def forward(self, input):
5932
return input.split(2)[1]
5934
x = torch.randn(5, 4, 3)
5935
self.run_test(SplitModel(), x)
5937
class SplitModel2(torch.jit.ScriptModule):
5938
@torch.jit.script_method
5939
def forward(self, input):
5940
return input.split(2, -3)[1]
5942
x = torch.randn(5, 4, 3)
5943
self.run_test(SplitModel2(), x)
5945
@skipIfUnsupportedMinOpsetVersion(11)
5946
def test_split_dynamic_axes(self):
5947
class Split(torch.nn.Module):
5948
def forward(self, x):
5949
return x.split(1, dim=-1)
5951
x = torch.randn(4, 384, 2)
5952
input_names = ["logits"]
5956
input_names=input_names,
5957
dynamic_axes={input_names[0]: {0: "batch"}},
5960
@skipIfUnsupportedMinOpsetVersion(11)
5961
def test_chunk(self):
5962
class ChunkModel(torch.nn.Module):
5963
def __init__(self, dim=1):
5967
def forward(self, x):
5968
return torch.chunk(x, 3, dim=self.dim)
5970
model = ChunkModel()
5972
model_neg_dim = ChunkModel(-1)
5973
model_neg_dim.eval()
5974
x = torch.randn(1, 18)
5976
for dim_size_ in range(13, 16):
5977
y = torch.randn(1, dim_size_)
5981
additional_test_inputs=[y],
5983
dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
5989
additional_test_inputs=[y],
5991
dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
5994
@skipIfUnsupportedMinOpsetVersion(11)
5995
def test_dynamic_chunk(self):
5996
class ChunkModel(torch.nn.Module):
5997
def __init__(self, dim=1):
6001
def forward(self, x):
6002
return torch.chunk(x, x.size(0), dim=self.dim)
6004
model = ChunkModel()
6006
model_neg_dim = ChunkModel(-1)
6007
model_neg_dim.eval()
6008
x = torch.randn(3, 18)
6010
for dim_size_ in range(13, 16):
6011
y = torch.randn(3, dim_size_)
6015
additional_test_inputs=[y],
6017
dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6023
additional_test_inputs=[y],
6025
dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
6028
def test_concat(self):
6029
class ConcatModel(torch.nn.Module):
6030
def forward(self, x, y, z):
6031
return torch.cat((x, y, z))
6033
x = torch.randn(3, 4, 5)
6034
y = torch.randn(1, 4, 5)
6035
z = torch.randn(2, 4, 5)
6036
self.run_test(ConcatModel(), (x, y, z))
6038
@skipIfUnsupportedMinOpsetVersion(11)
6039
def test_concat_dynamic(self):
6040
class ConcatDynamicModel(torch.jit.ScriptModule):
6041
@torch.jit.script_method
6042
def forward(self, x):
6043
return torch.cat(x.unbind())
6045
x = torch.randn(4, 5, 6)
6046
self.run_test(ConcatDynamicModel(), x)
6048
def test_stack(self):
6049
class StackModel(torch.nn.Module):
6050
def forward(self, x, y, z):
6051
return torch.stack((x, y, z), 1)
6053
x = torch.randn(3, 4, 5)
6054
y = torch.randn(3, 4, 5)
6055
z = torch.randn(3, 4, 5)
6056
self.run_test(StackModel(), (x, y, z))
6058
@skipIfUnsupportedMinOpsetVersion(11)
6059
def test_stack_dynamic(self):
6060
class StackDynamicModel(torch.jit.ScriptModule):
6061
@torch.jit.script_method
6062
def forward(self, x):
6063
return torch.stack(x.unbind(), 1)
6065
x = torch.randn(4, 5, 6)
6066
self.run_test(StackDynamicModel(), x)
6068
def test_loop_dynamic(self):
6069
class LoopModel(torch.jit.ScriptModule):
6070
@torch.jit.script_method
6071
def forward(self, x):
6072
for i in range(x.size(2)):
6077
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
6078
self.run_test(model, inputs)
6080
@skipIfUnsupportedMinOpsetVersion(9)
6081
def test_loop_nested(self):
6082
class NestedLoopsModel(torch.jit.ScriptModule):
6083
@torch.jit.script_method
6084
def forward(self, x):
6092
model = NestedLoopsModel()
6093
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
6094
self.run_test(model, inputs)
6096
@skipIfUnsupportedMinOpsetVersion(11)
6097
def test_loop_with_list(self):
6098
class ListLoopModel(torch.jit.ScriptModule):
6099
@torch.jit.script_method
6100
def forward(self, x):
6103
arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
6104
res2 = torch.zeros(3, 4, dtype=torch.long)
6107
for i in range(len(arr)):
6108
res.append(arr[i].sum(0, False))
6109
res1.append(arr[-1 - i].sum(0, False))
6111
res3 = res3 + [arr[i].sum(0, False)]
6112
res4 += [arr[-1 - i].sum(0, False)]
6113
return res, res1, res2, torch.stack(res3), torch.stack(res4)
6115
model = ListLoopModel()
6116
inputs = torch.randn(16)
6117
self.run_test(model, inputs)
6119
@skipIfUnsupportedMinOpsetVersion(11)
6120
def test_loop_transpose(self):
6121
class LoopModel(torch.nn.Module):
6122
def forward(self, x):
6123
res = torch.zeros_like(x[0])
6124
for i in range(x.size(0)):
6125
res += x[0].transpose(0, 1)
6128
model = torch.jit.script(LoopModel())
6129
x = torch.randn(5, 3, 3)
6130
self.run_test(model, x)
6132
@skipIfUnsupportedMinOpsetVersion(11)
6133
def test_loop_multi_dim(self):
6134
class LoopMultiDimModel(torch.jit.ScriptModule):
6135
@torch.jit.script_method
6136
def forward(self, x, y):
6137
for x_ in torch.flip(x.narrow(0, 0, 7), [0]):
6141
model = LoopMultiDimModel()
6142
x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long)
6143
y = torch.ones(1, dtype=torch.long)
6144
self.run_test(model, (x, y))
6146
@skipIfUnsupportedMinOpsetVersion(11)
6147
def test_list(self):
6148
class ListModel(torch.jit.ScriptModule):
6149
@torch.jit.script_method
6150
def forward(self, x):
6151
tensors = x.unbind()
6153
res.append(tensors[0])
6154
res.append(tensors[1])
6157
res.insert(0, tensors[1])
6158
res.append(tensors[2])
6159
res += [tensors[3], tensors[4]]
6160
res = res + [tensors[5]]
6161
return torch.ones(len(res))
6164
inputs = torch.randn(16, 1)
6165
self.run_test(model, inputs)
6167
@skipIfUnsupportedMinOpsetVersion(11)
6168
def test_list_append(self):
6169
class ListModel(torch.nn.Module):
6170
def forward(self, x, y):
6172
for i in range(x.size(0)):
6173
res += [torch.matmul(x[i], y)]
6176
model = torch.jit.script(ListModel())
6177
x = torch.randn(16, 3, 4)
6178
y = torch.randn(4, 5)
6179
self.run_test(model, (x, y))
6181
@skipIfUnsupportedMinOpsetVersion(13)
6182
def test_list_append_nested(self):
6183
class ListModel(torch.nn.Module):
6184
def forward(self, x, y):
6186
for i in range(x.size(0)):
6187
for j in range(x.size(1)):
6188
res += [torch.matmul(x[i][j], y)]
6191
model = torch.jit.script(ListModel())
6192
x = torch.randn(4, 4, 3, 4)
6193
y = torch.randn(4, 5)
6194
self.run_test(model, (x, y))
6196
@skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14
6197
def test_list_append_nested_2(self):
6198
class ListModel(torch.nn.Module):
6199
def forward(self, x):
6202
for i in range(x.size(0)):
6204
for j in range(x.size(1)):
6206
res_replicate.append(res[-1])
6207
res.append(res_replicate[-1])
6208
return res, res_replicate
6210
model = torch.jit.script(ListModel())
6211
x = torch.randn(4, 4, 3, 4)
6212
self.run_test(model, (x,))
6214
@skipIfUnsupportedMinOpsetVersion(13)
6215
def test_list_append_nested_mixed_dtype(self):
6216
class ListModel(torch.nn.Module):
6217
def forward(self, x, y):
6219
for i in range(x.size(0)):
6220
for j in range(x.size(1)):
6227
model = torch.jit.script(ListModel())
6228
x = torch.randn(4, 4, 3, 4)
6229
y = torch.randn(3, 4)
6230
self.run_test(model, (x, y))
6232
@skipIfUnsupportedMinOpsetVersion(11)
6233
def test_list_pop(self):
6234
class ListModel(torch.nn.Module):
6235
def forward(self, x, y):
6237
for i in range(x.size(0)):
6238
res += [torch.matmul(x[i], y)]
6242
model = torch.jit.script(ListModel())
6243
x = torch.randn(16, 3, 4)
6244
y = torch.randn(4, 5)
6245
self.run_test(model, (x, y))
6247
@skipIfUnsupportedMinOpsetVersion(13)
6248
def test_list_pop_nested(self):
6249
class ListModel(torch.nn.Module):
6250
def forward(self, x, y):
6252
for i in range(x.size(0)):
6253
for j in range(x.size(1)):
6254
res += [torch.matmul(x[i][j], y)]
6256
res += [torch.matmul(x[i][0], y)]
6259
model = torch.jit.script(ListModel())
6260
x = torch.randn(4, 4, 3, 4)
6261
y = torch.randn(4, 5)
6262
self.run_test(model, (x, y))
6264
@skipIfUnsupportedMinOpsetVersion(11)
6265
def test_list_del(self):
6266
class ListModel(torch.nn.Module):
6267
def forward(self, x, y):
6269
for i in range(x.size(0)):
6270
res += [torch.matmul(x[i], y)]
6274
model = torch.jit.script(ListModel())
6275
x = torch.randn(16, 3, 4)
6276
y = torch.randn(4, 5)
6277
self.run_test(model, (x, y))
6279
@skipIfUnsupportedMinOpsetVersion(13)
6280
def test_list_del_nested(self):
6281
class ListModel(torch.nn.Module):
6282
def forward(self, x, y):
6284
for i in range(x.size(0)):
6285
for j in range(x.size(1)):
6286
res += [torch.matmul(x[i][j], y)]
6288
res += [torch.matmul(x[i][0], y)]
6291
model = torch.jit.script(ListModel())
6292
x = torch.randn(4, 4, 3, 4)
6293
y = torch.randn(4, 5)
6294
self.run_test(model, (x, y))
6296
@skipIfUnsupportedMinOpsetVersion(11)
6297
def test_list_set(self):
6298
class ListModel(torch.nn.Module):
6299
def forward(self, x, y):
6301
for i in range(x.size(0)):
6306
model = torch.jit.script(ListModel())
6307
x = torch.randn(12, 4)
6308
y = torch.tensor(2, dtype=torch.long)
6309
self.run_test(model, (x, y))
6311
@skipIfUnsupportedMinOpsetVersion(13)
6312
def test_list_idx_sum(self):
6313
class ListModel(torch.nn.Module):
6314
def forward(self, x, y):
6315
indices = torch.arange(x.size(0))
6317
for i in range(x.size(0)):
6319
return res[torch.sum(indices[:y])]
6321
model = torch.jit.script(ListModel())
6322
x = torch.randn(12, 4)
6323
y = torch.tensor(2, dtype=torch.long)
6324
self.run_test(model, (x, y))
6326
@skipIfUnsupportedMinOpsetVersion(9)
6327
def test_tensor_factories(self):
6328
class TensorFactory(torch.nn.Module):
6329
def forward(self, x):
6330
return torch.zeros(x.size()) + torch.ones(x.size())
6332
x = torch.randn(2, 3, 4)
6334
TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6336
self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6338
@skipIfUnsupportedMinOpsetVersion(9)
6339
def test_tensor_factories_script(self):
6340
class TensorFactory(torch.jit.ScriptModule):
6341
@torch.jit.script_method
6342
def forward(self, x):
6343
return torch.zeros(x.shape, dtype=torch.float) + torch.ones(
6344
x.shape, dtype=torch.float
6347
x = torch.randn(2, 3, 4)
6349
TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6351
self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6353
@skipIfUnsupportedMinOpsetVersion(9)
6354
def test_tensor_like_factories_script(self):
6355
class TensorFactory(torch.jit.ScriptModule):
6356
@torch.jit.script_method
6357
def forward(self, x):
6358
zeros = torch.zeros_like(
6361
layout=torch.strided,
6362
device=torch.device("cpu"),
6364
ones = torch.ones_like(
6367
layout=torch.strided,
6368
device=torch.device("cpu"),
6372
x = torch.randn(2, 3, 4)
6374
TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
6376
self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
6378
@skipIfUnsupportedMinOpsetVersion(13)
6379
def test_tensor_split(self):
6380
class TensorSplitModel(torch.nn.Module):
6381
def forward(self, input):
6383
input.tensor_split([1, 3]),
6384
# test with output indexing.
6385
input.tensor_split([2, 4])[0],
6386
# test split on specific dim.
6387
input.tensor_split([1, 3, 4], dim=-2),
6388
# test split on specific dim and output indexing.
6389
input.tensor_split([0, 2], dim=-2)[-1],
6390
# test with out of bound end index (5).
6391
input.tensor_split([2, 3, 5]),
6394
self.run_test(TensorSplitModel(), torch.randn(5, 4, 3))
6396
@skipIfUnsupportedMinOpsetVersion(13)
6397
def test_tensor_split_scalar(self):
6398
class TensorSplitModel(torch.nn.Module):
6399
def forward(self, x):
6400
return torch.tensor_split(x, x.size(1))
6402
self.run_test(TensorSplitModel(), torch.randn(1, 2, 3))
6404
@skipIfUnsupportedMinOpsetVersion(13)
6405
def test_tensor_split_dynamic_axes(self):
6406
class TensorSplitModel(torch.nn.Module):
6407
def forward(self, x):
6408
return x.tensor_split(1, dim=-1)
6410
x = torch.randn(4, 384, 2)
6411
input_names = ["logits"]
6415
input_names=input_names,
6416
dynamic_axes={input_names[0]: {0: "batch"}},
6419
@skipIfUnsupportedMinOpsetVersion(9)
6421
class TensorFactory(torch.nn.Module):
6422
def forward(self, x):
6424
torch.eye(x.size()[1], 3),
6425
torch.eye(4, 4, dtype=torch.long),
6426
torch.eye(x.size()[1], 2, dtype=torch.long),
6427
torch.eye(x.shape[0]),
6428
torch.eye(x.shape[0], dtype=torch.float64),
6431
x = torch.randn(2, 3, 4)
6432
another_x = torch.randn(5, 6, 7)
6436
additional_test_inputs=[another_x],
6437
input_names=["input_1"],
6438
dynamic_axes={"input_1": [0, 1, 2]},
6441
@skipIfUnsupportedMinOpsetVersion(13)
6442
def test_diagonal(self):
6443
class DiagonalModel(torch.nn.Module):
6444
def forward(self, x):
6445
return torch.diagonal(x)
6447
x = torch.randn(2, 4, 5, 2)
6448
# Other test inputs to test dynamic behavior
6449
another_x = torch.randn(5, 6, 7, 8)
6453
additional_test_inputs=[another_x],
6454
input_names=["input_1"],
6455
dynamic_axes={"input_1": [0, 1, 2, 3]},
6458
class DiagonalModelNegOffset(torch.nn.Module):
6459
def forward(self, x):
6460
return torch.diagonal(x, offset=-1)
6462
x = torch.randn(2, 4, 5, 2)
6463
# Other test inputs to test dynamic behavior
6464
another_x = torch.randn(5, 6, 7, 8)
6466
DiagonalModelNegOffset(),
6468
additional_test_inputs=[another_x],
6469
input_names=["input_1"],
6470
dynamic_axes={"input_1": [0, 1, 2, 3]},
6473
class DiagonalModelPosOffset(torch.nn.Module):
6474
def forward(self, x):
6475
return torch.diagonal(x, offset=1)
6477
x = torch.randn(2, 4, 5, 2)
6478
# Other test inputs to test dynamic behavior
6479
another_x = torch.randn(5, 6, 7, 8)
6481
DiagonalModelPosOffset(),
6483
additional_test_inputs=[another_x],
6484
input_names=["input_1"],
6485
dynamic_axes={"input_1": [0, 1, 2, 3]},
6488
class DiagonalModelWithDims(torch.nn.Module):
6489
def forward(self, x):
6490
return torch.diagonal(x, offset=-1, dim1=1, dim2=2)
6492
x = torch.randn(2, 4, 5, 2)
6493
# Other test inputs to test dynamic behavior
6494
another_x = torch.randn(5, 6, 7, 8)
6496
DiagonalModelWithDims(),
6498
additional_test_inputs=[another_x],
6499
input_names=["input_1"],
6500
dynamic_axes={"input_1": [0, 1, 2, 3]},
6503
class DiagonalModelWithNegativeDims(torch.nn.Module):
6504
def forward(self, x):
6505
return torch.diagonal(x, offset=0, dim1=-2, dim2=-1)
6507
x = torch.randn(2, 4, 5, 2)
6508
# Other test inputs to test dynamic behavior
6509
another_x = torch.randn(5, 6, 7, 8)
6511
DiagonalModelWithNegativeDims(),
6513
additional_test_inputs=[another_x],
6514
input_names=["input_1"],
6515
dynamic_axes={"input_1": [0, 1, 2, 3]},
6518
class DiagonalModelOffsetOverrun(torch.nn.Module):
6519
def forward(self, x):
6520
return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5)
6522
x = torch.randn(2, 4, 5, 2)
6523
# Other test inputs to test dynamic behavior
6524
another_x = torch.randn(5, 6, 7, 8)
6526
DiagonalModelOffsetOverrun(),
6528
additional_test_inputs=[another_x],
6529
input_names=["input_1"],
6530
dynamic_axes={"input_1": [0, 1, 2, 3]},
6533
@skipIfUnsupportedMinOpsetVersion(9)
6534
def test_inplace_zero(self):
6535
class Zero_(torch.nn.Module):
6536
def forward(self, x):
6539
x = torch.randn(2, 3, 4)
6540
self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6541
self.run_test(Zero_(), x, remained_onnx_input_idx=[])
6543
@skipIfUnsupportedMinOpsetVersion(11)
6544
def test_inplace_zero_qkv(self):
6545
class Zero_(torch.nn.Module):
6546
def forward(self, x):
6547
return x[2:4].zero_()
6549
x = torch.randn(24, 3, 4)
6550
self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6552
@skipIfUnsupportedMinOpsetVersion(9)
6553
def test_new_zeros(self):
6554
class Zero_(torch.nn.Module):
6555
def forward(self, x):
6556
return x.new_zeros(x.shape[1:2]), x.new_zeros(
6557
x.shape[2:], dtype=torch.long
6560
x = torch.randn(2, 3, 4)
6561
self.run_test(Zero_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6562
self.run_test(Zero_(), x, remained_onnx_input_idx=[])
6564
@skipIfUnsupportedMinOpsetVersion(9)
6565
def test_new_zeros_with_dtype(self):
6566
class MyModel(torch.nn.Module):
6569
self.emb = torch.nn.Embedding(50, 64)
6571
def forward(self, x):
6572
inp = x.new_zeros(x.shape)
6573
return self.emb(inp)
6576
x = torch.Tensor([[2, 5, 6], [3, 2, 5]]).to(torch.int64)
6577
self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1]})
6579
@skipIfUnsupportedMinOpsetVersion(9)
6580
def test_new_ones(self):
6581
class OnesModel(torch.nn.Module):
6582
def forward(self, x):
6583
return x.new_ones(x.shape[1:2]), x.new_ones(
6584
x.shape[2:], dtype=torch.long
6587
x = torch.randn(2, 3, 4)
6588
self.run_test(OnesModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6589
self.run_test(OnesModel(), x, remained_onnx_input_idx=[])
6591
@skipIfUnsupportedMinOpsetVersion(9)
6592
@skipScriptTest() # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable.
6593
def test_zeros_ones_with_tensor_input(self):
6594
class ZeroAndOnes(torch.nn.Module):
6595
def forward(self, x):
6596
return torch.zeros(x, 1), torch.ones(x, 1)
6598
x = torch.tensor([2])
6599
self.run_test(ZeroAndOnes(), (x,))
6601
@skipIfUnsupportedMinOpsetVersion(9)
6603
def test_tolist(self):
6604
class List(torch.jit.ScriptModule):
6605
@torch.jit.script_method
6606
def forward(self, input):
6607
res: List[int] = input.tolist()
6610
self.run_test(List(), (torch.randint(100, (1,)),))
6612
@skipIfUnsupportedMinOpsetVersion(9)
6613
def test_list_pass(self):
6614
class Slice(torch.nn.Module):
6615
def forward(self, x, y):
6616
return x.new_zeros(x.shape[2:] + y.shape[1:])
6618
x = torch.randn(2, 3, 4, 5)
6619
y = torch.randn(1, 2, 3, 4)
6623
input_names=["x", "y"],
6624
dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
6626
self.run_test(Slice(), (x, y), remained_onnx_input_idx=[])
6628
class Size(torch.nn.Module):
6629
def forward(self, x, y):
6630
return x.new_zeros(x.shape + y.shape)
6632
x = torch.randn(2, 3, 4)
6633
y = torch.randn(1, 2, 3)
6637
input_names=["x", "y"],
6638
dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6640
self.run_test(Size(), (x, y), remained_onnx_input_idx=[])
6642
class Array(torch.nn.Module):
6643
def forward(self, x, y):
6644
arr1 = [x.shape[0], x.shape[1], 2]
6645
arr2 = [y.shape[0], y.shape[1]]
6646
return x.new_zeros(arr1 + arr2)
6648
x = torch.randn(2, 3, 4)
6649
y = torch.randn(1, 2, 3)
6653
input_names=["x", "y"],
6654
dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6656
self.run_test(Array(), (x, y), remained_onnx_input_idx=[])
6658
class List(torch.nn.Module):
6659
def forward(self, x, y):
6662
return x.new_zeros(l1 + l2)
6664
x = torch.randn(2, 3, 4)
6665
y = torch.randn(1, 2, 3)
6669
input_names=["x", "y"],
6670
dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
6672
self.run_test(List(), (x, y), remained_onnx_input_idx=[])
6674
@skipIfUnsupportedMinOpsetVersion(9)
6675
def test_new_empty(self):
6676
class Emtpy(torch.nn.Module):
6677
def forward(self, x):
6679
x.new_empty(x.shape[0]).fill_(0),
6680
x.new_empty(x.shape[0], dtype=torch.long) * 0,
6683
x = torch.randn(2, 3, 4)
6684
self.run_test(Emtpy(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6685
self.run_test(Emtpy(), x, remained_onnx_input_idx=[])
6687
@skipIfUnsupportedMinOpsetVersion(9)
6688
def test_new_full(self):
6689
class Full(torch.nn.Module):
6690
def forward(self, x):
6691
return x.new_full(x.shape[1:2], 5), x.new_full(
6692
x.shape[0:1], 1.3, dtype=torch.long
6695
x = torch.randn(2, 3, 4)
6696
self.run_test(Full(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6697
self.run_test(Full(), x, remained_onnx_input_idx=[])
6699
@skipIfUnsupportedMinOpsetVersion(9)
6700
def test_inplace_list(self):
6701
class Arithmetic(torch.jit.ScriptModule):
6702
@torch.jit.script_method
6703
def forward(self, x, y):
6704
return torch.cat([x.add_(3), y.fill_(0)])
6706
x = torch.randn(2, 3)
6707
y = torch.randn(2, 3)
6711
input_names=["x", "y"],
6712
dynamic_axes={"x": [0, 1], "y": [0, 1]},
6714
self.run_test(Arithmetic(), (x, y), remained_onnx_input_idx=[0])
6716
@skipIfUnsupportedMinOpsetVersion(9)
6717
def test_inplace_fill(self):
6718
class Fill_(torch.nn.Module):
6719
def forward(self, x):
6720
return x.fill_(3), x
6722
x = torch.randn(2, 3, 4)
6723
self.run_test(Fill_(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
6724
self.run_test(Fill_(), x, remained_onnx_input_idx=[])
6726
def test_inplace_arithmetic(self):
6727
class Arithmetic(torch.jit.ScriptModule):
6728
@torch.jit.script_method
6729
def forward(self, x, y):
6734
x = torch.randn(2, 3, 4)
6735
y = torch.randn(2, 3, 4)
6736
self.run_test(Arithmetic(), (x, y))
6738
def test_inplace_arithmetic_half(self):
6739
class InplaceAddModel(torch.nn.Module):
6740
def forward(self, x, y):
6743
class InplaceMulModel(torch.nn.Module):
6744
def forward(self, x, y):
6747
x = torch.randn(2, 2, dtype=torch.half)
6748
y = torch.randn(2, 2, dtype=torch.float)
6749
self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
6750
self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)
6752
@skipIfUnsupportedMinOpsetVersion(9)
6753
def test_inplace_with_loop(self):
6754
class M(torch.nn.Module):
6755
def forward(self, x):
6771
self.run_test(torch.jit.script(M()), (x))
6773
@skipIfUnsupportedMinOpsetVersion(9)
6774
def test_inplace_with_loop_2(self):
6775
class M(torch.nn.Module):
6776
def forward(self, x):
6782
) # used in loop, altered.
6783
a_ref = a # not used in loop, should be altered.
6784
b = x.clone() # used in loop, not be altered.
6785
b_ref = b # not used in loop, should not be altered.
6805
# TODO: value for a_ref is incorrect.
6806
# a_ref += torch.ones(12,)
6807
b_ref += torch.ones(
6810
return _bias + x, a, b, b_ref
6816
self.run_test(torch.jit.script(M()), (x))
6818
@skipIfUnsupportedMinOpsetVersion(11)
6819
def test_inplace_attr_with_loop(self):
6820
class M(torch.nn.Module):
6823
self._bias = torch.arange(
6827
def forward(self, x):
6828
self._bias = torch.arange(
6834
self._bias += torch.arange(
6837
return self._bias + x
6843
self.run_test(torch.jit.script(M()), (x))
6845
@skipIfUnsupportedMinOpsetVersion(11)
6846
def test_inplace_attr_copy_with_loop(self):
6847
class M(torch.nn.Module):
6850
self._bias = torch.arange(
6854
def forward(self, x):
6855
self._bias = torch.arange(
6879
return self._bias + x
6885
self.run_test(torch.jit.script(M()), (x))
6887
@skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14
6888
def test_inplace_sequence_with_loop(self):
6889
class M(torch.nn.Module):
6890
def process(self, beam_hyps: List[Tensor], done: Tensor, x):
6891
batch_size = x.shape[0]
6892
for i in range(batch_size):
6897
for _, token in enumerate(x[i]):
6898
beam_hyps.append(token)
6904
done[i] = len(beam_hyps) > 4
6906
return beam_hyps, done
6908
def forward(self, x):
6909
beam_hyps: List[Tensor] = []
6910
batch_size = x.shape[0]
6912
max_len = x.shape[1]
6913
done = torch.zeros(batch_size, dtype=torch.bool)
6914
while cur_len < max_len:
6915
beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
6916
cur_len = cur_len + 1
6920
m = torch.jit.script(M())
6921
x = torch.randn(8, 4, 3)
6922
self.run_test(torch.jit.script(M()), (x))
6924
@skipScriptTest() # Sort with dynamic dim not supported in ONNX
6925
def test_sort(self):
6926
class SortModel(torch.nn.Module):
6927
def forward(self, x):
6929
for i in range(-2, 2):
6930
out.append(torch.sort(x, dim=i, descending=True))
6933
x = torch.randn(3, 4)
6934
self.run_test(SortModel(), x)
6936
@skipIfUnsupportedMinOpsetVersion(11)
6937
@skipScriptTest() # Sort with dynamic dim not supported in ONNX
6938
def test_sort_ascending(self):
6939
class SortModel(torch.nn.Module):
6940
def forward(self, x):
6942
for i in range(-2, 2):
6943
out.append(torch.sort(x, dim=i, descending=False))
6946
x = torch.randn(3, 4)
6947
self.run_test(SortModel(), x)
6949
@skipIfUnsupportedMinOpsetVersion(11)
6950
def test_argsort(self):
6951
class ArgSortModel(torch.nn.Module):
6952
def forward(self, x):
6953
return torch.argsort(x, dim=1, descending=False)
6955
x = torch.randn(3, 4)
6956
self.run_test(ArgSortModel(), x)
6958
@skipIfUnsupportedMinOpsetVersion(9)
6959
def test_masked_fill(self):
6960
class MaskedFillModel(torch.nn.Module):
6961
def forward(self, x):
6962
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
6963
return x.masked_fill(mask, 2)
6965
x = torch.zeros(4, 2, 3, requires_grad=True)
6966
self.run_test(MaskedFillModel(), x)
6968
class MaskedFillModel2(torch.nn.Module):
6969
def forward(self, x):
6970
return x.masked_fill(x > 3, -1)
6972
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
6973
self.run_test(MaskedFillModel2(), x)
6975
@skipIfUnsupportedMinOpsetVersion(9)
6976
def test_masked_fill_inplace(self):
6977
class MaskedFillModel(torch.jit.ScriptModule):
6978
@torch.jit.script_method
6979
def forward(self, x):
6980
mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
6981
x.masked_fill_(mask, 2)
6984
x = torch.zeros(4, 2, 3, requires_grad=True)
6985
self.run_test(MaskedFillModel(), x)
6987
class MaskedFillModel2(torch.jit.ScriptModule):
6988
@torch.jit.script_method
6989
def forward(self, x):
6990
x.masked_fill_(x > 3, -1)
6993
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
6994
self.run_test(MaskedFillModel2(), x)
6996
@skipIfUnsupportedMinOpsetVersion(11)
6997
def test_masked_scatter(self):
6998
class MaskedScatterModel(torch.nn.Module):
6999
def forward(self, x):
7000
return torch.masked_scatter(x, x.ge(0.5), torch.ones(100, 100) * 5)
7002
x = torch.randn(3, 4, 5, requires_grad=True)
7003
self.run_test(MaskedScatterModel(), x)
7005
@skipIfUnsupportedMinOpsetVersion(11)
7006
def test_masked_select(self):
7007
class MaskedSelectModel(torch.nn.Module):
7008
def forward(self, x):
7009
return torch.masked_select(x, x.ge(0.5))
7011
x = torch.randn(3, 4, 5, requires_grad=True)
7012
self.run_test(MaskedSelectModel(), x)
7014
@skipIfUnsupportedMinOpsetVersion(11)
7015
def test_index_put_to_masked_fill(self):
7016
class MaskedFillModel(torch.nn.Module):
7017
def forward(self, input_mask, some_const):
7018
mask = input_mask.clone()
7019
mask[mask != some_const] = 1
7020
mask[mask == some_const] = 0
7023
mask = torch.randn(2, 2, 2, requires_grad=True)
7024
constant = torch.tensor(5, dtype=torch.float)
7025
self.run_test(MaskedFillModel(), (mask, constant))
7027
@skipIfUnsupportedMinOpsetVersion(11)
7028
def test_index_put_to_masked_scatter(self):
7029
class MaskedScatterModel(torch.nn.Module):
7030
def forward(self, input_mask, some_const):
7031
mask = input_mask.clone()
7032
mask[mask != some_const] = torch.ones(8)
7035
mask = torch.randn(2, 2, 2, requires_grad=True)
7036
constant = torch.tensor(5, dtype=torch.float)
7037
self.run_test(MaskedScatterModel(), (mask, constant))
7039
@skipIfUnsupportedMinOpsetVersion(11)
7040
def test_index_put_with_1d_mask_to_masked_scatter(self):
7041
class MaskedScatterModel(torch.nn.Module):
7042
def forward(self, tensor, mask, some_const):
7043
tensor[mask] = some_const
7046
mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
7047
tensor = torch.randn(8, 4, 5, requires_grad=True)
7048
some_const = torch.randn(4, 4, 5, dtype=torch.float)
7049
self.run_test(MaskedScatterModel(), (tensor, mask, some_const))
7051
@skipIfUnsupportedMinOpsetVersion(9)
7052
def test_pixel_shuffle(self):
7053
class PixelShuffle(torch.nn.Module):
7054
def forward(self, x):
7055
return torch.pixel_shuffle(x, upscale_factor=2)
7057
x = torch.randn(2, 16, 4, 3, requires_grad=True)
7058
y = torch.randn(4, 32, 8, 4, requires_grad=True)
7059
self.run_test(PixelShuffle(), x)
7064
dynamic_axes={"x": [0, 1, 2, 3]},
7065
additional_test_inputs=[y],
7068
@skipIfUnsupportedMinOpsetVersion(9)
7069
def test_pixel_unshuffle(self):
7070
class PixelUnshuffle(torch.nn.Module):
7071
def forward(self, x):
7072
return torch.pixel_unshuffle(x, downscale_factor=2)
7074
x = torch.randn(2, 16, 4, 6, requires_grad=True)
7075
y = torch.randn(4, 32, 8, 4, requires_grad=True)
7076
self.run_test(PixelUnshuffle(), x)
7081
dynamic_axes={"x": [0, 1, 2, 3]},
7082
additional_test_inputs=[y],
7085
@skipIfUnsupportedMinOpsetVersion(9)
7086
def test_reciprocal(self):
7087
class ReciprocalModel(torch.nn.Module):
7088
def forward(self, x):
7089
return torch.reciprocal(x)
7091
model = ReciprocalModel()
7092
x = torch.tensor([2, 4])
7093
self.run_test(model, x.to(torch.long))
7094
self.run_test(model, x.to(torch.float))
7095
self.run_test(model, x.to(torch.double))
7097
@skipIfUnsupportedMinOpsetVersion(9)
7098
def test_scalar_type(self):
7099
class ArithmeticModel(torch.nn.Module):
7100
def forward(self, x):
7101
return x.size(0) * 2 * x, 2 - x
7103
x = torch.ones(2, 3, dtype=torch.float32)
7104
self.run_test(ArithmeticModel(), x)
7106
class ComparisonModel(torch.nn.Module):
7107
def forward(self, x, y):
7108
a = torch.tensor([12.0])
7109
return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))
7111
x = torch.ones(2, 3, dtype=torch.int32)
7112
y = torch.ones(2, 3, dtype=torch.float32)
7113
self.run_test(ComparisonModel(), (x, y))
7115
class MatMulModel(torch.nn.Module):
7116
def forward(self, x):
7117
return torch.mm(x, x) + x + torch.mm(x, x) + x
7119
x = torch.ones(3, 3)
7120
self.run_test(MatMulModel(), x)
7122
class AddMMModel(torch.nn.Module):
7123
def forward(self, x):
7124
return torch.mm(x, x) + x
7126
x = torch.ones(3, 3)
7127
self.run_test(AddMMModel(), x)
7129
class FullModel(torch.nn.Module):
7130
# add is used for exporting full
7131
def forward(self, x):
7132
return torch.full((3, 4), x)
7134
x = torch.tensor(12.0)
7135
self.run_test(FullModel(), x)
7137
class CatModel(torch.nn.Module):
7138
def forward(self, fp16, fp32):
7139
return torch.cat([fp16, fp32])
7141
fp16 = Tensor([0.5])
7143
fp32 = Tensor([1.5])
7144
self.run_test(CatModel(), (fp16, fp32))
7146
@skipIfUnsupportedMinOpsetVersion(9)
7147
def test_scalar_type_does_not_trigger_upcast_type_promotion(self):
7148
class DoNotUpcastModel(torch.nn.Module):
7149
def forward(self, x):
7150
scale = x.size()[-1] ** -0.5
7151
# 'scale' is exported as onnx float32 rank 0 tensor.
7152
# The following 'Mul' should NOT be promoted to float32.
7155
x = torch.ones(2, 3, dtype=torch.float16)
7156
self.run_test(DoNotUpcastModel(), x)
7158
@skipIfUnsupportedMinOpsetVersion(9)
7159
def test_full_like(self):
7160
class FullLikeModel(torch.nn.Module):
7161
def forward(self, x):
7162
return torch.full_like(x, 1.3, dtype=torch.int)
7164
x = torch.tensor(12)
7165
self.run_test(FullLikeModel(), x)
7167
@skipIfUnsupportedMinOpsetVersion(9)
7169
def test_full_like_value(self):
7170
class FullLikeModel(torch.nn.Module):
7171
def forward(self, x, y):
7173
return torch.full_like(x, out)
7175
x = torch.tensor(12)
7177
self.run_test(FullLikeModel(), (x, y))
7179
def test_l1_norm(self):
7180
class NormModel(torch.nn.Module):
7181
def forward(self, x):
7182
return torch.norm(x, p=1, dim=-1, keepdim=False)
7184
x = torch.randn(4, 2, 3, requires_grad=True)
7185
self.run_test(NormModel(), x)
7187
def test_l2_norm(self):
7188
class NormModel(torch.nn.Module):
7189
def forward(self, x):
7190
return torch.norm(x, p=2, dim=-2, keepdim=False)
7192
x = torch.randn(4, 2, 3, requires_grad=True)
7193
self.run_test(NormModel(), x)
7195
def test_frobenius_norm(self):
7196
class NormModel(torch.nn.Module):
7197
def forward(self, x):
7198
return torch.norm(x, p="fro", dim=0, keepdim=False)
7200
x = torch.randn(4, 2, 3, requires_grad=True)
7201
self.run_test(NormModel(), x)
7203
def test_frobenius_norm_keepdim(self):
7204
class NormModel(torch.nn.Module):
7205
def forward(self, x):
7206
return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
7208
x = torch.randn(4, 2, 3, requires_grad=True)
7209
self.run_test(NormModel(), x)
7211
def test_unfold(self):
7212
class UnfoldModel(torch.nn.Module):
7213
def forward(self, x):
7214
return x.unfold(dimension=2, size=2, step=2)
7216
x = torch.randn(4, 2, 3, requires_grad=True)
7217
y = torch.randn(2, 1, 3, requires_grad=True)
7221
dynamic_axes={"x": [0, 1]},
7223
additional_test_inputs=[y],
7226
def test_unfold_infer_shape(self):
7227
class UnfoldModule(torch.jit.ScriptModule):
7230
self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
7232
@torch.jit.script_method
7233
def forward(self, x):
7235
return x.unfold(dimension=2, size=2, step=2)
7237
x = torch.randn(32, 3, 64)
7238
self.run_test(UnfoldModule(), x)
7240
@skipIfUnsupportedMinOpsetVersion(12)
7241
def test_unfold_dynamic_inputs(self):
7242
class UnfoldModel(torch.nn.Module):
7243
def forward(self, x):
7244
return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1)
7246
x = torch.randn(4, 2, 4, requires_grad=True)
7247
self.run_test(UnfoldModel(), x)
7249
class UnfoldModel(torch.nn.Module):
7250
def forward(self, x):
7251
return x.unfold(dimension=2, size=x.shape[1], step=1)
7253
x = torch.randn(4, 2, 4, requires_grad=True)
7254
self.run_test(UnfoldModel(), x)
7256
@skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9.
7258
class MatmulModel(torch.nn.Module):
7259
def forward(self, input, other):
7260
return torch.mv(input, other)
7262
x = torch.randn(4, 5, requires_grad=True)
7263
y = torch.randn(5, requires_grad=True)
7264
self.run_test(MatmulModel(), (x, y))
7266
x = torch.randint(10, (4, 5))
7267
y = torch.randint(10, (5,))
7268
self.run_test(MatmulModel(), (x, y))
7270
@skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9.
7272
class MatmulModel(torch.nn.Module):
7273
def forward(self, input, other):
7274
return torch.dot(input, other)
7276
x = torch.randn(5, requires_grad=True)
7277
y = torch.randn(5, requires_grad=True)
7278
self.run_test(MatmulModel(), (x, y))
7280
x = torch.randint(10, (5,))
7281
y = torch.randint(10, (5,))
7282
self.run_test(MatmulModel(), (x, y))
7284
@skipScriptTest() # SpectralNorm not TorchScript compatible.
7285
def test_spectral_norm(self):
7286
m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
7288
x = torch.randn(6, 2)
7289
self.run_test(m, (x,))
7291
def test_prelu(self):
7292
class PReluModel(torch.nn.Module):
7295
self.prelu = torch.nn.PReLU()
7297
def forward(self, x):
7298
return self.prelu(x)
7300
x = torch.randn(2, 3, 4)
7301
y = torch.randn(2, 4, 5)
7306
dynamic_axes={"x": [1, 2]},
7307
additional_test_inputs=[y],
7310
def test_prelu_scalar(self):
7311
x = torch.scalar_tensor(1.0)
7312
self.run_test(torch.nn.PReLU(), x, input_names=["x"])
7314
def test_relu6(self):
7315
class Relu6Model(torch.nn.Module):
7318
self.relu6 = torch.nn.ReLU6()
7320
def forward(self, x):
7321
return self.relu6(x)
7323
x = torch.randn(2, 3, 4) * 100.0
7324
y = torch.randn(2, 4, 5) * 100.0
7329
dynamic_axes={"x": [1, 2]},
7330
additional_test_inputs=[y],
7333
def test_silu(self):
7334
class SiLUModel(torch.nn.Module):
7337
self.silu = torch.nn.SiLU()
7339
def forward(self, x):
7342
x = torch.randn(2, 3, 4)
7343
self.run_test(SiLUModel(), (x))
7345
@skipIfUnsupportedMinOpsetVersion(14)
7346
def test_tril(self):
7347
class trilModel(torch.nn.Module):
7348
def forward(self, x):
7349
return torch.tril(x)
7351
x = torch.randn(2, 3, 4)
7352
self.run_test(trilModel(), (x))
7354
class trilModelwithDiagonal(torch.nn.Module):
7355
def forward(self, x):
7356
return torch.tril(x, diagonal=1)
7358
x = torch.randn(2, 3, 4)
7359
self.run_test(trilModelwithDiagonal(), (x))
7361
class trilModelwithNegDiagonal(torch.nn.Module):
7362
def forward(self, x):
7363
return torch.tril(x, diagonal=-1)
7365
x = torch.randn(2, 3, 4)
7366
self.run_test(trilModelwithNegDiagonal(), (x))
7368
class trilModelWithDiagonalInput(torch.nn.Module):
7369
def forward(self, x, diagnonal: int):
7370
return torch.tril(x, diagonal=diagnonal)
7372
x = torch.randn(2, 3, 4)
7373
self.run_test(trilModelWithDiagonalInput(), (x, 5))
7375
@skipIfUnsupportedMinOpsetVersion(14)
7376
def test_triu(self):
7377
class triuModel(torch.nn.Module):
7378
def forward(self, x):
7379
return torch.triu(x)
7381
x = torch.randn(2, 3, 4)
7382
self.run_test(triuModel(), (x))
7384
class triuModelwithDiagonal(torch.nn.Module):
7385
def forward(self, x):
7386
return torch.triu(x, diagonal=1)
7388
x = torch.randn(2, 3, 4)
7389
self.run_test(triuModelwithDiagonal(), (x))
7391
class triuModelwithNegDiagonal(torch.nn.Module):
7392
def forward(self, x):
7393
return torch.triu(x, diagonal=-1)
7395
x = torch.randn(2, 3, 4)
7396
self.run_test(triuModelwithNegDiagonal(), (x))
7398
class triuModelWithDiagonalInput(torch.nn.Module):
7399
def forward(self, x, diagnonal: int):
7400
return torch.triu(x, diagonal=diagnonal)
7402
x = torch.randn(2, 3, 4)
7403
self.run_test(triuModelWithDiagonalInput(), (x, 5))
7405
def test_mish(self):
7406
class MishModel(torch.nn.Module):
7409
self.mish = torch.nn.Mish()
7411
def forward(self, x):
7414
x = torch.randn(2, 3, 4)
7415
self.run_test(MishModel(), (x))
7417
def test_remainder(self):
7418
class RemainderModel(torch.nn.Module):
7419
def forward(self, input, other):
7420
return torch.remainder(input, other)
7422
x = torch.randn(4, 2, 3)
7423
y = torch.randn(1, 2, 1)
7424
self.run_test(RemainderModel(), (x, y))
7426
x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
7427
y = torch.tensor([2], dtype=torch.long)
7428
self.run_test(RemainderModel(), (x, y))
7430
x = x.to(torch.float)
7431
self.run_test(RemainderModel(), (x, y))
7433
y = y.to(torch.float)
7434
self.run_test(RemainderModel(), (x, y))
7436
x = x.to(torch.int32)
7437
self.run_test(RemainderModel(), (x, y))
7439
def test_remainder_scalar(self):
7440
class RemainderModel(torch.nn.Module):
7441
def __init__(self, scalar=2.55):
7443
self.scalar = scalar
7445
def forward(self, input):
7446
return torch.remainder(input, self.scalar)
7448
x = torch.randint(10, (2, 3))
7449
self.run_test(RemainderModel(), x)
7451
x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
7452
self.run_test(RemainderModel(2), x)
7454
@skipIfUnsupportedMinOpsetVersion(10)
7455
def test_fmod(self):
7456
class FModModel(torch.nn.Module):
7457
def forward(self, input, other):
7458
return torch.fmod(input, other)
7460
x = torch.randn(4, 2, 3)
7461
y = torch.randn(1, 2, 1)
7462
self.run_test(FModModel(), (x, y))
7464
@skipIfUnsupportedMinOpsetVersion(10)
7465
def test_fmod_scalar(self):
7466
class FModModel(torch.nn.Module):
7467
def forward(self, input):
7468
return torch.fmod(input, 2.55)
7470
x = torch.randint(10, (2, 3))
7471
self.run_test(FModModel(), x)
7473
@skipIfUnsupportedMinOpsetVersion(9)
7475
class GluModel(torch.nn.Module):
7476
def forward(self, x):
7477
return torch.nn.functional.glu(x)
7479
x = torch.randn(2, 4, 5, 6, requires_grad=True)
7480
self.run_test(GluModel(), x)
7482
@skipIfUnsupportedMinOpsetVersion(9)
7483
def test_gelu(self):
7484
class GeluModel(torch.nn.Module):
7485
def forward(self, x):
7486
return torch.nn.functional.gelu(x, approximate="none")
7488
x = torch.randn(2, 4, 5, 6, requires_grad=True)
7489
self.run_test(GeluModel(), x)
7491
@skipIfUnsupportedMinOpsetVersion(9)
7492
def test_tanh_gelu(self):
7493
class GeluModel(torch.nn.Module):
7494
def forward(self, x):
7495
return torch.nn.functional.gelu(x, approximate="tanh")
7497
x = torch.randn(2, 4, 5, 6, requires_grad=True)
7498
self.run_test(GeluModel(), x)
7500
def test_add_inplace(self):
7501
class InplaceAddModel(torch.nn.Module):
7502
def forward(self, x):
7506
x = torch.randn(4, 2, 3, requires_grad=True)
7507
self.run_test(InplaceAddModel(), x)
7509
def test_addcmul(self):
7510
class AddcmulModel(torch.nn.Module):
7511
def forward(self, x, t1, t2):
7512
return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2)
7514
x = torch.randn(1, 3)
7515
t1 = torch.randn(3, 1)
7516
t2 = torch.randn(1, 3)
7517
self.run_test(AddcmulModel(), (x, t1, t2))
7519
def test_rsqrt(self):
7520
class RsqrtModel(torch.nn.Module):
7521
def forward(self, x):
7524
x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
7525
self.run_test(RsqrtModel(), x)
7527
def test_rsqrt_zeros(self):
7528
class RsqrtModel(torch.nn.Module):
7529
def forward(self, x):
7532
x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
7533
self.run_test(RsqrtModel(), x)
7535
@skipIfUnsupportedMinOpsetVersion(11)
7536
def test_unique(self):
7537
class UniqueModel(torch.nn.Module):
7538
def forward(self, x):
7539
return torch.unique(
7540
x, sorted=True, return_inverse=False, return_counts=True
7543
x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
7544
self.run_test(UniqueModel(), x)
7546
@skipIfUnsupportedMinOpsetVersion(11)
7547
def test_unique_along_dim(self):
7548
class UniqueModel(torch.nn.Module):
7549
def forward(self, x):
7550
return torch.unique(
7551
x, dim=0, sorted=True, return_inverse=True, return_counts=False
7554
x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
7555
self.run_test(UniqueModel(), x)
7557
@skipIfUnsupportedMinOpsetVersion(11)
7558
def test_cumsum(self):
7559
class CumSum(torch.nn.Module):
7560
def forward(self, input):
7561
return torch.cumsum(input, dim=0)
7563
x = torch.randn(2, 3, 4)
7565
self.run_test(model, x)
7567
@skipIfUnsupportedMinOpsetVersion(11)
7568
def test_cumsum_with_cast(self):
7569
class CumSum(torch.nn.Module):
7570
def forward(self, input):
7571
return torch.cumsum(input, dim=0, dtype=torch.float32)
7574
x = torch.tensor([2, 3, 4], dtype=torch.int32)
7575
self.run_test(model, x)
7576
x = torch.tensor([False, True, True])
7577
self.run_test(model, x)
7579
@skipScriptTest() # error in propagate as assign input shape
7580
@skipIfUnsupportedMinOpsetVersion(10)
7581
def test_embedding_bag(self):
7582
model = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True)
7583
input = torch.randint(10, (7,))
7584
offset = torch.tensor([0, 2, 5, 6])
7585
self.run_test(model, (input, offset))
7587
model = torch.nn.EmbeddingBag(10, 5, mode="sum", include_last_offset=True)
7588
input = torch.randint(10, (7,))
7589
offset = torch.tensor([0, 2, 5, 6])
7590
self.run_test(model, (input, offset))
7592
model = torch.nn.EmbeddingBag(10, 5, mode="max")
7593
input = torch.randint(10, (7, 5))
7594
self.run_test(model, (input))
7596
@skipIfUnsupportedMinOpsetVersion(11)
7597
def test_embedding_bag_1d_per_sample_weights(self):
7598
class EmbeddingModel(torch.nn.Module):
7599
def forward(self, embedding_matrix, input, offset, weights):
7600
return torch.nn.functional.embedding_bag(
7605
per_sample_weights=weights,
7608
model = EmbeddingModel()
7609
x = torch.randint(7, (6,))
7613
offset = torch.tensor([0, 2, 5])
7614
embedding_matrix = torch.rand(10, 15)
7615
self.run_test(model, (embedding_matrix, x, offset, w))
7617
@skipIfUnsupportedMinOpsetVersion(11)
7619
"This test is broken with ONNXRuntime(17): "
7620
"when running with onnxruntime 1.17.0 this test fails with the following error:"
7621
"FAIL : Non-zero status code returned while running If node. "
7622
"Name:'/If' Status Message: if.cc:253 Compute "
7623
"If nodes condition input must have exactly one element"
7624
"https://github.com/pytorch/pytorch/issues/119442"
7626
def test_embedding_bag_2d_per_sample_weights(self):
7627
class EmbeddingModel(torch.nn.Module):
7628
def forward(self, embedding_matrix, input, weights):
7629
return torch.nn.functional.embedding_bag(
7630
input, embedding_matrix, mode="sum", per_sample_weights=weights
7633
embedding_matrix = torch.rand(10, 15)
7634
model = EmbeddingModel()
7635
x = torch.randint(7, (2, 3))
7636
w = torch.randn(2, 3)
7638
x2 = torch.randint(7, (4, 3))
7639
w2 = torch.randn(4, 3)
7642
(embedding_matrix, x, w),
7643
input_names=["embed", "x", "w"],
7644
dynamic_axes={"x": [0], "w": [0]},
7645
additional_test_inputs=[(embedding_matrix, x2, w2)],
7648
@skipScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
7649
@skipIfUnsupportedMinOpsetVersion(11)
7651
"Due to ONNX Loop shape inference issue. "
7652
"https://msdata.visualstudio.com/Vienna/_workitems/edit/1352001"
7654
def test_embedding_bag_dynamic_input(self):
7655
class EmbeddingModel1D(torch.nn.Module):
7656
def forward(self, embedding_matrix, input, weights, offsets):
7657
return torch.nn.functional.embedding_bag(
7662
per_sample_weights=weights,
7665
model = EmbeddingModel1D()
7666
x = torch.randint(7, (6,))
7670
offsets = torch.tensor([0, 2, 5], dtype=torch.long)
7671
embedding_matrix = torch.rand(10, 15)
7672
x2 = torch.randint(7, (2,))
7676
embedding_matrix2 = torch.rand(12, 25)
7677
offsets2 = torch.tensor(
7685
(embedding_matrix, x, w, offsets),
7686
additional_test_inputs=[(embedding_matrix2, x2, w2, offsets2)],
7687
input_names=["embedding_matrix", "x", "offsets", "w"],
7689
"embedding_matrix": [0, 1],
7696
class EmbeddingModel2D(torch.nn.Module):
7697
def forward(self, embedding_matrix, input, weights):
7698
return torch.nn.functional.embedding_bag(
7699
input, embedding_matrix, mode="sum", per_sample_weights=weights
7702
model = EmbeddingModel2D()
7703
x = torch.randint(7, (2, 3))
7704
w = torch.randn(2, 3)
7705
embedding_matrix = torch.rand(10, 15)
7706
x2 = torch.randint(7, (3, 5))
7707
w2 = torch.randn(3, 5)
7708
embedding_matrix2 = torch.rand(12, 25)
7711
(embedding_matrix, x, w),
7712
additional_test_inputs=[(embedding_matrix2, x2, w2)],
7713
input_names=["embedding_matrix", "x", "w"],
7714
dynamic_axes={"embedding_matrix": [0, 1], "x": [0, 1], "w": [0, 1]},
7717
@skipIfUnsupportedMinOpsetVersion(8)
7718
def test_meshgrid(self):
7719
class Meshgrid(torch.nn.Module):
7720
def forward(self, x, y, z):
7721
output1, output2, output3 = torch.meshgrid(x, y, z)
7722
return output1, output2, output3
7724
x = torch.randn(3, requires_grad=True)
7725
y = torch.zeros(4, requires_grad=True)
7726
z = torch.randn(5, requires_grad=True)
7727
self.run_test(Meshgrid(), (x, y, z))
7729
@skipIfUnsupportedMinOpsetVersion(8)
7730
def test_meshgrid_indexing(self):
7731
class Meshgrid(torch.nn.Module):
7732
def __init__(self, indexing):
7734
self.indexing = indexing
7736
def forward(self, x, y, z):
7737
output1, output2, output3 = torch.meshgrid(
7738
x, y, z, indexing=self.indexing
7740
return output1, output2, output3
7742
x = torch.randn(5, requires_grad=True)
7743
y = torch.zeros(6, requires_grad=True)
7744
z = torch.randn(7, requires_grad=True)
7745
for indexing in ("xy", "ij"):
7746
self.run_test(Meshgrid(indexing), (x, y, z))
7748
@skipIfUnsupportedMinOpsetVersion(8)
7749
def test_meshgrid_scalar(self):
7750
class Meshgrid(torch.nn.Module):
7751
def forward(self, x, y, z):
7752
output1, output2, output3 = torch.meshgrid(x, y, z)
7753
return output1, output2, output3
7755
x = torch.ones(3, requires_grad=True)
7756
y = torch.zeros(4, requires_grad=True)
7757
z = torch.tensor(2.0)
7758
self.run_test(Meshgrid(), (x, y, z))
7760
def test_baddbmm(self):
7761
class MyModule(torch.nn.Module):
7762
def forward(self, input, batch1, batch2):
7763
return torch.baddbmm(
7764
input, batch1, batch2, alpha=torch.tensor(5), beta=3.5
7767
x = torch.randn(10, 3, 5)
7768
batch1 = torch.randn(10, 3, 4)
7769
batch2 = torch.randn(10, 4, 5)
7771
self.run_test(model, (x, batch1, batch2))
7773
def test_baddbmm_dynamic(self):
7774
class MyModule(torch.nn.Module):
7775
def forward(self, input, batch1, batch2, alpha, beta):
7776
return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)
7778
x = torch.randn(10, 3, 5)
7779
batch1 = torch.randn(10, 3, 4)
7780
batch2 = torch.randn(10, 4, 5)
7781
alpha = torch.tensor(5)
7782
beta = torch.tensor(3.5)
7784
self.run_test(model, (x, batch1, batch2, alpha, beta))
7786
def test_numel(self):
7787
class MyModule(torch.nn.Module):
7788
def forward(self, input):
7789
return input.numel() * input
7791
x = torch.randn(2, 3, 5)
7792
x2 = torch.randn(4, 5, 6)
7798
dynamic_axes={"x": [0, 1, 2]},
7799
additional_test_inputs=[(x2,)],
7802
def test_numel_empty(self):
7803
class MyModule(torch.nn.Module):
7804
def forward(self, input):
7805
return input.numel() * input
7814
dynamic_axes={"x": [0]},
7815
additional_test_inputs=[(x2,)],
7818
def test_dtype(self):
7819
class MyModel(torch.jit.ScriptModule):
7820
@torch.jit.script_method
7821
def forward(self, input, other):
7822
return input.to(dtype=other.dtype) + other
7824
x = torch.randn(2, 3)
7825
y = torch.randn(2, 3)
7826
self.run_test(MyModel(), (x, y))
7828
def test_dtype_eq(self):
7829
class MyModel(torch.jit.ScriptModule):
7830
@torch.jit.script_method
7831
def forward(self, input, other):
7832
if input.dtype == other.dtype:
7833
return input + other
7836
x = torch.randn(2, 3)
7837
y = torch.randn(2, 3)
7838
self.run_test(MyModel(), (x, y))
7840
def test_cast_to(self):
7841
class MyModule(torch.jit.ScriptModule):
7842
@torch.jit.script_method
7843
def forward(self, input, other):
7844
return input.to(other) + other
7846
x = torch.randn(2, 3, 4)
7847
y = torch.tensor([1], dtype=torch.int64)
7849
self.run_test(model, (x, y))
7851
def test_cast_to_bool(self):
7852
class MyModule(torch.nn.Module):
7853
def forward(self, input, other):
7854
return torch.cat((input.to(other), other), 0)
7856
x = torch.randn(2, 3, 4)
7857
y = torch.zeros([2, 3, 4], dtype=torch.bool)
7859
self.run_test(model, (x, y))
7861
# ONNX supports bfloat16 for opsets >= 13
7862
@skipIfUnsupportedMinOpsetVersion(13)
7863
def test_cast_type_as_with_bfloat16(self):
7864
class MyModule(torch.nn.Module):
7865
def forward(self, x):
7866
y = torch.ones((3, 4), dtype=torch.bfloat16)
7868
return x.to(dtype=torch.float16)
7870
x = torch.ones(3, 4, dtype=torch.float16)
7872
self.run_test(model, x)
7874
@skipIfUnsupportedMinOpsetVersion(9)
7875
def test_type_as(self):
7876
class MyModule(torch.nn.Module):
7877
def forward(self, x):
7878
y = torch.tensor([1.0])
7881
a = torch.tensor([True, False], dtype=torch.bool)
7882
b = torch.randn(3, 4, dtype=torch.double)
7883
c = torch.ones((2, 2), dtype=torch.int64)
7885
self.run_test(model, a)
7886
self.run_test(model, b)
7887
self.run_test(model, c)
7889
@skipIfUnsupportedMinOpsetVersion(9)
7890
def test_ones_bool(self):
7891
class MyModule(torch.nn.Module):
7892
def forward(self, input):
7893
true = torch.ones(input.shape, dtype=torch.bool)
7894
return input.to(true) & true
7896
x = torch.randn(2, 3, 4)
7898
self.run_test(model, x)
7901
class Log(torch.nn.Module):
7902
def forward(self, input):
7903
return torch.log(input)
7905
x = torch.rand(2, 3, 4)
7907
self.run_test(model, x)
7909
def test_log1p(self):
7910
class Log1p(torch.nn.Module):
7911
def forward(self, input):
7912
return torch.log1p(input)
7914
x = torch.rand(2, 3, 4)
7916
self.run_test(model, x)
7918
def test_log10(self):
7919
class Log10(torch.nn.Module):
7920
def forward(self, input):
7921
return torch.log10(input)
7923
x = torch.rand(2, 3, 4)
7925
self.run_test(model, x)
7927
def test_log2(self):
7928
class Log2(torch.nn.Module):
7929
def forward(self, input):
7930
return torch.log2(input)
7932
x = torch.tensor(1.0)
7934
self.run_test(model, x)
7936
@skipIfUnsupportedMinOpsetVersion(11)
7937
def test_round(self):
7938
class Round(torch.nn.Module):
7939
def forward(self, x):
7940
return torch.round(x)
7942
x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
7943
self.run_test(Round(), x)
7945
int_x = torch.tensor([9920, 1036, -1500, 35], dtype=torch.int32)
7946
self.run_test(Round(), int_x)
7948
@skipIfUnsupportedMinOpsetVersion(11)
7949
def test_round_with_decimals(self):
7950
class Round(torch.nn.Module):
7951
def __init__(self, decimals):
7953
self.decimals = decimals
7955
def forward(self, x):
7956
return torch.round(x, decimals=self.decimals)
7958
x = torch.tensor([0.9920, -1234.0362, -1.58960, 3.5000])
7959
for decimals in (0, -2, 3):
7960
self.run_test(Round(decimals), x)
7962
@skipIfUnsupportedMinOpsetVersion(17)
7963
def test_stft_default(self):
7964
class STFT(torch.nn.Module):
7965
def forward(self, x):
7967
return torch.stft(x, n_fft=n_fft, center=False, return_complex=False)
7969
x = torch.randn((1, 32), requires_grad=True)
7970
self.run_test(STFT(), x, atol=1e-6)
7972
@skipIfUnsupportedMinOpsetVersion(17)
7973
def test_stft_hop_length(self):
7974
class STFT(torch.nn.Module):
7975
def forward(self, x):
7982
hop_length=hop_length,
7983
return_complex=False,
7986
x = torch.randn((1, 32), requires_grad=True)
7987
self.run_test(STFT(), x, atol=1e-6)
7989
@skipIfUnsupportedMinOpsetVersion(17)
7990
def test_stft_non_divisible_hop_length(self):
7991
class STFT(torch.nn.Module):
7992
def forward(self, x):
7999
hop_length=hop_length,
8000
return_complex=False,
8003
x = torch.randn((1, 32), requires_grad=True)
8004
self.run_test(STFT(), x, atol=1e-6)
8006
@skipIfUnsupportedMinOpsetVersion(17)
8007
def test_stft_window_int_same_size(self):
8008
class STFT(torch.nn.Module):
8009
def forward(self, x):
8016
win_length=win_length,
8017
return_complex=False,
8020
x = torch.randn((1, 32), requires_grad=True)
8021
self.run_test(STFT(), x, atol=1e-6)
8023
@skipIfUnsupportedMinOpsetVersion(17)
8024
def test_stft_window_int_different_size(self):
8025
class STFT(torch.nn.Module):
8026
def forward(self, x):
8033
win_length=win_length,
8034
return_complex=False,
8037
x = torch.randn((1, 32), requires_grad=True)
8038
self.run_test(STFT(), x, atol=1e-6)
8040
@skipIfUnsupportedMinOpsetVersion(17)
8041
def test_stft_window_custom(self):
8042
class STFT(torch.nn.Module):
8043
def forward(self, x):
8045
window = torch.hann_window(16)
8051
return_complex=False,
8054
x = torch.randn((1, 32), requires_grad=True)
8055
self.run_test(STFT(), x, atol=1e-6)
8057
@skipIfUnsupportedMinOpsetVersion(17)
8058
def test_stft_wrong_custom_window_size(self):
8059
class STFT(torch.nn.Module):
8060
def forward(self, x):
8062
window = torch.hann_window(10)
8064
x, n_fft=n_fft, window=window, center=False, return_complex=False
8067
x = torch.randn((1, 32), requires_grad=True)
8068
with self.assertRaises((AssertionError, RuntimeError)):
8069
self.run_test(STFT(), x)
8071
@skipIfUnsupportedMinOpsetVersion(17)
8072
def test_stft_wrong_window_length(self):
8073
class STFT(torch.nn.Module):
8074
def forward(self, x):
8082
return_complex=False,
8085
x = torch.randn((1, 32), requires_grad=True)
8086
with self.assertRaises(RuntimeError):
8087
self.run_test(STFT(), x)
8089
@skipIfUnsupportedMinOpsetVersion(17)
8090
def test_stft_window_size_with_win_len(self):
8091
class STFT(torch.nn.Module):
8092
def forward(self, x):
8094
window = torch.hann_window(10)
8102
return_complex=False,
8105
x = torch.randn((1, 32), requires_grad=True)
8106
self.run_test(STFT(), x, atol=1e-6)
8108
@skipIfUnsupportedMinOpsetVersion(17)
8109
def test_stft_one_dimension(self):
8110
class STFT(torch.nn.Module):
8111
def forward(self, x):
8117
return_complex=False,
8120
x = torch.randn((32), requires_grad=True)
8121
self.run_test(STFT(), x, atol=1e-6)
8123
@skipIfUnsupportedMinOpsetVersion(17)
8124
def test_stft_wrong_input_size(self):
8125
class STFT(torch.nn.Module):
8126
def forward(self, x):
8128
return torch.stft(x, n_fft=n_fft, center=False, return_complex=False)
8130
x = torch.randn((1, 1, 32), requires_grad=True)
8131
with self.assertRaises(RuntimeError):
8132
self.run_test(STFT(), x)
8134
@skipIfUnsupportedMinOpsetVersion(17)
8135
def test_stft_wrong_return_complex(self):
8136
class STFT(torch.nn.Module):
8137
def forward(self, x):
8139
return torch.stft(x, n_fft=n_fft, center=False, return_complex=True)
8141
x = torch.randn((1, 32), requires_grad=True)
8142
with self.assertRaises(errors.SymbolicValueError):
8143
self.run_test(STFT(), x)
8145
@skipIfUnsupportedMinOpsetVersion(17)
8146
def test_stft_normalize(self):
8147
class STFT(torch.nn.Module):
8148
def forward(self, x):
8155
return_complex=False,
8158
x = torch.randn((32), requires_grad=True)
8159
self.run_test(STFT(), x, atol=1e-6)
8161
@skipIfUnsupportedMinOpsetVersion(17)
8162
def test_stft_not_onesided(self):
8163
class STFT(torch.nn.Module):
8164
def forward(self, x):
8171
return_complex=False,
8174
x = torch.randn((32), requires_grad=True)
8175
self.run_test(STFT(), x, atol=1e-6)
8177
def test_constant_pad(self):
8178
model = torch.nn.ConstantPad1d(2, 3.5)
8179
x = torch.randn(2, 4, 4)
8180
self.run_test(model, x)
8182
model = torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5)
8183
x = torch.randn(2, 2, 4, 4)
8184
self.run_test(model, x)
8186
@common_utils.parametrize(
8189
common_utils.subtest([2, 4], name="scalar_list"),
8190
common_utils.subtest(
8192
torch.tensor(2, dtype=torch.int64),
8193
torch.tensor(4, dtype=torch.int64),
8195
name="scalar_tensor_list",
8199
@skipIfUnsupportedMinOpsetVersion(11) # Dynamic padding is added in opset 11
8200
def test_pad_types(self, pad):
8201
# Test for different pad integer types
8202
class Pad(torch.nn.Module):
8203
def forward(self, x, pad: List[int]):
8204
return torch.nn.functional.pad(x, pad)
8206
x = torch.randn(2, 2, 4, 4)
8207
self.run_test(Pad(), (x, pad))
8209
@skipIfUnsupportedMinOpsetVersion(11)
8210
def test_pad_circular(self):
8211
class PadModel(torch.nn.Module):
8212
def forward(self, x):
8213
out = torch.nn.functional.pad(x, (1, 2, 1, 2), mode="circular")
8216
x = torch.randn(2, 3, 3, 4)
8217
self.run_test(PadModel(), (x))
8219
@skipIfUnsupportedMinOpsetVersion(11)
8220
def test_pad_circular_negative(self):
8221
# Test for different pad integer types
8222
class PadModel(torch.nn.Module):
8223
def forward(self, x):
8224
out = torch.nn.functional.pad(x, (-1, -2), mode="circular")
8227
x = torch.randn(2, 3, 6)
8228
self.run_test(PadModel(), (x))
8230
@skipIfUnsupportedMinOpsetVersion(11)
8231
def test_pad_circular_dynamic_axes(self):
8232
class PadModel(torch.nn.Module):
8233
def forward(self, x):
8234
out = torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")
8237
x = torch.randn(4, 3, 5, 6)
8241
input_names=["input_1"],
8242
dynamic_axes={"input_1": [0, 1, 2, 3]},
8245
@skipIfUnsupportedMaxOpsetVersion(10)
8246
@skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script
8247
def test_unsupported_pad(self):
8248
class Pad(torch.nn.Module):
8249
def forward(self, x, pad: List[int]):
8250
return torch.nn.functional.pad(x, pad)
8252
x = torch.randn(2, 2, 4, 4)
8255
with self.assertRaisesRegex(
8258
"Unsupported: ONNX export of Pad.*"
8259
+ "The sizes of the padding must be constant"
8262
self.run_test(Pad(), (x, y))
8264
@skipIfUnsupportedMinOpsetVersion(9)
8265
def test_if_fold(self):
8266
class IfFoldModel(torch.nn.Module):
8267
def forward(self, y):
8275
x = torch.ones((3, 4), dtype=torch.int)
8276
self.run_test(IfFoldModel(), x)
8278
class IfFoldModel(torch.nn.Module):
8279
def forward(self, y):
8286
x = torch.ones((3, 4), dtype=torch.int)
8287
self.run_test(IfFoldModel(), x)
8289
class IfFoldModel(torch.nn.Module):
8290
def forward(self, y):
8298
x = torch.ones((3, 4), dtype=torch.int)
8299
self.run_test(IfFoldModel(), x)
8301
class IfFoldModel(torch.nn.Module):
8302
def forward(self, y):
8309
x = torch.ones((3, 4), dtype=torch.int)
8310
self.run_test(IfFoldModel(), x)
8312
class IfFoldModel(torch.nn.Module):
8313
def forward(self, y):
8320
x = torch.ones((3, 4), dtype=torch.int)
8321
self.run_test(IfFoldModel(), x)
8323
class IfFoldModel(torch.nn.Module):
8324
def forward(self, y):
8325
if y.dim() < 3 and y.dtype == torch.int:
8332
x = torch.ones((3, 4), dtype=torch.int)
8333
self.run_test(IfFoldModel(), x)
8335
class IfFoldModel(torch.nn.Module):
8336
def forward(self, y):
8337
if y.dim() == 3 and y.dtype == torch.int:
8344
x = torch.ones((3, 4), dtype=torch.int)
8345
self.run_test(IfFoldModel(), x)
8347
class IfFoldModel(torch.nn.Module):
8348
def forward(self, y):
8349
if y.numel() != 0 and y.dim() == 2:
8356
x = torch.ones((3, 4), dtype=torch.int)
8357
self.run_test(IfFoldModel(), x)
8359
class IfFoldModel(torch.nn.Module):
8360
def forward(self, x, y):
8361
if x.numel() == y.numel():
8367
x = torch.ones((3, 4), dtype=torch.int)
8368
y = torch.ones((3, 4), dtype=torch.int)
8369
self.run_test(IfFoldModel(), (x, y))
8371
class IfFoldModel(torch.nn.Module):
8372
def forward(self, x, y):
8373
if x.numel() != y.numel():
8379
x = torch.ones((3, 4), dtype=torch.int)
8380
y = torch.ones((3, 4), dtype=torch.int)
8381
self.run_test(IfFoldModel(), (x, y))
8383
@skipIfUnsupportedMinOpsetVersion(11)
8384
def test_uninitialized(self):
8385
class UninitializedModel(torch.nn.Module):
8386
def forward(self, y):
8394
x = torch.ones((3, 4), dtype=torch.int)
8395
self.run_test(UninitializedModel(), x)
8397
@skipIfUnsupportedMinOpsetVersion(11)
8398
def test_uninitialized_dynamic(self):
8399
class UninitializedModel(torch.nn.Module):
8400
def forward(self, y):
8408
x = torch.ones((3, 4), dtype=torch.int)
8409
y = torch.ones((6, 7), dtype=torch.int)
8411
UninitializedModel(),
8413
additional_test_inputs=[y],
8414
input_names=["input_1"],
8415
dynamic_axes={"input_1": [0, 1]},
8418
# onnx::Identity of sequence supported for ONNX opset >= 14
8419
@skipIfUnsupportedMinOpsetVersion(14)
8420
def test_uninitialized_tensorList(self):
8421
class UninitializedTensorListModel(torch.nn.Module):
8422
def forward(self, x):
8423
if x[0].shape[0] < 5:
8430
x = torch.ones((3, 4), dtype=torch.int)
8431
self.run_test(torch.jit.script(UninitializedTensorListModel()), x)
8433
# onnx::Identity of sequence supported for ONNX opset >= 14
8434
@skipIfUnsupportedMinOpsetVersion(14)
8435
def test_uninitialized_tensorList_dynamic(self):
8436
class UninitializedTensorListModel(torch.nn.Module):
8437
def forward(self, x):
8438
if x[0].shape[0] < 5:
8445
x = torch.ones((3, 4), dtype=torch.double)
8447
torch.jit.script(UninitializedTensorListModel()),
8449
input_names=["input_1"],
8450
dynamic_axes={"input_1": [0, 1]},
8453
# onnx::Identity of sequence supported for ONNX opset >= 14
8454
@skipIfUnsupportedMinOpsetVersion(14)
8455
def test_uninitialized_intList(self):
8456
class UninitializedListModel(torch.nn.Module):
8457
def forward(self, x):
8458
y = list(range(x.size(0)))
8460
# if x.size(0) != 3, ORT will throw type error.
8467
x = torch.ones((3, 4), dtype=torch.int)
8469
torch.jit.script(UninitializedListModel()),
8471
input_names=["input_1"],
8472
dynamic_axes={"input_1": [0, 1]},
8475
# onnx::Identity of sequence supported for ONNX opset >= 14
8476
@skipIfUnsupportedMinOpsetVersion(14)
8477
def test_uninitialized_tensorList_shape(self):
8478
class UninitializedModel(torch.nn.Module):
8479
def forward(self, x):
8489
x = torch.ones((3, 4), dtype=torch.int)
8490
y = torch.ones((4, 6), dtype=torch.int)
8492
torch.jit.script(UninitializedModel()),
8494
additional_test_inputs=[y],
8495
input_names=["input_1"],
8496
dynamic_axes={"input_1": [0, 1]},
8499
# Sequence type as loop-carried dependencies only supported for ONNX opset >= 13
8500
@skipIfUnsupportedMinOpsetVersion(13)
8501
def test_sequance_loopcarried(self):
8502
class SequanceLoopModel(torch.nn.Module):
8503
def forward(self, x):
8507
return torch.stack(outputs).transpose(0, 1)
8509
x = torch.ones((3, 4), dtype=torch.int)
8510
self.run_test(torch.jit.script(SequanceLoopModel()), x)
8512
def test_reflection_pad(self):
8513
model = torch.nn.ReflectionPad1d(2)
8514
x = torch.randn(2, 4, 4)
8515
self.run_test(model, x)
8517
model = torch.nn.ReflectionPad2d((3, 0, 2, 1))
8518
x = torch.randn(2, 2, 4, 4)
8519
self.run_test(model, x)
8521
def test_replication_pad(self):
8522
model = torch.nn.ReplicationPad1d(2)
8523
x = torch.randn(2, 4, 4)
8524
self.run_test(model, x)
8526
model = torch.nn.ReplicationPad2d((3, 0, 2, 1))
8527
x = torch.randn(2, 2, 4, 4)
8528
self.run_test(model, x)
8530
@skipIfUnsupportedMinOpsetVersion(11)
8531
def test_im2col(self):
8532
class Unfold(torch.nn.Module):
8533
def forward(self, input):
8535
torch.nn.functional.unfold(
8536
input, kernel_size=(10, 15), dilation=2, padding=5, stride=3
8538
torch.nn.functional.unfold(
8539
input, kernel_size=(2, 2), dilation=1, padding=0, stride=3
8541
torch.nn.functional.unfold(
8542
input, kernel_size=(1, 1), dilation=5, padding=2, stride=3
8546
x = torch.rand(1, 1, 200, 100)
8547
self.run_test(Unfold(), x)
8550
@skipIfUnsupportedMinOpsetVersion(11)
8552
class Det(torch.nn.Module):
8553
def forward(self, x):
8554
return torch.linalg.det(x)
8556
x = torch.randn(2, 3, 5, 5)
8557
self.run_test(Det(), x)
8559
def test_linalg_norm(self):
8560
class LinalgSingleDimModel(torch.nn.Module):
8561
def __init__(self, ord_val):
8565
def forward(self, x):
8566
return torch.linalg.norm(x, ord=self.ord, dim=1)
8568
x = torch.randn(2, 3, 5, 5)
8569
self.run_test(LinalgSingleDimModel(None), x)
8570
self.run_test(LinalgSingleDimModel(2), x)
8571
self.run_test(LinalgSingleDimModel(float("inf")), x)
8572
self.run_test(LinalgSingleDimModel(-float("inf")), x)
8573
self.run_test(LinalgSingleDimModel(-4), x)
8574
self.run_test(LinalgSingleDimModel(1.5), x)
8576
class LinalgMultiDimModel(torch.nn.Module):
8577
def __init__(self, ord_val):
8581
def forward(self, x):
8582
return torch.linalg.norm(x, ord=self.ord, dim=(0, 2))
8584
x = torch.randn(2, 3, 5, 5)
8585
self.run_test(LinalgMultiDimModel("fro"), x)
8586
self.run_test(LinalgMultiDimModel(float("inf")), x)
8587
self.run_test(LinalgMultiDimModel(-float("inf")), x)
8588
self.run_test(LinalgMultiDimModel(1), x)
8589
self.run_test(LinalgMultiDimModel(-1), x)
8591
class LinalgNoDimNoOrdModel(torch.nn.Module):
8592
def forward(self, x):
8593
return torch.linalg.norm(x)
8595
x = torch.randn(2, 3, 5, 5)
8596
self.run_test(LinalgNoDimNoOrdModel(), x)
8597
y = torch.randn(2, 3)
8598
self.run_test(LinalgNoDimNoOrdModel(), y)
8600
self.run_test(LinalgNoDimNoOrdModel(), z)
8602
class LinalgNoDim1DModel(torch.nn.Module):
8603
def __init__(self, ord_val):
8607
def forward(self, x):
8608
return torch.linalg.norm(x, ord=self.ord)
8611
self.run_test(LinalgNoDim1DModel(None), x)
8612
self.run_test(LinalgNoDim1DModel(2), x)
8613
self.run_test(LinalgNoDim1DModel(float("inf")), x)
8614
self.run_test(LinalgNoDim1DModel(-float("inf")), x)
8615
self.run_test(LinalgNoDim1DModel(-4), x)
8616
self.run_test(LinalgNoDim1DModel(1.5), x)
8618
class LinalgNoDim2DModel(torch.nn.Module):
8619
def __init__(self, ord_val):
8623
def forward(self, x):
8624
return torch.linalg.norm(x, ord=self.ord)
8626
x = torch.randn(2, 3)
8627
self.run_test(LinalgNoDim2DModel("fro"), x)
8628
self.run_test(LinalgNoDim2DModel(float("inf")), x)
8629
self.run_test(LinalgNoDim2DModel(-float("inf")), x)
8630
self.run_test(LinalgNoDim2DModel(1), x)
8631
self.run_test(LinalgNoDim2DModel(-1), x)
8633
@skipIfUnsupportedMinOpsetVersion(11)
8634
def test_linalg_vector_norm_zero(self):
8635
class LinalgVectorNormModel(torch.nn.Module):
8636
def __init__(self, ord_val):
8640
def forward(self, x):
8641
return torch.linalg.vector_norm(x, ord=self.ord)
8643
x = torch.randn(2, 3, 5, 5)
8644
self.run_test(LinalgVectorNormModel(0), x)
8646
def test_linalg_vector_norm(self):
8647
class LinalgVectorNormModel(torch.nn.Module):
8648
def __init__(self, ord_val, dim_info):
8651
self.dim, self.keepdim = dim_info
8653
def forward(self, x):
8654
return torch.linalg.vector_norm(
8655
x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
8658
x = torch.randn(2, 3, 5, 5)
8659
ord_options = [2, float("inf"), -float("inf"), -4, 1.5]
8660
dim_options = [(None, False), (1, False), ((1, 2), False), ((1, 2), True)]
8661
for ord_val in ord_options:
8662
for dim_info in dim_options:
8663
self.run_test(LinalgVectorNormModel(ord_val, dim_info), x)
8665
def test_linalg_matrix_norm(self):
8666
class LinalgMatrixNormModel(torch.nn.Module):
8667
def __init__(self, ord_val, dim_val=(-2, -1), keepdim_val=False):
8671
self.keepdim = keepdim_val
8673
def forward(self, x):
8674
return torch.linalg.matrix_norm(
8675
x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
8678
x = torch.randn(2, 3, 5, 5)
8679
ord_options = ["fro", float("inf"), -float("inf"), 1, -1]
8680
for ord_val in ord_options:
8681
self.run_test(LinalgMatrixNormModel(ord_val), x)
8682
self.run_test(LinalgMatrixNormModel(ord_val, (0, 2)), x)
8683
self.run_test(LinalgMatrixNormModel(ord_val, (0, 2), True), x)
8685
@skipIfUnsupportedMinOpsetVersion(9)
8686
def test_linalg_cross(self):
8687
class Cross(torch.nn.Module):
8688
def forward(self, x, y):
8689
return torch.linalg.cross(x, y, dim=1), torch.linalg.cross(x, y)
8691
x = torch.randn(5, 3, 2, 3)
8692
y = torch.randn(1, 3, 1, 3)
8693
self.run_test(Cross(), input_args=(x, y))
8695
# This test checks output scalar type in the ONNX graph should not be null
8696
# https://github.com/pytorch/pytorch/issues/28607
8697
@skipIfUnsupportedMinOpsetVersion(10)
8698
def test_trace_script(self):
8700
def center_slice_helper(input, h_offset):
8701
return input[:, h_offset:]
8703
class CenterCrop(torch.nn.Module):
8704
def forward(self, input):
8705
return center_slice_helper(input, torch.tensor(input.shape[1] - 1))
8707
x = torch.randn(3, 4)
8708
self.run_test(CenterCrop(), x)
8711
@skipIfUnsupportedMinOpsetVersion(11)
8712
def test_logdet(self):
8713
class LogDet(torch.nn.Module):
8714
def forward(self, x):
8715
return torch.logdet(x)
8717
x = torch.randn(2, 3, 5, 5)
8718
self.run_test(LogDet(), x)
8721
class DimModel(torch.jit.ScriptModule):
8722
@torch.jit.script_method
8723
def forward(self, input):
8728
empty_input = torch.randn(0, requires_grad=True)
8729
multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
8730
self.run_test(DimModel(), empty_input)
8731
self.run_test(DimModel(), multi_dim_input)
8733
@skipIfUnsupportedMinOpsetVersion(11)
8734
def test_dim_1(self):
8735
class M(torch.jit.ScriptModule):
8736
@torch.jit.script_method
8737
def forward(self, poses):
8738
boxes = torch.zeros([poses.shape[0], 2, 4])
8740
for kp_boxes in boxes:
8741
kp_boxes = torchvision.ops.clip_boxes_to_image(kp_boxes, (2, 3))
8742
batch_boxes.append(kp_boxes)
8745
dummy_inputs = torch.rand(2, 2, 3)
8746
self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]})
8748
@skipIfUnsupportedMinOpsetVersion(12)
8750
def test_outer(self):
8751
class Outer(torch.nn.Module):
8752
def forward(self, x, y):
8753
return torch.outer(x, y)
8755
x = torch.arange(1, 5)
8756
y = torch.arange(1, 4)
8757
self.run_test(Outer(), input_args=(x, y))
8759
x = torch.arange(1, 6).to(dtype=torch.float32)
8760
y = torch.arange(1, 4).to(dtype=torch.long)
8761
self.run_test(Outer(), input_args=(x, y))
8763
x = torch.arange(2, 5).to(dtype=torch.float32)
8764
y = torch.arange(2, 4).to(dtype=torch.float64)
8765
self.run_test(Outer(), input_args=(x, y))
8767
x = torch.arange(3, 6).to(dtype=torch.int32)
8768
y = torch.arange(4, 7).to(dtype=torch.long)
8769
self.run_test(Outer(), input_args=(x, y))
8771
@skipIfUnsupportedMinOpsetVersion(9)
8772
def test_movedim(self):
8773
class MovedimModel(torch.nn.Module):
8774
def forward(self, x):
8779
x.movedim((1, 2, 3), (3, 0, 1)),
8780
x.movedim((0, 1, 2), (1, 2, 3)),
8781
x.movedim((1, 3, 2), (1, 3, 2)),
8784
x = torch.randn(5, 3, 4, 2)
8786
self.run_test(MovedimModel(), x)
8788
@skipIfUnsupportedMinOpsetVersion(9)
8789
def test_moveaxis(self):
8790
# moveaxis is an alias of movedim; thus, mostly copied from `test_movedim`.
8791
class MoveaxisModel(torch.nn.Module):
8792
def forward(self, x):
8797
x.moveaxis((1, 2, 3), (3, 0, 1)),
8798
x.moveaxis((0, 1, 2), (1, 2, 3)),
8799
x.moveaxis((1, 3, 2), (1, 3, 2)),
8802
x = torch.randn(5, 3, 4, 2)
8804
self.run_test(MoveaxisModel(), x)
8806
@skipIfUnsupportedMinOpsetVersion(12)
8807
def test_einsum(self):
8808
class EinsumModelBatchDiagonal(torch.nn.Module):
8809
def forward(self, x):
8810
eqn = "...ii ->...i"
8811
return torch.einsum(eqn, x)
8813
for x in [torch.randn(3, 5, 5), torch.randn(3, 5, 5).to(dtype=torch.bool)]:
8814
self.run_test(EinsumModelBatchDiagonal(), input_args=(x,))
8816
class EinsumModelBatchMatmul(torch.nn.Module):
8817
def forward(self, x, y):
8818
eqn = "bij, bjk -> bik"
8819
return torch.einsum(eqn, x, y)
8821
x = torch.randn(5, 2, 3)
8822
y = torch.randn(5, 3, 4)
8823
self.run_test(EinsumModelBatchMatmul(), input_args=(x, y))
8825
class EinsumModelInnerProd(torch.nn.Module):
8826
def forward(self, x, y):
8828
return torch.einsum(eqn, x, y)
8832
self.run_test(EinsumModelInnerProd(), input_args=(x, y))
8834
class EinsumModelTranspose(torch.nn.Module):
8835
def forward(self, x):
8837
return torch.einsum(eqn, x)
8839
for x in [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]:
8840
self.run_test(EinsumModelTranspose(), input_args=(x,))
8842
@skipIfUnsupportedMinOpsetVersion(9)
8843
def test_cosine_similarity(self):
8844
x = torch.randn(5, 3, 2)
8845
y = torch.randn(5, 3, 2)
8846
self.run_test(torch.nn.CosineSimilarity(dim=2), input_args=(x, y))
8848
@skipIfUnsupportedMinOpsetVersion(9)
8849
def test_pairwise_distance(self):
8850
x = torch.randn(5, 3, 2)
8851
y = torch.randn(5, 3, 2)
8852
self.run_test(torch.nn.PairwiseDistance(p=2.0), input_args=(x, y))
8854
@skipIfUnsupportedMinOpsetVersion(9)
8855
def test_cross(self):
8856
class Cross(torch.nn.Module):
8857
def forward(self, x, y):
8858
return torch.cross(x, y, dim=3), torch.cross(x, y)
8860
x = torch.randn(5, 3, 2, 3)
8861
y = torch.randn(5, 3, 2, 3)
8862
self.run_test(Cross(), input_args=(x, y))
8864
@skipIfUnsupportedMinOpsetVersion(9)
8865
def test_cdist(self):
8866
class Cdist(torch.nn.Module):
8867
def forward(self, x, y):
8868
return torch.cdist(x, y)
8870
x = torch.randn(5, 3, 3)
8871
y = torch.randn(5, 2, 3)
8872
self.run_test(Cdist(), input_args=(x, y))
8874
@skipIfUnsupportedMinOpsetVersion(12)
8875
def test_crossentropyloss(self):
8876
for ignore_index in [-100, 1]:
8877
x = torch.randn(3, 5)
8878
y = torch.empty(3, dtype=torch.long).random_(5)
8879
y[y == 1] = ignore_index
8881
self._crossentropyloss(x, y, ignore_index)
8883
x = torch.randn(3, 5, 2)
8884
y = torch.empty(3, 2, dtype=torch.long).random_(5)
8885
y[y == 1] = ignore_index
8886
self._crossentropyloss(x, y, ignore_index)
8888
x = torch.randn(3, 5, 2, 7)
8889
y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
8890
y[y == 1] = ignore_index
8891
self._crossentropyloss(x, y, ignore_index)
8893
def _crossentropyloss(self, x, y, ignore_index):
8894
class CrossEntropyLossNone(torch.nn.Module):
8895
def __init__(self, ignore_index):
8897
if ignore_index == -100:
8898
self.loss = torch.nn.CrossEntropyLoss(reduction="none")
8900
self.loss = torch.nn.CrossEntropyLoss(
8901
reduction="none", ignore_index=ignore_index
8904
def forward(self, input, target):
8905
return self.loss(input, target)
8907
self.run_test(CrossEntropyLossNone(ignore_index), input_args=(x, y))
8909
class CrossEntropyLossNoneWeight(torch.nn.Module):
8910
def __init__(self, ignore_index):
8912
if ignore_index == -100:
8913
self.loss = torch.nn.CrossEntropyLoss(
8914
reduction="none", weight=torch.randn(5)
8917
self.loss = torch.nn.CrossEntropyLoss(
8919
weight=torch.randn(5),
8920
ignore_index=ignore_index,
8923
def forward(self, input, target):
8924
return self.loss(input, target)
8926
self.run_test(CrossEntropyLossNoneWeight(ignore_index), input_args=(x, y))
8928
class CrossEntropyLossSum(torch.nn.Module):
8929
def __init__(self, ignore_index):
8931
if ignore_index == -100:
8932
self.loss = torch.nn.CrossEntropyLoss(reduction="sum")
8934
self.loss = torch.nn.CrossEntropyLoss(
8935
reduction="sum", ignore_index=ignore_index
8938
def forward(self, input, target):
8939
return self.loss(input, target)
8941
self.run_test(CrossEntropyLossSum(ignore_index), input_args=(x, y))
8943
class CrossEntropyLossSumWeight(torch.nn.Module):
8944
def __init__(self, ignore_index):
8946
if ignore_index == -100:
8947
self.loss = torch.nn.CrossEntropyLoss(
8948
reduction="sum", weight=torch.randn(5)
8951
self.loss = torch.nn.CrossEntropyLoss(
8953
weight=torch.randn(5),
8954
ignore_index=ignore_index,
8957
def forward(self, input, target):
8958
return self.loss(input, target)
8960
self.run_test(CrossEntropyLossSumWeight(ignore_index), input_args=(x, y))
8962
class CrossEntropyLossMean(torch.nn.Module):
8963
def __init__(self, ignore_index):
8965
if ignore_index == -100:
8966
self.loss = torch.nn.CrossEntropyLoss()
8968
self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
8970
def forward(self, input, target):
8971
return self.loss(input, target)
8973
self.run_test(CrossEntropyLossMean(ignore_index), input_args=(x, y))
8975
class CrossEntropyLossMeanWeight(torch.nn.Module):
8976
def __init__(self, ignore_index):
8978
if ignore_index == -100:
8979
self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
8981
self.loss = torch.nn.CrossEntropyLoss(
8982
weight=torch.randn(5), ignore_index=ignore_index
8985
def forward(self, input, target):
8986
return self.loss(input, target)
8988
self.run_test(CrossEntropyLossMeanWeight(ignore_index), input_args=(x, y))
8990
@skipIfUnsupportedMinOpsetVersion(9)
8991
def test_MSELoss(self):
8992
class MSELoss(torch.nn.Module):
8995
self.loss1 = torch.nn.MSELoss(reduction="none")
8996
self.loss2 = torch.nn.MSELoss(reduction="sum")
8997
self.loss3 = torch.nn.MSELoss(reduction="mean")
8999
def forward(self, input, target):
9001
self.loss1(input, target),
9002
self.loss2(input, target),
9003
self.loss3(input, target),
9006
x = torch.randn(2, 3, 5)
9007
y = torch.randn(2, 3, 5)
9008
self.run_test(MSELoss(), input_args=(x, y))
9010
@skipIfUnsupportedMinOpsetVersion(9)
9011
def test_kldiv_loss(self):
9012
x = torch.rand(5).log()
9014
self._kldiv_loss(x, y)
9016
x = torch.rand(2, 3, 5).log()
9017
y = torch.rand(2, 3, 5)
9018
self._kldiv_loss(x, y)
9020
x = torch.rand(2, 3, 5, 7).log()
9021
y = torch.rand(2, 3, 5, 7)
9022
self._kldiv_loss(x, y)
9024
def _kldiv_loss(self, x, y):
9025
class KLDivLossNone(torch.nn.Module):
9028
self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
9030
def forward(self, input, target):
9031
return self.loss(input, target.log())
9033
self.run_test(KLDivLossNone(), input_args=(x, y))
9035
class KLDivLossMean(torch.nn.Module):
9038
self.loss = torch.nn.KLDivLoss(reduction="mean", log_target=False)
9040
def forward(self, input, target):
9041
return self.loss(input, target)
9043
self.run_test(KLDivLossMean(), input_args=(x, y))
9045
class KLDivLossSum(torch.nn.Module):
9048
self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
9050
def forward(self, input, target):
9051
return self.loss(input, target.log())
9053
self.run_test(KLDivLossSum(), input_args=(x, y))
9055
class KLDivLossBatchMean(torch.nn.Module):
9058
self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
9060
def forward(self, input, target):
9061
return self.loss(input, target)
9063
self.run_test(KLDivLossBatchMean(), input_args=(x, y))
9065
class KLDivLossMiniBatchMean(torch.nn.Module):
9068
self.loss = torch.nn.KLDivLoss(
9069
reduction="batchmean", size_average=False, log_target=True
9072
def forward(self, input, target):
9073
return self.loss(input, target.log())
9075
self.run_test(KLDivLossMiniBatchMean(), input_args=(x, y))
9077
@skipIfUnsupportedMinOpsetVersion(12)
9078
def test_nllloss(self):
9079
class NLLModel(torch.nn.Module):
9082
self.loss = torch.nn.NLLLoss(reduction="none")
9083
self.m = torch.nn.LogSoftmax(dim=1)
9085
def forward(self, input, target):
9086
output = self.loss(self.m(2 * input), target)
9090
input = torch.randn(N, 16)
9091
target = torch.empty(N, dtype=torch.long).random_(0, C)
9093
# using test data containing default ignore_index=-100
9094
target[target == 1] = -100
9095
self.run_test(NLLModel(), (input, target))
9097
@skipIfUnsupportedMinOpsetVersion(12)
9098
def test_nllloss_2d_none(self):
9099
class NLLModel(torch.nn.Module):
9102
self.loss = torch.nn.NLLLoss(reduction="none")
9103
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9104
self.m = torch.nn.LogSoftmax(dim=1)
9106
def forward(self, input, target):
9107
output = self.loss(self.m(self.conv(input)), target)
9111
input = torch.randn(N, 16, 10, 10)
9112
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9114
# using test data containing default ignore_index=-100
9115
target[target == 1] = -100
9116
self.run_test(NLLModel(), (input, target))
9118
@skipIfUnsupportedMinOpsetVersion(12)
9119
def test_nllloss_2d_mean(self):
9120
class NLLModel(torch.nn.Module):
9123
self.loss = torch.nn.NLLLoss(reduction="mean")
9124
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9125
self.m = torch.nn.LogSoftmax(dim=1)
9127
def forward(self, input, target):
9128
output = self.loss(self.m(self.conv(input)), target)
9132
input = torch.randn(N, 16, 10, 10)
9133
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9135
# using test data containing default ignore_index=-100
9136
target[target == 1] = -100
9137
self.run_test(NLLModel(), (input, target))
9139
@skipIfUnsupportedMinOpsetVersion(12)
9140
def test_nllloss_2d_sum(self):
9141
class NLLModel(torch.nn.Module):
9144
self.loss = torch.nn.NLLLoss(reduction="sum")
9145
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9146
self.m = torch.nn.LogSoftmax(dim=1)
9148
def forward(self, input, target):
9149
output = self.loss(self.m(self.conv(input)), target)
9153
input = torch.randn(N, 16, 10, 10)
9154
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9156
# using test data containing default ignore_index=-100
9157
target[target == 1] = -100
9158
self.run_test(NLLModel(), (input, target))
9160
@skipIfUnsupportedMinOpsetVersion(12)
9161
def test_nllloss_2d_mean_weights(self):
9162
class NLLModel(torch.nn.Module):
9165
self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(C))
9166
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9167
self.m = torch.nn.LogSoftmax(dim=1)
9169
def forward(self, input, target):
9170
output = self.loss(self.m(self.conv(input)), target)
9174
input = torch.randn(N, 16, 10, 10)
9175
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9177
# using test data containing default ignore_index=-100
9178
target[target == 1] = -100
9179
self.run_test(NLLModel(), (input, target))
9181
@skipIfUnsupportedMinOpsetVersion(12)
9182
def test_nllloss_2d_mean_ignore_index(self):
9183
class NLLModel(torch.nn.Module):
9186
self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
9187
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9188
self.m = torch.nn.LogSoftmax(dim=1)
9190
def forward(self, input, target):
9191
output = self.loss(self.m(self.conv(input)), target)
9195
input = torch.randn(N, 16, 10, 10)
9196
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9197
self.run_test(NLLModel(), (input, target))
9199
@skipIfUnsupportedMinOpsetVersion(12)
9200
def test_nllloss_dynamic_ignore_index(self):
9201
import torch.nn.functional as F
9203
def linear_combination(x, y, epsilon):
9204
return epsilon * x + (1 - epsilon) * y
9206
def reduce_loss(loss, reduction="mean"):
9209
if reduction == "mean"
9211
if reduction == "sum"
9215
class LabelSmoothingCrossEntropy(torch.nn.Module):
9216
def __init__(self, epsilon: float = 0.1, reduction="mean"):
9218
self.epsilon = epsilon
9219
self.reduction = reduction
9221
def forward(self, preds, target, start_position):
9222
n = preds.size()[-1]
9223
log_preds = F.log_softmax(preds, dim=-1)
9224
ignore_index = start_position.size(1)
9228
reduction=self.reduction,
9229
ignore_index=ignore_index,
9231
return nll + start_position.float()
9234
preds = torch.randn(N, 16)
9235
target = torch.randint(5, (N,))
9236
start_position = torch.randint(10, (N, N))
9237
self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position))
9239
@skipIfUnsupportedMinOpsetVersion(12)
9240
def test_nllloss_2d_mean_ignore_index_weights(self):
9241
class NLLModel(torch.nn.Module):
9244
self.loss = torch.nn.NLLLoss(
9245
reduction="mean", weight=torch.randn(C), ignore_index=1
9247
self.conv = torch.nn.Conv2d(16, C, (3, 3))
9248
self.m = torch.nn.LogSoftmax(dim=1)
9250
def forward(self, input, target):
9251
output = self.loss(self.m(self.conv(input)), target)
9255
input = torch.randn(N, 16, 10, 10)
9256
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
9257
self.run_test(NLLModel(), (input, target))
9259
@skipIfUnsupportedMinOpsetVersion(12)
9260
def test_binary_cross_entropy_with_logits(self):
9262
y = torch.empty(5).random_(2)
9263
self._bce_logits(x, y)
9265
x = torch.randn(3, 4)
9266
y = torch.empty(3, 4).random_(2)
9267
weight = torch.tensor([3])
9268
self._bce_logits_wegiht(x, y, weight)
9270
x = torch.randn(3, 2, 4)
9271
y = torch.empty(3, 2, 4).random_(2)
9272
pos_weight = torch.empty([2, 4]).random_(2)
9273
self._bce_logits_posweight(x, y, pos_weight)
9275
x = torch.randn(3, 3, 4)
9276
y = torch.empty(3, 3, 4).random_(2)
9277
weight = torch.tensor([3])
9278
pos_weight = torch.empty([3, 4]).random_(2)
9279
self._bce_logits_loss_weight_posweight(x, y, weight, pos_weight)
9281
def _bce_logits(self, x, y):
9282
class BCEWithLogitsLossNone(torch.nn.Module):
9283
def forward(self, input, target):
9284
return torch.nn.functional.binary_cross_entropy_with_logits(
9285
input, target, reduction="none"
9288
self.run_test(BCEWithLogitsLossNone(), input_args=(x, y))
9290
class BCEWithLogitsLossMean(torch.nn.Module):
9291
def forward(self, input, target):
9292
return torch.nn.functional.binary_cross_entropy_with_logits(
9293
input, target, reduction="mean"
9296
self.run_test(BCEWithLogitsLossMean(), input_args=(x, y))
9298
class BCEWithLogitsLossSum(torch.nn.Module):
9299
def forward(self, input, target):
9300
return torch.nn.functional.binary_cross_entropy_with_logits(
9301
input, target, reduction="sum"
9304
self.run_test(BCEWithLogitsLossSum(), input_args=(x, y))
9306
def _bce_logits_wegiht(self, x, y, weight):
9307
class BCEWithLogitsLossWegihtNone(torch.nn.Module):
9308
def forward(self, input, target, weight):
9309
return torch.nn.functional.binary_cross_entropy_with_logits(
9310
input, target, weight=weight, reduction="none"
9313
self.run_test(BCEWithLogitsLossWegihtNone(), input_args=(x, y, weight))
9315
class BCEWithLogitsLossWegihtMean(torch.nn.Module):
9316
def forward(self, input, target, weight):
9317
return torch.nn.functional.binary_cross_entropy_with_logits(
9318
input, target, weight=weight, reduction="mean"
9321
self.run_test(BCEWithLogitsLossWegihtMean(), input_args=(x, y, weight))
9323
class BCEWithLogitsLossWegihtSum(torch.nn.Module):
9324
def forward(self, input, target, weight):
9325
return torch.nn.functional.binary_cross_entropy_with_logits(
9326
input, target, weight=weight, reduction="sum"
9329
self.run_test(BCEWithLogitsLossWegihtSum(), input_args=(x, y, weight))
9331
def _bce_logits_posweight(self, x, y, pos_weight):
9332
class BCEWithLogitsLossPosWegihtNone(torch.nn.Module):
9333
def forward(self, input, target, pos_weight):
9334
return torch.nn.functional.binary_cross_entropy_with_logits(
9335
input, target, pos_weight=pos_weight, reduction="none"
9338
self.run_test(BCEWithLogitsLossPosWegihtNone(), input_args=(x, y, pos_weight))
9340
class BCEWithLogitsLossPosWegihtMean(torch.nn.Module):
9341
def forward(self, input, target, pos_weight):
9342
return torch.nn.functional.binary_cross_entropy_with_logits(
9343
input, target, pos_weight=pos_weight, reduction="mean"
9346
self.run_test(BCEWithLogitsLossPosWegihtMean(), input_args=(x, y, pos_weight))
9348
class BCEWithLogitsLossPosWegihtSum(torch.nn.Module):
9349
def forward(self, input, target, pos_weight):
9350
return torch.nn.functional.binary_cross_entropy_with_logits(
9351
input, target, pos_weight=pos_weight, reduction="sum"
9354
self.run_test(BCEWithLogitsLossPosWegihtSum(), input_args=(x, y, pos_weight))
9356
def _bce_logits_loss_weight_posweight(self, x, y, weight, pos_weight):
9357
class BCEWithLogitsLossWeightPosweightNone(torch.nn.Module):
9358
def forward(self, input, target, weight, pos_weight):
9359
return torch.nn.functional.binary_cross_entropy_with_logits(
9363
pos_weight=pos_weight,
9368
BCEWithLogitsLossWeightPosweightNone(),
9369
input_args=(x, y, weight, pos_weight),
9372
class BCEWithLogitsLossWeightPosweightMean(torch.nn.Module):
9373
def forward(self, input, target, weight, pos_weight):
9374
return torch.nn.functional.binary_cross_entropy_with_logits(
9378
pos_weight=pos_weight,
9383
BCEWithLogitsLossWeightPosweightMean(),
9384
input_args=(x, y, weight, pos_weight),
9387
class BCEWithLogitsLossWeightPosweightSum(torch.nn.Module):
9388
def forward(self, input, target, weight, pos_weight):
9389
return torch.nn.functional.binary_cross_entropy_with_logits(
9390
input, target, weight=weight, pos_weight=pos_weight, reduction="sum"
9394
BCEWithLogitsLossWeightPosweightSum(), input_args=(x, y, weight, pos_weight)
9397
def test_torch_mm(self):
9398
class M(torch.nn.Module):
9399
def forward(self, mat1, mat2):
9400
mm = torch.mm(mat1, mat2)
9403
mat1 = torch.randn(2, 3)
9404
mat2 = torch.randn(3, 3)
9405
self.run_test(M(), input_args=(mat1, mat2))
9407
@skipIfUnsupportedMinOpsetVersion(
9409
) # Because where op is not supported for opset < 9.
9410
def test_where_with_bool_tensor(self):
9411
class M(torch.nn.Module):
9412
def forward(self, mat1, mat2):
9413
out = torch.where(mat1 > 0, mat1, mat2)
9416
mat1 = torch.randn(2, 3)
9417
mat2 = torch.ones(2, 3)
9418
self.run_test(M(), input_args=(mat1, mat2))
9420
@skipIfUnsupportedMinOpsetVersion(
9422
) # Because where op is not supported for opset < 9.
9423
def test_where_with_byte_tensor(self):
9424
class M(torch.nn.Module):
9425
def forward(self, cond, mat1, mat2):
9426
out = torch.where(cond, mat1, mat2)
9429
cond = torch.ones(2, 3, dtype=torch.uint8)
9431
mat1 = torch.randn(2, 3)
9432
mat2 = torch.ones(2, 3)
9433
self.run_test(M(), input_args=(cond, mat1, mat2))
9435
@skipIfUnsupportedMinOpsetVersion(10) # ONNX IsInf op is added in opset 10.
9436
def test_isinf(self):
9437
class M(torch.nn.Module):
9438
def forward(self, x):
9441
x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
9442
self.run_test(M(), (x,))
9444
@skipIfUnsupportedMinOpsetVersion(10)
9445
def test_isfinite(self):
9446
class M(torch.nn.Module):
9447
def forward(self, x):
9450
x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9451
self.run_test(M(), (x,))
9453
@skipIfUnsupportedMinOpsetVersion(9) # ONNX IsNaN op is added in opset 9.
9454
def test_isnan(self):
9455
class M(torch.nn.Module):
9456
def forward(self, x):
9459
x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), float("inf")]])
9460
self.run_test(M(), (x,))
9462
@skipIfUnsupportedMinOpsetVersion(
9464
) # ONNX IsNaN, IsInf op is added in opset 9, 10 respectively.
9465
def test_nan_to_num(self):
9466
class NoParams(torch.nn.Module):
9467
def forward(self, x):
9468
return x.nan_to_num()
9470
x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9471
xint = torch.ones((2, 4), dtype=torch.int)
9472
xhalf = torch.ones((2, 4), dtype=torch.half)
9473
self.run_test(NoParams(), (x,))
9474
self.run_test(NoParams(), (xint,))
9475
self.run_test(NoParams(), (xhalf,))
9477
class WithParams(torch.nn.Module):
9478
def forward(self, x):
9479
return x.nan_to_num(nan=2.3, posinf=4.5, neginf=6.7)
9481
x = torch.tensor([[1, 2, float("inf")], [2, float("nan"), -float("inf")]])
9482
self.run_test(WithParams(), (x,))
9484
@skipIfUnsupportedMinOpsetVersion(9)
9485
def test_maximum_minimum(self):
9486
class ModelWithNan(torch.nn.Module):
9487
def forward(self, x, y):
9488
return torch.maximum(x, y), torch.minimum(x, y)
9490
x = torch.tensor([-2, -2, float("nan")])
9491
y = torch.rand(1, 3)
9492
self.run_test(ModelWithNan(), (x, y))
9494
@skipIfUnsupportedMinOpsetVersion(12)
9495
def test_minimum_dtypes(self):
9496
class MinimumModel(torch.nn.Module):
9497
def forward(self, x, y):
9498
return torch.minimum(x, y)
9500
x = torch.randn((5, 5), dtype=torch.float16)
9501
y = torch.randn((5, 5), dtype=torch.float)
9502
self.run_test(MinimumModel(), (x, y))
9504
x = torch.randn((5, 5), dtype=torch.float16)
9505
y = torch.randint(10, (5, 5), dtype=torch.int16)
9506
self.run_test(MinimumModel(), (x, y))
9508
x = torch.randint(10, (5, 5), dtype=torch.int16)
9509
y = torch.randint(10, (5, 5), dtype=torch.int32)
9510
self.run_test(MinimumModel(), (x, y))
9512
x = torch.randint(10, (5, 5), dtype=torch.int)
9513
y = torch.full_like(x, True)
9514
self.run_test(MinimumModel(), (x, y))
9516
@skipIfUnsupportedMinOpsetVersion(12)
9517
def test_maximum_dtypes(self):
9518
class MaximumModel(torch.nn.Module):
9519
def forward(self, x, y):
9520
return torch.maximum(x, y)
9522
x = torch.randn((5, 5), dtype=torch.float16)
9523
y = torch.randn((5, 5), dtype=torch.float)
9524
self.run_test(MaximumModel(), (x, y))
9526
x = torch.randn((5, 5), dtype=torch.float16)
9527
y = torch.randint(10, (5, 5), dtype=torch.int16)
9528
self.run_test(MaximumModel(), (x, y))
9530
x = torch.randint(10, (5, 5), dtype=torch.int16)
9531
y = torch.randint(10, (5, 5), dtype=torch.int32)
9532
self.run_test(MaximumModel(), (x, y))
9534
x = torch.randint(10, (5, 5), dtype=torch.int)
9535
y = torch.full_like(x, True)
9536
self.run_test(MaximumModel(), (x, y))
9538
@skipIfUnsupportedMinOpsetVersion(9)
9540
class M(torch.nn.Module):
9541
def forward(self, x):
9544
x = torch.tensor([[True, False], [False, False]])
9545
self.run_test(M(), (x,))
9547
class MDim(torch.nn.Module):
9548
def forward(self, x):
9551
x = torch.rand(3, 4).bool()
9552
self.run_test(MDim(), (x,))
9554
class MKeepdim(torch.nn.Module):
9555
def forward(self, x):
9556
return x.any(dim=1, keepdim=True)
9558
x = torch.rand(3, 4).bool()
9559
self.run_test(MKeepdim(), (x,))
9561
@skipIfUnsupportedMinOpsetVersion(9)
9563
class M(torch.nn.Module):
9564
def forward(self, x):
9567
x = torch.tensor([[True, False], [False, False]])
9568
self.run_test(M(), (x,))
9570
class MDim(torch.nn.Module):
9571
def forward(self, x):
9574
x = torch.rand(3, 4).bool()
9575
self.run_test(MDim(), (x,))
9577
class MKeepdim(torch.nn.Module):
9578
def forward(self, x):
9579
return x.all(dim=1, keepdim=True)
9581
x = torch.rand(3, 4).bool()
9582
self.run_test(MKeepdim(), (x,))
9584
def test_dropout(self):
9585
class M(torch.nn.Module):
9588
self.dropout = torch.nn.Dropout(0.3)
9590
def forward(self, x):
9591
dropout = self.dropout(x)
9594
x = torch.randn(10, 3, 53)
9595
self.run_test(M(), (x))
9597
def test_rrelu_eval(self):
9598
x = torch.tensor([0.5, -0.5])
9599
self.run_test(torch.nn.RReLU(0.1, 0.3).eval(), x)
9601
def test_shape_constant_fold(self):
9602
class ShapeModule(torch.nn.Module):
9605
self.register_buffer("weight", torch.ones(5))
9607
def forward(self, x):
9608
shape = self.weight.shape[0]
9611
x = torch.randn(2, 5)
9612
self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)
9614
@skipIfUnsupportedMinOpsetVersion(12)
9615
def test_celu(self):
9616
class Celu(torch.nn.Module):
9619
self.celu = torch.nn.CELU(alpha=1.0)
9621
def forward(self, input):
9622
return self.celu(input)
9624
input = torch.randn(2)
9625
self.run_test(Celu(), (input,))
9627
@skipIfUnsupportedMinOpsetVersion(12)
9628
def test_celu_default(self):
9629
class Celu(torch.nn.Module):
9632
self.celu = torch.nn.CELU()
9634
def forward(self, input):
9635
return self.celu(input)
9637
input = torch.randn(2)
9638
self.run_test(Celu(), (input,))
9640
@skipIfUnsupportedMinOpsetVersion(12)
9641
def test_celu_alpha(self):
9642
class Celu(torch.nn.Module):
9645
self.celu = torch.nn.CELU(alpha=2.0)
9647
def forward(self, input):
9648
return self.celu(input)
9650
input = torch.randn(2)
9651
self.run_test(Celu(), (input,))
9653
@skipIfUnsupportedMinOpsetVersion(12)
9654
def test_celu_cast(self):
9655
class Celu(torch.nn.Module):
9658
self.celu = torch.nn.CELU()
9660
def forward(self, input):
9661
return self.celu(input)
9663
input = torch.randn(2, 5, 7, dtype=torch.float64)
9664
self.run_test(Celu(), (input,))
9666
def test_lower_tuple(self):
9667
class TupleModule(torch.nn.Module):
9668
def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor:
9669
a = (input1, input2)
9671
c = (input1, input2, input3)
9678
if f.size(0) != input1.size(-1):
9689
input1 = torch.randn(2)
9690
input2 = torch.randn(2)
9691
input3 = torch.randn(2)
9692
self.run_test(TupleModule(), (input1, input2, input3))
9694
def test_lower_tuple_2(self):
9695
class TupleModule(torch.nn.Module):
9696
def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
9697
a = (input1, input2)
9703
input1 = torch.randn(2)
9704
input2 = torch.randn(2)
9705
self.run_test(TupleModule(), (input1, input2))
9707
def test_lower_tuple_3(self):
9708
class TupleModule(torch.nn.Module):
9711
input1: Tuple[Tensor, Tensor],
9712
input2: Tuple[Tensor, Tensor],
9713
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
9719
if c.shape[0] == e.shape[0]:
9727
input1 = (torch.randn(2), torch.randn(2))
9728
input2 = (torch.randn(2), torch.randn(2))
9729
self.run_test(TupleModule(), (input1, input2))
9731
@skipIfUnsupportedMinOpsetVersion(9)
9732
def test_where(self):
9733
class Model(torch.nn.Module):
9734
def forward(self, cond, input, other):
9735
return torch.where(cond, input, other)
9737
x = torch.randint(0, 1, (2, 3, 4), dtype=torch.bool)
9738
y = torch.randn(2, 1, 4)
9739
z = torch.ones(2, 3, 1)
9740
self.run_test(Model(), (x, y, z))
9742
@skipIfUnsupportedMinOpsetVersion(9)
9743
@skipScriptTest() # scripting tests run for opsets > 11. See: test_where_condition_script
9744
def test_where_condition(self):
9745
class Model1(torch.nn.Module):
9746
def forward(self, input):
9747
return torch.stack(torch.where(input > 0.5), dim=1)
9749
x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
9750
self.run_test(Model1(), (x))
9752
class Model2(torch.nn.Module):
9753
def forward(self, input, other):
9754
return torch.stack(torch.where(input > other), dim=1)
9756
x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
9757
y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
9758
self.run_test(Model2(), (x, y))
9760
@skipIfUnsupportedOpsetVersion([13])
9761
@skipIfUnsupportedMinOpsetVersion(11)
9762
def test_where_condition_script(self):
9763
class Model1(torch.nn.Module):
9764
def forward(self, input):
9765
return torch.stack(torch.where(input > 0.5), dim=1)
9767
x = torch.randint(0, 2, (2, 3, 4), dtype=bool)
9768
self.run_test(Model1(), (x))
9770
class Model2(torch.nn.Module):
9771
def forward(self, input, other):
9772
return torch.stack(torch.where(input > other), dim=1)
9774
x = torch.randint(0, 1, (2, 3, 4), dtype=bool)
9775
y = torch.randint(1, 2, (2, 3, 4), dtype=bool)
9776
self.run_test(Model2(), (x, y))
9778
def test_empty_branch(self):
9779
class EmptyBranchModel(torch.jit.ScriptModule):
9780
@torch.jit.script_method
9781
def forward(self, input):
9792
x = torch.randn(1, 2, 3, requires_grad=True)
9793
self.run_test(EmptyBranchModel(), x)
9795
@skipIfUnsupportedMinOpsetVersion(11)
9796
def test_derive_index_scripting(self):
9797
class MyModule(torch.nn.Module):
9798
def forward(self, x: Tensor):
9800
for idx in range(len(x) - 1, -len(x), -2):
9805
x = torch.randn(5, 13)
9806
self.run_test(MyModule(), x)
9808
class MyModule(torch.nn.Module):
9809
def forward(self, x: Tensor):
9811
for idx in range(-len(x), len(x) - 1, 2):
9816
x = torch.randn(5, 13)
9817
self.run_test(MyModule(), x)
9819
class MyModule(torch.nn.Module):
9820
def forward(self, x: Tensor):
9822
for idx in range(len(x) - 1, -len(x), -3):
9827
self.run_test(MyModule(), x)
9829
class MyModule(torch.nn.Module):
9830
def forward(self, x: Tensor):
9832
for idx in range(-len(x), len(x) - 1, 3):
9837
self.run_test(MyModule(), x)
9839
@skipScriptTest() # Scripting fails for add lists for opsets < 11. Chek test_derive_index_scripting
9840
def test_derive_index(self):
9841
class MyModule(torch.nn.Module):
9842
def forward(self, x: Tensor):
9844
for idx in range(len(x) - 1, -len(x), -2):
9849
x = torch.randn(5, 13)
9850
self.run_test(MyModule(), x)
9852
class MyModule(torch.nn.Module):
9853
def forward(self, x: Tensor):
9855
for idx in range(-len(x), len(x) - 1, 2):
9860
x = torch.randn(5, 13)
9861
self.run_test(MyModule(), x)
9863
class MyModule(torch.nn.Module):
9864
def forward(self, x: Tensor):
9866
for idx in range(len(x) - 1, -len(x), -3):
9871
self.run_test(MyModule(), x)
9873
class MyModule(torch.nn.Module):
9874
def forward(self, x: Tensor):
9876
for idx in range(-len(x), len(x) - 1, 3):
9881
self.run_test(MyModule(), x)
9883
@skipIfUnsupportedMinOpsetVersion(11)
9884
def test_if_transpose(self):
9885
class IfModel(torch.nn.Module):
9886
def forward(self, x):
9887
x = x.transpose(0, 1)
9889
return x.transpose(0, 1)
9893
x = torch.randn(2, 3)
9895
torch.jit.script(IfModel()),
9897
output_names=["output_1"],
9898
dynamic_axes={"output_1": [0, 1]},
9901
@skipIfUnsupportedMinOpsetVersion(13)
9902
def test_if_list(self):
9903
class IfModel(torch.nn.Module):
9904
def forward(self, x, y, cond):
9912
x = torch.randn(2, 3)
9913
y = torch.randn(3, 3)
9914
cond = torch.tensor(1, dtype=torch.bool)
9915
self.run_test(torch.jit.script(IfModel()), (x, y, cond))
9917
@skipIfUnsupportedMinOpsetVersion(13)
9918
def test_if_view(self):
9919
class IfModel(torch.nn.Module):
9920
def forward(self, x, y, cond):
9921
bs, seq = y.shape[:2]
9923
res = x.view(bs, seq, -1)
9926
return res.transpose(1, 2)
9928
x = torch.randn(2, 16, 2, 2)
9929
y = torch.randn(2, 16, 8)
9930
cond = torch.tensor(1, dtype=torch.bool)
9932
torch.jit.script(IfModel()),
9934
output_names=["output_1"],
9935
dynamic_axes={"output_1": [1]},
9939
skip_before_opset_version=11, reason="dynamic split support added in 11"
9941
def test_split_tensor_scalar(self):
9942
class SplitModel(torch.nn.Module):
9943
def forward(self, x):
9944
return torch.split(x, x.size(1))
9946
x = torch.randn(1, 2, 3, requires_grad=True)
9947
self.run_test(SplitModel(), x)
9949
def test_split_tensor_multi(self):
9950
class SplitModel(torch.nn.Module):
9951
def forward(self, x):
9952
return torch.split(x, torch.ones(3))
9954
x = torch.randn(1, 2, 3, requires_grad=True)
9959
self.assertRaises(TypeError, run_model)
9961
@skipIfUnsupportedMinOpsetVersion(9)
9962
def test_embedding(self):
9963
class EmbedModel(torch.nn.Module):
9964
def forward(self, input, emb):
9965
return torch.nn.functional.embedding(input, emb, padding_idx=1)
9967
model = EmbedModel()
9968
x = torch.randint(4, (4,))
9970
embedding_matrix = torch.rand(10, 3)
9971
self.run_test(model, (x, embedding_matrix))
9973
x = torch.randint(4, (4, 3, 2))
9976
self.run_test(model, (x, embedding_matrix))
9978
model, (x, embedding_matrix), training=torch.onnx.TrainingMode.TRAINING
9981
class EmbedModelWithoutPaddingIdx(torch.nn.Module):
9982
def forward(self, input, emb):
9983
return torch.nn.functional.embedding(input, emb)
9985
model = EmbedModelWithoutPaddingIdx()
9986
x = torch.randint(4, (4, 3, 2))
9987
self.run_test(model, (x, embedding_matrix))
9989
@skipIfUnsupportedMinOpsetVersion(9)
9990
def test_embedding_module(self):
9991
class EmbedModel(torch.nn.Module):
9994
self.emb = torch.nn.Embedding(4, 3, padding_idx=1)
9995
self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1)
9996
with torch.no_grad():
9997
self.emb2.weight[1] = torch.ones(3)
9999
def forward(self, input):
10000
return self.emb(input), self.emb2(input)
10002
model = EmbedModel()
10003
x = torch.randint(4, (4,))
10005
self.run_test(model, (x,))
10007
x = torch.randint(4, (4, 3, 2))
10010
self.run_test(model, (x,))
10012
class EmbedModelWithoutPaddingIdx(torch.nn.Module):
10013
def __init__(self):
10015
self.emb = torch.nn.Embedding(4, 3)
10017
def forward(self, input):
10018
return self.emb(input)
10020
model = EmbedModelWithoutPaddingIdx()
10021
x = torch.randint(4, (4, 3, 2))
10022
self.run_test(model, (x,))
10024
@skipIfUnsupportedMinOpsetVersion(11)
10025
def test_embedding_renorm(self):
10027
embedding = torch.nn.Embedding(n, d, max_norm=0.2)
10028
idx = torch.tensor([2, 1])
10029
self.run_test(embedding, idx)
10031
embedding = torch.nn.Embedding(n, d, max_norm=0.5, norm_type=1.0)
10032
idx = torch.tensor([4, 3, 4, 2])
10033
self.run_test(embedding, idx)
10035
def _dispatch_rnn_test(self, name, *args, **kwargs):
10036
if name == "elman":
10037
self._elman_rnn_test(*args, **kwargs)
10039
self._lstm_test(*args, **kwargs)
10041
self._gru_test(*args, **kwargs)
10043
def _elman_rnn_test(
10053
class ElmanWithStateModel(torch.nn.Module):
10054
def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
10057
self.batch_first = batch_first
10058
self.inner_model = torch.nn.RNN(
10062
nonlinearity=nonlinearity,
10063
bidirectional=bidirectional,
10065
batch_first=batch_first,
10068
def forward(self, input: rnn_utils.PackedSequence, hx=None):
10069
return self.inner_model(input, hx)
10071
class ElmanWithoutStateModel(torch.nn.Module):
10072
def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
10074
self.batch_first = batch_first
10075
self.inner_model = torch.nn.RNN(
10079
nonlinearity=nonlinearity,
10080
bidirectional=bidirectional,
10082
batch_first=batch_first,
10085
def forward(self, input: rnn_utils.PackedSequence):
10086
return self.inner_model(input)
10088
batch_first = packed_sequence == 2
10091
model = ElmanWithStateModel(
10093
bidirect=bidirectional,
10094
nonlinearity=nonlinearity,
10096
batch_first=batch_first,
10098
if packed_sequence:
10100
rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10105
model = ElmanWithoutStateModel(
10107
bidirect=bidirectional,
10108
nonlinearity=nonlinearity,
10110
batch_first=batch_first,
10112
if packed_sequence:
10113
model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10117
def make_input(batch_size):
10118
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10119
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10120
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
10121
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10123
input_names = ["input"]
10125
directions = 2 if bidirectional else 1
10128
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10130
input_names.append("h0")
10131
if packed_sequence != 0:
10132
inputs.append(torch.IntTensor(seq_lengths))
10133
input_names.append("seq_lengths")
10134
if len(inputs) == 1:
10137
input = tuple(inputs)
10138
return input, input_names
10140
input, input_names = make_input(RNN_BATCH_SIZE)
10141
dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10143
dynamic_axes.update({"h0": [1]})
10144
export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10146
# test that the model still runs with a different batch size
10147
other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10149
model, input, additional_test_inputs=[other_input], **export_options
10161
batch_first = packed_sequence == 2
10163
if packed_sequence:
10164
model = lstm_flattening_result.LstmFlatteningResultWithSeqLength(
10174
rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10179
model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10183
model = lstm_flattening_result.LstmFlatteningResultWithoutSeqLength(
10192
def make_input(batch_size):
10193
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10194
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10195
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
10196
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10198
input_names = ["input"]
10199
directions = 2 if bidirectional else 1
10202
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10203
c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10204
inputs.append((h0, c0))
10205
input_names.append("h0")
10206
input_names.append("c0")
10207
if packed_sequence != 0:
10208
inputs.append(torch.IntTensor(seq_lengths))
10209
input_names.append("seq_lengths")
10210
if len(inputs) == 1:
10213
input = tuple(inputs)
10214
return input, input_names
10216
input, input_names = make_input(RNN_BATCH_SIZE)
10217
dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10219
dynamic_axes.update({"h0": [1], "c0": [1]})
10220
export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10222
# test that the model still runs with a different batch size
10223
other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10225
model, input, additional_test_inputs=[other_input], **export_options
10237
class GRUWithStateModel(torch.nn.Module):
10238
def __init__(self, layers, bidirect, dropout, batch_first):
10241
self.batch_first = batch_first
10242
self.inner_model = torch.nn.GRU(
10246
bidirectional=bidirectional,
10248
batch_first=batch_first,
10251
def forward(self, input: rnn_utils.PackedSequence, hx):
10252
return self.inner_model(input, hx)
10254
class GRUWithoutStateModel(torch.nn.Module):
10255
def __init__(self, layers, bidirect, dropout, batch_first):
10257
self.batch_first = batch_first
10258
self.inner_model = torch.nn.GRU(
10262
bidirectional=bidirectional,
10264
batch_first=batch_first,
10267
def forward(self, input: rnn_utils.PackedSequence):
10268
return self.inner_model(input)
10270
class GRUNoSeqLengthWithoutStateModel(torch.nn.Module):
10271
def __init__(self, layers, bidirect, dropout, batch_first):
10273
self.batch_first = batch_first
10274
self.inner_model = torch.nn.GRU(
10278
bidirectional=bidirectional,
10280
batch_first=batch_first,
10283
def forward(self, input):
10284
return self.inner_model(input)
10286
class GRUNoSeqLengthWithStateModel(torch.nn.Module):
10287
def __init__(self, layers, bidirect, dropout, batch_first):
10289
self.batch_first = batch_first
10290
self.inner_model = torch.nn.GRU(
10294
bidirectional=bidirectional,
10296
batch_first=batch_first,
10299
def forward(self, input, hx):
10300
return self.inner_model(input, hx)
10302
batch_first = packed_sequence == 2
10304
if packed_sequence:
10306
model = GRUWithStateModel(
10308
bidirect=bidirectional,
10310
batch_first=batch_first,
10313
rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
10318
model = GRUWithoutStateModel(
10320
bidirect=bidirectional,
10322
batch_first=batch_first,
10324
model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
10329
model = GRUNoSeqLengthWithStateModel(
10331
bidirect=bidirectional,
10333
batch_first=batch_first,
10336
model = GRUNoSeqLengthWithoutStateModel(
10338
bidirect=bidirectional,
10340
batch_first=batch_first,
10343
def make_input(batch_size):
10344
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
10345
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
10346
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
10347
inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
10349
input_names = ["input"]
10351
directions = 2 if bidirectional else 1
10354
h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
10356
input_names.append("h0")
10357
if packed_sequence != 0:
10358
inputs.append(torch.IntTensor(seq_lengths))
10359
input_names.append("seq_lengths")
10360
if len(inputs) == 1:
10363
input = tuple(inputs)
10364
return input, input_names
10366
input, input_names = make_input(RNN_BATCH_SIZE)
10367
dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
10369
dynamic_axes.update({"h0": [1]})
10370
export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}
10372
# test that the model still runs with a different batch size
10373
other_input, _ = make_input(RNN_BATCH_SIZE + 1)
10375
model, input, additional_test_inputs=[other_input], **export_options
10378
@skipIfUnsupportedMinOpsetVersion(10)
10379
def test_fake_quantize_per_tensor(self):
10380
class FakeQuantizePerTensorModel(torch.nn.Module):
10381
def forward(self, input):
10386
return torch.fake_quantize_per_tensor_affine(
10387
input, scale, zero_point, quant_min, quant_max
10390
x = torch.randn(6, 4, 3, 3)
10391
self.run_test(FakeQuantizePerTensorModel(), (x))
10393
@skipIfUnsupportedMinOpsetVersion(13)
10394
def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
10395
class FakeQuantizePerTensorModel(torch.nn.Module):
10396
def forward(self, input, scale, zero_point):
10399
return torch.fake_quantize_per_tensor_affine(
10400
input, scale, zero_point, quant_min, quant_max
10403
x = torch.randn(6, 4, 3, 3)
10404
scale = torch.tensor(1.0 / 127)
10405
zero_point = torch.tensor(0)
10406
self.run_test(FakeQuantizePerTensorModel(), (x, scale, zero_point))
10408
@skipIfUnsupportedMinOpsetVersion(13)
10409
def test_fake_quantize_per_channel(self):
10410
class FakeQuantizePerChannelModel(torch.nn.Module):
10411
def forward(self, input):
10412
amax = torch.ones(4)
10413
scale = amax / 127.0
10414
zero_point = torch.zeros_like(amax, dtype=torch.int)
10415
# Quantize twice to test differnet branches
10416
y = torch.fake_quantize_per_channel_affine(
10417
input, scale, zero_point, 1, 0, 255
10419
return torch.fake_quantize_per_channel_affine(
10420
y, scale, zero_point, 1, -128, 127
10423
x = torch.randn(6, 4, 3, 3)
10424
self.run_test(FakeQuantizePerChannelModel(), (x))
10426
@skipIfUnsupportedMinOpsetVersion(13)
10427
# RuntimeError: Can't redefine method:
10428
# forward on class: __torch__.torch.nn.modules.linear.Linear
10430
def test_fake_quantize_activation(self):
10431
from torch.ao import quantization
10433
m = torch.nn.Linear(1, 1)
10434
m.qconfig = quantization.QConfig(
10435
activation=quantization.default_fake_quant,
10436
weight=quantization.default_per_channel_weight_fake_quant,
10438
quantization.prepare_qat(m.train(), inplace=True)
10439
m.apply(quantization.enable_observer)
10440
m.apply(quantization.enable_fake_quant)
10441
for module in m.modules():
10442
if isinstance(module, quantization.FakeQuantize):
10443
module.calculate_qparams()
10445
m.apply(quantization.disable_observer)
10448
# Fake quantize activation is a special case, as it restricts quantized range to be (0, 127),
10449
# while standard 8bit quantization range is (-128, 127) or (0, 255).
10450
# Set fixed weight, bias and inputs to test if ONNX handles the overflow correctly.
10451
m.weight = torch.nn.Parameter(torch.tensor([[1.0], [1.0], [1.0]]))
10452
m.bias = torch.nn.Parameter(torch.tensor([0.0]))
10453
x = torch.tensor([[150.0], [127.0], [-5.0]])
10454
self.run_test(m, x)
10456
def test_batchnorm_training(self):
10457
class MyModule(torch.nn.Module):
10458
def __init__(self):
10460
self.bn1 = torch.nn.BatchNorm2d(3, affine=False)
10461
self.cv1 = torch.nn.Conv2d(3, 3, 10)
10462
self.bn2 = torch.nn.BatchNorm2d(3, affine=True)
10463
self.cv2 = torch.nn.Conv2d(3, 3, 10)
10464
self.bn3 = torch.nn.BatchNorm2d(3, affine=False)
10466
def forward(self, x):
10474
x = torch.randn(10, 3, 20, 20) * 2
10475
model_export = MyModule()
10479
training=torch.onnx.TrainingMode.TRAINING,
10483
model_export.train()
10487
training=torch.onnx.TrainingMode.PRESERVE,
10492
def test_batchnorm_training_mode_fix_layer(self):
10493
class MyModule(torch.nn.Module):
10494
def __init__(self):
10496
self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
10497
self.cv1 = torch.nn.Conv2d(3, 3, 10)
10498
self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
10499
self.cv2 = torch.nn.Conv2d(3, 3, 10)
10500
self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
10503
def forward(self, x):
10511
x = torch.randn(10, 3, 128, 128)
10512
model_export = MyModule()
10516
training=torch.onnx.TrainingMode.TRAINING,
10520
model_export.train()
10524
training=torch.onnx.TrainingMode.PRESERVE,
10529
def test_batchnorm_eval_mode_train_layer(self):
10530
class MyModule(torch.nn.Module):
10531
def __init__(self):
10533
self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
10534
self.cv1 = torch.nn.Conv2d(3, 3, 10)
10535
self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
10536
self.cv2 = torch.nn.Conv2d(3, 3, 10)
10537
self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
10540
def forward(self, x):
10548
x = torch.randn(10, 3, 128, 128)
10549
model_export = MyModule()
10553
training=torch.onnx.TrainingMode.EVAL,
10557
model_export.eval()
10561
training=torch.onnx.TrainingMode.PRESERVE,
10566
def test_instancenorm_training(self):
10567
class MyModule(torch.nn.Module):
10568
def __init__(self):
10570
self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
10571
self.cv1 = torch.nn.Conv2d(3, 3, 10)
10572
self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
10573
self.cv2 = torch.nn.Conv2d(3, 3, 10)
10574
self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
10576
def forward(self, x):
10584
x = torch.randn(10, 3, 128, 128)
10585
model_export = MyModule()
10589
training=torch.onnx.TrainingMode.TRAINING,
10593
model_export.train()
10597
training=torch.onnx.TrainingMode.PRESERVE,
10602
def test_instancenorm_training_mode_fix_layer(self):
10603
class MyModule(torch.nn.Module):
10604
def __init__(self):
10606
self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
10607
self.cv1 = torch.nn.Conv2d(3, 3, 10)
10608
self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
10609
self.cv2 = torch.nn.Conv2d(3, 3, 10)
10610
self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
10613
def forward(self, x):
10621
x = torch.randn(10, 3, 128, 128)
10622
model_export = MyModule()
10626
training=torch.onnx.TrainingMode.TRAINING,
10630
model_export.train()
10634
training=torch.onnx.TrainingMode.PRESERVE,
10639
def test_instancenorm_eval_mode_train_layer(self):
10640
class MyModule(torch.nn.Module):
10641
def __init__(self):
10643
self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
10644
self.cv1 = torch.nn.Conv2d(8, 8, 10)
10645
self.in2 = torch.nn.InstanceNorm2d(8, affine=False)
10646
self.cv2 = torch.nn.Conv2d(8, 8, 10)
10647
self.in3 = torch.nn.InstanceNorm2d(8, affine=True)
10650
def forward(self, x):
10658
x = torch.randn(10, 8, 128, 128)
10659
model_export = MyModule()
10663
training=torch.onnx.TrainingMode.EVAL,
10667
model_export.eval()
10671
training=torch.onnx.TrainingMode.PRESERVE,
10676
@skipIfUnsupportedMinOpsetVersion(12)
10677
def test_dropout_training(self):
10678
class MyModule(torch.nn.Module):
10679
def __init__(self):
10681
self.dropout = torch.nn.Dropout(0.4)
10683
def forward(self, x):
10684
dropout = self.dropout(x)
10688
x = torch.randn(10)
10691
model_onnx = io.BytesIO()
10696
opset_version=self.opset_version,
10697
do_constant_folding=False,
10698
training=torch.onnx.TrainingMode.TRAINING,
10700
ort_sess = verification._ort_session(model_onnx)
10701
ort_outs = verification._run_onnx(ort_sess, (x,))
10702
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
10704
script_model = torch.jit.script(model)
10706
model_onnx = io.BytesIO()
10711
opset_version=self.opset_version,
10712
do_constant_folding=False,
10713
training=torch.onnx.TrainingMode.TRAINING,
10715
ort_outs = verification._run_onnx(ort_sess, (x,))
10716
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
10718
@skipIfUnsupportedMinOpsetVersion(12)
10719
def test_dropout_training_zero(self):
10720
class MyModule(torch.nn.Module):
10721
def __init__(self):
10723
self.dropout = torch.nn.Dropout(0.5)
10725
def forward(self, x):
10726
dropout = self.dropout(x)
10731
# ensure there are no zeros in the input
10732
x = torch.randn(10, 3, 128, 128)
10734
y_mask = np.where(y == 0, 1, y)
10735
input = torch.from_numpy(y_mask)
10736
nb_elements = torch.numel(input)
10739
model_onnx = io.BytesIO()
10744
opset_version=self.opset_version,
10745
do_constant_folding=False,
10746
training=torch.onnx.TrainingMode.TRAINING,
10748
ort_sess = verification._ort_session(model_onnx)
10749
ort_outs = verification._run_onnx(ort_sess, (x,))
10752
output = y.cpu().numpy()
10753
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
10754
pyt_mask = np.where(output != 0, 1, 0)
10756
ratio_pytorch = np.sum(pyt_mask) / nb_elements
10757
ratio_ort = np.sum(ort_mask) / nb_elements
10759
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
10761
script_model = torch.jit.script(model)
10763
output = y.cpu().numpy()
10764
model_onnx = io.BytesIO()
10769
opset_version=self.opset_version,
10770
do_constant_folding=False,
10771
training=torch.onnx.TrainingMode.TRAINING,
10773
ort_sess = verification._ort_session(model_onnx)
10774
ort_outs = verification._run_onnx(ort_sess, (x,))
10775
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
10776
pyt_mask = np.where(output != 0, 1, 0)
10778
ratio_pytorch = np.sum(pyt_mask) / nb_elements
10779
ratio_ort = np.sum(ort_mask) / nb_elements
10781
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
10783
def test_conv_bn(self):
10784
class MyModule(torch.nn.Module):
10785
def __init__(self):
10787
self.conv = torch.nn.Conv2d(
10788
3, 16, kernel_size=1, stride=2, padding=3, bias=True
10790
self.bn = torch.nn.BatchNorm2d(16, affine=True)
10792
def forward(self, x):
10797
model_export = MyModule()
10798
x = torch.randn(10, 3, 128, 128)
10799
self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
10803
training=torch.onnx.TrainingMode.TRAINING,
10808
def test_multiple_conv_bn(self):
10809
class MyModule(torch.nn.Module):
10810
def __init__(self):
10812
self.conv1 = torch.nn.Conv2d(
10813
3, 64, kernel_size=7, stride=2, padding=3, bias=False
10815
self.conv2 = torch.nn.Conv2d(
10816
64, 2, kernel_size=1, stride=1, padding=0, bias=False
10818
self.conv3 = torch.nn.Conv2d(
10819
2, 2, kernel_size=3, stride=1, padding=1, bias=False
10821
self.bn = torch.nn.BatchNorm2d(64)
10822
self.bn2 = torch.nn.BatchNorm2d(2)
10823
self.relu = torch.nn.ReLU(inplace=True)
10824
self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
10826
def forward(self, x):
10830
x = self.maxpool(x)
10839
model_export = MyModule()
10840
x = torch.randn(2, 3, 224, 224)
10844
training=torch.onnx.TrainingMode.TRAINING,
10848
self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
10850
@skipIfUnsupportedMinOpsetVersion(11)
10851
def test_nms(self):
10853
boxes = torch.rand(num_boxes, 4)
10854
boxes[:, 2:] += boxes[:, :2]
10855
scores = torch.randn(num_boxes)
10857
class Module(torch.nn.Module):
10858
def forward(self, boxes, scores):
10859
return torchvision.ops.nms(boxes, scores, 0.5)
10861
self.run_test(Module(), (boxes, scores))
10863
@skipIfUnsupportedMinOpsetVersion(11)
10864
def test_batched_nms(self):
10866
boxes = torch.rand(num_boxes, 4)
10867
boxes[:, 2:] += boxes[:, :2]
10868
scores = torch.randn(num_boxes)
10869
idxs = torch.randint(0, 5, size=(num_boxes,))
10871
class Module(torch.nn.Module):
10872
def forward(self, boxes, scores, idxs):
10873
return torchvision.ops.batched_nms(boxes, scores, idxs, 0.5)
10875
self.run_test(Module(), (boxes, scores, idxs))
10877
@skipIfUnsupportedMinOpsetVersion(11)
10879
def test_clip_boxes_to_image(self):
10880
boxes = torch.randn(5, 4) * 500
10881
boxes[:, 2:] += boxes[:, :2]
10882
size = torch.randn(200, 300)
10884
size_2 = torch.randn(300, 400)
10886
class Module(torch.nn.Module):
10887
def forward(self, boxes, size):
10888
shape = (size.shape[0], size.shape[1])
10889
return torchvision.ops.boxes.clip_boxes_to_image(boxes, shape)
10894
input_names=["boxes", "size"],
10895
dynamic_axes={"size": [0, 1]},
10896
additional_test_inputs=[(boxes, size), (boxes, size_2)],
10900
reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10902
@skipIfUnsupportedMinOpsetVersion(11)
10903
def test_roi_align(self):
10904
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10905
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
10906
model = torchvision.ops.RoIAlign((5, 5), 1.0, 2)
10907
self.run_test(model, (x, single_roi))
10910
reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10912
@skipIfUnsupportedMinOpsetVersion(16)
10913
def test_roi_align_aligned(self):
10914
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10915
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
10916
model1 = torchvision.ops.RoIAlign((5, 5), 1.0, 2, aligned=True)
10917
self.run_test(model1, (x, single_roi))
10919
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10920
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10921
model2 = torchvision.ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
10922
self.run_test(model2, (x, single_roi))
10924
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10925
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10926
model3 = torchvision.ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
10927
self.run_test(model3, (x, single_roi))
10929
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10930
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
10931
model4 = torchvision.ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
10932
self.run_test(model4, (x, single_roi))
10935
reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
10937
@skipIfUnsupportedMinOpsetVersion(11)
10938
def test_roi_pool(self):
10939
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
10940
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
10943
model = torchvision.ops.RoIPool((pool_h, pool_w), 2.0)
10944
self.run_test(model, (x, rois))
10946
@skipIfUnsupportedMinOpsetVersion(11)
10947
def test_resize_images(self):
10948
class TransformModule(torch.nn.Module):
10949
def __init__(self):
10951
self.transform = _init_test_generalized_rcnn_transform()
10953
def forward(self, images):
10954
return self.transform.resize(images, None)[0]
10956
input = torch.rand(3, 10, 20)
10957
input_test = torch.rand(3, 100, 150)
10961
input_names=["input1"],
10962
dynamic_axes={"input1": [0, 1, 2]},
10963
additional_test_inputs=[(input,), (input_test,)],
10966
@skipIfUnsupportedMinOpsetVersion(11)
10968
def test_transform_images(self):
10969
class TransformModule(torch.nn.Module):
10970
def __init__(self):
10972
self.transform = _init_test_generalized_rcnn_transform()
10974
def forward(self, images: List[Tensor]):
10975
return self.transform(images)[0].tensors
10977
input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
10978
input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
10982
additional_test_inputs=[(input,), (input_test,)],
10985
def get_features(self, images):
10986
s0, s1 = images.shape[-2:]
10988
("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
10989
("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
10990
("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
10991
("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
10992
("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
10994
features = OrderedDict(features)
10997
@skipIfUnsupportedMinOpsetVersion(11)
10999
def test_rpn(self):
11000
class RPNModule(torch.nn.Module):
11001
def __init__(self):
11003
self.rpn = _init_test_rpn()
11005
def forward(self, images, features: Dict[str, Tensor]):
11006
images_m = torchvision.models.detection.image_list.ImageList(
11007
images, [(i.shape[-1], i.shape[-2]) for i in images]
11009
return self.rpn(images_m, features)
11011
images = torch.rand(2, 3, 150, 150)
11012
features = self.get_features(images)
11013
images2 = torch.rand(2, 3, 80, 80)
11014
test_features = self.get_features(images2)
11016
model = RPNModule()
11018
model(images, features)
11021
(images, features),
11022
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
11024
"input1": [0, 1, 2, 3],
11025
"input2": [0, 1, 2, 3],
11026
"input3": [0, 1, 2, 3],
11027
"input4": [0, 1, 2, 3],
11028
"input5": [0, 1, 2, 3],
11029
"input6": [0, 1, 2, 3],
11031
additional_test_inputs=[(images, features), (images2, test_features)],
11032
# dict_check=False,
11035
@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
11036
@skipIfUnsupportedMinOpsetVersion(11)
11038
def test_multi_scale_roi_align(self):
11039
class TransformModule(torch.nn.Module):
11040
def __init__(self):
11042
self.model = torchvision.ops.MultiScaleRoIAlign(
11043
["feat1", "feat2"], 3, 2
11045
self.image_sizes = [(512, 512)]
11047
def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor:
11048
return self.model(input, boxes, self.image_sizes)
11051
i["feat1"] = torch.rand(1, 5, 64, 64)
11052
i["feat2"] = torch.rand(1, 5, 16, 16)
11053
boxes = torch.rand(6, 4) * 256
11054
boxes[:, 2:] += boxes[:, :2]
11057
i1["feat1"] = torch.rand(1, 5, 64, 64)
11058
i1["feat2"] = torch.rand(1, 5, 16, 16)
11059
boxes1 = torch.rand(6, 4) * 256
11060
boxes1[:, 2:] += boxes1[:, :2]
11068
additional_test_inputs=[
11080
def test_set_(self):
11081
class M(torch.nn.Module):
11082
def forward(self, x, y):
11086
x = torch.ones(2, 3)
11087
y = torch.randn(4, 6)
11088
self.run_test(M(), (x, y), remained_onnx_input_idx=[1])
11090
y2 = torch.randn(5, 2)
11094
remained_onnx_input_idx=[1],
11095
input_names=["x", "y"],
11096
dynamic_axes={"x": [0, 1], "y": [0, 1]},
11097
additional_test_inputs=[(y, y2)],
11100
@skipIfUnsupportedMinOpsetVersion(9)
11101
def test_set_attr_modules(self):
11102
class InnerModule2(torch.nn.Module):
11103
def __init__(self, embedding_dim):
11105
self.weights = InnerModule2.get_embedding(embedding_dim)
11106
self.register_buffer("_float_tensor", torch.FloatTensor(1))
11110
def get_embedding(embedding_dim: int):
11111
emb = 4 / ((embedding_dim // 2) - 1)
11113
torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11117
def forward(self, input, incremental_state: Optional[Tensor] = None):
11118
bsz, seq_len = input.shape[0], input.shape[1]
11120
if self.weights is None:
11121
self.weights = InnerModule.get_embedding(self.embedding_dim)
11122
self.weights = self.weights.to(self._float_tensor)
11123
self.weights = self.weights * self.const
11124
if incremental_state is not None:
11126
return self.weights[1 + pos, :].expand(bsz, 1, -1)
11127
return self.weights.index_select(
11128
0, torch.ones((bsz * seq_len), dtype=torch.int64)
11129
).view(bsz, seq_len, -1)
11131
class InnerModule(torch.nn.Module):
11132
def __init__(self, embedding_dim):
11134
self.weights = InnerModule.get_embedding(embedding_dim)
11135
self.module = InnerModule2(embedding_dim=8)
11138
def get_embedding(embedding_dim: int):
11139
emb = 4 / ((embedding_dim // 2) - 1)
11141
torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11145
def forward(self, x):
11146
return self.module(x) + self.weights
11148
class Module(torch.nn.Module):
11149
def __init__(self):
11151
self.module = InnerModule(embedding_dim=8)
11153
def forward(self, x):
11154
return self.module(x)
11156
x = torch.randn(3, 256)
11157
self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
11158
self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11160
@skipIfUnsupportedMinOpsetVersion(9)
11161
def test_set_attr_modules_2(self):
11162
class InnerModule(torch.nn.Module):
11163
def __init__(self, embedding_dim):
11165
self.embedding_dim = embedding_dim
11167
self.weights = InnerModule.get_embedding(self.embedding_dim)
11168
self.register_buffer("_float_tensor", torch.FloatTensor(1))
11171
def get_embedding(embedding_dim: int):
11172
emb = 4 / ((embedding_dim // 2) - 1)
11174
torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
11178
def forward(self, input, incremental_state: Optional[Tensor] = None):
11179
bsz, seq_len = input.shape[0], input.shape[1]
11181
self.weights = InnerModule.get_embedding(self.embedding_dim)
11183
self.weights.index_select(
11184
0, torch.ones((bsz * seq_len), dtype=torch.int64)
11185
).view(bsz, seq_len, -1)
11188
class Module(torch.nn.Module):
11189
def __init__(self):
11191
self.module = InnerModule(embedding_dim=8)
11193
def forward(self, x):
11194
return self.module(x)
11196
x = torch.randn(3, 256)
11197
self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
11198
self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11200
def test_set_attr(self):
11201
class MyModule(torch.nn.Module):
11202
def __init__(self):
11204
self.conv = torch.nn.Conv1d(3, 10, 2)
11207
def forward(self, box_regression, weight):
11209
self.conv.weight = weight
11210
w = torch.softmax(self.conv.weight, dim=0)
11211
self.conv.weight = w + w
11213
return box_regression + self.conv.weight
11215
return box_regression - self.conv.weight
11217
model = torch.jit.script(MyModule())
11218
weight = torch.ones(3, 2)
11219
box_regression = torch.randn(3, 2)
11220
self.run_test(model, (box_regression, weight))
11222
@skipIfUnsupportedMinOpsetVersion(11)
11223
def test_set_attr_2(self):
11224
class MyModule(torch.nn.Module):
11225
def __init__(self):
11227
self.conv = torch.nn.Conv1d(10, 3, 3)
11228
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11230
def set_cell_anchors(self, anchors):
11231
if self.conv.bias is not None:
11233
assert b is not None
11234
self.conv.bias = anchors + b
11235
elif self.conv.weight is not None:
11236
self.conv.weight = torch.randn(3, 10)
11237
self.conv.bias = self.conv.weight[:]
11239
def forward(self, anchors) -> Optional[Tensor]:
11240
self.set_cell_anchors(anchors)
11241
return self.conv.bias
11243
model = torch.jit.script(MyModule())
11244
anchors = torch.ones(3, 10, 3)
11245
self.run_test(model, (anchors))
11247
@skipIfUnsupportedMinOpsetVersion(11)
11248
def test_set_attr_3(self):
11249
class MyModule(torch.nn.Module):
11250
def __init__(self):
11252
self.conv = torch.nn.Conv1d(10, 3, 3)
11253
self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11254
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11256
def set_cell_anchors(self, anchors, boxes):
11257
self.conv.weight = torch.ones(3, 10)
11258
if self.conv.bias is not None:
11259
self.conv.bias = torch.randn(3, 10, 3)
11260
self.conv.weight = anchors + self.conv.weight
11261
boxes[:] = torch.zeros(2, 3)
11263
def forward(self, anchors) -> Tuple[Tensor, Tensor]:
11264
boxes = torch.ones(2, 2, 3)
11265
self.set_cell_anchors(anchors, boxes)
11266
if self.conv.bias is not None:
11267
return self.conv.weight, boxes
11268
return anchors, boxes
11270
model = torch.jit.script(MyModule())
11271
anchors = torch.rand(3, 10)
11272
self.run_test(model, (anchors))
11274
@skipIfUnsupportedMinOpsetVersion(11)
11275
def test_set_attr_4(self):
11276
class MyModule(torch.nn.Module):
11277
def __init__(self):
11279
self.conv = torch.nn.Conv1d(10, 3, 3)
11280
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11282
def set_cell_anchors(self, anchors):
11283
self.conv.weight = torch.zeros(10, 3)
11284
if self.conv.bias is not None:
11286
assert w is not None
11287
self.conv.bias = anchors + w
11289
self.conv.bias = torch.ones(3, 10, 3)
11291
def forward(self, feature_maps, anchors) -> Tuple[Tensor, Tensor]:
11292
self.set_cell_anchors(anchors)
11294
if self.conv.bias is not None:
11296
assert a is not None
11298
result += [feature_maps]
11299
return result[0], result[1]
11301
model = torch.jit.script(MyModule())
11302
x = torch.rand(5, 11, 30)
11303
anchors = torch.ones(3, 10, 3)
11304
self.run_test(model, (x, anchors))
11306
@skipIfUnsupportedMinOpsetVersion(11)
11307
def test_set_attr_5(self):
11308
class MyModule(torch.nn.Module):
11309
def __init__(self):
11311
self.conv = torch.nn.Conv1d(10, 3, 3)
11312
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11314
def set_cell_anchors(self, anchors):
11315
self.conv.weight = torch.arange(10)
11316
for i in range(10):
11318
for j in range(10):
11319
w = self.conv.weight
11320
self.conv.weight = torch.arange(10) + w
11322
self.conv.weight = self.conv.weight + torch.arange(10)
11323
# NOTE: `is not None` and `assert` is for passing torchscript.
11324
if self.conv.bias is not None:
11326
assert a is not None
11327
self.conv.bias = anchors + a
11329
def forward(self, anchors):
11330
self.set_cell_anchors(anchors)
11331
return self.conv.weight, self.conv.bias
11333
model = torch.jit.script(MyModule())
11334
anchors = torch.ones(3, 10, 3)
11335
self.run_test(model, (anchors))
11337
@skipIfUnsupportedMinOpsetVersion(11)
11338
def test_set_attr_in_loop(self):
11339
class MyModule(torch.nn.Module):
11340
def __init__(self):
11342
self.conv = torch.nn.Conv1d(10, 3, 3)
11343
self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11344
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11346
def set_cell_anchors(self, anchors, boxes):
11347
self.conv.weight = torch.randn(3, 10)
11348
for i in range(self.conv.weight.size(0)):
11349
for j in range(10):
11350
self.conv.bias = torch.randn(3, 10, 3)
11351
self.conv.weight = anchors * i
11352
boxes[j] += torch.ones(3, 3)
11354
def forward(self, anchors) -> Tuple[Tensor, Tensor]:
11355
boxes = torch.ones(10, 3, 3)
11356
self.set_cell_anchors(anchors, boxes)
11357
if self.conv.bias is not None:
11358
return self.conv.weight, boxes
11359
return anchors, boxes
11361
model = torch.jit.script(MyModule())
11362
anchors = torch.rand(10)
11363
self.run_test(model, anchors)
11365
@skipIfUnsupportedMinOpsetVersion(13)
11366
def test_set_attr_in_loop_with_list(self):
11367
class MyModule(torch.nn.Module):
11368
def __init__(self):
11370
self.conv = torch.nn.Conv1d(10, 3, 3)
11371
self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
11372
self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
11373
self.boxes: List[Tensor] = [
11375
] # Workaround placeholder for TorchScript
11377
def set_cell_anchors(self, anchors):
11378
self.conv.weight = torch.randn(3, 10)
11379
for i in range(self.conv.weight.size(0)):
11380
for j in range(10):
11381
self.conv.bias = torch.randn(3, 10, 3)
11382
self.conv.weight = anchors * i
11383
self.boxes.append(torch.ones(3, 3))
11385
def forward(self, anchors) -> Tuple[Tensor, List[Tensor]]:
11387
self.set_cell_anchors(anchors)
11388
if self.conv.bias is not None:
11389
return self.conv.weight, self.boxes
11390
return anchors, self.boxes
11392
model = torch.jit.script(MyModule())
11393
anchors = torch.rand(10)
11394
self.run_test(model, anchors)
11396
@skipIfUnsupportedMinOpsetVersion(11)
11397
def test_index_put_if(self):
11400
input_data: Tensor, hidden_size: int, prev_state: Tensor
11401
) -> Tuple[Tensor, Tensor]:
11402
batch_size = input_data.size(0)
11403
spatial_size_0 = input_data.size(2)
11404
spatial_size_1 = input_data.size(3)
11405
# generate empty prev_state, if None is provided
11406
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11407
state = torch.zeros(state_size, device=input_data.device)
11408
state_copy = torch.zeros(state_size, device=input_data.device)
11409
if prev_state.size(0) == 0:
11411
torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11415
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11419
torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11424
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11427
return state, state_copy
11429
class Example(torch.nn.Module):
11430
def __init__(self, hidden_size):
11432
self.hidden_size = hidden_size
11434
def forward(self, input_data, prev_state):
11435
prev_state = check_init(input_data, self.hidden_size, prev_state)
11436
return prev_state[0], prev_state[1]
11438
model = Example(10)
11439
random_data = torch.rand((1, 5, 30, 30))
11440
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11443
(random_data, empty_tensor),
11444
input_names=["random_data", "empty_tensor"],
11445
dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11447
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11449
@skipIfUnsupportedMinOpsetVersion(11)
11450
def test_index_put_if_2(self):
11453
input_data: Tensor, hidden_size: int, prev_state: Tensor
11454
) -> Tuple[Tensor, Tensor]:
11455
batch_size = input_data.size(0)
11456
spatial_size_0 = input_data.size(2)
11457
spatial_size_1 = input_data.size(3)
11458
# generate empty prev_state, if None is provided
11459
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11460
state = torch.zeros(state_size, device=input_data.device)
11461
state_copy = torch.zeros(state_size, device=input_data.device)
11462
if prev_state.size(0) == 0:
11466
batch_size, hidden_size, spatial_size_0, spatial_size_1
11472
batch_size, hidden_size, spatial_size_0, spatial_size_1
11476
elif prev_state.size(0) == 1:
11478
state[:] = prev_state + s
11479
elif prev_state.size(0) == 2:
11481
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11484
return state, state_copy
11486
class Example(torch.nn.Module):
11487
def __init__(self, hidden_size):
11489
self.hidden_size = hidden_size
11491
def forward(self, input_data, prev_state):
11492
prev_state = check_init(input_data, self.hidden_size, prev_state)
11493
return prev_state[0], prev_state[1]
11495
model = Example(10)
11496
random_data = torch.rand((1, 5, 30, 30))
11497
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11498
random_state = torch.rand((1, 1, 10, 30, 30))
11501
(random_data, empty_tensor),
11502
input_names=["data", "state"],
11503
dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
11504
additional_test_inputs=[(random_data, random_state)],
11508
(random_data, empty_tensor),
11509
input_names=["data", "state"],
11510
dynamic_axes={"state": [0, 1, 2, 3, 4]},
11511
additional_test_inputs=[(random_data, random_state)],
11512
remained_onnx_input_idx=[1],
11514
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11516
@skipIfUnsupportedMinOpsetVersion(11)
11517
def test_index_put_if_3(self):
11520
input_data: Tensor, hidden_size: int, prev_state: Tensor
11522
batch_size = input_data.size(0)
11523
spatial_size_0 = input_data.size(2)
11524
spatial_size_1 = input_data.size(3)
11525
# generate empty prev_state, if None is provided
11526
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11527
state = torch.zeros(state_size, device=input_data.device)
11528
if prev_state.size(0) < 2:
11530
if prev_state.size(0) == 0:
11533
batch_size, hidden_size, spatial_size_0, spatial_size_1
11542
class Example(torch.nn.Module):
11543
def __init__(self, hidden_size):
11545
self.hidden_size = hidden_size
11547
def forward(self, input_data, prev_state):
11548
prev_state = check_init(input_data, self.hidden_size, prev_state)
11552
random_data = torch.rand((1, 5, 4, 4))
11553
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11556
(random_data, empty_tensor),
11557
input_names=["random_data", "empty_tensor"],
11558
dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11560
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11562
@skipIfUnsupportedMinOpsetVersion(11)
11563
def test_index_put_if_4(self):
11566
input_data: Tensor, hidden_size: int, prev_state: Tensor
11568
batch_size = input_data.size(0)
11569
spatial_size_0 = input_data.size(2)
11570
spatial_size_1 = input_data.size(3)
11571
# generate empty prev_state, if None is provided
11572
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11573
state = torch.zeros(state_size, device=input_data.device)
11574
if prev_state.size(0) == 0:
11577
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11582
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11589
class Example(torch.nn.Module):
11590
def __init__(self, hidden_size):
11592
self.hidden_size = hidden_size
11594
def forward(self, input_data, prev_state):
11595
prev_state = check_init(input_data, self.hidden_size, prev_state)
11599
random_data = torch.rand((1, 5, 4, 4))
11600
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11603
(random_data, empty_tensor),
11604
input_names=["random_data", "empty_tensor"],
11605
dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11607
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11609
@skipIfUnsupportedMinOpsetVersion(11)
11610
def test_index_put_if_5(self):
11613
input_data: Tensor, hidden_size: int, prev_state: Tensor
11614
) -> Tuple[Tensor, Tensor]:
11615
batch_size = input_data.size(0)
11616
spatial_size_0 = input_data.size(2)
11617
spatial_size_1 = input_data.size(3)
11618
# generate empty prev_state, if None is provided
11619
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11620
state = torch.zeros(state_size, device=input_data.device)
11622
if prev_state.size(0) == 0:
11624
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11629
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11634
return state, state_ref
11636
class Example(torch.nn.Module):
11637
def __init__(self, hidden_size):
11639
self.hidden_size = hidden_size
11641
def forward(self, input_data, prev_state):
11642
prev_state, state_ref = check_init(
11643
input_data, self.hidden_size, prev_state
11645
return prev_state, state_ref
11648
random_data = torch.rand((1, 5, 4, 4))
11649
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
11652
(random_data, empty_tensor),
11653
input_names=["random_data", "empty_tensor"],
11654
dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
11656
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
11658
@skipIfUnsupportedMinOpsetVersion(11)
11659
def test_list_append_in_block(self):
11660
class ListModel(torch.nn.Module):
11661
def forward(self, x, y):
11663
for i in range(x.size(0)):
11664
res.append(torch.matmul(x[i], y))
11667
model = torch.jit.script(ListModel())
11668
x = torch.randn(16, 3, 4)
11669
y = torch.randn(4, 5)
11670
self.run_test(model, (x, y))
11672
@skipIfUnsupportedMinOpsetVersion(13)
11673
def test_list_append_in_nested_block(self):
11674
class ListModel(torch.nn.Module):
11675
def forward(self, x, y):
11677
for i in range(x.size(0)):
11678
for j in range(x.size(1)):
11679
res.append(torch.matmul(x[i][j], y))
11682
model = torch.jit.script(ListModel())
11683
x = torch.randn(4, 4, 3, 4)
11684
y = torch.randn(4, 5)
11685
self.run_test(model, (x, y))
11687
@skipIfUnsupportedMinOpsetVersion(13)
11688
def test_list_pop_in_block(self):
11689
class ListModel(torch.nn.Module):
11690
def forward(self, x, y):
11692
elem = torch.matmul(x[0], y)
11693
for i in range(x.size(0)):
11694
res.append(torch.matmul(x[i], y))
11695
for i in range(x.size(0)):
11697
for i in range(x.size(0)):
11698
res.append(torch.matmul(x[i], y))
11700
return res.append(elem)
11702
model = torch.jit.script(ListModel())
11703
x = torch.randn(16, 3, 4)
11704
y = torch.randn(4, 5)
11705
self.run_test(model, (x, y))
11707
@skipIfUnsupportedMinOpsetVersion(13)
11708
def test_list_del_in_block(self):
11709
class ListModel(torch.nn.Module):
11710
def forward(self, x, y):
11712
elem = torch.matmul(x[0], y)
11713
for i in range(x.size(0)):
11714
res.append(torch.matmul(x[i], y))
11715
for i in range(x.size(0)):
11717
for i in range(x.size(0)):
11718
res.append(torch.matmul(x[i], y))
11720
return res.append(elem)
11722
model = torch.jit.script(ListModel())
11723
x = torch.randn(16, 3, 4)
11724
y = torch.randn(4, 5)
11725
self.run_test(model, (x, y))
11727
@skipIfUnsupportedMinOpsetVersion(11)
11728
def test_list_unpack(self):
11729
class ListModel(torch.nn.Module):
11730
def forward(self, x, y):
11732
elem = torch.matmul(x[0], y)
11733
for i in range(x.size(0)):
11734
res.append(torch.matmul(x[i], y))
11738
model = torch.jit.script(ListModel())
11739
x = torch.randn(3, 3, 4)
11740
y = torch.randn(4, 5)
11741
self.run_test(model, (x, y))
11743
@skipIfUnsupportedMinOpsetVersion(11)
11744
def test_index_put_inplace_ops(self):
11746
def check_init(input_data: Tensor, hidden_size: int) -> Tensor:
11747
batch_size = input_data.size(0)
11748
spatial_size_0 = input_data.size(2)
11749
spatial_size_1 = input_data.size(3)
11750
# generate empty prev_state, if None is provided
11751
state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
11752
state = torch.zeros(state_size, device=input_data.device)
11753
if input_data.size(0) == 1:
11755
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11759
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11762
for i in range(input_data.size(0)):
11763
state[1] += torch.ones(
11764
batch_size, hidden_size, spatial_size_0, spatial_size_1
11767
torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
11772
class Example(torch.nn.Module):
11773
def __init__(self, hidden_size):
11775
self.hidden_size = hidden_size
11777
def forward(self, input_data):
11778
state = check_init(input_data, self.hidden_size)
11781
model = Example(10)
11782
random_data = torch.rand((1, 5, 30, 30))
11786
input_names=["random_data"],
11787
dynamic_axes={"random_data": [0, 1, 2, 3]},
11789
self.run_test(model, (random_data), remained_onnx_input_idx=[])
11791
@skipIfUnsupportedMinOpsetVersion(11)
11792
def test_input_mask_model(self):
11793
class InputMaskModel(torch.nn.Module):
11794
def __init__(self, output_size):
11796
self.bias = torch.nn.Parameter(
11797
torch.empty(output_size, dtype=torch.float)
11799
with torch.no_grad():
11802
def forward(self, model_input, y):
11803
input_mask = (model_input <= 0) | (model_input > 25)
11804
y[input_mask, :] = 0.0
11805
output = y + self.bias
11809
m = InputMaskModel(output_size)
11810
x = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
11813
[0.1, 0.2, 0.3, 0.4],
11814
[0.1, 0.2, 0.3, 0.4],
11815
[0.1, 0.2, 0.3, 0.4],
11816
[0.1, 0.2, 0.3, 0.4],
11820
self.run_test(m, (x, y))
11822
class InputMaskModel(torch.nn.Module):
11823
def __init__(self, output_size):
11826
def forward(self, model_input_1, model_input_2, y):
11827
input_mask_1 = (model_input_1 <= 0) | (model_input_1 > 25)
11828
input_mask_2 = (model_input_2 < 1) | (model_input_2 >= 12)
11829
y[input_mask_1, input_mask_2] = 0.0
11833
m = InputMaskModel(output_size)
11834
x1 = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
11835
x2 = torch.tensor([0, 3, 12, 15], dtype=torch.int64)
11838
[0.1, 0.2, 0.3, 0.4],
11839
[0.1, 0.2, 0.3, 0.4],
11840
[0.1, 0.2, 0.3, 0.4],
11841
[0.1, 0.2, 0.3, 0.4],
11845
self.run_test(m, (x1, x2, y))
11848
def test_unsafe_chunk(self):
11849
class ChunkModel(torch.nn.Module):
11850
def forward(self, x):
11851
return torch.unsafe_chunk(x, 3, dim=1)
11853
model = ChunkModel()
11855
x = torch.randn(1, 18)
11856
self.run_test(model, x, input_names=["x"])
11858
def test_symbolic_shape_inference(self):
11859
# ConstantOfShape is tested in test_embedding_bag
11860
# Tile is tested in test_repeat
11861
# test Shape, Reshape, Transpose, Gather
11862
class ShapeModel(torch.nn.Module):
11863
def forward(self, x, y):
11864
shape = x.size()[:3] + (-1,) # shape [4], ("batch", 3, 4, -1)
11865
y = y.reshape(shape) # batch, 3, 4, 10/batch
11866
return y.transpose(1, 2)
11868
model = ShapeModel()
11870
x = torch.ones(2, 3, 4, 5)
11871
y = torch.ones(3, 4, 5, 2)
11875
input_names=["x", "y"],
11876
dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
11878
self.run_test(model, (x, y), remained_onnx_input_idx=[1])
11880
class ViewModel(torch.nn.Module):
11881
def forward(self, x):
11884
model = ViewModel()
11886
x = torch.tensor(2.0)
11887
self.run_test(model, (x,))
11889
# test prim::ListConstruct for Reshape input 1
11890
class ViewModel_2(torch.nn.Module):
11891
def forward(self, x):
11892
N, C, H, W = x.shape[0], x.shape[2], x.shape[3], x.shape[4]
11893
x1 = x.view(N, -1, C, H, W)
11894
x2 = x1.permute(0, 3, 4, 1, 2)
11895
return x2.reshape(N, -1, C)
11897
model = ViewModel_2()
11899
x = torch.ones(2, 3, 4, 5, 6)
11900
self.run_test(model, x)
11902
@skipIfUnsupportedMinOpsetVersion(9)
11903
def test_symbolic_shape_inference_arange(self):
11905
class ArangeModel(torch.nn.Module):
11906
def forward(self, signal):
11908
outer_dimensions = signal.size()[:-2]
11909
frames, frame_length = signal.size()[-2:]
11911
subframe_length = signal.size()[0]
11912
subframe_step = frame_step // subframe_length
11913
subframes_per_frame = frame_length // subframe_length
11914
output_size = frame_step * (frames - 1) + frame_length
11915
output_subframes = output_size // subframe_length
11917
frame = torch.arange(0, output_subframes)
11920
model = ArangeModel()
11922
M, C, K, N = 1, 2, 3, 4
11923
x = torch.randint(5, (M, C, K, N))
11924
y = torch.randint(5, (M, C + 1, K + 1, N + 1))
11925
self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]})
11926
self.run_test(model, x, remained_onnx_input_idx=[])
11931
dynamic_axes={"x": [0, 1, 2, 3]},
11932
additional_test_inputs=[(x,), (y,)],
11935
@skipIfUnsupportedMinOpsetVersion(11)
11936
def test_symbolic_shape_inference_box(self):
11938
class BoxModel(torch.nn.Module):
11939
def forward(self, boxes):
11941
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
11942
keep = (ws >= min_size) & (hs >= min_size)
11943
keep = torch.where(keep)[0]
11948
x = torch.ones(2, 4)
11949
y = torch.ones(3, 5)
11950
self.run_test(model, x)
11955
dynamic_axes={"x": [0, 1]},
11956
additional_test_inputs=[(x,), (y,)],
11959
@skipIfUnsupportedMinOpsetVersion(11)
11960
def test_symbolic_shape_inference_box_if(self):
11962
class BoxIfModel(torch.nn.Module):
11963
def forward(self, boxes, scores):
11965
inds = torch.where(scores > score_thresh)[0]
11966
boxes_1 = boxes[inds]
11967
if boxes_1.numel() > 3:
11972
model = BoxIfModel()
11974
boxes = torch.ones(2, 4)
11975
scores = torch.ones(1, 4)
11976
self.run_test(model, (boxes, scores))
11978
@skipIfUnsupportedMinOpsetVersion(11)
11980
def test_symbolic_shape_inference_arange_2(self):
11982
class ArangeModel(torch.nn.Module):
11983
def forward(self, start):
11984
return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64)
11986
x = torch.randn(2, 3, 4)
11988
ArangeModel(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
11990
self.run_test(ArangeModel(), (x,), remained_onnx_input_idx=[])
11992
class ArangeModel2(torch.nn.Module):
11993
def forward(self, start):
11994
return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double)
11996
x = torch.randn(2, 3, 4)
11998
ArangeModel2(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12000
self.run_test(ArangeModel2(), (x,), remained_onnx_input_idx=[])
12002
@skipIfUnsupportedMinOpsetVersion(9)
12003
def test_symbolic_shape_inference_nonzero(self):
12004
class OneLikeModel(torch.nn.Module):
12005
def forward(self, x):
12006
ones = torch.ones_like(
12009
layout=torch.strided,
12010
device=torch.device("cpu"),
12012
return torch.nonzero(ones)
12015
self.run_test(OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]})
12016
self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[])
12017
x = torch.randn(2, 3, 4)
12019
OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12021
self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[])
12023
class ZeroLikeModel(torch.nn.Module):
12024
def forward(self, x):
12025
zeros = torch.zeros_like(
12028
layout=torch.strided,
12029
device=torch.device("cpu"),
12031
return torch.nonzero(zeros)
12034
self.run_test(ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]})
12035
self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[])
12036
x = torch.randn(2, 3, 4)
12038
ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
12040
self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[])
12042
@skipIfUnsupportedMinOpsetVersion(9)
12043
def test_symbolic_shape_inference_expand_1(self):
12044
class ExpandModel(torch.nn.Module):
12045
def forward(self, x):
12046
return x.expand(4, 6, 2)
12048
x = torch.randn(6, 1, requires_grad=True)
12049
self.run_test(ExpandModel(), (x,))
12051
@skipIfUnsupportedMinOpsetVersion(9)
12052
def test_symbolic_shape_inference_expand_2(self):
12053
class M(torch.nn.Module):
12054
def forward(self, x):
12055
input_shape = x.size()
12056
batch_size, seq_length = input_shape
12057
seq_ids = torch.arange(seq_length)
12059
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
12060
<= seq_ids[None, :, None]
12062
return causal_mask.transpose(0, 1)
12064
x = torch.randn(3, 16)
12065
self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
12066
self.run_test(M(), (x,), remained_onnx_input_idx=[])
12068
@skipIfUnsupportedMinOpsetVersion(10)
12069
def test_symbolic_shape_inference_slice(self):
12070
class M(torch.nn.Module):
12071
def forward(self, x, position_bias):
12072
input_shape = x.size()
12073
batch_size, seq_length = input_shape
12074
position_bias = position_bias[:, :, -seq_length:, :]
12075
return position_bias.transpose(0, 1)
12077
x = torch.randn(3, 16)
12078
position_bias = torch.randn(1, 3, 20, 8)
12081
(x, position_bias),
12082
input_names=["x", "position_bias"],
12083
dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]},
12085
self.run_test(M(), (x, position_bias), remained_onnx_input_idx=[1])
12087
def test_symbolic_shape_inference_slice_2(self):
12088
class M(torch.nn.Module):
12089
def forward(self, position_bias):
12090
position_bias = position_bias[:, :, -2:, :]
12091
return position_bias.transpose(0, 1)
12093
position_bias = torch.randn(1, 3, 20, 8)
12094
self.run_test(M(), (position_bias,))
12096
@skipIfUnsupportedMinOpsetVersion(9)
12098
def test_symbolic_shape_inference_time(self):
12099
input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
12100
h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
12101
c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
12102
model_lstm = torch.nn.LSTM(
12103
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
12108
input_names=["x", "y"],
12109
dynamic_axes={"x": [0, 1]},
12111
model_gru = torch.nn.GRU(
12112
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
12115
model_gru, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]}
12117
model_rnn = torch.nn.RNN(
12118
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
12121
model_rnn, (input, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]}
12124
def test_symbolic_shape_inference_dynamic_axes(self):
12125
class M(torch.nn.Module):
12126
def forward(self, input_ids):
12127
input_shape = input_ids.size()
12128
input_ids = input_ids.view(-1, input_shape[-1])
12129
return input_ids.transpose(0, 1)
12131
x = torch.randn(3, 16)
12135
input_names=["input_ids"],
12136
dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
12139
@skipIfUnsupportedMinOpsetVersion(9)
12140
def test_hann_window_periodic(self):
12141
class HannWindowModule_Periodic(torch.nn.Module):
12142
def __init__(self):
12144
self.window_length = 0
12146
def forward(self, x, window_length: int):
12147
self.window_length = window_length
12151
self.window_length, periodic=True, dtype=torch.float
12156
x = torch.randn(win_length)
12158
module = HannWindowModule_Periodic()
12159
self.run_test(module, (x, win_length))
12161
@skipIfUnsupportedMinOpsetVersion(9)
12162
def test_hann_window_not_periodic(self):
12163
class HannWindowModule_NotPeriodic(torch.nn.Module):
12164
def __init__(self):
12166
self.window_length = 0
12168
def forward(self, x, window_length: int):
12169
self.window_length = window_length
12173
self.window_length, periodic=False, dtype=torch.float
12178
x = torch.randn(win_length)
12180
module = HannWindowModule_NotPeriodic()
12181
self.run_test(module, (x, win_length))
12183
@skipIfUnsupportedMinOpsetVersion(9)
12185
def test_hann_window_default_values(self):
12186
class HannWindowModule(torch.nn.Module):
12187
def __init__(self):
12189
self.window_length = 0
12191
def forward(self, x, window_length: int):
12192
import torch.nn.functional as F
12194
self.window_length = window_length
12195
return torch.add(x, F.relu(torch.hann_window(self.window_length)))
12198
x = torch.randn(win_length, dtype=torch.float)
12199
module = HannWindowModule()
12201
output = module(x, win_length)
12202
self.run_test(module, (x, win_length))
12204
@skipIfUnsupportedMinOpsetVersion(12)
12205
def test_tensordot_dim_count(self):
12206
class M(torch.nn.Module):
12207
def forward(self, x, y):
12208
output = torch.tensordot(x, y, 2)
12211
x = torch.randint(6, (7, 5, 3, 4))
12212
y = torch.randint(6, (3, 4, 9, 2))
12214
self.run_test(M(), (x, y))
12216
@skipIfUnsupportedMinOpsetVersion(12)
12217
def test_tensordot_dim_list(self):
12218
class M(torch.nn.Module):
12219
def forward(self, x, y):
12220
output = torch.tensordot(x, y, ([1, -2, -1], [1, 0, 3]))
12223
x = torch.randint(6, (7, 4, 3, 5, 2))
12224
y = torch.randint(6, (5, 4, 4, 2, 6))
12226
self.run_test(M(), (x, y))
12228
@skipIfUnsupportedMinOpsetVersion(12)
12229
def test_tensordot_dynamic_dim(self):
12230
class M(torch.nn.Module):
12231
def forward(self, x, y):
12232
output = torch.tensordot(x, y, 2)
12235
x = torch.randint(6, (7, 5, 3, 4))
12236
y = torch.randint(6, (3, 4, 9, 2))
12238
new_x = torch.randint(6, (8, 6, 2, 5))
12239
new_y = torch.randint(6, (2, 5, 3, 4))
12244
additional_test_inputs=[(new_x, new_y)],
12245
input_names=["input_x", "input_y"],
12246
dynamic_axes={"input_x": [0, 1, 2, 3], "input_y": [0, 1, 2, 3]},
12249
@skipIfUnsupportedMinOpsetVersion(9)
12250
def test_to_device(self):
12251
class M_ToDevice(torch.nn.Module):
12252
def forward(self, x, y):
12253
return x.to(y.device), y
12255
class M_ToDeviceDtype(torch.nn.Module):
12256
def forward(self, x, y):
12257
return x.to(y.device, dtype=torch.long), y
12262
self.run_test(M_ToDevice(), (x, y))
12263
self.run_test(M_ToDeviceDtype(), (x, y))
12265
@skipIfUnsupportedMinOpsetVersion(9)
12266
def test_fill(self):
12267
class FillModule(torch.nn.Module):
12268
def forward(self, x, filled_value: int):
12269
return x.fill_(filled_value)
12271
x = torch.randn((4, 5, 6))
12273
self.run_test(FillModule(), (x, filled_value))
12275
class FillFloatModule(torch.nn.Module):
12276
def forward(self, x, filled_value: float):
12277
return x.fill_(filled_value)
12279
x = torch.randn((4, 5, 6))
12281
self.run_test(FillFloatModule(), (x, filled_value))
12283
class FillScalarModule(torch.nn.Module):
12284
def forward(self, x):
12289
x = torch.ones(2, 3, 4, dtype=torch.long)
12290
self.run_test(FillScalarModule(), x)
12292
@skipIfUnsupportedMinOpsetVersion(9)
12293
def test_index_add_normal(self):
12294
class M(torch.nn.Module):
12295
def __init__(self, dim, index, updates):
12299
self.updates = updates
12301
def forward(self, x):
12302
x.index_add_(self.dim, self.index, self.updates)
12305
x = torch.ones(5, 1)
12306
updates = torch.tensor([[1], [4], [7], [3], [2]], dtype=torch.float)
12307
index = torch.tensor([0, 2, 3, 1, 4])
12308
self.run_test(M(0, index, updates), (x,))
12310
x = torch.ones(1, 4, 3)
12311
updates = torch.tensor(
12312
[[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12314
index = torch.tensor([0, 2, 3, 1])
12315
self.run_test(M(1, index, updates), (x,))
12317
updates = torch.tensor(
12318
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4]]], dtype=torch.float
12320
index = torch.tensor([0, 2, 1])
12321
self.run_test(M(2, index, updates), (x,))
12323
@skipIfUnsupportedMinOpsetVersion(9)
12324
def test_index_add_dim_size_differ(self):
12325
class M(torch.nn.Module):
12326
def __init__(self, dim, index, updates):
12330
self.updates = updates
12332
def forward(self, x):
12333
x.index_add_(self.dim, self.index, self.updates)
12336
x = torch.ones(1, 4, 3)
12337
updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6]]], dtype=torch.float)
12338
index = torch.tensor([0, 2, 1])
12339
self.run_test(M(1, index, updates), (x,))
12341
@skipIfUnsupportedMinOpsetVersion(9)
12342
def test_index_add_in_loop(self):
12343
class M(torch.nn.Module):
12344
def __init__(self, dim, index, updates, loop_count):
12348
self.updates = updates
12349
self.loop_count = loop_count
12351
def forward(self, x):
12352
for i in range(self.loop_count):
12353
x.index_add_(self.dim, self.index, self.updates)
12356
x = torch.ones(1, 4, 3)
12357
updates = torch.tensor(
12358
[[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12360
index = torch.tensor([0, 2, 3, 1])
12361
loop_count = torch.randint(20, (1,))[0].item()
12362
self.run_test(M(1, index, updates, loop_count), (x,))
12364
@skipIfUnsupportedMinOpsetVersion(9)
12365
def test_index_add_if(self):
12366
class M(torch.nn.Module):
12367
def __init__(self, dim, updates, index_true, index_false):
12370
self.updates = updates
12371
self.index_true = index_true
12372
self.index_false = index_false
12374
def forward(self, x, cond):
12376
x.index_add_(self.dim, self.index_true, self.updates)
12378
x.index_add_(self.dim, self.index_false, self.updates)
12381
x = torch.ones(1, 4, 3)
12382
updates = torch.tensor(
12383
[[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12385
index_true = torch.tensor([0, 2, 3, 1])
12386
index_false = torch.tensor([1, 0, 2, 3])
12387
cond = torch.tensor(1, dtype=torch.bool)
12389
torch.jit.script(M(1, updates, index_true, index_false)), (x, cond)
12392
@skipIfUnsupportedMinOpsetVersion(9)
12393
def test_index_add_dynamic_axes(self):
12394
class M(torch.nn.Module):
12395
def __init__(self, dim, index, updates):
12399
self.updates = updates
12401
def forward(self, x):
12402
x.index_add_(self.dim, self.index, self.updates)
12405
x = torch.ones(1, 4, 3)
12406
updates = torch.tensor(
12407
[[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
12409
index = torch.tensor([0, 2, 3, 1])
12412
M(1, index, updates),
12414
input_names=["input_1"],
12415
dynamic_axes={"input_1": [0, 1]},
12418
def test_roll(self):
12419
class M(torch.nn.Module):
12420
def __init__(self, shifts, dims):
12422
self.shifts = shifts
12425
def forward(self, x):
12426
return torch.roll(x, self.shifts, self.dims)
12428
x = torch.randn(2, 3, 4)
12429
self.run_test(M([1, 1], [1, 0]), (x,))
12430
self.run_test(M([0, 1, 2], [1, 0, 2]), (x,))
12431
self.run_test(M(2, 1), (x,))
12432
self.run_test(M([-1, 3], [-2, -1]), (x,))
12434
def test_sum(self):
12435
class M(torch.nn.Module):
12436
def forward(self, x):
12437
return torch.sum(x)
12439
x = torch.ones(12, 3)
12440
self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]})
12443
def test_sum_empty_tensor(self):
12444
class M(torch.nn.Module):
12445
def forward(self, x):
12446
return x[0:0].sum(), x.sum()
12449
self.run_test(M(), (x,))
12451
x = torch.ones(2, 0, 3)
12452
self.run_test(M(), (x,))
12455
self.run_test(M(), (x,))
12457
@skipIfUnsupportedMinOpsetVersion(11)
12458
def test_broad_cast_tensors(self):
12459
class M(torch.nn.Module):
12460
def forward(self, x, y):
12461
m = torch.broadcast_tensors(x, y)
12464
x = torch.randint(5, (1,))
12465
y = torch.randint(5, (5,))
12467
self.run_test(M(), (x, y))
12469
x = torch.randint(5, (4, 2, 1, 4))
12470
y = torch.randint(5, (2, 3, 1))
12472
self.run_test(M(), (x, y))
12474
x = torch.randn(2, 1, 4)
12475
y = torch.randn(5, 2, 3, 1)
12477
self.run_test(M(), (x, y))
12480
@skipIfUnsupportedMinOpsetVersion(11)
12481
def test_dist_normal(self):
12482
class M(torch.nn.Module):
12483
def forward(self, x, y):
12484
return torch.distributions.Normal(x, y).sample().size(0), x, y
12486
self.run_test(M(), (torch.tensor([0.0]), torch.tensor([[1.0], [2.0]])))
12487
self.run_test(M(), (torch.tensor([0.0]), torch.tensor([1.0])))
12492
torch.tensor([[[0.0], [10.0]], [[2.0], [8.0]], [[2.0], [8.0]]]),
12493
torch.tensor([[1.0], [3.0]]),
12498
@skipIfUnsupportedMinOpsetVersion(11)
12499
def test_dist_normal_correctness(self):
12500
class M(torch.nn.Module):
12501
def forward(self, x, y):
12502
return torch.distributions.Normal(x, y).sample([20000])
12504
expected_mean = 5.0
12505
expected_std = 10.0
12508
dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
12509
model_onnx = io.BytesIO()
12511
model_export, dummy_input, model_onnx, opset_version=self.opset_version
12513
ort_sess = verification._ort_session(model_onnx)
12514
ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
12516
actual_std = np.std(ort_out)
12517
actual_mean = np.mean(ort_out)
12520
abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
12521
), "the gap of mean between ort outputs and expected one is unacceptable."
12523
abs(abs(actual_std) - expected_std) <= expected_std * 0.1
12524
), "the gap of variance between ort outputs and expected one is unacceptable."
12527
@skipIfUnsupportedMinOpsetVersion(11)
12528
def test_nn_init_normal_correctness(self):
12529
expected_mean = 5.0
12530
expected_std = 10.0
12532
class M(torch.nn.Module):
12534
x = torch.ones([]).new_empty(1, 400, 50)
12535
torch.nn.init.normal_(x, expected_mean, expected_std)
12539
model_onnx = io.BytesIO()
12540
test_inputs = tuple()
12542
model_export, test_inputs, model_onnx, opset_version=self.opset_version
12544
ort_sess = verification._ort_session(model_onnx)
12545
ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)
12547
actual_std = np.std(ort_out)
12548
actual_mean = np.mean(ort_out)
12551
abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
12552
), "the gap of mean between ort outputs and expected one is unacceptable."
12554
abs(abs(actual_std) - expected_std) <= expected_std * 0.1
12555
), "the gap of variance between ort outputs and expected one is unacceptable."
12558
@skipIfUnsupportedMinOpsetVersion(11)
12559
def test_dist_uniform(self):
12560
class M(torch.nn.Module):
12561
def forward(self, x, y):
12562
return torch.distributions.Uniform(x, y).sample().size(0), x, y
12564
self.run_test(M(), (torch.tensor([0.0]), torch.tensor([10.0])))
12565
self.run_test(M(), (torch.tensor([[0.0], [6.0]]), torch.tensor([[1.0], [7.0]])))
12567
M(), (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]]))
12571
@skipIfUnsupportedMinOpsetVersion(11)
12572
def test_dist_uniform_correctness(self):
12573
class M(torch.nn.Module):
12574
def forward(self, x, y):
12575
return torch.distributions.Uniform(x, y).sample([10000])
12578
expected_max = 10.0
12579
expected_mean = (expected_min + expected_max) / 2
12582
dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
12583
model_onnx = io.BytesIO()
12585
model_export, dummy_input, model_onnx, opset_version=self.opset_version
12587
ort_sess = verification._ort_session(model_onnx)
12589
ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
12590
actual_min = np.min(ort_out)
12591
actual_max = np.max(ort_out)
12592
actual_mean = np.mean(ort_out)
12595
actual_min >= expected_min
12596
), "the minimum value of ort outputs is out of scope."
12598
actual_max <= expected_max
12599
), "the maximum value of ort outputs is out of scope."
12601
abs(actual_mean - expected_mean) <= expected_mean * 0.05
12602
), "the mean value of ort outputs is out of scope."
12604
@skipIfUnsupportedMinOpsetVersion(13)
12605
def test_sequence_to_int(self):
12606
class M(torch.nn.Module):
12607
def forward(self, x):
12608
result = torch.tensor([2 for i in range(x.size()[0])], dtype=torch.int)
12611
x = torch.randn(10, 5)
12612
self.run_test(M(), (x,))
12614
@skipIfUnsupportedMinOpsetVersion(13)
12615
def test_sequence_to_float(self):
12616
class M(torch.nn.Module):
12617
def forward(self, x):
12618
result = torch.tensor(
12619
[1.1 for i in range(x.size()[0])], dtype=torch.float
12623
x = torch.randn(10, 5)
12624
self.run_test(M(), (x,))
12626
@skipIfUnsupportedMinOpsetVersion(13)
12627
def test_sequence_to_bool(self):
12628
class M(torch.nn.Module):
12629
def forward(self, x):
12630
result = torch.tensor(
12631
[False for i in range(x.size()[0])], dtype=torch.bool
12635
x = torch.randn(10, 5)
12636
self.run_test(M(), (x,))
12638
def test_tuple_output_from_if_with_raised_exception(self):
12639
class M(torch.nn.Module):
12640
def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
12642
raise Exception("Negative input")
12644
return torch.zeros(5), torch.zeros(5)
12647
self.run_test(torch.jit.script(M()), (x,))
12649
# NOTE: For quantization tests, choose scale and zero point carefully
12650
# such that inputs and outputs do not always overflow/underflow.
12651
# Otherwise test results could be inaccurate.
12652
@skipIfUnsupportedMinOpsetVersion(10)
12653
def test_quantized_linear(self):
12654
model = torch.ao.nn.quantized.Linear(4, 8)
12655
# Set fixed weight to avoid flaky test.
12656
weight = torch.quantize_per_tensor(
12657
torch.arange(32, dtype=torch.float).view(8, 4), 0.5, 0, torch.qint8
12659
# Set non-zero bias.
12660
bias = torch.arange(8, dtype=torch.float)
12661
model.set_weight_bias(weight, bias)
12662
# Set fixed input to avoid flaky test.
12663
input = torch.randn(4, 4)
12664
input = torch.arange(16, dtype=torch.float).view(4, 4) - 8
12665
input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12666
self.run_test(model, input_tensor)
12668
@skipIfUnsupportedMinOpsetVersion(10)
12669
def test_quantized_conv1d(self):
12670
model = torch.ao.nn.quantized.Conv1d(16, 33, 3, stride=2)
12671
# Manually initialize model weight and bias to random numbers.
12672
# By default all zeros.
12673
q_weight = torch.quantize_per_tensor(
12674
torch.randn(33, 16, 3), 0.5, 0, torch.qint8
12676
bias = torch.arange(33).to(torch.float) - 16
12677
model.set_weight_bias(q_weight, bias)
12678
input = torch.randn(3, 16, 32)
12679
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12680
self.run_test(model, q_input)
12682
@skipIfUnsupportedMinOpsetVersion(10)
12683
def test_quantized_conv2d(self):
12684
model = torch.ao.nn.quantized.Conv2d(16, 33, 3, stride=2)
12685
# Manually initialize model weight and bias to random numbers.
12686
# By default all zeros.
12687
q_weight = torch.quantize_per_tensor(
12688
torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
12690
bias = torch.arange(33).to(torch.float) - 16
12691
model.set_weight_bias(q_weight, bias)
12692
input = torch.randn(3, 16, 32, 32)
12693
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12694
self.run_test(model, q_input)
12696
@skipIfUnsupportedMinOpsetVersion(10)
12697
@skipIfQuantizationBackendQNNPack
12698
def test_quantized_conv3d(self):
12699
model = torch.ao.nn.quantized.Conv3d(16, 33, [2, 3, 4], stride=[3, 1, 2])
12700
# Manually initialize model weight and bias to random numbers.
12701
# By default all zeros.
12702
q_weight = torch.quantize_per_tensor(
12703
torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
12705
bias = torch.arange(33).to(torch.float) - 16
12706
model.set_weight_bias(q_weight, bias)
12707
input = torch.randn(3, 16, 8, 8, 8)
12708
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12709
self.run_test(model, q_input)
12711
@skipIfUnsupportedMinOpsetVersion(10)
12712
def test_quantized_adaptive_avg_pool2d(self):
12713
model = torch.nn.AdaptiveAvgPool2d((5, 7))
12714
input = torch.randn(4, 3, 10, 14)
12715
q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
12716
self.run_test(model, q_input)
12718
@skipIfUnsupportedMinOpsetVersion(10)
12719
def test_quantized_conv1d_relu(self):
12720
model = torch.ao.nn.intrinsic.quantized.ConvReLU1d(16, 33, 3, stride=2)
12721
# Manually initialize model weight and bias to random numbers.
12722
# By default all zeros.
12723
q_weight = torch.quantize_per_tensor(
12724
torch.randn(33, 16, 3), 0.5, 0, torch.qint8
12726
bias = torch.arange(33).to(torch.float) - 16
12727
model.set_weight_bias(q_weight, bias)
12728
input = torch.randn(3, 16, 32)
12729
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12730
self.run_test(model, q_input)
12732
@skipIfUnsupportedMinOpsetVersion(10)
12733
def test_quantized_conv2d_relu(self):
12734
model = torch.ao.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
12735
# Manually initialize model weight and bias to random numbers.
12736
# By default all zeros.
12737
q_weight = torch.quantize_per_tensor(
12738
torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
12740
bias = torch.arange(33).to(torch.float) - 16
12741
model.set_weight_bias(q_weight, bias)
12742
input = torch.randn(3, 16, 32, 32)
12743
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12744
self.run_test(model, q_input)
12746
@skipIfUnsupportedMinOpsetVersion(10)
12747
@skipIfQuantizationBackendQNNPack
12748
def test_quantized_conv3d_relu(self):
12749
model = torch.ao.nn.intrinsic.quantized.ConvReLU3d(
12750
16, 33, [2, 3, 4], stride=[3, 1, 2]
12752
# Manually initialize model weight and bias to random numbers.
12753
# By default all zeros.
12754
q_weight = torch.quantize_per_tensor(
12755
torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
12757
bias = torch.arange(33).to(torch.float) - 16
12758
model.set_weight_bias(q_weight, bias)
12759
input = torch.randn(3, 16, 8, 8, 8)
12760
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12761
self.run_test(model, q_input)
12763
@skipIfUnsupportedMinOpsetVersion(10)
12764
def test_quantized_conv_transpose1d(self):
12765
model = torch.ao.nn.quantized.ConvTranspose1d(
12766
16, 33, 3, output_padding=1, stride=2
12768
# Manually initialize model weight and bias to random numbers.
12769
# By default all zeros.
12770
q_weight = torch.quantize_per_tensor(
12771
torch.randn(16, 33, 3), 0.5, 0, torch.qint8
12773
bias = torch.arange(33).to(torch.float) - 16
12774
model.set_weight_bias(q_weight, bias)
12775
input = torch.randn(3, 16, 32)
12776
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12777
self.run_test(model, q_input)
12779
@skipIfUnsupportedMinOpsetVersion(10)
12780
def test_quantized_conv_transpose2d(self):
12781
model = torch.ao.nn.quantized.ConvTranspose2d(
12782
16, 33, 3, output_padding=(0, 1), stride=2
12784
# Manually initialize model weight and bias to random numbers.
12785
# By default all zeros.
12786
q_weight = torch.quantize_per_tensor(
12787
torch.randn(16, 33, 3, 3), 0.5, 0, torch.qint8
12789
bias = torch.arange(33).to(torch.float) - 16
12790
model.set_weight_bias(q_weight, bias)
12791
input = torch.randn(3, 16, 32, 32)
12792
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12793
self.run_test(model, q_input)
12795
@skipIfUnsupportedMinOpsetVersion(10)
12796
@skipIfQuantizationBackendQNNPack
12797
def test_quantized_conv_transpose3d(self):
12798
model = torch.ao.nn.quantized.ConvTranspose3d(
12799
16, 33, [2, 3, 4], output_padding=(0, 1, 2), stride=[3, 1, 2]
12801
# Manually initialize model weight and bias to random numbers.
12802
# By default all zeros.
12803
q_weight = torch.quantize_per_tensor(
12804
torch.randn(16, 33, 2, 3, 4), 0.5, 0, torch.qint8
12806
bias = torch.arange(33).to(torch.float) - 16
12807
model.set_weight_bias(q_weight, bias)
12808
input = torch.randn(3, 16, 8, 8, 8)
12809
q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12810
self.run_test(model, q_input)
12812
@common_utils.parametrize(
12813
"function_or_module",
12815
common_utils.subtest(
12819
common_utils.subtest(
12820
torch.nn.LeakyReLU(),
12823
common_utils.subtest(
12824
torch.ao.nn.quantized.LeakyReLU(2.0, 1),
12825
name="quantized_leaky_relu",
12827
common_utils.subtest(
12828
torch.ao.nn.quantized.Hardswish(2.0, 1),
12829
name="quantized_hardswish",
12831
common_utils.subtest(
12832
torch.nn.Sigmoid(),
12835
common_utils.subtest(
12836
torch.ao.nn.quantized.Sigmoid(2.0, 1),
12837
name="quantized_sigmoid",
12839
common_utils.subtest(
12840
torch.nn.Hardsigmoid(),
12841
name="hardsigmoid",
12843
common_utils.subtest(
12847
common_utils.subtest(
12848
torch.nn.Hardtanh(),
12851
common_utils.subtest(
12852
lambda x: torch.transpose(x, 0, 1),
12855
common_utils.subtest(
12856
lambda x: x.expand(2, 4, 2, 3),
12859
common_utils.subtest(
12860
lambda x: x.view(1, 4, 6),
12863
common_utils.subtest(
12864
lambda x: x.select(1, 1),
12867
common_utils.subtest(
12868
torch.ao.nn.quantized.LayerNorm(
12870
torch.nn.Parameter(torch.ones([4, 2, 3])),
12871
torch.nn.Parameter(torch.zeros([4, 2, 3])),
12877
common_utils.subtest(
12878
torch.ao.nn.quantized.InstanceNorm1d(
12880
torch.nn.Parameter(torch.ones(4)),
12881
torch.nn.Parameter(torch.zeros(4)),
12885
name="instance_norm",
12887
common_utils.subtest(
12888
torch.ao.nn.quantized.GroupNorm(
12891
torch.nn.Parameter(torch.zeros(4)),
12892
torch.nn.Parameter(torch.zeros(4)),
12898
common_utils.subtest(
12899
lambda x: torch.as_strided(x, (2, 2), (1, 2)),
12905
@skipIfUnsupportedMinOpsetVersion(10)
12906
def test_quantized_unary_ops(self, function_or_module):
12907
input = torch.randn(1, 4, 2, 3)
12908
q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
12910
class Model(torch.nn.Module):
12911
def __init__(self, function_or_module):
12913
self.function_or_module = function_or_module
12915
def forward(self, x):
12916
return self.function_or_module(x)
12918
self.run_test(Model(function_or_module), q_input)
12920
@skipIfUnsupportedMinOpsetVersion(10)
12921
def test_quantized_flatten(self):
12922
class FlattenModel(torch.nn.Module):
12923
def forward(self, input):
12924
return torch.flatten(input)
12926
x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
12927
self.run_test(FlattenModel(), x)
12929
@skipIfUnsupportedMinOpsetVersion(10)
12930
@skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
12931
def test_quantized_cat_when_concatinating_the_same_tensor(self):
12932
class QuantizedSelfConcatenationModel(torch.nn.Module):
12933
def forward(self, x):
12934
return torch.ao.nn.quantized.QFunctional().cat((x, x), dim=1)
12936
q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
12937
self.run_test(QuantizedSelfConcatenationModel(), q_input)
12939
@common_utils.parametrize(
12942
common_utils.subtest(
12944
torch.quantize_per_tensor(
12945
torch.ones(2, 3), 0.26, 128, torch.quint8
12947
torch.quantize_per_tensor(
12948
torch.zeros(1, 3), 0.26, 128, torch.quint8
12951
name="different_shape",
12953
common_utils.subtest(
12955
torch.quantize_per_tensor(
12956
torch.ones(2, 3), 0.26, 128, torch.quint8
12958
torch.quantize_per_tensor(torch.ones(2, 3), 42, 1, torch.quint8),
12960
name="different_scale",
12962
common_utils.subtest(
12964
torch.quantize_per_tensor(
12965
torch.ones(2, 3), 0.26, 128, torch.quint8
12967
torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 63, torch.quint8),
12969
name="different_zero_point",
12971
common_utils.subtest(
12973
torch.quantize_per_tensor(
12974
torch.ones(2, 3), 0.26, 128, torch.quint8
12976
torch.quantize_per_tensor(torch.ones(2, 3), 0.1, 63, torch.quint8),
12978
name="different_zero_point_and_scale",
12982
@skipIfUnsupportedMinOpsetVersion(10)
12983
@skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
12984
def test_quantized_cat(self, x: torch.Tensor, y: torch.Tensor):
12985
class QuantizedConcatenationModel(torch.nn.Module):
12986
def forward(self, x, y):
12987
return torch.ao.nn.quantized.QFunctional().cat((x, y), dim=0)
12989
self.run_test(QuantizedConcatenationModel(), (x, y))
12991
@skipIfUnsupportedMinOpsetVersion(10)
12992
# torch.jit.frontend.FrontendError:
12993
# Cannot instantiate class 'QFunctional' in a script function
12995
def test_quantized_arithmetic_qfunctional(self):
12996
x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
12997
y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
12999
class ArithmeticModel(torch.nn.Module):
13000
def forward(self, x, y):
13001
o = torch.ao.nn.quantized.QFunctional().add(x, y)
13002
o = torch.ao.nn.quantized.QFunctional().mul(o, x)
13005
self.run_test(ArithmeticModel(), (x, y))
13007
@skipIfUnsupportedMinOpsetVersion(10)
13008
def test_quantized_arithmetic(self):
13009
x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
13010
y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
13012
class ArithmeticModel2(torch.nn.Module):
13013
def forward(self, x, y):
13014
o = torch.ops.quantized.add(x, y, 0.4, 100)
13015
o = torch.ops.quantized.mul(o, x, 0.4, 100)
13018
self.run_test(ArithmeticModel2(), (x, y))
13020
@skipIfUnsupportedMinOpsetVersion(10)
13021
def test_quantize_per_tensor(self):
13022
class Module(torch.nn.Module):
13023
def forward(self, x):
13025
torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
13026
torch.quantize_per_tensor(x, 0.2, 128, torch.quint8),
13029
x = torch.randn(4, 6)
13030
self.run_test(Module(), x)
13032
@skipIfUnsupportedMinOpsetVersion(10)
13033
def test_dequantize(self):
13034
class Module(torch.nn.Module):
13035
def forward(self, x):
13036
return torch.dequantize(x)
13038
x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
13039
self.run_test(Module(), x)
13041
@skipIfUnsupportedMinOpsetVersion(13)
13042
def test_qat_linear_per_channel(self):
13043
class M(torch.nn.Module):
13044
def __init__(self):
13046
self.quant = torch.ao.quantization.QuantStub()
13047
self.linear = torch.nn.Linear(4, 3)
13048
self.dequant = torch.ao.quantization.DeQuantStub()
13050
def forward(self, x):
13053
x = self.dequant(x)
13057
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13058
model = torch.ao.quantization.prepare_qat(model)
13059
# Set fixed weight and bias to avoid flaky test.
13060
model.linear.weight = torch.nn.Parameter(
13061
_construct_tensor_for_quantization_test((3, 4))
13063
model.linear.bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float))
13064
model = torch.ao.quantization.convert(model)
13066
# Set fixed input to avoid flaky test.
13067
input = _construct_tensor_for_quantization_test((4, 4), offset=-8)
13068
self.run_test(model, input)
13071
"ORT fails with Validating no unexpected access using an invalid node_index on torch converted model"
13073
@skipIfUnsupportedMinOpsetVersion(13)
13074
def test_quantized_list_of_inputs_with_cat(self):
13075
class TestModel(torch.nn.Module):
13076
def __init__(self):
13078
self.quant = torch.ao.quantization.QuantStub()
13079
self.dequant = torch.ao.quantization.DeQuantStub()
13081
def forward(self, x):
13083
x = torch.cat([x, x], 1)
13084
x = self.dequant(x)
13087
model = TestModel()
13088
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13089
model = torch.ao.quantization.prepare_qat(model)
13090
model = torch.ao.quantization.convert(model)
13091
x = torch.randn(2, 4, 6)
13092
self.run_test(model, x)
13094
@skipIfUnsupportedMinOpsetVersion(13)
13095
def test_qat_relu(self):
13096
class M(torch.nn.Module):
13097
def __init__(self):
13099
self.quant = torch.ao.quantization.QuantStub()
13100
self.relu = torch.nn.ReLU()
13101
self.dequant = torch.ao.quantization.DeQuantStub()
13103
def forward(self, x):
13106
x = self.dequant(x)
13110
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13111
model = torch.ao.quantization.prepare_qat(model)
13112
model = torch.ao.quantization.convert(model)
13113
input = torch.randn(8, 4)
13114
self.run_test(model, input)
13116
@skipIfUnsupportedMinOpsetVersion(13)
13117
def test_qat_conv2d(self):
13118
class M(torch.nn.Module):
13119
def __init__(self):
13121
self.quant = torch.ao.quantization.QuantStub()
13122
self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13123
self.dequant = torch.ao.quantization.DeQuantStub()
13125
def forward(self, x):
13128
x = self.dequant(x)
13132
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13133
model = torch.ao.quantization.prepare_qat(model)
13134
# Set fixed weight and bias to avoid flaky test.
13135
model.conv.weight = torch.nn.Parameter(
13136
_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13138
model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13139
model = torch.ao.quantization.convert(model)
13141
# Set fixed input to avoid flaky test.
13142
input = _construct_tensor_for_quantization_test(
13143
(3, 4, 8, 8), offset=-384, max_val=12
13145
self.run_test(model, input)
13147
@skipIfUnsupportedMinOpsetVersion(13)
13148
def test_qat_conv2d_relu(self):
13149
class M(torch.nn.Module):
13150
def __init__(self):
13152
self.quant = torch.ao.quantization.QuantStub()
13153
self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13154
self.relu = torch.nn.ReLU()
13155
self.dequant = torch.ao.quantization.DeQuantStub()
13157
def forward(self, x):
13161
x = self.dequant(x)
13165
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13166
model = torch.ao.quantization.prepare_qat(model)
13167
# Set fixed weight and bias to avoid flaky test.
13168
model.conv.weight = torch.nn.Parameter(
13169
_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13171
model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13172
model = torch.ao.quantization.convert(model)
13174
# Set fixed input to avoid flaky test.
13175
input = _construct_tensor_for_quantization_test(
13176
(3, 4, 8, 8), offset=-384, max_val=12
13178
self.run_test(model, input)
13180
@skipIfUnsupportedMinOpsetVersion(13)
13181
def test_qat_conv2d_relu_fused(self):
13182
class M(torch.nn.Module):
13183
def __init__(self):
13185
self.quant = torch.ao.quantization.QuantStub()
13186
self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13187
self.relu = torch.nn.ReLU()
13188
self.dequant = torch.ao.quantization.DeQuantStub()
13190
def forward(self, x):
13194
x = self.dequant(x)
13198
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13199
model = torch.ao.quantization.fuse_modules(model.eval(), [["conv", "relu"]])
13200
model = torch.ao.quantization.prepare_qat(model.train())
13201
# Set fixed weight and bias to avoid flaky test.
13202
model.conv.weight = torch.nn.Parameter(
13203
_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13205
model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13206
model = torch.ao.quantization.convert(model)
13208
# Set fixed input to avoid flaky test.
13209
input = _construct_tensor_for_quantization_test(
13210
(3, 4, 8, 8), offset=-384, max_val=12
13212
self.run_test(model, input)
13214
@skipIfUnsupportedMinOpsetVersion(13)
13215
def test_qat_linear_relu_fused(self):
13216
class M(torch.nn.Module):
13217
def __init__(self):
13219
self.quant = torch.ao.quantization.QuantStub()
13220
self.linear = torch.nn.Linear(4, 2)
13221
self.relu = torch.nn.ReLU()
13222
self.dequant = torch.ao.quantization.DeQuantStub()
13224
def forward(self, x):
13228
x = self.dequant(x)
13232
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13233
model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
13234
model = torch.ao.quantization.prepare_qat(model.train())
13235
# Set fixed weight and bias to avoid flaky test.
13236
model.linear.weight = torch.nn.Parameter(
13237
_construct_tensor_for_quantization_test((2, 4), max_val=2)
13239
model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13240
model = torch.ao.quantization.convert(model)
13242
# Set fixed input to avoid flaky test.
13243
input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12)
13244
self.run_test(model, input)
13246
@skipIfUnsupportedMinOpsetVersion(10)
13247
def test_qat_maxpool2d(self):
13248
class M(torch.nn.Module):
13249
def __init__(self):
13251
self.quant = torch.ao.quantization.QuantStub()
13252
self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
13253
self.dequant = torch.ao.quantization.DeQuantStub()
13255
def forward(self, x):
13258
x = self.dequant(x)
13262
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13263
model = torch.ao.quantization.prepare_qat(model.train())
13264
model = torch.ao.quantization.convert(model)
13266
# Set fixed input to avoid flaky test.
13267
input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
13268
self.run_test(model, input)
13270
@skipIfUnsupportedMinOpsetVersion(10)
13271
@skipScriptTest() # Scale and Zero-point must be a scalar in ORT:optimization
13272
def test_qat_avg_pool2d(self):
13273
model = torch.nn.Sequential(
13274
torch.ao.quantization.QuantStub(),
13275
torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
13276
torch.ao.quantization.DeQuantStub(),
13278
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13279
model = torch.ao.quantization.prepare_qat(model.train())
13280
model = torch.ao.quantization.convert(model)
13281
input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
13282
self.run_test(model, input)
13284
@skipIfUnsupportedMinOpsetVersion(11)
13285
def test_qat_upsample_nearest2d(self):
13286
model = torch.nn.Sequential(
13287
torch.ao.quantization.QuantStub(),
13288
torch.nn.UpsamplingNearest2d(scale_factor=1.5),
13289
torch.ao.quantization.DeQuantStub(),
13291
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13292
model = torch.ao.quantization.prepare_qat(model.train())
13293
model = torch.ao.quantization.convert(model)
13294
input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
13295
self.run_test(model, input)
13297
def test_0d_tensor_broadcast(self):
13298
class fn(torch.nn.Module):
13299
def forward(self, x, y):
13300
a = torch.add(x, y)
13301
b = torch.mul(y, y)
13306
self.run_test(fn(), (x, y), input_names=["x", "y"], output_names=["output"])
13308
@skipIfUnsupportedMinOpsetVersion(9)
13309
def test_convolution_allow_tf32(self):
13310
class Module(torch.nn.Module):
13311
def __init__(self, allow_tf32):
13314
self.allow_tf32 = allow_tf32
13315
weight = torch.rand(32, 3, 3, 3)
13316
self.weight = torch.nn.Parameter(weight)
13318
def forward(self, x):
13319
if self.allow_tf32:
13320
return torch._convolution(
13336
return torch._convolution(
13351
x = torch.randn(1, 3, 224, 224)
13352
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
13353
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
13355
@skipIfUnsupportedMinOpsetVersion(16)
13356
@common_utils.parametrize(
13358
("bilinear", "nearest", "bicubic"),
13360
@common_utils.parametrize(
13362
("zeros", "border", "reflection"),
13364
@common_utils.parametrize(
13367
name_fn=lambda align_corners: str(align_corners),
13369
def test_grid_sample(self, mode, padding_mode, align_corners):
13370
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
13372
class GridSampleModule(torch.nn.Module):
13373
def __init__(self, mode, padding_mode, align_corners) -> None:
13375
self.mode, self.padding_mode, self.align_corners = (
13381
def forward(self, input, grid):
13382
return torch.nn.functional.grid_sample(
13383
input, grid, self.mode, self.padding_mode, self.align_corners
13387
if (mode, padding_mode) == ("bicubic", "border"):
13389
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
13391
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
13392
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
13394
GridSampleModule(mode, padding_mode, align_corners),
13399
# ONNX Opset 16 GridSample with 5D volumetric input is not supported.
13402
volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in)
13403
volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3)
13404
for mode, padding_mode, align_corners in itertools.product(
13408
), # PyTorch grid_sample "bicubic" mode does not support 5D volumetric input.
13419
with self.assertRaises(
13420
torch.onnx.errors.OnnxExporterError,
13423
GridSampleModule(mode, padding_mode, align_corners),
13424
(volumetric_input_tensor, volumetric_grid_tensor),
13428
class IfNoneInput(torch.nn.Module):
13429
def forward(self, x) -> Optional[Tensor]:
13430
y: Optional[Tensor] = None
13435
class IfNoneOutput(torch.nn.Module):
13436
def forward(self, x) -> Optional[Tensor]:
13437
y: Optional[Tensor] = x
13442
class LoopNoneInput(torch.nn.Module):
13443
def forward(self, x) -> Optional[Tensor]:
13444
y: Optional[Tensor] = None
13445
for _ in range(x.size(0)):
13449
class LoopNoneOutput(torch.nn.Module):
13450
def forward(self, x) -> Optional[Tensor]:
13451
y: Optional[Tensor] = x
13452
for _ in range(x.size(0)):
13456
@common_utils.parametrize(
13458
(IfNoneOutput, IfNoneInput, LoopNoneOutput, LoopNoneInput),
13459
name_fn=lambda module_class: module_class.__name__,
13461
@common_utils.parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
13463
@skipIfUnsupportedMinOpsetVersion(16)
13464
def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
13465
# Need scripting to preserve control flow for this test to be
13467
model = torch.jit.script(module_class())
13469
x = torch.ones(x_size)
13470
dynamic_axis_name = "condition"
13475
opset_version=self.opset_version,
13476
# Ensure condition is not constant
13477
dynamic_axes={"x": {0: dynamic_axis_name}},
13480
exported = onnx.load_from_string(f.getvalue())
13481
expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type()
13482
expected_output_type = onnx.helper.make_optional_type_proto(
13483
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,))
13485
self.assertEqual(expected_output_type, exported.graph.output[0].type)
13486
for node in exported.graph.node:
13487
# Both branches output types should match.
13488
if node.op_type == "If":
13489
for attr in node.attribute:
13490
if attr.name in ("then_branch", "else_branch"):
13491
self.assertEqual(expected_output_type, attr.g.output[0].type)
13496
# Ensure condition is not constant
13497
dynamic_axes={"x": {0: dynamic_axis_name}},
13502
@skipIfUnsupportedMinOpsetVersion(16)
13503
def test_uninitialized_optional(self):
13504
class Module(torch.nn.Module):
13505
def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
13516
torch.ones((3, 4), dtype=torch.int),
13517
dynamic_axes={"y": {0: "y0", 1: "y1"}},
13521
@skipIfUnsupportedMinOpsetVersion(9)
13522
def test_device_eq(self):
13523
class M(torch.nn.Module):
13524
def forward(self, a):
13525
# exercise both Tensor.device (prim::device)
13526
# and torch.device (prim::Constant).
13527
if a.device != torch.device("cpu"):
13529
return torch.zeros_like(a)
13531
mod = torch.jit.script(M()) # preserve control flow
13535
# In order for the ONNX model behavior to match the torch model, we
13536
# need to construct input that has the same device that is checked for
13537
# in forward(). In ONNX there is no such thing as a device, so the if
13538
# condition is always false.
13539
torch.randn(3, 3, device="cpu"),
13540
# Force dynamic axes so that the output shape depends on the input.
13541
# Otherwise the entire model will just return a constant and not have
13544
dynamic_axes={"a": {0: "a0"}},
13547
@skipIfUnsupportedMinOpsetVersion(9)
13548
def test_lerp(self):
13549
class LerpModel(torch.nn.Module):
13550
def forward(self, x):
13552
x.lerp(torch.full_like(x, 10), 0.4),
13553
x.lerp(torch.full_like(x, 20), 0.7),
13554
x.lerp(torch.full_like(x, 30), torch.tensor(0.4)),
13555
x.lerp(torch.full_like(x, 40), x / 10.0),
13556
x.lerp(torch.tensor(10.0), x / 10.0),
13557
x.lerp(torch.tensor(10.0), 0.4),
13558
x.lerp(torch.tensor(10.0), torch.tensor(0.4)),
13561
self.run_test(LerpModel(), torch.rand(5, 4, 3))
13563
@common_utils.parametrize("input_dtype", [torch.cfloat, torch.float])
13564
@skipIfUnsupportedMinOpsetVersion(9)
13565
def test_print_tensor_within_torch_nn_module(self, input_dtype: torch.dtype):
13566
class PrintTensorOnMyModel(torch.nn.Module):
13567
def forward(self, x):
13568
# 'print' has side effect calling 'resolve_conj' and 'resolve_neg'.
13570
print(f"x_firsts: {x_firsts}")
13571
# 'tolist' has side effect calling 'resolve_conj' and 'resolve_neg'.
13572
# Annotation added to pass torch script.
13573
_: List[float] = x.tolist()
13576
m = PrintTensorOnMyModel()
13577
x = torch.randn(10, 5, dtype=input_dtype)
13578
if input_dtype == torch.cfloat:
13579
with self.assertRaises(RuntimeError):
13591
@skipIfUnsupportedMinOpsetVersion(16)
13593
not torch.hub._check_module_exists("torch_geometric"),
13594
"torch_geometric not installed.",
13596
def test_sage_conv(self):
13597
from torch_geometric import nn as torch_geometric_nn
13600
coords0 = torch.randn(1, 6)
13601
coords1 = torch.randn(1, 6)
13602
coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
13603
adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True)
13604
edge_from = adj[0:1, :]
13605
edge_to = adj[1:, :]
13606
inputs = (coords0, coords1, edge_from, edge_to)
13608
class MySAGEConv(torch.nn.Module):
13609
def __init__(self):
13611
self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv(
13612
2, 512, normalize=True
13614
self.bano1 = torch_geometric_nn.BatchNorm(512)
13615
self.relu = torch.nn.ReLU()
13616
self.dense1 = torch.nn.Seq(Lin(512, 1)) # noqa: F821
13617
self.sigmoid = torch.nn.Sigmoid()
13619
def forward(self, coords0, coords1, edge_from, edge_to):
13620
adj = torch.cat((edge_from, edge_to), dim=0)
13621
gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
13622
x1 = self.SAGEConvBlock1(gra, edge_index=adj)
13623
x = torch.unsqueeze(torch.sum(x1), dim=0)
13626
input_names = ["coords0", "coords1", "edge_from", "edge_to"]
13627
output_names = ["outputs"]
13629
"coords0": {0: "batch_size", 1: "features"},
13630
"coords1": {0: "batch_size", 1: "features"},
13631
"edge_from": {0: "batch_size", 1: "features"},
13632
"edge_to": {0: "batch_size", 1: "features"},
13633
"outputs": {0: "batch_size"},
13638
input_names=input_names,
13639
output_names=output_names,
13640
dynamic_axes=dynamic_axes,
13643
# Cannot export with older opsets because of "ConstantFill" op
13644
# ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime
13645
# There are still some issues prevent us from enabling script test for these scenarios:
13647
# Operator aten::as_tensor is not supported by exporter yet.
13648
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
13649
# Operator aten::_pack_padded_sequence is not supported by exporter yet.
13650
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
13652
# Compiling in script mode fails with errors like:
13653
# torch.jit.frontend.UnsupportedNodeError: annotated assignments
13654
# without assigned value aren't supported
13655
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
13657
# Compiling in script mode fails with errors like:
13658
# RuntimeError: Arguments for call are not valid.
13659
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
13661
@skipIfUnsupportedMinOpsetVersion(9)
13662
@common_utils.parametrize(
13663
"name, nonlinearity",
13671
@common_utils.parametrize(**_parametrize_rnn_args("layers"))
13672
@common_utils.parametrize(**_parametrize_rnn_args("bidirectional"))
13673
@common_utils.parametrize(**_parametrize_rnn_args("initial_state"))
13674
@common_utils.parametrize(**_parametrize_rnn_args("packed_sequence"))
13675
@common_utils.parametrize(**_parametrize_rnn_args("dropout"))
13676
def test_rnn(self, *args, **kwargs):
13677
self._dispatch_rnn_test(*args, **kwargs)
13680
if __name__ == "__main__":
13681
common_utils.TestCase._default_dtype_check_enabled = True
13682
common_utils.run_tests()