# pytorch (Fork: 0)
#
# pytorch / test_pytorch_onnx_onnxruntime.py
# 13682 lines · 472.2 KB
1
# Owner(s): ["module: onnx"]
2

3
from __future__ import annotations
4

5
import functools
6

7
import io
8
import itertools
9
import os
10
import unittest
11
from collections import OrderedDict
12
from typing import Dict, List, Optional, Tuple, Type, Union
13

14
import numpy as np
15
import onnx
16
import onnx_test_common
17
import parameterized
18
import torch
19
import torchvision
20
from model_defs import (
21
    lstm_flattening_result,
22
    rnn_model_with_packed_sequence,
23
    word_language_model,
24
)
25
from pytorch_test_common import (
26
    BATCH_SIZE,
27
    RNN_BATCH_SIZE,
28
    RNN_HIDDEN_SIZE,
29
    RNN_INPUT_SIZE,
30
    RNN_SEQUENCE_LENGTH,
31
    skipDtypeChecking,
32
    skipIfQuantizationBackendQNNPack,
33
    skipIfUnsupportedMaxOpsetVersion,
34
    skipIfUnsupportedMinOpsetVersion,
35
    skipIfUnsupportedOpsetVersion,
36
    skipScriptTest,
37
    skipShapeChecking,
38
    skipTraceTest,
39
)
40

41
from torch import Tensor
42
from torch.nn.utils import rnn as rnn_utils
43
from torch.onnx import errors, verification
44
from torch.testing._internal import common_utils
45
from torch.testing._internal.common_utils import skipIfNoLapack
46

47

48
def _init_test_generalized_rcnn_transform():
    """Return a small ``GeneralizedRCNNTransform`` for the detection tests.

    The transform is configured with ``min_size=100``, ``max_size=200`` and the
    per-channel mean / standard deviation values listed below.
    """
    detection_transform = torchvision.models.detection.transform
    return detection_transform.GeneralizedRCNNTransform(
        100,  # min_size
        200,  # max_size
        [0.485, 0.456, 0.406],  # image_mean
        [0.229, 0.224, 0.225],  # image_std
    )


def _init_test_rpn():
    """Return a ``RegionProposalNetwork`` configured for the detection tests.

    Five single-size anchor groups (32, 64, 128, 256, 512) are each paired with
    the aspect ratios 0.5 / 1.0 / 2.0, and the ``RPNHead`` works on 256 channels.
    """
    rpn_module = torchvision.models.detection.rpn

    sizes = tuple((side,) for side in (32, 64, 128, 256, 512))
    ratios = ((0.5, 1.0, 2.0),) * len(sizes)
    anchor_generator = rpn_module.AnchorGenerator(sizes, ratios)

    anchors_per_location = anchor_generator.num_anchors_per_location()[0]
    head = rpn_module.RPNHead(256, anchors_per_location)  # 256 = out_channels

    return rpn_module.RegionProposalNetwork(
        anchor_generator,
        head,
        0.7,  # fg_iou_thresh
        0.3,  # bg_iou_thresh
        256,  # batch_size_per_image
        0.5,  # positive_fraction
        {"training": 2000, "testing": 1000},  # pre_nms_top_n
        {"training": 2000, "testing": 1000},  # post_nms_top_n
        0.7,  # nms_thresh
        score_thresh=0.0,
    )


def _construct_tensor_for_quantization_test(
    shape: Tuple[int, ...],
    offset: Optional[Union[int, float]] = None,
    max_val: Optional[Union[int, float]] = None,
) -> Tensor:
    """Helper function to generate weights and test inputs in a deterministic way.

    Due to differences in implementation details between PyTorch and ONNXRuntime,
    randomly generated test data for quantization tests can be flaky. To help
    stabilize the tests, this helper builds the tensor deterministically: it holds
    ``0, 1, 2, ...`` in row-major order before the optional offset/scaling.

    Args:
        shape (Tuple[int, ...]): Shape of the tensor to construct.
        offset (Optional[Union[int, float]]): Offset added to every element.
        max_val (Optional[Union[int, float]]): If any element has a larger absolute
            value than max_val, the tensor is scaled by max_val / tensor.abs().max().
            This step is applied after the offset. Ignored for zero-size tensors.

    Returns:
        Tensor: A ``torch.float`` (float32) tensor of the requested shape.
    """
    tensor = torch.arange(np.prod(shape), dtype=torch.float).view(shape)
    if offset is not None:
        tensor = tensor + offset
    # A zero-size tensor has no maximum (``max()`` would raise), and there is
    # nothing to rescale anyway.
    if max_val is not None and tensor.numel() > 0:
        # Compute the peak magnitude once and reuse it for the scale factor.
        peak = tensor.abs().max()
        if peak > max_val:
            tensor = tensor * max_val / peak
    return tensor


def _parameterized_class_attrs_and_values(
    min_opset_version: int, max_opset_version: int
):
    """Build the ``parameterized_class`` kwargs describing every test configuration.

    Opsets 7 and 8 are always included, but only with
    ``keep_initializers_as_inputs=True``. Every opset in the inclusive range
    ``[min_opset_version, max_opset_version]`` is combined with all
    ``is_script`` x ``keep_initializers_as_inputs`` settings.

    Raises:
        ValueError: if ``min_opset_version`` is below 9.
    """
    # Valid opset versions are defined in torch/onnx/_constants.py.
    # Versions are intentionally set statically, to not be affected by changes elsewhere.
    if min_opset_version < 9:
        raise ValueError("min_opset_version must be >= 9")

    legacy_configs = [
        (opset, scripted, True) for opset in (7, 8) for scripted in (True, False)
    ]
    ranged_configs = list(
        itertools.product(
            range(min_opset_version, max_opset_version + 1),
            (True, False),
            (True, False),
        )
    )
    return {
        "attrs": ("opset_version", "is_script", "keep_initializers_as_inputs"),
        "input_values": legacy_configs + ranged_configs,
    }


def _parametrize_rnn_args(arg_name):
    """Return kwargs (``arg_str``, ``arg_values``, ``name_fn``) for one RNN test argument.

    ``arg_name`` selects one of the option tables below. ``arg_values`` are the
    table keys, and ``name_fn`` maps a value to its human-readable test-name
    suffix. An unknown ``arg_name`` raises ``KeyError``.
    """
    option_tables = {
        "layers": {1: "unilayer", 3: "trilayer"},
        "bidirectional": {True: "bidirectional", False: "forward"},
        "initial_state": {True: "with_initial_state", False: "no_initial_state"},
        "packed_sequence": {
            0: "without_sequence_lengths",
            1: "with_variable_length_sequences",
            2: "with_batch_first_sequence_lengths",
        },
        "dropout": {0.2: "with_dropout", 0.0: "without_dropout"},
    }
    choices = option_tables[arg_name]

    return {
        "arg_str": arg_name,
        "arg_values": choices.keys(),
        "name_fn": lambda value: choices[value],
    }


@parameterized.parameterized_class(
160
    **_parameterized_class_attrs_and_values(
161
        onnx_test_common.MIN_ONNX_OPSET_VERSION, onnx_test_common.MAX_ONNX_OPSET_VERSION
162
    ),
163
    class_name_func=onnx_test_common.parameterize_class_name,
164
)
165
@common_utils.instantiate_parametrized_tests
166
class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
167
    def test_fuse_conv_bn1d(self):
        """Run a Conv1d -> BatchNorm1d stack through ``run_test``."""

        class ConvBn1d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.bn = torch.nn.BatchNorm1d(33)

            def forward(self, x):
                return self.bn(self.conv(x))

        # Build the model before sampling the input so the global RNG is
        # consumed in the same order (weight init first, then the input).
        model = ConvBn1d()
        sample = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, (sample,))

    def test_fuse_conv_bn2d(self):
        """Run a bias-free Conv2d -> BatchNorm2d stack through ``run_test``."""

        class ConvBn2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
                )
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                return self.bn(self.conv(x))

        # Model first, then input: keeps the RNG consumption order unchanged.
        model = ConvBn2d()
        sample = torch.randn(2, 3, 2, 2, requires_grad=True)
        self.run_test(model, (sample,))

    def test_fuse_conv_bn3d(self):
        """Run a bias-free Conv3d -> BatchNorm3d stack with relaxed tolerances."""

        class ConvBn3d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv3d(
                    3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False
                )
                self.bn = torch.nn.BatchNorm3d(2)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = ConvBn3d()
        volume = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
        self.run_test(model, (volume,), rtol=1e-3, atol=1e-6)

    def test_fuse_conv_in_block(self):
        """Scripted model whose Conv1d -> BatchNorm1d sits inside an ``if`` block."""

        class ConditionalConvBn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=5,
                    out_channels=5,
                    kernel_size=3,
                    stride=1,
                    padding=2,
                    dilation=1,
                )
                self.bn = torch.nn.BatchNorm1d(5)

            def forward(self, x):
                # Keep this control flow verbatim: per the test name, the
                # conv/bn pair must live inside a data-dependent block.
                results_available = True

                if x.sum() > -1:
                    results_available = False

                if results_available:
                    x = self.conv(x)
                    x = self.bn(x)

                return x

        model = ConditionalConvBn()
        sample = torch.randn(2, 5, 9, requires_grad=True)
        self.run_test(
            torch.jit.script(model),
            (sample,),
            input_names=["x"],
            dynamic_axes={"x": [0, 2]},
            rtol=1e-3,
            atol=1e-6,
        )

    def test_conv_tbc(self):
        """Export ``torch.conv_tbc`` wrapped in a module with learnable weights."""
        from torch.nn.modules.utils import _single

        class ConvTBC(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, padding=0):
                super().__init__()
                self.in_channels = in_channels
                self.out_channels = out_channels
                self.kernel_size = _single(kernel_size)
                self.padding = _single(padding)

                # Allocate with torch.empty instead of the legacy
                # `Tensor(*sizes)` constructor; the values are filled in by
                # reset_parameters() right below.
                self.weight = torch.nn.Parameter(
                    torch.empty(self.kernel_size[0], in_channels, out_channels)
                )
                self.bias = torch.nn.Parameter(torch.empty(out_channels))
                self.reset_parameters()

            def reset_parameters(self):
                torch.nn.init.xavier_normal_(self.weight)
                torch.nn.init.zeros_(self.bias)

            def conv_tbc(self, input):
                return torch.conv_tbc(
                    input.contiguous(), self.weight, self.bias, self.padding[0]
                )

            def forward(self, input):
                return self.conv_tbc(input)

        in_channels = 3
        out_channels = 5
        kernel_size = 5
        model = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
        # conv_tbc takes (time, batch, channels) input.
        x = torch.randn(10, 7, in_channels, requires_grad=True)
        self.run_test(model, (x,), atol=1e-5)

    def test_reshape_constant_fold(self):
        """Multiply the input by a reshaped, registered constant buffer."""

        class ScaleByBuffer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                scale_1 = self.weight.reshape(1, -1, 1, 1)
                return x * scale_1

        sample = torch.randn(4, 5)
        self.run_test(ScaleByBuffer(), (sample,), rtol=1e-3, atol=1e-5)

    def run_word_language_model(self, model_name):
        """Build a small word-language model for ``model_name`` and run it.

        ``"GRU"`` selects ``RNNModelWithTensorHidden``, ``"LSTM"`` selects
        ``RNNModelWithTupleHidden``, and any other name uses ``RNNModel``.
        """
        ntokens, emsize, nhid, nlayers = 50, 5, 5, 5
        dropout, tied, batchsize = 0.2, False, 5

        if model_name == "GRU":
            model_cls = word_language_model.RNNModelWithTensorHidden
        elif model_name == "LSTM":
            model_cls = word_language_model.RNNModelWithTupleHidden
        else:
            model_cls = word_language_model.RNNModel
        model = model_cls(
            model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize
        )

        tokens = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only support CPU version, since tracer is not working in GPU RNN.
        self.run_test(model, (tokens, model.hidden))

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor:
        """Load an image from the ``assets`` directory next to this file.

        Args:
            rel_path: "/"-separated path relative to the ``assets`` directory.
            size: Target size passed to ``PIL.Image.Image.resize``
                (``(width, height)`` in PIL's convention).

        Returns:
            The image converted to RGB, resized with bilinear resampling, and
            turned into a tensor by ``transforms.ToTensor()``.
        """
        from PIL import Image
        from torchvision import transforms

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        # Close the file handle deterministically; convert()/resize() return
        # new in-memory images, so `image` stays valid after the block exits.
        with Image.open(path) as raw_image:
            image = raw_image.convert("RGB").resize(size, Image.BILINEAR)

        return transforms.ToTensor()(image)

    def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
        """Return two one-element lists of image tensors from the test assets.

        The first list holds ``grace_hopper_517x606.jpg`` resized to
        ``(100, 320)``; the second holds ``rgb_pytorch.png`` resized to
        ``(250, 380)``.
        """
        hopper_images = [self.get_image("grace_hopper_517x606.jpg", (100, 320))]
        logo_images = [self.get_image("rgb_pytorch.png", (250, 380))]
        return hopper_images, logo_images

    def test_paste_mask_in_image(self):
        """Check that a traced ``paste_masks_in_image`` matches eager mode.

        The function is traced once and then re-run on inputs with a different
        number of masks and a different output image size.
        """
        # Imported once; the original re-imported the same name mid-test.
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        def as_size_tensors(image_size):
            # The traced function receives the image size as a list of 0-d
            # tensors rather than Python ints.
            return [torch.tensor(image_size[0]), torch.tensor(image_size[1])]

        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)

        out = paste_masks_in_image(masks, boxes, o_im_s)
        jit_trace = torch.jit.trace(
            paste_masks_in_image, (masks, boxes, as_size_tensors(o_im_s))
        )
        out_trace = jit_trace(masks, boxes, as_size_tensors(o_im_s))

        assert torch.all(out.eq(out_trace))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)

        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(masks2, boxes2, as_size_tensors(o_im_s2))

        assert torch.all(out2.eq(out_trace2))

    def test_heatmaps_to_keypoints(self):
        """Check that a traced ``heatmaps_to_keypoints`` matches eager mode.

        The function is traced on maps of shape (10, 1, 26, 26) and then re-run
        on maps of shape (20, 2, 21, 21). Outputs are compared element-wise at
        indices 0 and 1.
        """
        # Imported once; the original re-imported the same name mid-test.
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)

        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert torch.all(out[0].eq(out_trace[0]))
        assert torch.all(out[1].eq(out_trace[1]))

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)

        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert torch.all(out2[0].eq(out_trace2[0]))
        assert torch.all(out2[1].eq(out_trace2[1]))

    def test_word_language_model_RNN_TANH(self):
        """Word-language model built with ``model_name="RNN_TANH"``."""
        self.run_word_language_model(model_name="RNN_TANH")

    def test_word_language_model_RNN_RELU(self):
        """Word-language model built with ``model_name="RNN_RELU"``."""
        self.run_word_language_model(model_name="RNN_RELU")

    @skipScriptTest()  # scripting prim::unchecked_cast prim::setattr
    def test_word_language_model_LSTM(self):
        """Word-language model built with ``model_name="LSTM"`` (trace only)."""
        self.run_word_language_model(model_name="LSTM")

    def test_word_language_model_GRU(self):
        """Word-language model built with ``model_name="GRU"``."""
        self.run_word_language_model(model_name="GRU")

    def test_index_1d(self):
        """Integer indexing on the leading dimension: ``input[0]``."""

        class IndexFirst(torch.nn.Module):
            def forward(self, input):
                return input[0]

        self.run_test(IndexFirst(), torch.randn(3, 4, 5, 6, 7))

    def test_index_2d_1dimslice(self):
        """Slice on the first dimension plus a full slice: ``input[0:1, :]``."""

        class SliceFirstRow(torch.nn.Module):
            def forward(self, input):
                return input[0:1, :]

        self.run_test(SliceFirstRow(), torch.randn(3, 4, 5, 6, 7))

    def test_index_2d_sliceint(self):
        """Integer index followed by a full slice: ``input[1, :]``."""

        class IndexSecondRow(torch.nn.Module):
            def forward(self, input):
                return input[1, :]

        self.run_test(IndexSecondRow(), torch.randn(3, 4, 5, 6, 7))

    def test_index_2d_neg_slice(self):
        """Slice with a negative stop: ``input[0:-1, :]``."""

        class DropLastRow(torch.nn.Module):
            def forward(self, input):
                return input[0:-1, :]

        self.run_test(DropLastRow(), torch.randn(3, 4, 5, 6, 7))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_mask(self):
        """Mask indexing on dim 0 with ``uint8`` and then ``bool`` mask tensors."""

        class Uint8MaskIndex(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]

        self.run_test(Uint8MaskIndex(), torch.randn(3, 4, 5, 6, 7))

        class BoolMaskIndex(torch.nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.bool)]

        self.run_test(BoolMaskIndex(), torch.randn(3, 4, 5, 6, 7))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_data(self):
        """``x.data.size()`` inside a ScriptModule, with and without dynamic axes."""

        class Data(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.new_zeros(x.data.size())

        x = torch.randn(3, 4)
        # Both axes of "x" marked dynamic.
        self.run_test(Data(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
        # Static shapes: remained_onnx_input_idx=[] states that no ONNX input
        # is expected to remain in the exported graph.
        self.run_test(Data(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_mask_nd(self):
        """Indexing with an N-d boolean mask derived from the input itself."""

        class PositiveElements(torch.nn.Module):
            def forward(self, input):
                return input[input > 0]

        self.run_test(PositiveElements(), torch.randn(3, 4, 5, 6, 7))

    @skipScriptTest()
    def test_dict(self):
        """Dict input keyed by a tensor; the first key is also added to its value."""

        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                first_key = next(iter(x_in))
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in[first_key], first_key)
                return x_out

        x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x,))

    @skipScriptTest()
    def test_dict_str(self):
        """Dict input and dict output keyed by strings."""

        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                return {"test_key_out": torch.add(x_in["test_key_in"], 2.0)}

        inputs = {"test_key_in": torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (inputs,))

    @skipScriptTest()  # User-defined class not supported
    def test_dict_output(self):
        """Return an ``OrderedDict`` subclass holding tensor, tuple and list values."""

        class DictModelOutput(OrderedDict):
            tensor_out: Tensor
            tuple_out: Optional[Tuple[Tensor]] = None
            list_out: Optional[List[Tensor]] = None

        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return DictModelOutput(tensor_out=a, tuple_out=(b, c), list_out=[d])

        # Four independent (2, 3) inputs, drawn in order a, b, c, d.
        a, b, c, d = (torch.randn(2, 3) for _ in range(4))
        self.run_test(MyModel(), (a, b, c, d))

    def test_tuple_output(self):
        """Return a tuple that contains a nested tuple."""

        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return a, (b, c), d

        # Four independent (2, 3) inputs, drawn in order a, b, c, d.
        a, b, c, d = (torch.randn(2, 3) for _ in range(4))
        self.run_test(MyModel(), (a, b, c, d))

    def test_nested_tuple_output(self):
        """Return a tuple with two levels of nesting."""

        class MyModel(torch.nn.Module):
            def forward(self, a, b, c, d):
                return a, ((b,), (c, d))

        # Four independent (2, 3) inputs, drawn in order a, b, c, d.
        a, b, c, d = (torch.randn(2, 3) for _ in range(4))
        self.run_test(MyModel(), (a, b, c, d))

    def test_tuple_input(self):
        """Pass a tuple of two tensors as one argument and return it unchanged."""

        class TupleModel(torch.nn.Module):
            def forward(self, a: Tuple[Tensor, Tensor]):
                return a

        pair = (torch.randn(3, 4), torch.randn(4, 3))
        self.run_test(TupleModel(), input_args=(pair,))

    def test_tuple_primitive_input(self):
        """Tuple argument that mixes a Python int with a tensor."""

        class TupleModel(torch.nn.Module):
            def forward(self, a: Tuple[int, Tensor], b):
                return a[0], a[1] + b

        mixed = (3, torch.randn(4, 3))
        addend = torch.randn(4, 3)
        self.run_test(TupleModel(), input_args=(mixed, addend))

    def test_nested_tuple_input(self):
        """Nested tuple argument whose leaves broadcast against the first input."""

        class NestedTupleModel(torch.nn.Module):
            def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]):
                return a + b[0] + b[1][0] + b[1][1]

        base = torch.randn(4, 5)
        nested = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1)))
        self.run_test(NestedTupleModel(), input_args=(base, nested))

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_none(self):
        """Optional tensor args defaulting to ``None``, given positionally and by keyword."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = None,
                z: Optional[Tensor] = None,
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = Model()

        # Without kwargs dict.
        for positional in ((x, y, None), (x, None, z)):
            self.run_test(model, positional)
        # With kwargs dict.
        for keyword in (
            {"y": y, "z": None},
            {"y": None, "z": z},
            {"z": z},
            {"y": y},
        ):
            self.run_test(model, (x,), keyword)

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_tensor(self):
        """Optional tensor args with tensor defaults, traced with one ``None`` each."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = torch.ones(2, 3),
                z: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = Model()

        for positional in ((x, y, None), (x, None, z)):
            self.run_test(model, positional)

    @skipTraceTest()  # tracing is verified with different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional_default_tensor_script(self):
        """Scripted variant: real tensors work, explicit ``None`` inputs are rejected."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Optional[Tensor] = torch.ones(2, 3),
                z: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        model = torch.jit.script(Model())

        self.run_test(model, (x, y, z), input_names=("x", "y", "z"))
        self.run_test(model, (x,), {"y": y, "z": z}, input_names=("x", "y", "z"))
        self.run_test(model, (x,), {"y": y}, input_names=("x", "y"))

        # Every call that passes an explicit None must fail with this ValueError.
        none_cases = (
            ((x, y, None), {}),
            ((x, None, z), {}),
            ((x,), {"y": y, "z": None}),
            ((x,), {"y": None, "z": z}),
        )
        for example_inputs, example_kwargs in none_cases:
            with self.assertRaisesRegex(
                ValueError, "args contained 1 None's after flattening."
            ):
                self.run_test(
                    model, example_inputs, example_kwargs, input_names=("x", "y", "z")
                )

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_none(self):
        """Every argument optional with a ``None`` default."""

        class Model(torch.nn.Module):
            def forward(self, x: Optional[Tensor] = None, y: Optional[Tensor] = None):
                if x is not None:
                    return x
                if y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        model = Model()

        self.run_test(model, (x, None))
        # Keyword form; y disappears in tracing, so only "x" is named.
        self.run_test(model, (), {"x": x, "y": None}, input_names=("x",))

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_tensor(self):
        """Every argument optional with a tensor default (trace only)."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x: Optional[Tensor] = torch.ones(2, 3),
                y: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if x is not None:
                    return x
                elif y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        model = Model()

        for positional in ((x, None), (None, y)):
            self.run_test(model, positional)

        # Tracing with x present never uses y, so y is dropped from the
        # exported model's inputs and feeding it to ORT fails.
        with self.assertRaisesRegex(ValueError, "got too many positional inputs"):
            self.run_test(model, (x, y))

    @skipTraceTest()  # tracing is verified with different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_all_optional_default_tensor_script(self):
        """Scripted variant of the all-optional model with tensor defaults."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x: Optional[Tensor] = torch.ones(2, 3),
                y: Optional[Tensor] = torch.zeros(2, 3),
            ):
                if x is not None:
                    return x
                elif y is not None:
                    return y
                else:
                    return torch.tensor(-1.0)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        model = torch.jit.script(Model())

        # Optional supports None inputs.
        self.run_test(model, (x,))

        # NOTE: default values are not supported in ONNX, so PyTorch and ONNX
        # have different behavior when x is omitted.
        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            self.run_test(model, (), {"y": y}, input_names=["y"])

        self.run_test(model, (x, y))
        self.run_test(model, (), {"x": x, "y": y}, input_names=("x", "y"))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_logit(self):
        """``Tensor.logit`` with an ``eps`` value stored on the module."""

        class Logit(torch.nn.Module):
            def __init__(self, eps):
                super().__init__()
                self.eps = eps

            def forward(self, x):
                return x.logit(self.eps)

        sample = torch.randn(1, 3, 640, 640)
        self.run_test(Logit(eps=1e-6), sample)

    class Atleast1d(torch.nn.Module):
        """Apply ``torch.atleast_1d`` to a tuple of five tensors."""

        def forward(self, t, w, x, y, z):
            tensors = (t, w, x, y, z)
            return torch.atleast_1d(tensors)

    class Atleast2d(torch.nn.Module):
        """Apply ``torch.atleast_2d`` to a tuple of five tensors."""

        def forward(self, t, w, x, y, z):
            tensors = (t, w, x, y, z)
            return torch.atleast_2d(tensors)

    class Atleast3d(torch.nn.Module):
        """Apply ``torch.atleast_3d`` to a tuple of five tensors."""

        def forward(self, t, w, x, y, z):
            tensors = (t, w, x, y, z)
            return torch.atleast_3d(tensors)

    class Atleast1dTensor(torch.nn.Module):
        """Apply ``torch.atleast_1d`` to a single tensor."""

        def forward(self, x):
            promoted = torch.atleast_1d(x)
            return promoted

    class Atleast2dTensor(torch.nn.Module):
        """Apply ``torch.atleast_2d`` to a single tensor."""

        def forward(self, x):
            promoted = torch.atleast_2d(x)
            return promoted

    class Atleast3dTensor(torch.nn.Module):
        """Apply ``torch.atleast_3d`` to a single tensor."""

        def forward(self, x):
            promoted = torch.atleast_3d(x)
            return promoted

    @skipScriptTest()  # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
    @skipIfUnsupportedMinOpsetVersion(11)
    @common_utils.parametrize("module_class", (Atleast1d, Atleast2d, Atleast3d))
    def test_atleast_nd_list_input(self, module_class: Type[torch.nn.Module]):
        """``atleast_{1,2,3}d`` on a tuple of tensors with ranks 0 through 4.

        Args:
            module_class: The module *class* (not an instance) to construct and
                export; hence annotated as ``Type[torch.nn.Module]``.
        """
        # Rank-0 scalar first, then random tensors of ranks 1..4 (drawn in
        # the same order as before).
        inputs = (torch.tensor(1.0),) + tuple(
            torch.randn(*shape) for shape in ((2,), (2, 3), (2, 3, 4), (2, 3, 4, 5))
        )
        self.run_test(module_class(), inputs)

    @skipScriptTest()  # tracing uses prim::ListUnpack to avoid onnx::SequenceConstruct
    @skipIfUnsupportedMinOpsetVersion(11)
    @common_utils.parametrize(
        "module_class", (Atleast1dTensor, Atleast2dTensor, Atleast3dTensor)
    )
    @common_utils.parametrize(
        "inputs",
        [
            torch.tensor(1.0),
            torch.randn(2),
            torch.randn(2, 3),
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4, 5),
        ],
    )
    def test_atleast_nd_single_tensor_input(
        self, module_class: Type[torch.nn.Module], inputs: torch.Tensor
    ):
        """``atleast_{1,2,3}d`` on one tensor, for every rank from 0 to 4.

        Args:
            module_class: The module *class* (not an instance) to construct and
                export; hence annotated as ``Type[torch.nn.Module]``.
            inputs: The single tensor fed to the module.
        """
        self.run_test(module_class(), inputs)

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_mixed_optional(self):
        """Required ``Optional[Tensor]`` argument passed as ``None`` and as a tensor."""

        class Model(torch.nn.Module):
            def forward(self, x, y: Optional[Tensor]):
                if y is not None:
                    return x + y
                return x

        x = torch.randn(2, 3)
        model = Model()
        for second in (None, x):
            self.run_test(model, (x, second))

    @skipScriptTest()  # Needs https://github.com/pytorch/rfcs/pull/21
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional(self):
        """Tuple argument whose first element is ``None``."""

        class Model(torch.nn.Module):
            def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]):
                if y[0] is not None:
                    return x + y[0]
                if y[1] is not None:
                    return x + y[1]
                return x

        base = torch.randn(2, 3)
        second = torch.randn(2, 3)
        self.run_test(Model(), (base, (None, second)))

    @skipScriptTest()  # tracing eliminates None inputs so it works differently. See _script version below.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional_default_tensor(self):
        """Tuple-of-optionals argument with a tensor default, traced with a ``None``."""

        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Tuple[Optional[Tensor], Optional[Tensor]] = (
                    torch.zeros(2, 3),
                    torch.zeros(2, 3),
                ),
            ):
                y0, y1 = y
                if y0 is not None:
                    return x + y0
                if y1 is not None:
                    return x + y1
                return x

        base = torch.randn(2, 3)
        second = torch.randn(2, 3)
        self.run_test(Model(), (base, (None, second)))

    @skipTraceTest()  # tracing is verified with different set of inputs. See above.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_tuple_of_optional_default_tensor_script(self):
        """Scripted counterpart of ``test_tuple_of_optional_default_tensor``.

        A ``None`` nested inside the tuple is rejected with a ``ValueError``.
        A fully populated tuple exports and runs.
        """

        class Model(torch.nn.Module):
            def forward(
                self,
                x,
                y: Tuple[Optional[Tensor], Optional[Tensor]] = (
                    torch.zeros(2, 3),
                    torch.zeros(2, 3),
                ),
            ):
                y0, y1 = y
                if y0 is not None:
                    return x + y0
                if y1 is not None:
                    return x + y1
                return x

        x = torch.randn(2, 3)
        y0 = torch.randn(2, 3)
        y1 = torch.randn(2, 3)
        scripted = torch.jit.script(Model())

        # A None inside the tuple is reported once the inputs are flattened.
        with self.assertRaisesRegex(
            ValueError, "args contained 1 None's after flattening."
        ):
            self.run_test(scripted, (x, (None, y1)))

        self.run_test(scripted, (x, (y0, y1)))

        # Exporting with ``y`` given as a keyword dict succeeds. It is not run
        # through run_test because the exported graph takes the inputs
        # flattened into three tensors, so checking it against ORT there would fail.
        export_args = (x, {"y": (y0, y1)})
        torch.onnx.export(
            scripted, export_args, io.BytesIO(), opset_version=self.opset_version
        )
898

899
    def test_primitive_input_integer(self):
        """Export a model that adds a Python ``int`` input to an integer tensor."""

        class Model(torch.nn.Module):
            def forward(self, x: int, y):
                return x + y

        scalar = 3
        int_tensor = torch.randint(10, (2, 3, 4))
        self.run_test(Model(), (scalar, int_tensor))
907

908
    @skipDtypeChecking
    def test_primitive_input_floating(self):
        """Export a model that adds a Python ``float`` input to a float tensor.

        Output dtype checking is disabled for this test via ``@skipDtypeChecking``.
        """

        class Model(torch.nn.Module):
            def forward(self, x: float, y):
                return x + y

        inputs = (3.0, torch.randn(2, 3, 4))
        self.run_test(Model(), inputs)
917

918
    def test_primitive_input_bool(self):
        """Export a scripted model that picks one of two tensors based on a ``bool`` input."""

        class Model(torch.nn.Module):
            def forward(self, flag: bool, x, y):
                if flag:
                    return x
                else:
                    return y

        first = torch.randn(2, 3, 4)
        second = torch.randn(2, 3, 4)
        scripted = torch.jit.script(Model())
        self.run_test(scripted, (True, first, second))
930

931
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cste_script(self):
        """Export a scripted model returning constant tensors sized from the input shape.

        It runs once with both input dims dynamic. It runs again with
        ``remained_onnx_input_idx=[]``, presumably because the exported graph is
        expected to keep none of the inputs.
        """

        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.size(0)), torch.ones(
                    (x.size(1), x.size(0)), dtype=torch.int64
                )

        inp = torch.randn(3, 4)
        self.run_test(
            MyModel(), inp, input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        self.run_test(MyModel(), inp, remained_onnx_input_idx=[])
943

944
    def test_scalar_tensor(self):
        """Export ``torch.scalar_tensor`` built from input sizes (default and int64 dtype)."""

        class ScalarTensorModel(torch.nn.Module):
            def forward(self, input):
                return torch.scalar_tensor(input.size(0)), torch.scalar_tensor(
                    input.size(1), dtype=torch.int64
                )

        primary = torch.randn(2, 3, 4)
        extra = torch.randn(7, 8, 9)
        self.run_test(
            ScalarTensorModel(),
            primary,
            additional_test_inputs=[extra],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1, 2]},
        )
961

962
    def test_tensor(self):
        """Export ``torch.tensor`` built from input shapes and scalar conversions."""

        def run_dynamic_then_constant(model_cls, inp):
            # First with both input dims dynamic, then with no ONNX inputs kept.
            self.run_test(
                model_cls(), inp, input_names=["x"], dynamic_axes={"x": [0, 1]}
            )
            self.run_test(model_cls(), inp, remained_onnx_input_idx=[])

        class ScalarInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1])

        run_dynamic_then_constant(ScalarInputModel, torch.randn(3, 4))

        class TensorInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], input.shape[1]])

        run_dynamic_then_constant(TensorInputModel, torch.randn(3, 4))

        class FloatInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([float(input)])

        single = torch.randn(1)
        self.run_test(FloatInputModel(), single)

        class InputWithDtypeModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor(input.shape[1], dtype=torch.long)

        run_dynamic_then_constant(InputWithDtypeModel, torch.randn(3, 4))

        class MixedInputModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.tensor([input.shape[0], int(input)])

        single = torch.randn(1)
        self.run_test(MixedInputModel(), single)
1011

1012
    def test_hardtanh(self):
        """Export ``nn.Hardtanh(-1.5, 2.5)`` over the float values -5 .. 4."""
        inp = torch.arange(-5, 5, dtype=torch.float32)
        self.run_test(torch.nn.Hardtanh(min_val=-1.5, max_val=2.5), inp)
1016

1017
    def test_hardtanh_script_with_default_values(self):
        """Export a scripted ``F.hardtanh`` call that relies on its default bounds."""

        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.hardtanh(x)

        self.run_test(MyModel(), torch.arange(-5, 5, dtype=torch.float32))
1025

1026
    def test_hardswish(self):
        """Export ``nn.Hardswish`` on random data and at the inputs 3 and -3."""
        model = torch.nn.Hardswish()
        self.run_test(model, torch.rand(3, 3).to(dtype=torch.float32))

        # Edge cases: the piecewise definition switches at x = 3 and x = -3.
        for edge in (3, -3):
            self.run_test(model, torch.tensor(edge).to(dtype=torch.float32))
1037

1038
    def test_hardswish_script(self):
        """Export a scripted ``F.hardswish`` call on a random float32 matrix."""

        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.hardswish(x)

        self.run_test(MyModel(), torch.rand(3, 3).to(dtype=torch.float32))
1046

1047
    def test_hardsigmoid(self):
        """Export ``nn.Hardsigmoid`` on random data and at the inputs 3 and -3."""
        model = torch.nn.Hardsigmoid()
        self.run_test(model, torch.rand(3, 3).to(dtype=torch.float32))

        # Corner cases: the piecewise definition switches at x = 3 and x = -3.
        for corner in (3, -3):
            self.run_test(model, torch.tensor(corner).to(dtype=torch.float32))
1058

1059
    def test_tanhshrink(self):
        """Export ``nn.Tanhshrink`` on a random float32 matrix."""
        model = torch.nn.Tanhshrink()
        self.run_test(model, torch.rand(3, 3).to(dtype=torch.float32))
1064

1065
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_hardshrink(self):
        """Export ``nn.Hardshrink`` on random data and at +/-0.5."""
        model = torch.nn.Hardshrink()
        self.run_test(model, torch.rand(3, 3).to(dtype=torch.float32))

        # Edge cases: inputs exactly at the default threshold (lambd=0.5).
        for boundary in (0.5, -0.5):
            self.run_test(model, torch.tensor(boundary).to(dtype=torch.float32))
1077

1078
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_hardshrink_dtype(self):
        """Export ``nn.Hardshrink`` on a float64 input."""
        double_input = torch.rand(3, 3).to(dtype=torch.float64)
        self.run_test(torch.nn.Hardshrink(), double_input)
1082

1083
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_softshrink(self):
        """Export ``nn.Softshrink`` on random data and at +/-0.5."""
        model = torch.nn.Softshrink()
        self.run_test(model, torch.rand(3, 3).to(dtype=torch.float32))

        # Edge cases: inputs exactly at the default threshold (lambd=0.5).
        for boundary in (0.5, -0.5):
            self.run_test(model, torch.tensor(boundary).to(dtype=torch.float32))
1095

1096
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_softshrink_dtype(self):
        """Export ``nn.Softshrink`` on a float64 input."""
        double_input = torch.rand(3, 3).to(dtype=torch.float64)
        self.run_test(torch.nn.Softshrink(), double_input)
1100

1101
    def test_clamp(self):
        """Export ``Tensor.clamp`` with both bounds, with only ``min``, and with only ``max``."""

        class ClampModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(-0.5, 0.5)

        class ClampMinModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(min=-0.5)

        class ClampMaxModel(torch.nn.Module):
            def forward(self, x):
                return x.clamp(max=0.5)

        for model_cls in (ClampModel, ClampMinModel, ClampMaxModel):
            inp = torch.randn(3, 4)
            self.run_test(model_cls(), inp)
1122

1123
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_clamp_dyn(self):
        """Export ``clamp`` whose bounds come from input sizes or from tensor inputs."""

        def grid(*shape):
            # The values 0..15 as floats, laid out in the requested shape.
            return torch.arange(16).view(*shape).float()

        class ClampMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(None, x.size(0))

        class ClampMinModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), None)

        class ClampMinMaxModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x.clamp(x.size(0), x.size(1))

        self.run_test(ClampMaxModel(), grid(4, 4))
        self.run_test(ClampMinModel(), grid(4, 4))
        self.run_test(ClampMinMaxModel(), grid(2, 8))

        class ClampTensorModel(torch.nn.Module):
            def forward(self, x, min, max):
                return x.clamp(min, max)

        class ClampTensorMinModel(torch.nn.Module):
            def forward(self, x, min):
                return x.clamp(min=min)

        class ClampTensorMaxModel(torch.nn.Module):
            def forward(self, x, max):
                return x.clamp(max=max)

        x, lower, upper = (torch.randn(3, 4) for _ in range(3))
        self.run_test(ClampTensorModel(), (x, lower, upper))
        self.run_test(ClampTensorMinModel(), (x, lower))
        self.run_test(ClampTensorMaxModel(), (x, upper))
1169

1170
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_trace(self):
        """Export an eager module calling ``torch.full`` with a 0-d tensor fill value."""

        class FullModel(torch.nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        fill_value = torch.tensor(12)
        self.run_test(FullModel(), fill_value)
1178

1179
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_full_script(self):
        """Export a ScriptModule calling ``torch.full`` with a 0-d tensor fill value."""

        class FullModelScripting(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        fill_value = torch.tensor(12)
        self.run_test(FullModelScripting(), fill_value)
1188

1189
    def test_fuse_addmm(self):
        """Export ``mm(x, x) + x``, the matmul-plus-add pattern the test name refers to."""

        class AddmmModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        self.run_test(AddmmModel(), torch.ones(3, 3))
1196

1197
    def test_maxpool(self):
        """Export ``nn.MaxPool1d`` with kernel 2 and stride 1."""
        pool = torch.nn.MaxPool1d(kernel_size=2, stride=1)
        self.run_test(pool, torch.randn(20, 16, 50))
1201

1202
    def test_conv(self):
        """Export strided/padded/dilated Conv1d, Conv2d and Conv3d in one model."""

        class TraceModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )
                self.conv3 = torch.nn.Conv3d(
                    16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        signal_1d = torch.randn(20, 16, 50)
        image_2d = torch.randn(20, 16, 50, 50)
        volume_3d = torch.randn(20, 16, 10, 50, 50)
        # atol=1e-4 (the same float as 10e-5).
        self.run_test(TraceModel(), (signal_1d, image_2d, volume_3d), atol=1e-4)
1222

1223
    def test_conv_str_padding(self):
        """Export convolutions configured with string padding ("valid" / "same")."""

        class TraceModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv1d(16, 33, 3, padding="valid")
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=1, padding="valid", dilation=(3, 1)
                )
                self.conv3 = torch.nn.Conv3d(
                    16, 33, (3, 5, 2), stride=1, padding="same"
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        signal_1d = torch.randn(20, 16, 50)
        image_2d = torch.randn(20, 16, 50, 50)
        volume_3d = torch.randn(20, 16, 10, 50, 50)
        # atol=1e-4 (the same float as 10e-5).
        self.run_test(TraceModel(), (signal_1d, image_2d, volume_3d), atol=1e-4)
1243

1244
    def test_conv_shape_inference(self):
        """Export Conv2d followed by a scalar add, with a dynamic batch dimension."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )

            def forward(self, input):
                return self.conv2(input) + 2

        batch = torch.randn(20, 16, 50, 100)
        self.run_test(
            Model(),
            batch,
            atol=1e-4,
            input_names=["x"],
            dynamic_axes={"x": [0]},
        )
1259

1260
    def test_conv_transpose(self):
        """Export strided/padded ConvTranspose1d, 2d and 3d in one model."""

        class TraceModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
                self.conv2 = torch.nn.ConvTranspose2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )
                self.conv3 = torch.nn.ConvTranspose3d(
                    16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)
                )

            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)

        signal_1d = torch.randn(20, 16, 10)
        image_2d = torch.randn(20, 16, 10, 10)
        volume_3d = torch.randn(20, 16, 10, 10, 10)
        # atol=1e-4 (the same float as 10e-5).
        self.run_test(TraceModel(), (signal_1d, image_2d, volume_3d), atol=1e-4)
1280

1281
    def test_numpy_T(self):
        """Export the NumPy-style ``Tensor.T`` transpose on a 2-D input."""

        class NumpyTranspose(torch.nn.Module):
            def forward(self, x):
                return x.T

        matrix = torch.randn(4, 7)
        self.run_test(NumpyTranspose(), matrix)
1287

1288
    # Converting Transpose depends on the input shape being known, so this test
    # only works when ONNX shape inference is enabled.
    def test_transpose_infer_shape(self):
        """Export a scripted conv + ``transpose(0, 1)`` with dims 0 and 2 dynamic."""

        class TransposeModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        export_input = torch.randn(32, 3, 64, 64)
        other_input = torch.randn(16, 3, 8, 64)
        self.run_test(
            TransposeModule(),
            export_input,
            input_names=["x"],
            dynamic_axes={"x": [0, 2]},
            additional_test_inputs=[other_input],
        )
1310

1311
    def squeeze_model_tests(self, d, x1, x2):
        """Export ``torch.squeeze`` (optionally along dim ``d``) and check it in ORT.

        Args:
            d: Dimension to squeeze, or ``None`` to squeeze every size-1 dim.
            x1: Input used for export.
            x2: Optional extra input, or ``None``. When given, the first three
                input axes are exported as dynamic and ``x2`` is checked as an
                additional input. Callers use this to check one exported model
                on inputs that do and do not have a size-1 dim at ``d``.
        """

        class Squeeze(torch.nn.Module):
            def __init__(self, d):
                super().__init__()
                self.d = d

            def forward(self, x):
                if self.d is not None:
                    return torch.squeeze(x, dim=self.d)
                else:
                    return torch.squeeze(x)

        if x2 is None:
            self.run_test(Squeeze(d), x1)
        else:
            self.run_test(
                Squeeze(d),
                x1,
                input_names=["input"],
                dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}},
                additional_test_inputs=[x2],
            )
1334

1335
    def test_squeeze_without_no_op(self):
        """Squeeze the size-1 dim 1 of a static-shape input."""
        squeezable = torch.randn(2, 1, 4)
        self.squeeze_model_tests(1, squeezable, None)
1338

1339
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_dynamic(self):
        """Squeeze dim 1 with dynamic axes; the second input makes it a no-op."""
        squeezable, unsqueezable = torch.randn(2, 1, 4), torch.randn(2, 2, 3)
        self.squeeze_model_tests(1, squeezable, unsqueezable)
1344

1345
    def test_squeeze_neg_without_no_op(self):
        """Squeeze the size-1 dim addressed by the negative index -2 (static shape)."""
        squeezable = torch.randn(2, 1, 4)
        self.squeeze_model_tests(-2, squeezable, None)
1348

1349
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_neg(self):
        """Squeeze dim -2 with dynamic axes; the second input makes it a no-op."""
        squeezable, unsqueezable = torch.randn(2, 1, 4), torch.randn(2, 2, 3)
        self.squeeze_model_tests(-2, squeezable, unsqueezable)
1354

1355
    def test_squeeze_all_dims(self):
        """Squeeze without a dim, on inputs with and without a size-1 dim."""
        squeezable, unsqueezable = torch.randn(2, 1, 4), torch.randn(2, 2, 3)
        self.squeeze_model_tests(None, squeezable, unsqueezable)
1359

1360
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_no_op(self):
        """Export with dim 2 not of size 1 (a no-op), then check an input where it is."""
        unsqueezable, squeezable = torch.randn(2, 1, 4), torch.randn(2, 2, 1)
        self.squeeze_model_tests(2, unsqueezable, squeezable)
1365

1366
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_squeeze_runtime_dim(self):
        """Squeeze dim 0 of a tensor whose shape is only known at runtime.

        Each case exports with one leading size (1 or 3) and is then checked
        with the other. This covers both the squeeze and the no-op path in
        each exported graph.
        """

        class Squeeze(torch.nn.Module):
            def forward(self, d1, d2):
                t = torch.zeros(d1[0], d2[0])
                return t.squeeze(0)

        one, three, four = (torch.tensor([v]) for v in (1, 3, 4))
        self.run_test(Squeeze(), (one, four), additional_test_inputs=[(three, four)])
        self.run_test(Squeeze(), (three, four), additional_test_inputs=[(one, three)])
1378

1379
    def test_squeeze(self):
        """Export ``torch.squeeze`` along the constant negative dim -2."""

        class Squeeze(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x, dim=-2)

        squeezable = torch.randn(2, 1, 4)
        self.run_test(Squeeze(), squeezable)
1386

1387
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_squeeze_dynamic_dim(self):
        """Export ``torch.squeeze`` whose dim is passed in as an ``int`` input."""

        class Squeeze(torch.nn.Module):
            def forward(self, x, dim: int):
                return torch.squeeze(x, dim)

        squeezable = torch.randn(2, 1, 4)
        self.run_test(Squeeze(), (squeezable, 1))
1396

1397
    def test_unsqueeze(self):
        """Export ``torch.unsqueeze`` at the constant negative dim -2."""

        class Unsqueeze(torch.nn.Module):
            def forward(self, x):
                return torch.unsqueeze(x, dim=-2)

        inp = torch.randn(2, 3, 4)
        self.run_test(Unsqueeze(), inp)
1404

1405
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_unsqueeze_dynamic_dim(self):
        """Export ``torch.unsqueeze`` whose dim (-1) is passed in as an ``int`` input."""

        class Unsqueeze(torch.nn.Module):
            def forward(self, x, dim: int):
                return torch.unsqueeze(x, dim)

        inp = torch.randn(2, 1, 4)
        self.run_test(Unsqueeze(), (inp, -1))
1414

1415
    def test_maxpool_default_stride(self):
        """Export ``F.max_pool2d`` with kernel 2 and no explicit stride."""

        class MaxPoolModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.max_pool2d(x, 2)

        batch = torch.randn(10, 20, 16, 50)
        self.run_test(MaxPoolModel(), batch)
1423

1424
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_adaptive(self):
        """Export AdaptiveMaxPool1d (output length 5) with a dynamic batch dim."""
        pool = torch.nn.AdaptiveMaxPool1d(output_size=5, return_indices=False)
        export_input = torch.randn(20, 16, 50, requires_grad=True)
        larger_batch = torch.randn(32, 16, 50, requires_grad=True)
        self.run_test(
            pool,
            export_input,
            input_names=["x"],
            dynamic_axes={"x": [0]},
            additional_test_inputs=[larger_batch],
        )
1436

1437
    def test_maxpool_2d(self):
        """Export MaxPool2d with kernel 5 and asymmetric padding (1, 2)."""
        pool = torch.nn.MaxPool2d(kernel_size=5, padding=(1, 2))
        self.run_test(pool, torch.randn(1, 20, 16, 50, requires_grad=True))
1441

1442
    def test_maxpool_1d_ceil(self):
        """Export MaxPool1d (kernel 3, stride 2) with ceil_mode enabled."""
        pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, ceil_mode=True)
        self.run_test(pool, torch.randn(20, 16, 50))
1446

1447
    def test_maxpool_2d_ceil(self):
        """Export MaxPool2d (kernel 3, stride 2) with ceil_mode enabled."""
        pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.run_test(pool, torch.randn(20, 16, 50, 32))
1451

1452
    def test_maxpool_3d_ceil(self):
        """Export MaxPool3d (kernel 3, stride 2) with ceil_mode enabled."""
        pool = torch.nn.MaxPool3d(kernel_size=3, stride=2, ceil_mode=True)
        self.run_test(pool, torch.randn(20, 16, 50, 44, 31))
1456

1457
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_maxpool_dynamic(self):
        """Export max-pool -> 1x1 conv -> batch-norm with dynamic spatial dims.

        Dims 2 and 3 of both the input and the output are exported as dynamic
        (named ``y`` and ``x`` respectively).
        """

        class MaxPoolConvNorm(torch.nn.Module):
            def __init__(self, in_channels, out_channels):
                super().__init__()
                norm_layer = functools.partial(torch.nn.BatchNorm2d, eps=0.0009)
                # This is a *max* pool; the attribute name says so. MaxPool2d
                # has no parameters, so the name does not affect any exported
                # initializer.
                self.maxpool = torch.nn.MaxPool2d((2, 2), stride=2, ceil_mode=True)
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                )
                self.norm = norm_layer(out_channels)

            def forward(self, x):
                return self.norm(self.conv(self.maxpool(x)))

        model = MaxPoolConvNorm(8, 16)
        inputs = torch.randn(2, 8, 64, 64)
        self.run_test(
            model,
            inputs,
            input_names=["input_0"],
            dynamic_axes={"input_0": {3: "x", 2: "y"}, "output_0": {3: "x", 2: "y"}},
            output_names=["output_0"],
        )
1481

1482
    # TODO: Enable maxpool-ceil family after ONNX 1.15.1+ is bumped
    @skipIfUnsupportedMaxOpsetVersion(9)
    def test_maxpool_1d_ceil_corner(self):
        """MaxPool1d with a 1-wide kernel and stride 2 under ceil_mode.

        With these sizes, the extra window ceil_mode could add would start past
        the end of the 32-long input.
        """
        pool = torch.nn.MaxPool1d(
            kernel_size=1,
            stride=2,
            dilation=1,
            return_indices=False,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 32))
1490

1491
    @skipIfUnsupportedMaxOpsetVersion(9)
    def test_maxpool_2d_ceil_corner(self):
        """MaxPool2d with a 1x1 kernel and stride 2 under ceil_mode, without indices."""
        pool = torch.nn.MaxPool2d(
            kernel_size=[1, 1],
            stride=[2, 2],
            dilation=[1, 1],
            return_indices=False,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 32, 32))
1502

1503
    @skipIfUnsupportedMaxOpsetVersion(9)
    def test_maxpool_3d_ceil_corner(self):
        """MaxPool3d under ceil_mode with padding and strides larger than the kernel in two dims."""
        pool = torch.nn.MaxPool3d(
            kernel_size=[7, 8, 4],
            stride=[10, 11, 3],
            padding=[2, 2, 2],
            dilation=[1, 1, 1],
            return_indices=False,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 51, 52, 45))
1515

1516
    @skipIfUnsupportedMaxOpsetVersion(9)
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_1d_ceil_corner_with_indices(self):
        """Same corner case as ``test_maxpool_1d_ceil_corner``, also returning indices."""
        pool = torch.nn.MaxPool1d(
            kernel_size=1,
            stride=2,
            dilation=1,
            return_indices=True,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 32))
1524

1525
    @skipIfUnsupportedMaxOpsetVersion(9)
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_2d_ceil_corner_with_indices(self):
        """MaxPool2d with a 1x1 kernel and stride 2 under ceil_mode, returning indices."""
        pool = torch.nn.MaxPool2d(
            kernel_size=[1, 1],
            stride=[2, 2],
            dilation=[1, 1],
            return_indices=True,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 32, 32))
1537

1538
    @skipIfUnsupportedMaxOpsetVersion(9)
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_3d_ceil_corner_with_indices(self):
        """Padded, strided MaxPool3d under ceil_mode, returning indices."""
        pool = torch.nn.MaxPool3d(
            kernel_size=[7, 8, 4],
            stride=[10, 11, 3],
            padding=[2, 2, 2],
            dilation=[1, 1, 1],
            return_indices=True,
            ceil_mode=True,
        )
        self.run_test(pool, torch.randn(1, 3, 51, 52, 45))
1551

1552
    @skipIfUnsupportedMinOpsetVersion(8)
    def test_maxpool_with_indices(self):
        """Export MaxPool1d (kernel 2, stride 1) that also returns indices."""
        pool = torch.nn.MaxPool1d(kernel_size=2, stride=1, return_indices=True)
        self.run_test(pool, torch.randn(20, 16, 50))
1557

1558
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_maxpool_dilation(self):
        """Export MaxPool1d (kernel 2, stride 1) with dilation 2."""
        pool = torch.nn.MaxPool1d(kernel_size=2, stride=1, dilation=2)
        self.run_test(pool, torch.randn(20, 16, 50))
1563

1564
    def test_avgpool_default_stride(self):
        """Export ``F.avg_pool2d`` with kernel 2 and no explicit stride."""

        class AvgPoolModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.avg_pool2d(x, 2)

        batch = torch.randn(10, 20, 16, 50)
        self.run_test(AvgPoolModel(), batch)
1572

1573
    def test_avgpool(self):
        """Export AvgPool1d with kernel 2 and stride 1."""
        pool = torch.nn.AvgPool1d(kernel_size=2, stride=1)
        self.run_test(pool, torch.randn(20, 16, 50))
1577

1578
    def test_avgpool_1d_ceil(self):
        """Export AvgPool1d (kernel 3, stride 2) with ceil_mode on a length-7 input."""
        pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, ceil_mode=True)
        self.run_test(pool, torch.randn(1, 1, 7))
1582

1583
    # TODO: ceil_mode is left out of this parametrization because ORT and
    # PyTorch compute the last ceil_mode output value differently; see
    # https://github.com/microsoft/onnxruntime/issues/16203
    @common_utils.parametrize("padding", (0, 1))
    @common_utils.parametrize("count_include_pad", (True, False))
    def test_avgpool_2d(self, padding, count_include_pad):
        """Export AvgPool2d (kernel 3, stride 3) across padding/count_include_pad combinations."""
        pool = torch.nn.AvgPool2d(
            kernel_size=3,
            stride=3,
            padding=padding,
            count_include_pad=count_include_pad,
        )
        self.run_test(pool, torch.randn(20, 16, 50, 32))
1603

1604
    # This test does exercise ceil_mode=True, but only from opset 19 on.
    # Presumably the ORT/PyTorch mismatch on the last ceil_mode output value
    # (https://github.com/microsoft/onnxruntime/issues/16203) does not affect
    # the opset-19 AveragePool path. TODO confirm.
    @skipIfUnsupportedMinOpsetVersion(19)
    def test_avgpool_3d_ceil(self):
        """Export AvgPool3d (kernel 3, stride 2, ceil_mode) with dynamic batch/channel dims."""
        model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
        x = torch.randn(20, 16, 50, 44, 31)
        y = torch.randn(32, 8, 50, 44, 31)
        self.run_test(
            model,
            x,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
            additional_test_inputs=[y],
        )
1619

1620
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_avgpool_dynamic(self):
        """Export avg-pool -> 1x1 conv -> batch-norm with dynamic spatial dims.

        Dims 2 and 3 of both the input and the output are exported as dynamic
        (named ``y`` and ``x`` respectively).
        """

        class AvgPoolConvNorm(torch.nn.Module):
            def __init__(self, in_channels, out_channels):
                super().__init__()
                self.avgpool = torch.nn.AvgPool2d(
                    (2, 2), stride=2, ceil_mode=True, count_include_pad=False
                )
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                )
                self.norm = torch.nn.BatchNorm2d(out_channels, eps=0.0009)

            def forward(self, x):
                pooled = self.avgpool(x)
                return self.norm(self.conv(pooled))

        model = AvgPoolConvNorm(8, 16)
        batch = torch.randn(2, 8, 64, 64)
        self.run_test(
            model,
            batch,
            input_names=["input_0"],
            output_names=["output_0"],
            dynamic_axes={
                "input_0": {3: "x", 2: "y"},
                "output_0": {3: "x", 2: "y"},
            },
        )
1646

1647
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_floating_point(self):
        """Export scripted models that branch on ``Tensor.is_floating_point``."""

        class ZerosFloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.is_floating_point():
                    return x.new_zeros(x.shape)
                return x.new_zeros(x.shape)

        inp = torch.randn(2, 3, 4)
        self.run_test(
            ZerosFloatingPoint(), inp, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        self.run_test(ZerosFloatingPoint(), inp, remained_onnx_input_idx=[])

        class NestedFloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x + 1
                return x

        inp = torch.randn(2, 3, 4)
        self.run_test(NestedFloatingPoint(), inp)
1674

1675
    # Below opset 11 the two branches' outputs have mismatched ranks
    # (x.shape[1:] vs x.shape), so this test requires opset 11+.
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_floating_point_infer_dtype(self):
        """Export scripted ``is_floating_point`` branches whose outputs differ in rank or dtype."""

        class RankChangingFloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x.new_zeros(x.shape[1:])
                    return x.new_zeros(x.shape)
                return x

        inp = torch.randn(2, 3, 4)
        self.run_test(
            RankChangingFloatingPoint(),
            inp,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        self.run_test(RankChangingFloatingPoint(), inp, remained_onnx_input_idx=[])

        class IntInputFloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x
                return x

        int_input = torch.randn(2, 3, 4).to(torch.int32)
        self.run_test(IntInputFloatingPoint(), int_input)
1706

1707
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_prim_min(self):
        @torch.jit.script
        def list_append(boxes: List[Tensor]):
            temp = []
            # `enumerate` introduces a prim::min op in the TorchScript graph.
            for i, b in enumerate(boxes):
                temp.append(torch.full_like(b[:, 1], i))
            return temp[0]

        # prim::min coming from enumerate over a Python list of tensors.
        class EnumerateMin(torch.nn.Module):
            def forward(self, x):
                boxes = [x for _ in range(3)]
                return list_append(boxes)

        self.run_test(EnumerateMin(), (torch.rand(5, 5),))

        # Builtin `min` between a tensor element and a Python int.
        class BuiltinMin(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                i = 3
                return min(x[i], i)

        self.run_test(BuiltinMin(), (torch.arange(6, dtype=torch.int64),))

    def test_arithmetic(self):
        # A chain of tensor/integer-literal add, sub, mul and div.
        class ScalarChain(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        self.run_test(ScalarChain(), torch.randn(2, 3, 4))

    def test_arithmetic_prim_long(self):
        # Tensor arithmetic against an `int` forward argument.
        class IntOperand(torch.nn.Module):
            def forward(self, x, y: int):
                x = x + y
                x = x - y
                x = x * (y * 3)
                x = x / (y * 4)
                return x

        self.run_test(IntOperand(), (torch.randn(2, 3, 4), 2))

        # The output is a Python int read from the input's shape.
        class ShapeIntOutput(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 3
                return x.shape[0]

        self.run_test(
            ShapeIntOutput(), torch.randn(2, 3, 4), remained_onnx_input_idx=[]
        )

    @skipDtypeChecking
    def test_arithmetic_prim_float(self):
        # Tensor arithmetic against a `float` forward argument.
        class FloatOperand(torch.nn.Module):
            def forward(self, x, y: float):
                x = x + y
                x = x - y
                x = x * (y * 3)
                x = x / (y * 4)
                return x

        self.run_test(FloatOperand(), (torch.randn(2, 3, 4), 2.5))

        # The output is a Python float computed from an input dimension.
        class ShapeFloatOutput(torch.nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 3
                return x.shape[1] / 2

        self.run_test(
            ShapeFloatOutput(), torch.randn(2, 3, 4), remained_onnx_input_idx=[]
        )

    @skipDtypeChecking
    def test_arithmetic_prim_bool(self):
        # Mixed int / bool / float forward arguments; the bool both gates a
        # branch and is returned as an output.
        class MixedScalarArgs(torch.nn.Module):
            def forward(self, x, y: int, z: bool, t: float):
                x = x + y
                x = x - y
                if z:
                    x = x * (y * 3)
                    x = x / (y * 4)
                return x / t, z

        self.run_test(MixedScalarArgs(), (torch.randn(2, 3, 4), 2, False, 2.5))

        # Comparison of two Python ints producing a bool output.
        class IntEquality(torch.nn.Module):
            def forward(self, x: int, y: int):
                return x == y

        self.run_test(IntEquality(), (3, 2))

    @skipScriptTest(
        15,
        reason="In trace: Outputs that are always None are removed. \
                In script: Outputs that are always None are removed before opset 15. \
                After opset 15, we replace the None in output with Optional node.",
    )
    def test_tuple_with_none_outputs(self):
        # Nested output tuples with `None` entries at two depths.
        class NestedNoneTuple(torch.nn.Module):
            def forward(self, x):
                return (x, (x, None, (x, None)))

        self.run_test(NestedNoneTuple(), (torch.randn(3, 4),))

    # Under scripting, the leading transpose node carries no shape or dtype
    # information, so this test only works when ONNX shape inference is enabled.
    def test_arithmetic_infer_dtype(self):
        class TransposeThenScalarChain(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                x = x.t()
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        self.run_test(TransposeThenScalarChain(), torch.randn(2, 3))

    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
    def test_floor_div(self):
        # Floor division over int64 / float / double operands with scalar and
        # tensor divisors; the numerators include negative values.
        class FloorDivModule(torch.nn.Module):
            def forward(self, x, y):
                return (
                    x // 3,
                    x // 2.0,
                    x.to(dtype=torch.float64) // 3,
                    x.to(dtype=torch.float64) // 2.0,
                    x.to(dtype=torch.int64) // 3,
                    x.to(dtype=torch.int64) // 2.0,
                    x // (y + 1.0).to(dtype=torch.int64),
                    x // y,
                    x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
                    x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
                    x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
                    x.to(dtype=torch.int64) // y,
                )

        numerators = torch.arange(-2, 4).reshape(2, 3, 1)
        divisors = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
        self.run_test(FloorDivModule(), (numerators, divisors))

    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
    def test_floor_div_script(self):
        # Scripted floor division by an int, a float and a float tensor.
        class FloorDivModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                return x // 3, x // 2.0, x // y

        numerators = torch.arange(-2, 4).reshape(2, 3, 1)
        divisors = torch.randn(2, 3, 4)
        self.run_test(FloorDivModule(), (numerators, divisors))

    @unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_floordiv(self):
        # Floor division between two input sizes, used as an output shape.
        class FloordivModule(torch.nn.Module):
            def forward(self, x):
                return x.new_zeros(x.size(2) // x.size(1))

        sample = torch.randn(2, 3, 4)
        self.run_test(
            FloordivModule(),
            sample,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        self.run_test(FloordivModule(), (sample,), remained_onnx_input_idx=[])

    def test_div(self):
        # `/` and `true_divide` on int32 operands, then on float32 copies.
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                return x / y, torch.true_divide(x, y)

        int_x = torch.randn(2, 3, 4).to(torch.int)
        int_y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
        for lhs, rhs in ((int_x, int_y), (int_x.float(), int_y.float())):
            self.run_test(DivModule(), (lhs, rhs))

    # Note: div is exercised here through tracing only. Its type-promotion
    # logic depends on the scalar types of the input tensors, so the emitted
    # ONNX graph depends on the input dtypes, which scripting generally cannot
    # know.
    def test_div_promotion_trace(self):
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                return x / y, torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        # int / int must promote to whichever default dtype is active.
        for default_dtype in (torch.float, torch.double):
            with common_utils.set_default_dtype(default_dtype):
                self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))

    # Under scripting, x and y carry no shape or dtype information, so this
    # test only works when ONNX shape inference is enabled.
    def test_div_promotion_script(self):
        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                # Transposing hides the shape/type information that would
                # otherwise still be available from the graph inputs.
                x = x.transpose(1, 2)
                y = y.transpose(1, 2)
                return x / y, torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        # Cases 1 and 2: int / int producing the default dtype (float, then
        # double). The default lowering casts both operands to that dtype,
        # so it works even when the types of x and y are unknown.
        for default_dtype in (torch.float, torch.double):
            with common_utils.set_default_dtype(default_dtype):
                self.run_test(torch.jit.script(DivModule()), (x, y))

        # Case 3: int / double producing double. This can only be handled
        # when the types of both x and y are known.
        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
        self.run_test(torch.jit.script(DivModule()), (x, y))

    @skipDtypeChecking
    def test_div_rounding_mode(self):
        class TrueDivModule(torch.nn.Module):
            def forward(self, x, y):
                return (
                    x.div(y, rounding_mode=None),
                    torch.div(x, y, rounding_mode=None),
                )

        class TruncDivModule(torch.nn.Module):
            def forward(self, x, y):
                return (
                    x.div(y, rounding_mode="trunc"),
                    torch.div(x, y, rounding_mode="trunc"),
                )

        class FloorDivModule(torch.nn.Module):
            def forward(self, x, y):
                return (
                    x.div(y, rounding_mode="floor"),
                    torch.div(x, y, rounding_mode="floor"),
                )

        modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]

        def check_all_modes(inputs):
            # Each rounding mode runs as the eager module, as a traced module
            # and as a scripted module.
            for module in modules:
                self.run_test(module, inputs)
                self.run_test(torch.jit.trace(module, inputs), inputs)
                self.run_test(torch.jit.script(module), inputs)

        # Integer operands; the numerators include negative values, where
        # truncation and flooring disagree.
        check_all_modes(
            (
                (torch.randn(2, 3, 4) * 100).to(torch.int),
                torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int),
            )
        )
        # Float operands with a divisor bounded away from zero (>= 0.1).
        check_all_modes((torch.randn(2, 3, 4), torch.rand(2, 3, 4) * 10.0 + 0.1))

    def test_slice_trace(self):
        # A constant leading slice of a 1-D input.
        class LeadingSlice(torch.nn.Module):
            def forward(self, x):
                return x[0:1]

        self.run_test(LeadingSlice(), torch.randn(3))

    def test_slice_neg(self):
        # A slice whose start is a negative index on dim 0.
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[-1:]

        self.run_test(NegSlice(), torch.randn(3, 4, 5))

    def test_slice_neg_large(self):
        # Negative start/stop on dim 2 combined with a negative select on the
        # last dim of a rank-5 input.
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[:, :, -3:-1, :, -1]

        self.run_test(NegSlice(), torch.randn(3, 4, 5, 6, 7))

    def test_slice_neg_large_negone(self):
        # Only a `-1` select on the last dim of a rank-5 input.
        class NegSlice(torch.nn.Module):
            def forward(self, x):
                return x[:, :, :, :, -1]

        self.run_test(NegSlice(), torch.randn(3, 4, 5, 6, 7))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_slice_with_input_index(self):
        # In-place slice assignment whose stop is the size of another input.
        class InputIndexSlice(torch.nn.Module):
            def forward(self, x, y):
                x[: y.size(0), 0, :] = y
                return x

        target = torch.zeros((56, 6, 256))
        source = torch.rand((22, 256))
        self.run_test(InputIndexSlice(), (target, source))

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()  # Torchscript doesn't support 1d index.
    def test_slice_with_1d_input_index(self):
        # The slice stop is a one-element tensor, which is also the value
        # being assigned.
        class InputIndexSlice(torch.nn.Module):
            def forward(self, x, y):
                x[:y, 0, :] = y
                return x

        target = torch.zeros((56, 6, 256))
        bound = torch.tensor([5], dtype=torch.int64)
        self.run_test(InputIndexSlice(), (target, bound))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_slice_with_input_step_size(self):
        # Both the stop and the step of the assigned slices come from
        # 0-d tensor inputs.
        class InputIndexSlice(torch.nn.Module):
            def forward(self, x, y, z):
                x[:y:z, 0::z, :] = 1
                return x

        target = torch.zeros((56, 6, 256))
        stop = torch.tensor(5, dtype=torch.int64)
        step = torch.tensor(2, dtype=torch.int64)
        self.run_test(InputIndexSlice(), (target, stop, step))

    @skipIfUnsupportedMinOpsetVersion(10)
    @skipScriptTest()  # scripting tuple/list append
    def test_slice_dynamic(self):
        # Slice bounds derived from input sizes, re-checked on an input of a
        # different shape with every input axis dynamic.
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
                return tuple(results)

        export_input = torch.rand(5, 5, 5)
        other_shape_input = torch.randn(6, 7, 8)
        self.run_test(
            DynamicSliceExportMod(),
            export_input,
            additional_test_inputs=[other_shape_input],
            input_names=["input_1"],
            output_names=["output_1"],
            dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1, 2]},
        )

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_slice_dynamic_script(self):
        # Scripted slice whose stop is read from the input's size.
        class DynamicSliceModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return x[1 : x.size(1)]

        self.run_test(DynamicSliceModel(), torch.rand(1, 2))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_slice_dynamic_shape_script(self):
        # The output shape is a slice of the input's shape whose stop is
        # another dimension's size.
        class DynamicSliceModel(torch.nn.Module):
            def forward(self, x):
                return x.new_zeros(x.shape[1 : x.size(2)])

        sample = torch.rand(1, 2, 3, 4)
        self.run_test(
            DynamicSliceModel(),
            sample,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )
        self.run_test(DynamicSliceModel(), sample, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(10)
    @skipScriptTest()  # scripting tuple/list append
    def test_slice_dynamic_to_end(self):
        # Slices that run to the end of dim 1 while selecting a size-derived
        # index in dim 2, exported with dynamic axes.
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[:, i:, x.size(2) - 5])
                return tuple(results)

        x = torch.rand(5, 5, 5)
        y = torch.randn(6, 7, 8)
        # `dynamic_axes` is keyed by graph input/output names, so those names
        # must be set explicitly; otherwise the keys match nothing and the
        # dynamic axes are ignored. Each output is rank 2 (dim 2 is indexed
        # away), so only output axes 0 and 1 can be dynamic. The extra input
        # of a different shape checks that the export is really dynamic.
        self.run_test(
            DynamicSliceExportMod(),
            x,
            additional_test_inputs=[y],
            input_names=["input_1"],
            output_names=["output_1"],
            dynamic_axes={"input_1": [0, 1, 2], "output_1": [0, 1]},
        )

    def test_square(self):
        # Elementwise `torch.square`.
        class Square(torch.nn.Module):
            def forward(self, x):
                return torch.square(x)

        self.run_test(Square(), torch.randn(2, 3, 4))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_dynamic(self):
        # arange bounds taken from a dynamic input dimension, next to a
        # constant arange; checked as an eager module and as a scripted one.
        class ArangeModel(torch.nn.Module):
            def forward(self, input):
                return (
                    torch.arange(input.shape[0]),
                    torch.arange(12),
                    torch.arange(start=input.shape[0], end=input.shape[0] + 5),
                )

        x = torch.randn(5, 3, 2)
        y = torch.randn(8, 3, 2)
        # Factories keep the scripting step after the eager run, as before.
        model_factories = (ArangeModel, lambda: torch.jit.script(ArangeModel()))
        for make_model in model_factories:
            self.run_test(
                make_model(),
                x,
                additional_test_inputs=[y],
                input_names=["input_1"],
                output_names=["output_1", "output_2", "output_3"],
                dynamic_axes={"input_1": [0], "output_1": [0]},
            )

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_dynamic_arange_out(self):
        # arange writing into an int64 `out=` tensor, with its end supplied
        # as a 0-d tensor input.
        class ArangeOutModel(torch.nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8)
        # `(x)` is just `x`, not a tuple; pass an explicit 1-tuple.
        self.run_test(ArangeOutModel(), (x,))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_dynamic_arange_start_out(self):
        # arange into an int64 `out=` tensor, starting at an input size and
        # ending at a 0-d tensor input.
        class ArangeStartOutModel(torch.nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(start.size(0), end, out=out_t)

        sized = torch.randn(2, 3, 4)
        end = torch.tensor(8)
        self.run_test(
            ArangeStartOutModel(),
            (sized, end),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        # Only input 1 (`end`) is expected to remain; `sized` contributes
        # nothing but its size.
        self.run_test(ArangeStartOutModel(), (sized, end), remained_onnx_input_idx=[1])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_linspace(self):
        # linspace with float 0-d start/end tensors and an int 0-d step count.
        class LinspaceModel(torch.nn.Module):
            def forward(self, start, end, steps):
                return torch.linspace(start, end, steps)

        args = (
            torch.tensor(3, dtype=torch.float),
            torch.tensor(10, dtype=torch.float),
            torch.tensor(5, dtype=torch.int),
        )
        self.run_test(LinspaceModel(), args)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_linspace_negative_start(self):
        # linspace over a range that starts below zero.
        class LinspaceModel(torch.nn.Module):
            def forward(self, start, end, steps):
                return torch.linspace(start, end, steps)

        args = (
            torch.tensor(-1, dtype=torch.float),
            torch.tensor(1, dtype=torch.float),
            torch.tensor(6, dtype=torch.int),
        )
        self.run_test(LinspaceModel(), args)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_with_floats_out(self):
        # arange into a float `out=` tensor with a fractional end, first with
        # only an end, then with a size-derived start and a fractional step.
        class ArangeModelEnd(torch.nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(end, out=out_t)

        y = torch.tensor(8.5, dtype=torch.float)
        # `(y)` is just `y`, not a tuple; pass an explicit 1-tuple.
        self.run_test(ArangeModelEnd(), (y,))

        class ArangeModelStep(torch.nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(start.size(0), end, 1.5, out=out_t)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(
            ArangeModelStep(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        # Only input 1 (`y`) is expected to remain; `x` contributes nothing
        # but its size.
        self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_with_floats(self):
        # arange with a fractional end, alone and combined with a size-derived
        # start and positive, negative or default steps.
        class ArangeModelEnd(torch.nn.Module):
            def forward(self, end):
                return torch.arange(end)

        y = torch.tensor(8.5, dtype=torch.float)
        # `(y)` is just `y`, not a tuple; pass an explicit 1-tuple.
        self.run_test(ArangeModelEnd(), (y,))

        class ArangeModelStep(torch.nn.Module):
            def forward(self, start, end):
                return torch.arange(start.size(0), end, 1.5)

        class ArangeModelStepNeg(torch.nn.Module):
            def forward(self, start, end):
                return torch.arange(end, start.size(0), -1.5)

        class ArangeModelStart(torch.nn.Module):
            def forward(self, start, end):
                return torch.arange(start.size(0), end)

        # Every two-input variant gets the same pair of runs: once with the
        # sized input fully dynamic, once expecting only input 1 (`y`) to
        # remain. Fresh inputs are drawn per variant, in the original order.
        for model_cls in (ArangeModelStep, ArangeModelStepNeg, ArangeModelStart):
            x = torch.randn(2, 3, 4)
            y = torch.tensor(8.5, dtype=torch.float)
            self.run_test(
                model_cls(),
                (x, y),
                input_names=["x", "y"],
                dynamic_axes={"x": [0, 1, 2]},
            )
            self.run_test(model_cls(), (x, y), remained_onnx_input_idx=[1])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_arange_with_floats_override(self):
        # A fractional float end combined with an explicit int64 dtype.
        class ArangeModelEnd(torch.nn.Module):
            def forward(self, end):
                return torch.arange(end, dtype=torch.int64)

        y = torch.tensor(8.5, dtype=torch.float)
        # `(y)` is just `y`, not a tuple; pass an explicit 1-tuple.
        self.run_test(ArangeModelEnd(), (y,))

        class ArangeModelStep(torch.nn.Module):
            def forward(self, start, end):
                return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(
            ArangeModelStep(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        # Only input 1 (`y`) is expected to remain; `x` contributes nothing
        # but its size.
        self.run_test(ArangeModelStep(), (x, y), remained_onnx_input_idx=[1])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_out(self):
        # arange into a float `out=` tensor with a fractional 0-d end input.
        class ArangeOutModel(torch.nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8.5, dtype=torch.float)
        # `(x)` is just `x`, not a tuple; pass an explicit 1-tuple.
        self.run_test(ArangeOutModel(), (x,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_start_out(self):
        # arange into a float `out=` tensor from an input size up to a
        # fractional 0-d end input.
        class ArangeStartOutModel(torch.nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(start.size(0), end, out=out_t)

        sized = torch.randn(2, 3, 4)
        end = torch.tensor(8.5, dtype=torch.float)
        self.run_test(
            ArangeStartOutModel(),
            (sized, end),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        # Only input 1 (`end`) is expected to remain in the exported graph.
        self.run_test(ArangeStartOutModel(), (sized, end), remained_onnx_input_idx=[1])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_arange_no_type(self):
        # arange with a float end and no explicit dtype, with and without an
        # explicit zero start.
        class ArangeModel(torch.nn.Module):
            def forward(self, end):
                return torch.arange(end), torch.arange(0, end)

        self.run_test(ArangeModel(), torch.tensor(6.2, dtype=torch.float))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_size(self):
        # Outputs built from `size(0)`, `size(-1)` and the full shape.
        class SizeModel(torch.nn.Module):
            def forward(self, input):
                return (
                    torch.arange(input.size(0)),
                    torch.arange(input.size(-1)),
                    torch.ones(input.shape),
                )

        sample = torch.randn(5, 3, 2)
        self.run_test(
            SizeModel(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        self.run_test(SizeModel(), sample, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(9)
    @skipScriptTest()  # x.stride() not scriptable
    def test_as_strided(self):
        # as_strided with literal size/stride/offset, and with size/stride
        # derived from the input's own size and stride.
        class StridedViews(torch.nn.Module):
            def forward(self, x):
                chunk_size = list(x.size())
                chunk_size[1] = chunk_size[1] * 2 - 1
                chunk_stride = list(x.stride())
                chunk_stride[1] = chunk_stride[1] // 2
                return x.as_strided(
                    (3, 3, 3), (1, 4, 2), storage_offset=2
                ), x.as_strided(chunk_size, chunk_stride)

        self.run_test(StridedViews(), torch.randn(5, 8, 7))

    @skipScriptTest()  # Ellipses followed by tensor indexing not scriptable
    def test_tensor_index_advanced_indexing_ellipsis(self):
        # An ellipsis followed by two index tensors on the trailing dims.
        class EllipsisIndex(torch.nn.Module):
            def forward(self, input):
                return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]

        self.run_test(EllipsisIndex(), (torch.randn(3, 4, 5, 6, 7),))

    def test_tensor_index_advanced_indexing(self):
        # Non-consecutive advanced indices mixed with slices and `None`.
        class SplitIndices(torch.nn.Module):
            def forward(self, input):
                return input[
                    :,
                    torch.tensor([[0, 2], [1, 1]]),
                    :,
                    torch.tensor([2, 1]),
                    torch.tensor([0, 3]),
                ]

        class IndicesAroundNone(torch.nn.Module):
            def forward(self, input):
                return input[
                    :, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])
                ]

        class BroadcastIndices(torch.nn.Module):
            def forward(self, input):
                return input[
                    :,
                    torch.tensor([0, 2]),
                    torch.tensor([1]),
                    2:4,
                    torch.tensor([[1], [4]]),
                ]

        data = torch.randn(3, 4, 5, 6, 7)
        for model_cls in (SplitIndices, IndicesAroundNone, BroadcastIndices):
            self.run_test(model_cls(), (data,))

    def test_tensor_index_advanced_indexing_consecutive(self):
        # Two consecutive advanced indices followed by a trailing `None`.
        class ConsecutiveIndices(torch.nn.Module):
            def forward(self, input):
                return input[
                    :, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None
                ]

        self.run_test(ConsecutiveIndices(), (torch.randn(3, 4, 5, 6, 7),))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put(self):
        # Row assignment through a 1-element index tensor.
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                x[ind] = update
                return x

        args = (
            torch.randn(3, 4),
            torch.tensor([1], dtype=torch.long),
            torch.ones(4),
        )
        self.run_test(IndexPutModel(), args)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_singular(self):
        # Scalar assignment through a random index tensor that may contain
        # duplicates; every duplicate writes the same value.
        class IndexPutBoolModel(torch.nn.Module):
            def forward(self, mask, indices):
                mask[indices] = True
                return mask

        bool_mask = torch.zeros(100, dtype=torch.bool)
        bool_indices = (torch.rand(25) * bool_mask.shape[0]).to(torch.int64)
        self.run_test(IndexPutBoolModel(), (bool_mask, bool_indices))

        class IndexPutFloatModel(torch.nn.Module):
            def forward(self, mask, indices):
                mask[indices] = torch.tensor(5.5)
                return mask

        float_mask = torch.rand(100, dtype=torch.float)
        float_indices = (torch.rand(50) * float_mask.shape[0]).to(torch.int64)
        self.run_test(IndexPutFloatModel(), (float_mask, float_indices))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_accumulate(self):
        # Out-of-place index_put that adds into the selected row.
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                return x.index_put((ind,), update, accumulate=True)

        args = (
            torch.randn(3, 4),
            torch.tensor([2], dtype=torch.long),
            torch.ones(4),
        )
        self.run_test(IndexPutModel(), args)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_slice_index(self):
        """Index-put through mixed slice / integer / tensor indices.

        Each nested model writes (``=``) or accumulates (``+=``) into a
        different combination of slices, integers and index tensors.
        """

        # Slice, slice, then a 1-element index tensor; update viewed as (1, 2, 1).
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[1:2, 1:3, torch.tensor([1])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(1, 2, 1)
        self.run_test(IndexPutModel(), (x, update))

        # Two index tensors on the two leading dims.
        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.randn(2, 5)
        self.run_test(IndexPutModel2(), (x, update))

        # Index tensor followed by a slice.
        class IndexPutModel3(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 1:2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1, 1)
        self.run_test(IndexPutModel3(), (x, update))

        # Index tensor followed by a plain integer.
        class IndexPutModel4(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel4(), (x, update))

        # Slice, index tensor, integer.
        class IndexPutModel5(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel5(), (x, update))

        # Models 6-8: assignment through slice + integer, with bounded (1:3),
        # open-ended (1:) and start-omitted (:3) slices.
        class IndexPutModel6(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
        self.run_test(IndexPutModel6(), (x, update))

        class IndexPutModel7(torch.nn.Module):
            def forward(self, x, update):
                x[1:, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
        self.run_test(IndexPutModel7(), (x, update))

        class IndexPutModel8(torch.nn.Module):
            def forward(self, x, update):
                x[:3, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(3 * 5).to(torch.float).view(3, 5)
        self.run_test(IndexPutModel8(), (x, update))

        # int64 values derived from the input are written into a float buffer
        # created inside forward.
        class IndexPutModel9(torch.nn.Module):
            def forward(self, poses):
                w = 32
                x = poses[:, :, 0] - (w - 1) // 2
                boxes = torch.zeros([poses.shape[0], 17, 4])
                boxes[:, :, 0] = x
                return boxes

        x = torch.zeros([2, 17, 3], dtype=torch.int64)
        self.run_test(IndexPutModel9(), (x,))

        # 2-D index tensor plus slice; the update is expanded to (2, 2, 2, 5).
        class IndexPutModel10(torch.nn.Module):
            def forward(self, x, ind, update):
                x[ind, 1:3] = update.view(1, 1, 1, 5).expand(2, 2, 2, 5)
                return x

        x = torch.randn(3, 4, 5)
        ind = torch.tensor([[0, 2], [1, 1]])
        update = torch.randn(5)
        self.run_test(IndexPutModel10(), (x, ind, update))
2563

2564
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()  # Ellipses followed by tensor indexing not scriptable
    def test_index_put_ellipsis(self):
        """``+=`` through an ellipsis followed by an index tensor and a slice.

        The second model also puts a leading integer before the ellipsis.
        """

        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(3, 1, 1, 3, 2)
        self.run_test(IndexPutModel(), (x, update))

        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(4, 1, 3, 2)
        self.run_test(IndexPutModel2(), (x, update))
2584

2585
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_loop(self):
        """Index-put inside nested loops of a scripted helper, with dynamic input axes."""

        # Scripted helper: element assignments interleaved with out-of-place
        # rebinding of ``bias`` (bias = bias * k) inside nested loops.
        @torch.jit.script
        def ngram_attention_bias(
            sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype
        ):
            bias = torch.ones(
                (ngram, sequence_length), device=device, dtype=dtype
            ) * float("-inf")
            for stream_idx in range(ngram):
                for i in range(sequence_length):
                    bias = bias * 2
                    bias[stream_idx, i] = 5
                    bias = bias * 5
                    bias[0, 0] = 5

            for stream_idx in range(ngram):
                for i in range(sequence_length):
                    bias[stream_idx, i] = 5
                    bias[0, i] = 5
            return bias

        class ScriptModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ngram = 2
                self.max_target_positions = 512

            def forward(self, hidden_states):
                # batch_size is not used below.
                seq_length, batch_size = hidden_states.shape[:2]
                predict_causal_mask = ngram_attention_bias(
                    self.max_target_positions,
                    self.ngram,
                    hidden_states.device,
                    hidden_states.dtype,
                )
                # The (ngram, 512) mask is cropped to the input's leading dim.
                predict_causal_mask = predict_causal_mask[:, :seq_length]
                return predict_causal_mask

        x = torch.randn(6, 2)
        # Differently shaped second input; presumably run_test feeds it to the
        # exported model to exercise the dynamic axes declared below.
        y = torch.randn(4, 1)
        self.run_test(
            ScriptModel(),
            x,
            input_names=["x"],
            dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}},
            additional_test_inputs=[y],
        )
2633

2634
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_(self):
        """Slice/select assignment and ``Tensor.copy_`` with broadcasting updates."""

        class CopyModel(torch.nn.Module):
            def forward(self, x, data):
                x[1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.randn(2, 4)
        self.run_test(CopyModel(), (x, update))

        # mixed slice and select
        class CopyModel2(torch.nn.Module):
            def forward(self, x, data):
                x[1:3, 0] = data
                return x

        # The same x is reused for three runs although forward mutates it in
        # place. NOTE(review): presumably run_test copies its inputs - confirm.
        # Updates cover a broadcast 1-element tensor and exact 2-element ones.
        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel2(), (x, update))

        # mixed select and slice (reverse order of CopyModel2)
        class CopyModel3(torch.nn.Module):
            def forward(self, x, data):
                x[1, 1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel3(), (x, update))

        # Row assignment with a 0-d index tensor.
        class CopyModel4(torch.nn.Module):
            def forward(self, x, ind, data):
                x[ind] = data
                return x

        x = torch.randn(3, 4)
        ind = torch.tensor(2)
        data = torch.randn(4)
        self.run_test(CopyModel4(), (x, ind, data))

        # Explicit copy_ of a (3, 1) tensor broadcast into (3, 4). forward
        # returns None when mask is None; this test only passes a tensor.
        class CopyModel5(torch.nn.Module):
            def forward(self, x, mask):
                if mask is not None:
                    x.copy_(mask)
                    return x

        x = torch.randn(3, 4)
        mask = torch.randn(3, 1)
        self.run_test(CopyModel5(), (x, mask))
2695

2696
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()  # Model not scriptable (output with shape doesn't match the broadcast shape)
    def test_copy_tracing(self):
        """Trace-only: assign a (1, 2) update into the 1-D view ``x[1, 1:3]``."""

        class CopyModel(torch.nn.Module):
            def forward(self, x, data):
                x[1, 1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.randn(1, 2)
        self.run_test(CopyModel(), (x, update))
2707

2708
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_ellipsis(self):
        """Broadcast a 1-element tensor into ``x[..., 1]`` for rank-3 and rank-5 inputs."""

        class CopyModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., 1] = update
                return x

        for shape in ((2, 3, 4), (2, 3, 4, 5, 6)):
            data = torch.randn(*shape)
            self.run_test(CopyModel(), (data, torch.ones(1)))
2722

2723
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_ellipsis_script(self):
        """Assign through ``x[2, ..., 1:3]`` after a reshape inside forward."""

        class CopyModel(torch.nn.Module):
            def forward(self, x, update):
                # Insert reshape node to ensure no shape/type info for
                # x in scripting, without onnx shape inference.
                x = x.reshape(4, 3, 5, 6)
                x[2, ..., 1:3] = update
                return x

        # 3*4*5*6 == 4*3*5*6, so the reshape inside forward is valid.
        x = torch.randn(3, 4, 5, 6)

        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))
2737

2738
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_flip(self):
        """``torch.flip`` along dim 0 of a 2x3 float64 tensor."""

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.flip(x, dims=[0])

        # np.arange(6.0) is float64, so the tensor is float64 as well.
        x = torch.tensor(np.arange(6.0).reshape(2, 3))
        self.run_test(MyModule(), x)
2746

2747
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_randint(self):
        """``torch.randint`` with a shape taken from the input."""

        class RandInt(torch.nn.Module):
            def forward(self, x):
                randint = torch.randint(1, 10, x.shape)
                # Multiplying by 0 removes the random values, so the result
                # numerically equals x and can be compared across backends.
                x = 0 * randint + x
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(RandInt(), x)
2757

2758
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_randint_value(self):
        """``torch.randint`` whose single-value range makes it deterministic."""

        class RandInt(torch.nn.Module):
            def forward(self, x):
                # This randint call always returns 3
                # (the range [3, 4) contains only the value 3).
                return torch.randint(3, 4, x.shape) + x

        x = torch.randn(2, 3, 4)
        self.run_test(RandInt(), x)
2767

2768
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_randint_like(self):
        """``torch.randint_like`` with a single-value range, so the output is deterministic."""

        class RandInt(torch.nn.Module):
            def forward(self, x):
                # This randint call always returns 3
                # (the range [3, 4) contains only the value 3).
                return torch.randint_like(x, 3, 4) + x

        x = torch.randn(2, 3, 4)
        self.run_test(RandInt(), x)
2777

2778
    def test_randn(self):
        """``torch.randn`` whose values never reach the output; only ``size(0)`` does."""

        class RandN(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))

        self.run_test(RandN(), torch.randn(2, 3, 4))
2785

2786
    def test_rand(self):
        """``torch.rand`` whose values never reach the output; only ``size(0)`` does."""

        class Rand(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))

        self.run_test(Rand(), torch.randn(2, 3, 4))
2793

2794
    def test_randn_dtype(self):
        """``torch.randn`` with an explicit double dtype mixed into float32 arithmetic."""

        class RandN(torch.nn.Module):
            def forward(self, x):
                # The resulting node's dtype should be double.
                # The trailing multiply by a zero tensor makes the values
                # deterministic, so only dtype/shape matter.
                return (
                    x.to(torch.float32)
                    * torch.randn(2, 3, 4, dtype=torch.double)
                    * torch.tensor(0, dtype=torch.float32)
                )

        x = torch.randn(2, 3, 4)
        self.run_test(RandN(), x)
2806

2807
    def test_rand_dtype(self):
        """``torch.rand`` with an explicit double dtype mixed into float32 arithmetic."""

        class Rand(torch.nn.Module):
            def forward(self, x):
                # The resulting node's dtype should be double.
                # The trailing multiply by a zero tensor makes the values
                # deterministic, so only dtype/shape matter.
                return (
                    x.to(torch.float32)
                    * torch.rand(2, 3, 4, dtype=torch.double)
                    * torch.tensor(0, dtype=torch.float32)
                )

        x = torch.randn(2, 3, 4)
        self.run_test(Rand(), x)
2819

2820
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_randn_dynamic_size(self):
        """``torch.randn`` sized from the input; only its ``size(1)`` reaches the output."""

        class RandN(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.randn(x.size()).size(1))

        self.run_test(RandN(), torch.randn(2, 3, 4))
2828

2829
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_rand_dynamic_size(self):
        """``torch.rand`` sized from the input; only its ``size(1)`` reaches the output."""

        class Rand(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.rand(x.size()).size(1))

        self.run_test(Rand(), torch.randn(2, 3, 4))
2837

2838
    def test_randn_like(self):
        """``torch.randn_like`` as a plain module and as an already-scripted module."""

        class RandNLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.randn_like(x).size(0))

        x = torch.randn(2, 3, 4)
        for model in (RandNLike(), torch.jit.script(RandNLike())):
            self.run_test(model, x)
2846

2847
    def test_rand_like(self):
        """``torch.rand_like`` as a plain module and as an already-scripted module."""

        class RandLike(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.rand_like(x).size(0))

        x = torch.randn(2, 3, 4)
        for model in (RandLike(), torch.jit.script(RandLike())):
            self.run_test(model, x)
2855

2856
    def test_randn_like_dtype(self):
        """``torch.randn_like`` with a dtype override (double) mixed into float32 math."""

        class RandNLike(torch.nn.Module):
            def forward(self, x):
                # The resulting node's dtype should be double.
                # Multiplying by a zero tensor keeps the values deterministic.
                return (
                    x.to(torch.float32)
                    * torch.randn_like(x, dtype=torch.double)
                    * torch.tensor(0, dtype=torch.float32)
                )

        x = torch.randn(2, 3, 4)
        self.run_test(RandNLike(), x)
2868

2869
    def test_rand_like_dtype(self):
        """``torch.rand_like`` with a dtype override (double) mixed into float32 math."""

        class RandLike(torch.nn.Module):
            def forward(self, x):
                # The resulting node's dtype should be double.
                # Multiplying by a zero tensor keeps the values deterministic.
                return (
                    x.to(torch.float32)
                    * torch.rand_like(x, dtype=torch.double)
                    * torch.tensor(0, dtype=torch.float32)
                )

        x = torch.randn(2, 3, 4)
        self.run_test(RandLike(), x)
2881

2882
    def test_bernoulli(self):
        """``torch.bernoulli`` on probabilities in [0, 1]; only its ``size(0)`` reaches the output."""

        class Bernoulli(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.bernoulli(x).size(0))

        # dtype=None means the default dtype, exactly as in torch.empty(3, 3).
        # Inputs are drawn one per iteration so the RNG order stays unchanged.
        for shape, dtype in (((3, 3), None), ((2, 3, 3), torch.double)):
            probs = torch.empty(*shape, dtype=dtype).uniform_(0, 1)
            self.run_test(Bernoulli(), probs)
2892

2893
    def test_bernoulli_p(self):
        """Bernoulli sampling with a scalar probability and with a probability tensor."""

        class Bernoulli_float(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.bernoulli(x, 0.2).size(0))

        class Bernoulli_tensor(torch.nn.Module):
            def forward(self, x):
                return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0))

        # dtype=None means the default dtype, exactly as in torch.rand(3, 3).
        for shape, dtype in (((3, 3), None), ((2, 3, 3), torch.double)):
            x = torch.rand(*shape, dtype=dtype)
            for model_cls in (Bernoulli_float, Bernoulli_tensor):
                self.run_test(model_cls(), x)
2909

2910
    @unittest.skip("Bug in ORT, skip test until rel-1.11.")
    @skipIfUnsupportedMinOpsetVersion(14)
    def test_reshape_allowzero(self):
        """Reshape an empty (0, 3, 4) tensor to (3, 4, 0), with a literal 0 in the target shape."""

        class ReshapeModel(torch.nn.Module):
            def forward(self, x):
                x = x.reshape(3, 4, 0)
                return x

        # NOTE(review): the opset-14 gate presumably targets ONNX Reshape's
        # ``allowzero`` attribute, where a literal 0 means "size zero" rather
        # than "copy the input dim" - confirm against the symbolic.
        x = torch.randn(0, 3, 4)
        self.run_test(ReshapeModel(), x)
2920

2921
    def test_reshape_different_rank(self):
        """Reshape rank 4 -> rank 6 with an inferred (-1) leading dimension."""

        class ReshapeModel(torch.nn.Module):
            def forward(self, x):
                x = x.reshape(-1, 2, 4, 4, 5, 5)
                return x

        # 1*32*5*5 == 800 == 2*4*4*5*5, so -1 resolves to 1.
        x = torch.randn(1, 32, 5, 5)
        self.run_test(ReshapeModel(), x)
2929

2930
    def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
        """Build a two-output interpolate model for ``x`` and pass it to ``run_test``.

        The model returns one result computed with a scalar size/scale and one
        computed with a per-spatial-axis size/scale list.

        Args:
            x: input tensor. Rank 3 gets one per-axis entry, rank 4 gets two,
                and any other rank gets three (callers here pass rank 5).
            mode: interpolation mode forwarded to ``F.interpolate``.
            use_size: if True interpolate by ``size``, otherwise by ``scale_factor``.
            is_upsample: picks scale 2.0 / size 24 (True) or scale 0.5 / size 2
                (False) for the scalar output.
            align_corners: on the ``size`` path, forwarded as ``align_corners=True``.
        """

        class MyModel(torch.nn.Module):
            __constants__ = [
                "mode",
                "use_size",
                "is_upsample",
                "size",
                "scale",
                "size_array",
                "scale_array",
                "align_corners",
            ]

            def __init__(self, mode, use_size, is_upsample, align_corners):
                super().__init__()
                self.mode = mode
                self.use_size = use_size
                self.is_upsample = is_upsample
                self.align_corners = align_corners
                self.scale = 2.0 if self.is_upsample else 0.5
                self.size = 24 if self.is_upsample else 2
                # ``x`` here is the enclosing ``_interpolate`` argument (closure),
                # not a constructor parameter.
                if x.dim() == 3:
                    self.scale_array = [2.3]
                    self.size_array = [16]
                elif x.dim() == 4:
                    self.scale_array = [2.3, 3.1]
                    self.size_array = [16, 32]
                else:
                    self.scale_array = [2.3, 3.1, 4.6]
                    self.size_array = [16, 32, 64]

            def forward(self, x):
                if self.use_size:
                    if self.align_corners:
                        return torch.nn.functional.interpolate(
                            x, mode=self.mode, size=self.size, align_corners=True
                        ), torch.nn.functional.interpolate(
                            x, mode=self.mode, size=self.size_array, align_corners=True
                        )
                    return torch.nn.functional.interpolate(
                        x, mode=self.mode, size=self.size
                    ), torch.nn.functional.interpolate(
                        x, mode=self.mode, size=self.size_array
                    )
                # NOTE(review): this branch is identical to the fall-through
                # below and never passes align_corners=True, so align_corners is
                # not exercised on the scale_factor path. Confirm whether that is
                # intentional (e.g. an ORT mismatch) before changing it.
                if self.align_corners:
                    return torch.nn.functional.interpolate(
                        x,
                        mode=self.mode,
                        scale_factor=self.scale,
                        recompute_scale_factor=False,
                    ), torch.nn.functional.interpolate(
                        x,
                        mode=self.mode,
                        scale_factor=self.scale_array,
                        recompute_scale_factor=False,
                    )
                return torch.nn.functional.interpolate(
                    x,
                    mode=self.mode,
                    scale_factor=self.scale,
                    recompute_scale_factor=False,
                ), torch.nn.functional.interpolate(
                    x,
                    mode=self.mode,
                    scale_factor=self.scale_array,
                    recompute_scale_factor=False,
                )

        model = MyModel(mode, use_size, is_upsample, align_corners)
        self.run_test(model, x, atol=1e-6)
3000

3001
    def _interpolate_tests(self, is_upsample):
        """Run ``_interpolate`` over every supported mode / input-rank combination.

        Args:
            is_upsample: forwarded to ``_interpolate``; True selects upsampling
                (scale 2.0 / size 24), False downsampling (scale 0.5 / size 2).
        """
        # - cubic mode is not supported for opsets below 11;
        # - linear mode does not match for opsets below 11;
        modes = ["nearest", "linear", "bicubic"]
        if self.opset_version < 11:
            modes = ["nearest"]
        x = [
            torch.randn(1, 2, 6, requires_grad=True),
            torch.randn(1, 2, 4, 6, requires_grad=True),
            torch.randn(1, 2, 4, 4, 6, requires_grad=True),
        ]

        for mode in modes:
            for xi in x:
                mode_i = mode
                # TODO: enable bicubic downsample when ORT precision loss fixed
                if mode == "bicubic" and xi.dim() != 4:
                    continue
                elif mode == "linear":
                    if xi.dim() == 3:
                        # TODO : enable when linear mode is implemented for 1d inputs in ORT
                        continue
                    elif xi.dim() == 4:
                        mode_i = "bilinear"
                    elif xi.dim() == 5:
                        # TODO : enable "trilinear" when linear mode is implemented
                        # for 3d inputs in ORT
                        continue
                # Size-based interpolation; runs for every opset.
                self._interpolate(xi, mode_i, True, is_upsample)
                # test with align_corners if supported
                if mode != "nearest":
                    self._interpolate(xi, mode_i, True, is_upsample, True)
                # The following cases require dynamic sizes/scales,
                # which are not supported for opset_version < 9.
                # (A repeat of the size-based call above used to sit here; it was
                # byte-for-byte identical and has been dropped.)
                if self.opset_version >= 9:
                    # test with align_corners if supported
                    if mode != "nearest":
                        self._interpolate(xi, mode_i, False, is_upsample, True)
                    self._interpolate(xi, mode_i, False, is_upsample)
3041

3042
    # ONNX export failed on interpolate scripting because dynamic size not supported for opsets below 9.
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_interpolate_upsample(self):
        """Upsampling sweep of the shared interpolate helper (opset >= 9)."""
        self._interpolate_tests(is_upsample=True)
3046

3047
    @skipIfUnsupportedMaxOpsetVersion(8)
    @skipScriptTest()  # Scripting supported for opsets > 8. See test_interpolate_upsample
    def test_interpolate_upsample_trace(self):
        """Trace-only upsampling sweep for opsets <= 8."""
        self._interpolate_tests(is_upsample=True)
3051

3052
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_interpolate_function_substitution(self):
        """``F.interpolate`` called from nested script modules and from a scripted function.

        The first case nests one ScriptModule in another. The second calls a
        ``torch.jit.script`` function from a regular ``nn.Module``.
        """

        class ScriptModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.nn.functional.interpolate(
                    x, mode="nearest", scale_factor=2.0
                )

        class ScriptModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.submodule = ScriptModel()

            @torch.jit.script_method
            def forward(self, input):
                return self.submodule(input)

        x = torch.randn(1, 2, 4, 4, 6)
        self.run_test(ScriptModule(), (x,))

        @torch.jit.script
        def script_method(x):
            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0)

        class TracingModule(torch.nn.Module):
            def forward(self, x):
                return script_method(x)

        self.run_test(TracingModule(), (x,))
3082

3083
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_interpolate_downsample(self):
        """Downsampling sweep of the shared interpolate helper (opset >= 10)."""
        self._interpolate_tests(is_upsample=False)
3086

3087
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_interpolate_half_pixel(self):
        """Interpolate to an output size of 1 along one spatial axis at a time."""
        # testing whether it uses "half_pixel" or "pytorch_half_pixel"
        # see https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize

        class MyModel(torch.nn.Module):
            def __init__(self, mode, size):
                super().__init__()
                self.mode = mode
                self.size = size

            def forward(self, x):
                return torch.nn.functional.interpolate(
                    x, mode=self.mode, size=self.size
                )

        # "linear" is renamed per input rank; rank 3 keeps plain "linear".
        linear_mode_by_rank = {4: "bilinear", 5: "trilinear"}
        inputs = [
            torch.randn(1, 2, 6, requires_grad=True),
            torch.randn(1, 2, 4, 6, requires_grad=True),
            torch.randn(1, 2, 4, 4, 6, requires_grad=True),
        ]
        for mode in ("linear", "bicubic"):
            for xi in inputs:
                rank = xi.dim()
                if mode == "bicubic":
                    # bicubic is only exercised on 4-D inputs.
                    if rank != 4:
                        continue
                    resolved_mode = mode
                else:
                    resolved_mode = linear_mode_by_rank.get(rank, mode)
                spatial = list(xi.shape[2:])
                for axis in range(len(spatial)):
                    target = spatial.copy()
                    target[axis] = 1
                    self.run_test(MyModel(resolved_mode, target), xi)
3123

3124
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_interpolate_no_shape(self):
        """Interpolate inside a ScriptModule, with the output size read from another input.

        ``out2`` takes its size from ``y.size()``; only ``y``'s shape is used.
        """

        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                x = torch.add(x, x)
                out1 = torch.nn.functional.interpolate(
                    x, mode="bilinear", size=(16, 16), align_corners=False
                )
                out2 = torch.nn.functional.interpolate(
                    x, mode="nearest", size=(int(y.size(0)), int(y.size(1)))
                )
                return out1, out2

        x = torch.randn(1, 2, 4, 4, requires_grad=True)
        y = torch.randn(16, 16, requires_grad=True)
        # All axes dynamic: y's sizes stay symbolic in the graph.
        self.run_test(
            MyModel(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]},
        )
        # Static shapes: presumably y (used only for its sizes) is folded away,
        # leaving x as the only ONNX input - confirm against run_test.
        self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0])
3147

3148
    @skipScriptTest()  # scripting raises OnnxRuntimeError
    def test_interpolate_adaptive_pooling_error(self):
        """``mode="area"`` via ``_interpolate`` is expected to raise ``RuntimeError``.

        Checked for both size-based and scale-based calls. ``subTest`` records
        which variant failed, since both now share one loop body.
        """
        x = torch.randn(1, 2, 6, requires_grad=True)
        for use_size in (True, False):
            with self.subTest(use_size=use_size), self.assertRaises(RuntimeError):
                self._interpolate(x, "area", use_size, True)
3156

3157
    def test_groupnorm(self):
        """Affine GroupNorm over 6 channels (eps=0.002) for several group counts and ranks."""
        cases = (
            (3, (4, 6, 36, 36, 18)),
            (1, (4, 6, 180, 180)),
            (6, (4, 6, 180, 180)),
        )
        for num_groups, shape in cases:
            model = torch.nn.GroupNorm(num_groups, 6, 0.002)
            self.run_test(model, torch.randn(*shape))
3169

3170
    def test_groupnorm_noaffine(self):
        """GroupNorm with ``affine=False`` (eps=0.002) for several group/channel combinations."""
        cases = (
            (4, 8, (3, 8, 224, 224)),
            (1, 6, (4, 6, 180, 180)),
            (6, 6, (4, 6, 180, 180)),
        )
        for num_groups, num_channels, shape in cases:
            model = torch.nn.GroupNorm(num_groups, num_channels, 0.002, affine=False)
            self.run_test(model, torch.randn(*shape))
3182

3183
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_list_unpack_scripted(self):
        """Unpack ``x.shape`` into two ints in a scripted module and build a new tensor."""

        class ListUnpack(torch.nn.Module):
            def forward(self, x):
                a, b = x.shape
                return x.new_zeros((a, b))

        x = torch.randn(2, 3)
        # Dynamic axes keep the shape symbolic, so x stays a graph input.
        self.run_test(
            torch.jit.script(ListUnpack()),
            x,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        # With static shapes only x's shape is used; no ONNX inputs are expected.
        self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
3198

3199
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
        self,
    ):
        """Unpack a list built inside forward (ListConstruct followed by ListUnpack)."""

        class PackUnpack(torch.nn.Module):
            """Create and unpack a list of tensors.

            When scripted, it should produce a graph similar to

            ```
            graph(%self : __torch__.PackUnpack,
                %a.1 : Tensor,
                %b.1 : Tensor):
            %packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
            %c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
            return (%c.1)
            ```
            """

            def forward(self, a, b):
                packed = [a, b]
                c, _ = packed
                return c

        # Only ``a`` reaches the output, so only input 0 is expected to remain.
        self.run_test(
            torch.jit.script(PackUnpack()),
            (torch.tensor(0), torch.tensor([42])),
            remained_onnx_input_idx=[0],
        )
3228

3229
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_list_unpack_slice_scripted(self):
        """Unpack a slice of ``x.shape`` (the last two dims) in a scripted module."""

        class ListUnpackSlice(torch.nn.Module):
            def forward(self, x):
                a, b = x.shape[2:]
                return x.new_zeros((a, b))

        x = torch.randn(2, 3, 4, 5)
        # Dynamic axes keep the shape symbolic, so x stays a graph input.
        self.run_test(
            torch.jit.script(ListUnpackSlice()),
            x,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )
        # With static shapes only x's shape is used; no ONNX inputs are expected.
        self.run_test(
            torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
        )
3246

3247
    @skipDtypeChecking
    def test_pow(self):
        """``pow`` across float/int base and exponent dtypes, plus pow used as an index."""

        class PowModule(torch.nn.Module):
            def forward(self, x, y):
                return x.pow(y)

        # float ** float. Negative bases with fractional exponents give NaN;
        # NOTE(review): presumably run_test treats NaNs as equal - confirm.
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(PowModule(), (x, y))

        # int64 ** int32
        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
        self.run_test(PowModule(), (x, y))

        # int64 ** int64
        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

        # float64 ** int64
        x = torch.randn(2, 3, 4).to(dtype=torch.float64)
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

        # Python scalar base with a tensor exponent.
        class PowModule2(torch.nn.Module):
            def forward(self, x):
                return torch.pow(2, x)

        x = torch.randn(1, 10)
        self.run_test(PowModule2(), (x,))

        x = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule2(), (x,))

        x = torch.randn(1, 10).to(dtype=torch.float64)
        self.run_test(PowModule2(), (x,))

        # Integer pow result used as an index: x in [0, 5) gives indices
        # up to 2**4 == 16, all inside y's 100 elements.
        class PowModule3(torch.nn.Module):
            def forward(self, x, y):
                return y[torch.pow(2, x)]

        x = torch.randint(5, (2, 3, 4))
        y = torch.rand(100)
        self.run_test(PowModule3(), (x, y))
3289

3290
    # Arithmetic ops (Add/Sub/Mul/Div/Gemm/Pow/Mod) on low-precision inputs
    # such as uint8 are known to fail in ORT; the original note says this will
    # be fixed in ONNX opset 14, hence the opset <= 13 cap below.
    @skipIfUnsupportedMaxOpsetVersion(13)
    @skipDtypeChecking
    def test_arithmeticOps_with_low_precision(self):
        """Export + - * / and pow on low-precision and mixed-precision operands."""

        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        class SubModule(torch.nn.Module):
            def forward(self, x, y):
                return x - y

        class MulModule(torch.nn.Module):
            def forward(self, x, y):
                return x * y

        class DivModule(torch.nn.Module):
            def forward(self, x, y):
                return x / y

        class PowModule(torch.nn.Module):
            def forward(self, x, y):
                return x.pow(y)

        # (lhs dtype, rhs dtype for binary ops, exponent dtype for pow)
        dtype_cases = (
            (torch.uint8, torch.uint8, torch.uint8),
            (torch.int8, torch.int8, torch.int8),
            (torch.int16, torch.int16, torch.int16),
            (torch.uint8, torch.float32, torch.float64),
            (torch.uint8, torch.int64, torch.int32),
        )
        binary_modules = (AddModule, SubModule, MulModule, DivModule)
        for lhs_dtype, rhs_dtype, exp_dtype in dtype_cases:
            lhs = torch.tensor([2, 3, 5], dtype=lhs_dtype)
            rhs = torch.tensor([2, 3, 5], dtype=rhs_dtype)
            exponent = torch.tensor([1], dtype=exp_dtype)
            for module_cls in binary_modules:
                self.run_test(module_cls(), (lhs, rhs))
            self.run_test(PowModule(), (lhs, exponent))
3360

3361
    def test_mul_bool(self):
        """Export ``torch.mul`` for bool*bool, bool*float and float*bool."""

        class MyModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.mul(x, y)

        lhs_bool = torch.tensor([True, False, True, False])
        rhs_bool = torch.tensor([True, True, False, False])
        floats = torch.tensor([1.0, 2.0, 3.0, 0.0])
        operand_pairs = ((lhs_bool, rhs_bool), (lhs_bool, floats), (floats, rhs_bool))
        for pair in operand_pairs:
            self.run_test(MyModel(), pair)
3372

3373
    # fmod was added in opset 10.
    @skipIfUnsupportedMinOpsetVersion(10)
    @skipIfUnsupportedMaxOpsetVersion(13)
    def test_mod_with_low_precision(self):
        """Export ``torch.fmod`` on low-precision and mixed-dtype operands."""

        class ModModule(torch.nn.Module):
            def forward(self, x, y):
                # Result cast to int64 — presumably to avoid an ORT output-type
                # mismatch for low-precision inputs (TODO confirm).
                return torch.fmod(x, y).to(dtype=torch.long)

        dtype_pairs = (
            (torch.uint8, torch.uint8),
            (torch.int8, torch.int8),
            (torch.int16, torch.int16),
            (torch.uint8, torch.int32),
            (torch.uint8, torch.float64),
        )
        for lhs_dtype, rhs_dtype in dtype_pairs:
            dividend = torch.tensor([2, 3, 5], dtype=lhs_dtype)
            divisor = torch.tensor([2, 3, 5], dtype=rhs_dtype)
            self.run_test(ModModule(), (dividend, divisor))
3400

3401
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_empty_constant_shape(self):
        """Export 0-d tensors built via zeros/ones/full/empty and updated in place."""

        class Zeros(torch.nn.Module):
            def forward(self, x):
                out = torch.zeros(())
                out += x
                return out

        class Ones(torch.nn.Module):
            def forward(self, x):
                out = torch.ones(())
                out += x
                return out

        class Full(torch.nn.Module):
            def forward(self, x):
                out = torch.full((), 1.0)
                out += x
                return out

        class Empty(torch.nn.Module):
            def forward(self, x):
                # fill_ gives the uninitialized scalar a defined value.
                out = torch.empty(()).fill_(0)
                out += x
                return out

        for model_cls in (Zeros, Ones, Full, Empty):
            self.run_test(model_cls(), torch.tensor(42.0))
3438

3439
    def test_std(self):
        """Export full-reduction ``torch.std`` with biased and unbiased estimators."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, unbiased=False)

        class StandardDeviationUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, unbiased=True)

        # Both variants share one input tensor.
        shared = torch.randn(2, 3, 4)
        for model in (StandardDeviation(), StandardDeviationUnbiased()):
            self.run_test(model, shared)
3454

3455
    def test_std_along_dims(self):
        """Export ``torch.std`` reduced over dims (0, 1), biased and unbiased."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=False)

        class StandardDeviationUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=True)

        for model_cls in (StandardDeviation, StandardDeviationUnbiased):
            # A fresh random input per variant.
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3471

3472
    def test_std_keepdim(self):
        """Export ``torch.std`` over dims (0, 1) with ``keepdim=True``."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)

        class StandardDeviationUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True)

        for model_cls in (StandardDeviation, StandardDeviationUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3488

3489
    def test_std_correction(self):
        """Export ``torch.std`` with an explicit ``correction`` of 3."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std(input, dim=(0, 1), correction=3, keepdim=True)

        sample = torch.randn(2, 3, 4)
        self.run_test(StandardDeviation(), sample)
3497

3498
    def test_var(self):
        """Export full ``torch.var`` (biased/unbiased) and sqrt(var over dim 1)."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, unbiased=False)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, unbiased=True)

        class VarianceSqrt(torch.nn.Module):
            def forward(self, input):
                variance = torch.var(input, 1)
                # The small epsilon is added before the square root.
                return torch.sqrt(variance + 1e-8)

        small = torch.randn(2, 3, 4)
        for model in (Variance(), VarianceUnbiased()):
            self.run_test(model, small)

        large = torch.randn(1, 2, 3, 300, 300)
        self.run_test(VarianceSqrt(), large)
3522

3523
    def test_var_along_dims(self):
        """Export ``torch.var`` reduced over dims (0, 1), biased and unbiased."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, dim=(0, 1), unbiased=False)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, dim=(0, 1), unbiased=True)

        for model_cls in (Variance, VarianceUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3539

3540
    def test_var_keepdim(self):
        """Export ``torch.var`` over dims (0, 1) with ``keepdim=True``."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True)

        for model_cls in (Variance, VarianceUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3556

3557
    def test_var_correction(self):
        """Export ``torch.var`` with an explicit ``correction`` of 3."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var(input, dim=(0, 1), correction=3, keepdim=True)

        sample = torch.randn(2, 3, 4)
        self.run_test(Variance(), sample)
3565

3566
    def test_var_mean(self):
        """Export full-reduction ``torch.var_mean``, biased and unbiased."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, unbiased=False)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, unbiased=True)

        shared = torch.randn(2, 3, 4)
        for model in (Variance(), VarianceUnbiased()):
            self.run_test(model, shared)
3581

3582
    def test_var_mean_along_dims(self):
        """Export ``torch.var_mean`` reduced over dims (0, 1), biased and unbiased."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 1), unbiased=False)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 1), unbiased=True)

        for model_cls in (Variance, VarianceUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3598

3599
    def test_var_mean_mixed_dims(self):
        """Export ``torch.var_mean`` with reversed, non-adjacent and trailing dims."""

        class ReverseDims(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(2, 1), unbiased=False)

        class SkipDims(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 2), unbiased=False)

        class NonZeroDims(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(1, 2), unbiased=False)

        for model_cls in (ReverseDims, SkipDims, NonZeroDims):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3623

3624
    def test_var_mean_keepdim(self):
        """Export ``torch.var_mean`` over dims (0, 1) with ``keepdim=True``."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True)

        class VarianceUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True)

        for model_cls in (Variance, VarianceUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3640

3641
    def test_var_mean_correction(self):
        """Export ``torch.var_mean`` with an explicit ``correction`` of 3."""

        class Variance(torch.nn.Module):
            def forward(self, input):
                return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)

        sample = torch.randn(2, 3, 4)
        self.run_test(Variance(), sample)
3649

3650
    def test_std_mean(self):
        """Export full-reduction ``torch.std_mean``, biased and unbiased."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std_mean(input, unbiased=False)

        class StandardDeviationUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.std_mean(input, unbiased=True)

        shared = torch.randn(2, 3, 4)
        for model in (StandardDeviation(), StandardDeviationUnbiased()):
            self.run_test(model, shared)
3665

3666
    def test_std_mean_along_dims(self):
3667
        class StandardDeviation(torch.nn.Module):
3668
            def forward(self, input):
3669
                return torch.std_mean(input, dim=(0, 1), unbiased=False)
3670

3671
        x = torch.randn(2, 3, 4)
3672
        model = StandardDeviation()
3673
        self.run_test(model, x)
3674

3675
        class VarianceUnbiased(torch.nn.Module):
3676
            def forward(self, input):
3677
                return torch.std_mean(input, dim=(0, 1), unbiased=True)
3678

3679
        x = torch.randn(2, 3, 4)
3680
        model = VarianceUnbiased()
3681
        self.run_test(model, x)
3682

3683
    def test_std_mean_keepdim(self):
        """Export ``torch.std_mean`` over dims (0, 1) with ``keepdim=True``."""

        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
                return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True)

        class StandardDeviationUnbiased(torch.nn.Module):
            def forward(self, input):
                return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True)

        for model_cls in (StandardDeviation, StandardDeviationUnbiased):
            sample = torch.randn(2, 3, 4)
            self.run_test(model_cls(), sample)
3699

3700
    def test_std_mean_correction(self):
3701
        class StandardDeviation(torch.nn.Module):
3702
            def forward(self, input):
3703
                return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)
3704

3705
        x = torch.randn(2, 3, 4)
3706
        model = StandardDeviation()
3707
        self.run_test(model, x)
3708

3709
    def test_bitshift(self):
        """Export int64 shifts by scalars and by a tensor broadcast on the last dim."""

        class BitshiftModel(torch.nn.Module):
            def forward(self, input):
                last_dim_shifts = torch.tensor([1, 2])
                return (
                    input >> 1,
                    input << 3,
                    input >> last_dim_shifts,
                    input << 4,
                )

        data = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
        self.run_test(BitshiftModel(), data)
3721

3722
    # ORT has no uint8 Mul kernel, and Mul is what the bitshift export uses
    # for opset_version < 10, so this uint8 variant is limited to newer opsets.
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_bitshift_uint8(self):
        """Export uint8 shifts by scalars and by a uint8 tensor."""

        class BitshiftModel(torch.nn.Module):
            def forward(self, input, input2):
                last_dim_shifts = torch.tensor([1, 2], dtype=torch.uint8)
                return (
                    input >> 1,
                    input << 3,
                    input2 >> last_dim_shifts,
                    input2 << 4,
                )

        first = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        second = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.run_test(BitshiftModel(), (first, second))
3738

3739
    def test_narrow(self):
        """Export ``torch.narrow`` with a constant start and length."""

        class NarrowModel(torch.nn.Module):
            def forward(self, input):
                return torch.narrow(input, 0, 0, 2)

        sample = torch.randn(3, 3, requires_grad=True)
        self.run_test(NarrowModel(), sample)
3746

3747
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_narrow_dynamic(self):
        """Export ``torch.narrow`` whose length is derived from the input shape."""

        class NarrowModel(torch.nn.Module):
            def forward(self, input):
                length = input.shape[0] - 1
                return torch.narrow(input, 0, 0, length)

        sample = torch.randn(3, 3, requires_grad=True)
        self.run_test(NarrowModel(), sample)
3755

3756
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_fill(self):
        """Export ``index_fill`` along dim 2 with a constant index tensor."""

        class IndexFillModel(torch.nn.Module):
            def forward(self, input):
                return input.index_fill(2, torch.tensor([2, 0]), -1)

        sample = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(IndexFillModel(), sample)
3765

3766
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_index_copy(self):
        """Export ``index_copy`` with a positive and an equivalent negative dim."""

        class IndexCopyModel(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, input):
                return input.index_copy(
                    self.dim, torch.tensor([2, 0]), torch.ones(3, 2, 5)
                )

        sample = torch.randn(3, 4, 5, requires_grad=True)
        # 1 and -2 name the same axis of a 3-D input.
        for axis in (1, -2):
            self.run_test(IndexCopyModel(axis), sample)
3781

3782
    def test_select(self):
        """Export basic integer indexing on the second dimension."""

        class Select(torch.nn.Module):
            def forward(self, x):
                return x[:, 1]

        self.run_test(Select(), torch.randn(3, 4))
3789

3790
    def test_select_negative_index(self):
        """Export integer indexing with a negative index on the second dimension."""

        class Select(torch.nn.Module):
            def forward(self, x):
                return x[:, -1]

        self.run_test(Select(), torch.randn(3, 4))
3797

3798
    def test_index_select_constant_scaler_index(self):
        """Export ``index_select`` with a constant 0-d (scalar) index tensor."""

        class IndexSelectScalarIndexModel(torch.nn.Module):
            def forward(self, x):
                return torch.index_select(x, 1, torch.tensor(2))

        self.run_test(IndexSelectScalarIndexModel(), torch.randn(3, 4))
3806

3807
    def test_index_select_scaler_index(self):
        """Export ``index_select`` whose 0-d index is a stored base plus a runtime offset."""

        class IndexSelectScalarIndexModel(torch.nn.Module):
            def __init__(self, index_base):
                super().__init__()
                self.index_base = torch.tensor(index_base)

            def forward(self, x, index_offset):
                return torch.index_select(x, 1, self.index_base + index_offset)

        data = torch.randn(3, 4)
        # base 1 + offset 2 selects column 3 of the 4 columns.
        model = IndexSelectScalarIndexModel(1)
        self.run_test(model, (data, torch.tensor(2)))
3822

3823
    def test_take(self):
        """Export ``torch.take`` with flat indices into a 4-D tensor."""

        class TakeModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.take(x, y)

        source = torch.randn(6, 4, 3, 3)
        flat_indices = torch.tensor([4, 1, 7, 15, 63])
        self.run_test(TakeModel(), (source, flat_indices))
3831

3832
    def test_topk(self):
        """Export ``torch.topk`` with a constant k of 3."""

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.topk(x, 3)

        values = torch.arange(1.0, 6.0, requires_grad=True)
        self.run_test(MyModule(), values)
3839

3840
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_topk_int32_k(self):
        """Export ``torch.topk`` where k is passed as an int32 tensor input."""

        class Model(torch.nn.Module):
            def forward(self, x, k):
                return torch.topk(x, k)

        values = torch.arange(1.0, 6.0)
        k_int32 = torch.tensor(3, dtype=torch.int32)
        self.run_test(Model(), (values, k_int32))
3849

3850
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_topk_smallest_unsorted(self):
        """Export ``torch.topk`` with ``largest=False`` in sorted and unsorted modes."""

        class MyModule(torch.nn.Module):
            def forward(self, x, k):
                # With sorted=False the element order of the output tensors is
                # not expected to match between PyTorch and ORT, so the
                # unsorted values are sorted before being returned.
                topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
                topk_sorted = torch.topk(x, k, largest=False, sorted=True)
                return topk_sorted, torch.sort(topk_unsorted.values).values

        values = torch.arange(1.0, 6.0, requires_grad=True)
        self.run_test(MyModule(), (values, torch.tensor(3)))
3863

3864
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_topk_script(self):
        """Export a ``ScriptModule`` calling ``torch.topk`` with a tensor k."""

        class MyModuleDynamic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, k):
                return torch.topk(x, k)

        values = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.run_test(MyModuleDynamic(), (values, k))
3874

3875
    @skipScriptTest()  # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
    @skipIfUnsupportedMinOpsetVersion(11)  # Clip op min is an input since opset 11.
    def test_auto_grad(self):
        """Export custom autograd Functions via a ``prim::PythonOp`` symbolic.

        The symbolic registered for opset 1 maps ``MyClip`` to ONNX ``Clip`` and
        ``MyRelu`` to ``Relu``; any other PythonOp name is reported through
        ``symbolic_helper._unimplemented``.
        """

        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

        class MyRelu(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

        def symbolic_python_op(
            ctx: torch.onnx.SymbolicContext, g: torch._C.Graph, *args, **kwargs
        ):
            node = ctx.cur_node
            op_name = kwargs["name"]
            if op_name == "MyClip":
                return g.op("Clip", args[0], args[1], outputs=node.outputsSize())
            if op_name == "MyRelu":
                return g.op("Relu", args[0], outputs=node.outputsSize())
            # TODO(justinchuby): Remove reference to internal names in symbolic_helper
            return torch.onnx.symbolic_helper._unimplemented(
                "prim::PythonOp", "unknown node kind: " + op_name
            )

        torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
        # The registration is global; undo it when this test finishes.
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)

        class MyClipModule(torch.nn.Module):
            def forward(self, x, min):
                return MyClip.apply(x, min)

        class MyReluModule(torch.nn.Module):
            def forward(self, x):
                return MyRelu.apply(x)

        clip_input = torch.randn(3, 3)
        lower_bound = torch.tensor([0.0])
        self.run_test(MyClipModule(), (clip_input, lower_bound))

        relu_input = torch.randn(3, 3)
        self.run_test(MyReluModule(), relu_input)
3922

3923
    def test_clip_int(self):
        """Export ``torch.clamp`` on an int64 input."""

        class MyClipInt(torch.nn.Module):
            def forward(self, x):
                return torch.clamp(x, 0, 1)

        int_input = torch.randn(3, 3).to(torch.int64)
        self.run_test(MyClipInt(), int_input)
3929

3930
    def test_relu_int(self):
        """Export ``nn.ReLU`` on an int32 input."""
        relu = torch.nn.ReLU()
        int_input = torch.randn(3, 3).to(torch.int32)
        self.run_test(relu, int_input)
3932

3933
    def test_pad_int(self):
        """Export ``F.pad`` of the last dimension on an int32 input."""

        class MyPadInt(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.pad(x, (1, 1))

        int_input = torch.randn(3, 3).to(torch.int32)
        self.run_test(MyPadInt(), int_input)
3939

3940
    def test_min_int(self):
        """Export element-wise ``torch.min`` of two int32 tensors."""

        class MyMinInt(torch.nn.Module):
            def forward(self, x):
                return torch.min(x, x + 1)

        int_input = torch.randn(3, 3).to(torch.int32)
        self.run_test(MyMinInt(), int_input)
3946

3947
    def test_max_int(self):
        """Export element-wise ``torch.max`` of two int32 tensors."""

        class MyMaxInt(torch.nn.Module):
            def forward(self, x):
                return torch.max(x, x + 1)

        int_input = torch.randn(3, 3).to(torch.int32)
        self.run_test(MyMaxInt(), int_input)
3953

3954
    @skipIfUnsupportedOpsetVersion([7])
    def test_normalize(self):
        """Export ``F.normalize`` with its default arguments."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.normalize(x)

        self.run_test(Model(), torch.randn(3, 3))
3962

3963
    def test_norm_with_dtype(self):
        """Export ``aten::norm`` with an explicit float64 ``dtype`` argument."""

        class Model(torch.nn.Module):
            def forward(self, x):
                # TODO(bowbao): There is a slight gap in today's test infrastructure
                # to directly test aten ops. OpInfo `torch.norm` in
                # `common_methods_invocations.py` will not decompose to the aten
                # op below, so it is called directly here.
                return torch.ops.aten.norm(
                    x, p=2, dim=[1], keepdim=True, dtype=torch.float64
                )

        self.run_test(Model(), torch.randn(3, 3))
3975

3976
    def test_layer_norm(self):
        """Export ``nn.LayerNorm`` for every elementwise_affine/bias combination."""
        # LayerNorm works on the trailing D dimensions; the input must stay at
        # least three-dimensional so that axis=2 does not alias axis=-2.
        for elementwise_affine, bias in itertools.product((True, False), repeat=2):
            model = torch.nn.LayerNorm(
                [10, 10, 10], elementwise_affine=elementwise_affine, bias=bias
            )
            self.run_test(model, torch.randn(20, 5, 10, 10, 10))
3987

3988
    def test_batchnorm1d(self):
        """Export affine ``BatchNorm1d`` on 2-D and 3-D inputs with one module instance."""
        model = torch.nn.BatchNorm1d(10, affine=True)
        for shape in ((10, 10), (10, 10, 128)):
            self.run_test(model, torch.randn(*shape))
3995

3996
    def test_batchnorm1d_noaffine(self):
        """Export non-affine ``BatchNorm1d`` on 2-D and 3-D inputs."""
        model = torch.nn.BatchNorm1d(10, affine=False)
        for shape in ((10, 10), (10, 10, 128)):
            self.run_test(model, torch.randn(*shape))
4003

4004
    def test_batchnorm1d_norunningstats(self):
        """Export ``BatchNorm1d`` without running stats on 2-D and 3-D inputs."""
        model = torch.nn.BatchNorm1d(10, track_running_stats=False)
        for shape in ((10, 10), (10, 10, 128)):
            self.run_test(model, torch.randn(*shape))
4011

4012
    def test_batchnorm2d(self):
        """Export affine ``BatchNorm2d`` on a 4-D input."""
        images = torch.randn(10, 3, 128, 128)
        self.run_test(torch.nn.BatchNorm2d(3, affine=True), images)
4016

4017
    def test_batchnorm2d_noaffine(self):
        """Export non-affine ``BatchNorm2d`` on a 4-D input."""
        images = torch.randn(10, 3, 128, 128)
        self.run_test(torch.nn.BatchNorm2d(3, affine=False), images)
4021

4022
    def test_batchnorm2d_norunningstats(self):
        """Export ``BatchNorm2d`` without running stats on a 4-D input."""
        images = torch.randn(10, 3, 128, 128)
        self.run_test(torch.nn.BatchNorm2d(3, track_running_stats=False), images)
4026

4027
    def test_batchnorm3d(self):
        """Export affine ``BatchNorm3d`` on a 5-D input."""
        volumes = torch.randn(10, 3, 64, 64, 64)
        self.run_test(torch.nn.BatchNorm3d(3, affine=True), volumes)
4031

4032
    def test_batchnorm3d_noaffine(self):
        """Export non-affine ``BatchNorm3d`` on a 5-D input."""
        volumes = torch.randn(10, 3, 64, 64, 64)
        self.run_test(torch.nn.BatchNorm3d(3, affine=False), volumes)
4036

4037
    @skipIfUnsupportedMinOpsetVersion(9)  # ConstantOfShape op is not supported for opset < 9
    def test_instancenorm1d_runningstats(self):
        """Export ``InstanceNorm1d`` tracking running stats, with and without affine."""
        sample = torch.randn(10, 5, 128)
        for affine in (True, False):
            model = torch.nn.InstanceNorm1d(5, affine=affine, track_running_stats=True)
            self.run_test(model, sample)
4047

4048
    def test_instancenorm1d_norunningstats(self):
        """Export ``InstanceNorm1d`` without running stats, with and without affine."""
        sample = torch.randn(10, 5, 128)
        for affine in (True, False):
            model = torch.nn.InstanceNorm1d(5, affine=affine, track_running_stats=False)
            self.run_test(model, sample)
4055

4056
    @skipIfUnsupportedMinOpsetVersion(9)  # ConstantOfShape op is not supported for opset < 9
    def test_instancenorm2d_runningstats(self):
        """Export ``InstanceNorm2d`` tracking running stats, with and without affine."""
        images = torch.randn(10, 3, 128, 128)
        for affine in (True, False):
            model = torch.nn.InstanceNorm2d(3, affine=affine, track_running_stats=True)
            self.run_test(model, images)
4066

4067
    def test_instancenorm2d_norunningstats(self):
        """Export ``InstanceNorm2d`` without running stats, with and without affine."""
        images = torch.randn(10, 3, 128, 128)
        for affine in (True, False):
            model = torch.nn.InstanceNorm2d(3, affine=affine, track_running_stats=False)
            self.run_test(model, images)
4074

4075
    @skipIfUnsupportedMinOpsetVersion(9)  # ConstantOfShape op is not supported for opset < 9
    def test_instancenorm3d_runningstats(self):
        """Export ``InstanceNorm3d`` tracking running stats, with and without affine."""
        volumes = torch.randn(10, 3, 64, 64, 64)
        for affine in (True, False):
            model = torch.nn.InstanceNorm3d(3, affine=affine, track_running_stats=True)
            self.run_test(model, volumes)
4085

4086
    def test_instancenorm3d_norunningstats(self):
        """InstanceNorm3d without running stats, first with then without affine."""
        sample = torch.randn(10, 3, 64, 64, 64)
        for use_affine in (True, False):
            norm_layer = torch.nn.InstanceNorm3d(
                3, affine=use_affine, track_running_stats=False
            )
            self.run_test(norm_layer, sample)
4093

4094
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_with_scalar(self):
        """Scatter a Python float scalar into a float64 tensor along dim 1."""

        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices):
                fill_value = 1.0
                return input.scatter(1, indices, fill_value)

        base = torch.zeros(3, 3, dtype=torch.float64)
        positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScatterModel(), input_args=(base, positions))
4106

4107
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_with_scalar_different_types(self):
        """Scatter a Python float scalar into a float32 tensor along dim 1.

        Covers the case where the scalar src (update values) has a different
        type from ``self``. This only happens with a scalar src; PyTorch does
        not allow the mismatch when src is a tensor.
        """

        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices):
                fill_value = 1.0
                return input.scatter(1, indices, fill_value)

        base = torch.zeros(3, 3, dtype=torch.float32)
        positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScatterModel(), input_args=(base, positions))
4122

4123
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter(self):
        """Export ``Tensor.scatter`` along dim 1 for several input/index layouts."""

        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, values):
                return input.scatter(1, indices, values)

        small_values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        # Index for the 4-D case: broadcast a (3, 2) table over the last two dims
        # (produces a non-contiguous, expanded tensor).
        expanded_index = (
            torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
            .view(3, 2, 1, 1)
            .expand(3, 2, 5, 6)
        )
        cases = [
            (
                torch.zeros(3, 3),
                torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64),
                small_values,
            ),
            (
                torch.zeros(3, 3),
                torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64),
                torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]),
            ),
            (
                torch.zeros(3, 4, 5, 6),
                expanded_index,
                torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6),
            ),
            (
                torch.zeros(3, 4, 2),
                torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]]),
                torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2),
            ),
        ]
        for data, index, updates in cases:
            self.run_test(ScatterModel(), (data, index, updates))
4149

4150
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_scatter_add(self):
        """``scatter_add`` on module inputs, then inside a scripted helper function."""

        class ScatterAddModel(torch.nn.Module):
            def forward(self, input, indices, values):
                return input.scatter_add(1, indices, values)

        base = torch.zeros(3, 3)
        positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        updates = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.run_test(ScatterAddModel(), input_args=(base, positions, updates))

        @torch.jit.script
        def scatter_sum(src: Tensor, index: Tensor):
            # Accumulate ``src`` into a zero tensor with src's size and dtype.
            out = torch.zeros(src.size(), dtype=src.dtype)
            return out.scatter_add_(1, index, src)

        class ScriptedScatterSumModel(torch.nn.Module):
            def forward(self, src, index):
                return scatter_sum(src, index)

        src = torch.rand(3, 2)
        positions = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScriptedScatterSumModel(), (src, positions))
4174

4175
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_scatter_add_index_not_unique(self):
        """``scatter_add`` where the same index repeats and must accumulate."""

        class ScatterAddModel(torch.nn.Module):
            def forward(self, input, indices, values):
                return input.scatter_add(1, indices, values)

        base = torch.zeros(3, 3)
        repeated = torch.tensor([[0, 0], [1, 1], [2, 2]], dtype=torch.int64)
        updates = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.run_test(ScatterAddModel(), input_args=(base, repeated, updates))

        @torch.jit.script
        def scatter_sum(src: Tensor, index: Tensor):
            # Accumulate ``src`` into a zero tensor with src's size and dtype.
            out = torch.zeros(src.size(), dtype=src.dtype)
            return out.scatter_add_(1, index, src)

        class ScriptedScatterSumModel(torch.nn.Module):
            def forward(self, src, index):
                return scatter_sum(src, index)

        src = torch.rand(3, 2)
        repeated = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
        self.run_test(ScriptedScatterSumModel(), (src, repeated))
4199

4200
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_scatter_add_different_size_index_src(self):
        """``scatter_add`` along dim 0 where index has fewer rows than src."""

        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, src):
                return input.scatter_add(0, indices, src)

        src = torch.ones((2, 5))
        base = torch.zeros(3, 5, dtype=src.dtype)
        single_row_index = torch.tensor([[0, 1, 2, 0, 0]])
        self.run_test(ScatterModel(), input_args=(base, single_row_index, src))
4210

4211
    @common_utils.parametrize(
        "src, indices",
        [
            common_utils.subtest(
                [torch.ones((1, 5)), torch.tensor([[0, 1, 2, 0, 0]])],
                name="src_indices_dynamic_combination1",
            ),
            common_utils.subtest(
                [torch.ones((2, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
                name="src_indices_dynamic_combination2",
            ),
            common_utils.subtest(
                [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])],
                name="src_indices_dynamic_combination3",
            ),
            common_utils.subtest(
                [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0], [1, 0, 2, 1]])],
                name="src_indices_dynamic_combination4",
            ),
        ],
    )
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_scatter_add_dynamic_index(self, src, indices):
        """``scatter_add`` along dim 0 with both axes of indices/src marked dynamic."""

        class ScatterModel(torch.nn.Module):
            def forward(self, input, indices, src):
                return input.scatter_add(0, indices, src)

        base = torch.zeros(3, 5, dtype=src.dtype)
        axes = {
            "indices": {0: "a", 1: "b"},
            "src": {0: "c", 1: "d"},
        }
        self.run_test(
            ScatterModel(),
            input_args=(base, indices, src),
            input_names=["input", "indices", "src"],
            dynamic_axes=axes,
        )
4245

4246
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_scatter_reduce(self):
        """Export ``scatter_reduce`` with the amax, sum, amin and prod reductions."""

        class Model(torch.nn.Module):
            def forward(self, x, index, input):
                reduced_max = input.scatter_reduce(0, index, x, reduce="amax")
                reduced_sum = input.scatter_reduce(0, index, x, reduce="sum")
                reduced_min = input.scatter_reduce(0, index, x, reduce="amin")
                reduced_prod = input.scatter_reduce(0, index, x, reduce="prod")
                return reduced_max, reduced_sum, reduced_min, reduced_prod

        model = Model().eval()

        values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        targets = torch.tensor([0, 1, 0, 1, 2, 1])
        base = torch.tensor([1.0, 2.0, 3.0, 8.0])
        self.run_test(model, (values, targets, base))
4267

4268
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_scatter_reduce_self_rank_zero(self):
        """``scatter_reduce`` with every reduction on element-less inputs.

        The inputs are ``torch.tensor([])`` values, i.e. 1-D with zero elements.
        The same empty tensor is passed as both ``x`` and ``input``.
        """

        class Model(torch.nn.Module):
            def forward(self, x, index, input):
                reduced_max = input.scatter_reduce(0, index, x, reduce="amax")
                reduced_sum = input.scatter_reduce(0, index, x, reduce="sum")
                reduced_min = input.scatter_reduce(0, index, x, reduce="amin")
                reduced_prod = input.scatter_reduce(0, index, x, reduce="prod")
                return reduced_max, reduced_sum, reduced_min, reduced_prod

        model = Model().eval()

        empty_values = torch.tensor([])
        empty_index = torch.tensor([], dtype=torch.int64)
        self.run_test(model, (empty_values, empty_index, empty_values))
4288

4289
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_bucketize(self):
        """Export ``torch.bucketize`` with both ``right=False`` and ``right=True``."""

        class BucketModel(torch.nn.Module):
            def forward(self, input, boundaries):
                left_buckets = torch.bucketize(input, boundaries)
                right_buckets = torch.bucketize(input, boundaries, right=True)
                return left_buckets, right_buckets

        values = torch.tensor([[2, 5, 10], [6, 8, 3]])
        edges = torch.tensor([1, 5, 7, 8, 10])
        self.run_test(BucketModel(), (values, edges))
4300

4301
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_one_hot(self):
        """Export ``one_hot`` with a Python-int and with a tensor-provided class count."""

        class OneHot(torch.nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.num_classes = num_classes

            def forward(self, x):
                return torch.nn.functional.one_hot(x, self.num_classes)

        x = torch.arange(10)
        # Wrap the single input in a 1-tuple so model inputs are always a tuple.
        self.run_test(OneHot(15), (x,))

        class OneHotDynamicClasses(torch.nn.Module):
            def forward(self, x, num_classes):
                # The class count arrives as a float tensor; cast it to int first.
                num_classes = num_classes.to(torch.int32)
                return torch.nn.functional.one_hot(x, num_classes[0])

        x = torch.arange(10)
        num_classes = 15 * torch.ones(1)
        self.run_test(OneHotDynamicClasses(), (x, num_classes))
4322

4323
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_gather(self):
        """Export ``Tensor.gather`` along dim 1 with an int64 index."""

        class GatherModel(torch.nn.Module):
            def forward(self, input, indices):
                return input.gather(1, indices)

        table = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
        picks = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        self.run_test(GatherModel(), input_args=(table, picks))
4332

4333
    @skipScriptTest()  # Scripting error: Cannot instantiate nn module
    def test_gather_constant_fold(self):
        """Gather patterns over constants/buffers that constant folding should fold."""

        class ClampAndEmbed(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))
                # torch.nn.Embedding is converted to ONNX::Gather; with constant
                # inputs, constant folding is triggered. This pattern is common
                # for constant mask inputs in transformer models.
                self.embed = torch.nn.Embedding(8, 3)

            def forward(self, x):
                dim0 = self.weight.shape[0]  # rank-0 size value
                lower_bound = 5 - dim0
                token_ids = torch.ones(1, 4, dtype=torch.long)
                return x.clamp(min=lower_bound), self.embed(token_ids)

        # Input is created before the module so RNG consumption order is kept.
        x = torch.randn(1)
        self.run_test(ClampAndEmbed(), (x,))

        class PadFromBufferShape(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(2))

            def forward(self, x):
                dim0 = self.weight.shape[0]  # rank-0 size value
                padding = [1, dim0, dim0, dim0]
                pad_layer = torch.nn.ZeroPad2d(padding)
                return pad_layer(x)

        x = torch.randn(1, 3, 2)
        self.run_test(PadFromBufferShape(), (x,))

        class AddBufferSlice(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))

            def forward(self, x):
                # In-place add of a constant buffer slice onto the input.
                x += self.rb[0]
                return x

        x = torch.randn(1, 3, 224, 224)
        axes = {
            "input": {0: "batch", 2: "height", 3: "width"},
            "output": {0: "batch", 1: "class", 2: "height", 3: "width"},
        }
        self.run_test(
            AddBufferSlice(),
            (x,),
            dynamic_axes=axes,
            input_names=["input"],
            output_names=["output"],
        )
4389

4390
    @skipIfUnsupportedOpsetVersion([13])
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_expand(self):
        """Export ``Tensor.expand`` with literal, inferred and tensor-valued sizes."""

        class ExpandModel(torch.nn.Module):
            def forward(self, input):
                return input.expand(2, 3, -1)

        input = torch.randn(2, 1, 4)
        # Pass a real 1-tuple: a parenthesised name alone is not a tuple.
        self.run_test(ExpandModel(), input_args=(input,))

        class ExpandInferDimModel(torch.nn.Module):
            def forward(self, input):
                return input.expand(-1, input.size(0))

        input = torch.randn(3, 1)
        self.run_test(ExpandInferDimModel(), input_args=(input,))

        class ExpandTensorSizeModel(torch.nn.Module):
            def forward(self, input, size):
                return input.expand(size)

        input = torch.randn(3)
        size = torch.tensor(-1)  # -1 keeps the existing size of the only dim
        self.run_test(ExpandTensorSizeModel(), input_args=(input, size))
4416

4417
    @skipIfUnsupportedMinOpsetVersion(11)  # index_put is supported in opsets >= 11
    def test_dynamic_expand_as(self):
        """Slice assignment and ``expand_as`` with every input axis marked dynamic."""

        def check(model, x, other, axes):
            # Export with ``x`` as the sole input, then re-verify on ``other``.
            self.run_test(
                model,
                (x,),
                input_names=["x"],
                dynamic_axes={"x": axes},
                additional_test_inputs=[other],
            )

        class ZeroTail(torch.nn.Module):
            def forward(self, x):
                x[:, x.size(0) :] = 0
                return x

        check(ZeroTail(), torch.ones(2, 5), torch.randn(3, 4), [0, 1])

        class FillTailWithVector(torch.nn.Module):
            def forward(self, x):
                x[:, x.size(0) :] = torch.tensor([1, 2, 3])
                return x

        check(
            FillTailWithVector(),
            torch.ones(2, 5, 3),
            torch.randn(3, 4, 3),
            [0, 1, 2],
        )

        class ExpandConstantColumn(torch.nn.Module):
            def forward(self, x):
                column = torch.tensor([[0], [1], [2]])
                return column.expand_as(x)

        check(ExpandConstantColumn(), torch.ones(3, 2), torch.randn(3, 5), [0, 1])
4463

4464
    def test_multinomial(self):
        """Export ``torch.multinomial`` with and without replacement."""

        class Multinomial(torch.nn.Module):
            def forward(self, weight):
                return torch.multinomial(weight, 3, replacement=True)

        class MultinomialNoReplacement(torch.nn.Module):
            def forward(self, weight):
                return torch.multinomial(weight, 1)

        # Each row has exactly one non-zero weight, presumably so the sampled
        # indices are deterministic and comparable across backends.
        probs = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
        for model_cls in (Multinomial, MultinomialNoReplacement):
            self.run_test(model_cls(), (probs,))
4476

4477
    def _test_reduced_ops(self, op):
        """Export ``op(input, dim=-1)`` for each input dtype that ``op`` supports.

        Integer inputs are skipped for ``torch.mean`` (it only supports float
        types). float64 is skipped for ``torch.mean`` and ``torch.prod``
        (ORT does not support double ReduceProd). float16 is skipped for
        ``torch.prod`` (not implemented for Half). float32 always runs last.
        """

        class ReducedOpModule(torch.nn.Module):
            def forward(self, input):
                return op(input, dim=-1)

        is_mean = op == torch.mean
        is_prod = op == torch.prod

        if not is_mean:
            integer_dtypes = (
                torch.uint8,
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            )
            for int_dtype in integer_dtypes:
                sample = torch.randint(10, (4, 4), dtype=int_dtype)
                self.run_test(ReducedOpModule(), sample)

        if not (is_prod or is_mean):
            sample = torch.randn(4, 5, dtype=torch.double)
            self.run_test(ReducedOpModule(), sample)

        if not is_prod:
            sample = torch.randn(4, 4, dtype=torch.half)
            self.run_test(ReducedOpModule(), sample)

        sample = torch.randn(4, 5, dtype=torch.float)
        self.run_test(ReducedOpModule(), sample)
4510

4511
    def test_reduced_sum(self):
        """Run the shared reduction-export checks for ``torch.sum``."""
        self._test_reduced_ops(op=torch.sum)
4513

4514
    def test_reduced_mean(self):
        """Run the shared reduction-export checks for ``torch.mean``."""
        self._test_reduced_ops(op=torch.mean)
4516

4517
    def test_reduced_prod(self):
        """Run the shared reduction-export checks for ``torch.prod``."""
        self._test_reduced_ops(op=torch.prod)
4519

4520
    def test_reduced_sum_dtypes(self):
        """``sum(dtype=torch.float)`` on a half input, with and without ``dim``."""

        class NoDimModel(torch.nn.Module):
            def forward(self, input):
                return input.sum(dtype=torch.float)

        class DimModel(torch.nn.Module):
            def forward(self, input):
                return input.sum(dim=-1, dtype=torch.float)

        half_input = torch.randn((4, 4), dtype=torch.half)
        for model_cls in (NoDimModel, DimModel):
            self.run_test(model_cls(), half_input)
4532

4533
    def test_reduced_min_max(self):
        """Export ``torch.min`` / ``torch.max`` with a dim, for int and float inputs."""

        class ReducedMinMaxModule(torch.nn.Module):
            def forward(self, input):
                row_min = torch.min(input, dim=-1)[0]
                col_max = torch.max(input, dim=0)[0]
                return row_min, col_max

        for int_dtype in (torch.int32, torch.int64):
            sample = torch.randint(10, (4, 4), dtype=int_dtype)
            self.run_test(ReducedMinMaxModule(), sample)

        sample = torch.randn(4, 5, dtype=torch.float)
        self.run_test(ReducedMinMaxModule(), sample)
4546

4547
    def test_reduce_log_sum_exp(self):
        """Export ``torch.logsumexp`` over a single dim and over a dim tuple."""

        class ReduceLogSumExpModel(torch.nn.Module):
            def forward(self, input):
                over_dim0 = torch.logsumexp(input, dim=0)
                over_both = torch.logsumexp(input, dim=(0, 1))
                return over_dim0 + over_both

        matrix = torch.randn(4, 4, requires_grad=True)
        self.run_test(ReduceLogSumExpModel(), matrix)
4556

4557
    def test_softmax(self):
        """Export Softmax over dims -4..2 of a rank-4 input, eagerly and scripted."""

        # Defined once, outside the loop: the class body does not depend on the
        # loop variable, only the constructor argument does.
        class SoftmaxUnknownRank(torch.nn.Module):
            # Presumably (per the name) the reshape is there so softmax sees an
            # input whose rank is not known up front in the scripted graph.
            def __init__(self, i):
                super().__init__()
                self.softmax = torch.nn.Softmax(dim=i)

            def forward(self, x):
                return self.softmax(x.reshape(3, 4, 5, 6))

        for i in range(-4, 3):
            model = torch.nn.Softmax(dim=i)
            input = torch.randn(3, 4, 5, 6)
            self.run_test(model, input)

            model = torch.jit.script(SoftmaxUnknownRank(i))
            self.run_test(model, input)
4573

4574
    def test_softmax_large_values(self):
        """Softmax over dims -2..0 of a 3x3 input holding +/-1e12 entries."""

        # Defined once, outside the loop: only the constructor argument varies.
        class SoftmaxUnknownRank(torch.nn.Module):
            # Presumably (per the name) the reshape is there so softmax sees an
            # input whose rank is not known up front in the scripted graph.
            def __init__(self, i):
                super().__init__()
                self.softmax = torch.nn.Softmax(dim=i)

            def forward(self, x):
                return self.softmax(x.reshape(3, 3))

        input = torch.tensor(
            [[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]]
        )
        for i in range(-2, 1):
            model = torch.nn.Softmax(dim=i)
            self.run_test(model, input)

            model = torch.jit.script(SoftmaxUnknownRank(i))
            self.run_test(model, input)
4592

4593
    def test_logsoftmax(self):
        """LogSoftmax over the last dim of all-ones inputs of rank 2 through 6."""
        for rank in range(2, 7):
            log_softmax = torch.nn.LogSoftmax(dim=rank - 1)
            shape = [2] * (rank - 2) + [3, 4]
            self.run_test(log_softmax, torch.ones(*shape, requires_grad=True))
4599

4600
    def test_logsoftmax_dim(self):
        """LogSoftmax over each dim in -4..2 of a random rank-4 input."""
        for dim in range(-4, 3):
            log_softmax = torch.nn.LogSoftmax(dim=dim)
            self.run_test(log_softmax, torch.randn(3, 4, 5, 6))
4605

4606
    def test_logsoftmax_dtype(self):
        """``log_softmax`` with an explicit float64 output dtype."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64)

        self.run_test(Model(), torch.randn(3, 4, 5, requires_grad=True))
4613

4614
    def test_softplus(self):
        """Export ``softplus`` with the default beta, an int beta and a float beta."""

        class BetaOneModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.softplus(x)

        class BetaModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.softplus(x, beta=2)

        class BetaFloatModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.softplus(x, beta=1.7)

        for model_cls in (BetaOneModel, BetaModel, BetaFloatModel):
            sample = torch.randn(3, 4, 5, requires_grad=True)
            self.run_test(model_cls(), sample)
4635

4636
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_no_hidden(self):
        """Export an LSTM that is called without an explicit initial state."""

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)

            def forward(self, x):
                return self.rnn(x)

        sequence = torch.randn((10, 16, 16))
        self.run_test(LSTMModel(), (sequence,))
4648

4649
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_proj_no_hidden(self):
        """An LSTM with ``proj_size`` set is expected to fail with RuntimeError."""

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8)

            def forward(self, x):
                return self.rnn(x)

        sequence = torch.randn((10, 16, 16))
        with self.assertRaises(RuntimeError):
            # Model construction stays inside the context, as in the original.
            self.run_test(LSTMModel(), (sequence,))
4662

4663
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm(self):
        """Export a single-layer unidirectional LSTM with explicit (h0, c0)."""

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )

            def forward(self, x, h0, c0):
                state = (h0, c0)
                return self.rnn(x, state)

        sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        state_shape = (1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        h0 = torch.randn(*state_shape)
        c0 = torch.randn(*state_shape)
        self.run_test(LSTMModel(), (sequence, h0, c0))
4679

4680
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_cell(self):
        """Export an LSTMCell with and without bias terms."""

        class LSTMCellModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.lstm_cell = torch.nn.LSTMCell(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias
                )

            def forward(self, x, h0, c0):
                state = (h0, c0)
                return self.lstm_cell(x, state)

        step_input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
        h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
        c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
        for use_bias in (True, False):
            self.run_test(LSTMCellModel(use_bias), (step_input, h0, c0))
4697

4698
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_default_init_state(self):
        """Export an LSTM that relies on its default (implicit) initial state."""

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )

            def forward(self, x):
                return self.rnn(x)

        sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(LSTMModel(), sequence)
4712

4713
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_fixed_batch_size(self):
        """LSTM whose initial state is built from the input's batch size.

        Exported with ``fixed_batch_size=True`` and re-checked on a second
        input that has the same batch size.
        """

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )
                self.hidden_dim = RNN_HIDDEN_SIZE

            def forward(self, input):
                batch_size = input.size()[1]
                state_shape = [1, batch_size, self.hidden_dim]
                h0 = torch.ones(state_shape)
                c0 = torch.ones(state_shape)
                return self.lstm(input, (h0, c0))

        first = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        second = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(
            LSTMModel(), first, fixed_batch_size=True, additional_test_inputs=[second]
        )
4735

4736
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_post_fix_init_state(self):
        """LSTM with input-derived initial state and dynamic seq/batch axes.

        Exported with batch size 1, then re-checked with batch size BATCH_SIZE.
        """

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )
                self.hidden_dim = RNN_HIDDEN_SIZE

            def forward(self, input):
                batch_size = input.size()[1]
                state_shape = [1, batch_size, self.hidden_dim]
                h0 = torch.ones(state_shape)
                c0 = torch.ones(state_shape)
                return self.lstm(input, (h0, c0))

        model = LSTMModel()
        single_batch = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
        full_batch = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        self.run_test(
            model,
            single_batch,
            input_names=["input.1"],
            dynamic_axes={"input.1": {0: "seq", 1: "batch"}},
            additional_test_inputs=[full_batch],
        )
4763

4764
    def test_lstm_constant_folding(self):
        """Export multi-layer LSTMs with ``do_constant_folding=True``."""

        class LstmNet(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, num_layers, bidirectional=bidirectional
                )

            def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
                return self.lstm(input, initial_state)

        def build(
            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
        ):
            # Returns the model plus (input, (h0, c0)) sized to match it.
            state_layers = num_layers * (2 if bidirectional else 1)
            model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
            sequence = torch.randn(seq_len, batch_size, input_size)
            h0 = torch.randn(state_layers, batch_size, hidden_size)
            c0 = torch.randn(state_layers, batch_size, hidden_size)
            return model, (sequence, (h0, c0))

        # (input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)
        configs = (
            (7, 3, 2, 3, 5, True),
            (5, 4, 3, 4, 7, False),
        )
        for config in configs:
            model, inputs = build(*config)
            self.run_test(model, inputs, do_constant_folding=True)
4792

4793
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lstm_no_bias(self):
        """Export bias-free LSTMs across several layer/direction combinations."""

        class LstmNet(torch.nn.Module):
            def __init__(self, num_layers, bidirectional):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    num_layers,
                    bias=False,
                    bidirectional=bidirectional,
                )

            def forward(self, input, initial_state: Tuple[Tensor, Tensor]):
                return self.lstm(input, initial_state)

        def build(num_layers, bidirectional):
            # Returns the model plus (input, (h0, c0)) sized to match it.
            sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
            state_layers = num_layers * (2 if bidirectional else 1)
            model = LstmNet(num_layers, bidirectional)
            h0 = torch.randn(state_layers, BATCH_SIZE, RNN_HIDDEN_SIZE)
            c0 = torch.randn(state_layers, BATCH_SIZE, RNN_HIDDEN_SIZE)
            return model, (sequence, (h0, c0))

        # Every (model, inputs) pair is built before any export runs.
        cases = [
            build(layers, bidir)
            for layers, bidir in ((1, True), (1, False), (2, True), (3, False))
        ]
        for model, inputs in cases:
            self.run_test(model, inputs)
4825

4826
    @skipIfUnsupportedMinOpsetVersion(9)
4827
    def test_lstm_sequence(self):
4828
        class LstmNet(torch.nn.Module):
4829
            def __init__(self):
4830
                super().__init__()
4831
                self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4832
                self.linear1 = torch.nn.Linear(8 * 2, 8)
4833
                self.rnn2 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
4834
                self.linear2 = torch.nn.Linear(8 * 2, 8)
4835

4836
            def forward(self, input):
4837
                rnn_output1, _ = self.rnn1(input)
4838
                linear_output1 = self.linear1(rnn_output1)
4839
                rnn_output2, _ = self.rnn2(linear_output1)
4840
                linear_output2 = self.linear2(rnn_output2)
4841
                return linear_output2
4842

4843
        input = torch.zeros((1, 100, 8), dtype=torch.float32)
4844
        self.run_test(
4845
            LstmNet(),
4846
            input,
4847
            input_names=["input"],
4848
            output_names=["output"],
4849
            dynamic_axes={
4850
                "input": {0: "batch_size", 1: "w", 2: "h"},
4851
                "output": {0: "batch_size", 1: "w", 2: "h"},
4852
            },
4853
        )
4854

4855
    @skipScriptTest()
4856
    def test_rnn_no_bias(self):
        """Export bias-free RNNs, optionally wrapped to accept packed sequences.

        ``packed_sequence`` selects the variant:
        0 = plain ``torch.nn.RNN`` (time-major input),
        1 = ``RnnModelWithPackedSequence`` wrapper, time-major,
        2 = ``RnnModelWithPackedSequence`` wrapper with ``batch_first=True``.
        """

        def make_model(layers, packed_sequence):
            batch_first = packed_sequence == 2
            model = torch.nn.RNN(
                RNN_INPUT_SIZE,
                RNN_HIDDEN_SIZE,
                layers,
                bidirectional=False,
                batch_first=batch_first,
                bias=False,
            )
            if packed_sequence in (1, 2):
                # The wrapper's second argument mirrors the RNN's batch_first.
                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence(
                    model, batch_first
                )
            return model

        def make_input(batch_size, layers, packed_sequence):
            batch_first = packed_sequence == 2
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            # Longest-first ordering, presumably required by the packed wrapper
            # (pack_padded_sequence with enforce_sorted) — NOTE(review): confirm.
            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
            padded = rnn_utils.pad_sequence(
                [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths],
                batch_first=batch_first,
            )
            h0 = torch.randn(layers, batch_size, RNN_HIDDEN_SIZE)
            # The padded input and h0 are always present, so the result is always
            # a tuple; the lengths are only appended for the packed variants.
            if packed_sequence != 0:
                return padded, h0, torch.IntTensor(seq_lengths)
            return padded, h0

        layers = [1, 3, 1, 3, 1, 3]
        packed_sequence = [0, 0, 1, 1, 2, 2]
        models = [make_model(n, p) for n, p in zip(layers, packed_sequence)]
        inputs = [
            make_input(RNN_BATCH_SIZE, n, p) for n, p in zip(layers, packed_sequence)
        ]

        for model, input in zip(models, inputs):
            self.run_test(model, input)
4905

4906
    def test_gru_no_bias(self):
4907
        class GruNet(torch.nn.Module):
4908
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4909
                super().__init__()
4910
                self.mygru = torch.nn.GRU(
4911
                    input_size,
4912
                    hidden_size,
4913
                    num_layers,
4914
                    bidirectional=bidirectional,
4915
                    bias=False,
4916
                )
4917

4918
            def forward(self, input, initial_state):
4919
                out = self.mygru(input, initial_state)
4920
                return out
4921

4922
        def get_GruNet_model_and_inputs(
4923
            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4924
        ):
4925
            num_directions = 2 if bidirectional else 1
4926
            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4927
            input = torch.randn(seq_len, batch_size, input_size)
4928
            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4929
            return model, (input, h0)
4930

4931
        input_size = [7, 5]
4932
        hidden_size = [3, 4]
4933
        num_layers = [2, 3]
4934
        batch_size = [3, 4]
4935
        seq_len = [5, 7]
4936
        bidirectional = [True, False]
4937
        models_and_inputs = [
4938
            get_GruNet_model_and_inputs(i, h, n, b, s, bi)
4939
            for i, h, n, b, s, bi in zip(
4940
                input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4941
            )
4942
        ]
4943
        for model, input in models_and_inputs:
4944
            self.run_test(model, input, do_constant_folding=True)
4945

4946
    def test_gru_constant_folding(self):
4947
        class GruNet(torch.nn.Module):
4948
            def __init__(self, input_size, hidden_size, num_layers, bidirectional):
4949
                super().__init__()
4950
                self.mygru = torch.nn.GRU(
4951
                    input_size, hidden_size, num_layers, bidirectional=bidirectional
4952
                )
4953

4954
            def forward(self, input, initial_state):
4955
                out = self.mygru(input, initial_state)
4956
                return out
4957

4958
        def get_GruNet_model_and_inputs(
4959
            input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional
4960
        ):
4961
            num_directions = 2 if bidirectional else 1
4962
            model = GruNet(input_size, hidden_size, num_layers, bidirectional)
4963
            input = torch.randn(seq_len, batch_size, input_size)
4964
            h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
4965
            return model, (input, h0)
4966

4967
        batch_size1 = 3
4968
        model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
4969
        self.run_test(model1, input1, do_constant_folding=True)
4970

4971
        batch_size2 = 4
4972
        model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
4973
        self.run_test(model2, input2, do_constant_folding=True)
4974

4975
    @skipIfUnsupportedMinOpsetVersion(8)
4976
    def test_max_tensors(self):
4977
        class MaxModel(torch.nn.Module):
4978
            def forward(self, input, other):
4979
                return torch.max(input, other)
4980

4981
        model = MaxModel()
4982
        x = torch.randn(4, 4, requires_grad=True)
4983
        y = torch.randn(4, 1, requires_grad=True)
4984
        self.run_test(model, (x, y))
4985

4986
    def test_amax_amin(self):
4987
        class Model(torch.nn.Module):
4988
            def forward(self, x):
4989
                return torch.amax(x, dim=0, keepdim=True), torch.amin(
4990
                    x, dim=[0, 1], keepdim=False
4991
                )
4992

4993
        model = Model()
4994
        x = torch.randn(4, 4)
4995
        self.run_test(model, x)
4996

4997
    def test_aminmax(self):
4998
        class Model(torch.nn.Module):
4999
            def forward(self, x):
5000
                return torch.aminmax(x, dim=1, keepdim=True), torch.aminmax(
5001
                    x, keepdim=False
5002
                )
5003

5004
        model = Model()
5005
        x = torch.randn(3, 4)
5006
        self.run_test(model, x)
5007

5008
    @skipIfUnsupportedMinOpsetVersion(9)
5009
    def test_arange_end(self):
5010
        class ArangeScript(torch.jit.ScriptModule):
5011
            @torch.jit.script_method
5012
            def forward(self, a):
5013
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5014

5015
        x = torch.randn(3, 4, requires_grad=True)
5016
        outputs = ArangeScript()(x)
5017
        self.run_test(ArangeScript(), x)
5018

5019
        class ArangeModel(torch.nn.Module):
5020
            def forward(self, a):
5021
                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
5022

5023
        self.run_test(ArangeModel(), x)
5024

5025
    @skipIfUnsupportedMinOpsetVersion(11)
5026
    def test_arange_end_notype(self):
5027
        class ArangeScript(torch.jit.ScriptModule):
5028
            @torch.jit.script_method
5029
            def forward(self, a):
5030
                return torch.arange(a.size(0))
5031

5032
        x = torch.randn(3, 4, requires_grad=True)
5033
        outputs = ArangeScript()(x)
5034
        self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5035
        self.run_test(ArangeScript(), x, remained_onnx_input_idx=[])
5036

5037
        class ArangeModel(torch.nn.Module):
5038
            def forward(self, a):
5039
                return torch.arange(a.size(0))
5040

5041
        self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5042
        self.run_test(ArangeModel(), x, remained_onnx_input_idx=[])
5043

5044
    @skipIfUnsupportedMinOpsetVersion(9)
5045
    def test_arange_start_end(self):
5046
        class ArangeScript(torch.jit.ScriptModule):
5047
            @torch.jit.script_method
5048
            def forward(self, a):
5049
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5050

5051
        x = torch.randn(3, 4, requires_grad=True)
5052
        self.run_test(ArangeScript(), x)
5053

5054
        class ArangeModel(torch.nn.Module):
5055
            def forward(self, a):
5056
                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
5057

5058
        self.run_test(ArangeModel(), x)
5059

5060
    @skipIfUnsupportedMinOpsetVersion(11)
5061
    def test_arange_start_end_notype(self):
5062
        class ArangeScript(torch.jit.ScriptModule):
5063
            @torch.jit.script_method
5064
            def forward(self, a):
5065
                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5066

5067
        x = torch.randn(3, 4, requires_grad=True)
5068
        self.run_test(ArangeScript(), x)
5069

5070
        class ArangeModel(torch.nn.Module):
5071
            def forward(self, a):
5072
                return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a
5073

5074
        self.run_test(ArangeModel(), x)
5075

5076
    @skipIfUnsupportedMinOpsetVersion(9)
5077
    def test_arange_start_end_step(self):
5078
        class ArangeScript(torch.jit.ScriptModule):
5079
            @torch.jit.script_method
5080
            def forward(self, a):
5081
                return (
5082
                    torch.arange(
5083
                        2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5084
                    ).view(-1, 1)
5085
                    + a
5086
                )
5087

5088
        x = torch.randn(3, 4, requires_grad=True)
5089
        self.run_test(ArangeScript(), x)
5090

5091
        class ArangeModel(torch.nn.Module):
5092
            def forward(self, a):
5093
                return (
5094
                    torch.arange(
5095
                        2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float
5096
                    ).view(-1, 1)
5097
                    + a
5098
                )
5099

5100
        self.run_test(ArangeModel(), x)
5101

5102
    @skipIfUnsupportedMinOpsetVersion(11)
5103
    def test_arange_start_end_step_notype(self):
5104
        class ArangeScript(torch.jit.ScriptModule):
5105
            @torch.jit.script_method
5106
            def forward(self, a):
5107
                return (
5108
                    torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5109
                    + a
5110
                )
5111

5112
        x = torch.randn(3, 4, requires_grad=True)
5113
        self.run_test(ArangeScript(), x)
5114

5115
        class ArangeModel(torch.nn.Module):
5116
            def forward(self, a):
5117
                return (
5118
                    torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1)
5119
                    + a
5120
                )
5121

5122
        self.run_test(ArangeModel(), x)
5123

5124
    @skipIfUnsupportedMinOpsetVersion(9)
5125
    def test__dim_arange(self):
5126
        class DimArange(torch.nn.Module):
5127
            def forward(self, input):
5128
                return torch._dim_arange(input, 1)
5129

5130
        x = torch.ones(5, 6)
5131
        self.run_test(DimArange(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
5132
        remained_onnx_input_idx = None if self.opset_version < 11 else []
5133
        self.run_test(DimArange(), x, remained_onnx_input_idx=remained_onnx_input_idx)
5134

5135
    def _test_compare_ops(self, model, num_inputs):
        """Run a comparison ``model`` on float32 and int32 tensors.

        Args:
            model: module taking one tensor (``num_inputs <= 1``) or two tensors.
            num_inputs: for values > 1, the four float/int pairings are run, with
                lhs float first; otherwise the float and int inputs run alone.
        """
        x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
        x_int = torch.randint(10, (3, 4), dtype=torch.int32)
        if num_inputs <= 1:
            self.run_test(model, x_float)
            self.run_test(model, x_int)
            return

        y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
        y_int = torch.randint(10, (3, 4), dtype=torch.int32)
        for lhs, rhs in itertools.product((x_float, x_int), (y_float, y_int)):
            self.run_test(model, (lhs, rhs))
5148

5149
    @skipIfUnsupportedMinOpsetVersion(9)
5150
    def test_and_or_xor(self):
5151
        class MyModel(torch.nn.Module):
5152
            def forward(self, x, y):
5153
                return x ^ y, x | y, x & y, ~x
5154

5155
        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5156
        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5157
        self.run_test(MyModel(), input_args=(x, y))
5158

5159
    @skipIfUnsupportedMinOpsetVersion(9)
5160
    def test_logical_and(self):
5161
        class AndModel(torch.nn.Module):
5162
            def forward(self, x, y):
5163
                return torch.logical_and(x, y)
5164

5165
        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5166
        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5167
        self.run_test(AndModel(), input_args=(x, y))
5168

5169
        x = torch.randint(10, (5, 5), dtype=torch.int32)
5170
        y = torch.randint(10, (5, 5), dtype=torch.int32)
5171
        self.run_test(AndModel(), input_args=(x, y))
5172

5173
        x = torch.randint(10, (5, 5), dtype=torch.double)
5174
        y = torch.randint(10, (5, 5), dtype=torch.double)
5175
        self.run_test(AndModel(), input_args=(x, y))
5176

5177
        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5178
        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5179
        self.run_test(AndModel(), input_args=(x, y))
5180

5181
    @skipIfUnsupportedMinOpsetVersion(9)
5182
    def test_logical_or(self):
5183
        class OrModel(torch.nn.Module):
5184
            def forward(self, x, y):
5185
                return torch.logical_or(x, y)
5186

5187
        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5188
        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5189
        self.run_test(OrModel(), input_args=(x, y))
5190

5191
        x = torch.randint(10, (5, 5), dtype=torch.int32)
5192
        y = torch.randint(10, (5, 5), dtype=torch.int32)
5193
        self.run_test(OrModel(), input_args=(x, y))
5194

5195
        x = torch.randint(10, (5, 5), dtype=torch.double)
5196
        y = torch.randint(10, (5, 5), dtype=torch.double)
5197
        self.run_test(OrModel(), input_args=(x, y))
5198

5199
        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5200
        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5201
        self.run_test(OrModel(), input_args=(x, y))
5202

5203
    @skipIfUnsupportedMinOpsetVersion(9)
5204
    def test_logical_xor(self):
5205
        class XorModel(torch.nn.Module):
5206
            def forward(self, x, y):
5207
                return torch.logical_xor(x, y)
5208

5209
        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5210
        y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5211
        self.run_test(XorModel(), input_args=(x, y))
5212

5213
        x = torch.randint(10, (5, 5), dtype=torch.int32)
5214
        y = torch.randint(10, (5, 5), dtype=torch.int32)
5215
        self.run_test(XorModel(), input_args=(x, y))
5216

5217
        x = torch.randint(10, (5, 5), dtype=torch.double)
5218
        y = torch.randint(10, (5, 5), dtype=torch.double)
5219
        self.run_test(XorModel(), input_args=(x, y))
5220

5221
        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5222
        y = torch.randint(10, (2, 3, 5), dtype=torch.long)
5223
        self.run_test(XorModel(), input_args=(x, y))
5224

5225
    @skipIfUnsupportedMinOpsetVersion(9)
5226
    def test_logical_not(self):
5227
        class NotModel(torch.nn.Module):
5228
            def forward(self, x):
5229
                return torch.logical_not(x)
5230

5231
        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
5232
        self.run_test(NotModel(), input_args=(x,))
5233

5234
        x = torch.randint(10, (5, 5), dtype=torch.int32)
5235
        self.run_test(NotModel(), input_args=(x,))
5236

5237
        x = torch.randint(10, (5, 5), dtype=torch.double)
5238
        self.run_test(NotModel(), input_args=(x,))
5239

5240
        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
5241
        self.run_test(NotModel(), input_args=(x,))
5242

5243
    @skipIfUnsupportedMinOpsetVersion(11)  # float equal added after opset 11
5244
    def test_eq(self):
5245
        class EqualModel(torch.nn.Module):
5246
            def forward(self, input, other):
5247
                return input == other
5248

5249
        self._test_compare_ops(EqualModel(), 2)
5250

5251
    def test_gt(self):
5252
        class GreaterModel(torch.nn.Module):
5253
            def forward(self, input, other):
5254
                return input > other
5255

5256
        self._test_compare_ops(GreaterModel(), 2)
5257

5258
    @skipIfUnsupportedMinOpsetVersion(9)
5259
    def test_ge(self):
5260
        class GreaterOrEqualModel(torch.nn.Module):
5261
            def forward(self, input, other):
5262
                return input >= other
5263

5264
        self._test_compare_ops(GreaterOrEqualModel(), 2)
5265

5266
    def test_gt_scalar(self):
5267
        class GreaterModel(torch.nn.Module):
5268
            def forward(self, input):
5269
                return input > 1
5270

5271
        self._test_compare_ops(GreaterModel(), 1)
5272

5273
    def test_gt_primitive(self):
5274
        class GreaterModel(torch.nn.Module):
5275
            def __init__(self):
5276
                super().__init__()
5277
                self.y: int = 2
5278

5279
            def forward(self, x: int):
5280
                return self.y > x
5281

5282
        x = 3
5283
        self.run_test(GreaterModel(), (x,))
5284

5285
    @skipIfUnsupportedMinOpsetVersion(9)
5286
    def test_ge_scalar(self):
5287
        class GreaterOrEqualModel(torch.nn.Module):
5288
            def forward(self, input):
5289
                return input >= 1
5290

5291
        self._test_compare_ops(GreaterOrEqualModel(), 1)
5292

5293
    def test_lt(self):
5294
        class LessModel(torch.nn.Module):
5295
            def forward(self, input, other):
5296
                return input > other
5297

5298
        self._test_compare_ops(LessModel(), 2)
5299

5300
    @skipIfUnsupportedMinOpsetVersion(9)
5301
    def test_le(self):
5302
        class LessOrEqualModel(torch.nn.Module):
5303
            def forward(self, input, other):
5304
                return input <= other
5305

5306
        self._test_compare_ops(LessOrEqualModel(), 2)
5307

5308
    def test_lt_scalar(self):
5309
        class LessModel(torch.nn.Module):
5310
            def forward(self, input):
5311
                return input < 1
5312

5313
        self._test_compare_ops(LessModel(), 1)
5314

5315
    @skipIfUnsupportedMinOpsetVersion(9)
5316
    def test_le_scalar(self):
5317
        class LessOrEqualModel(torch.nn.Module):
5318
            def forward(self, input):
5319
                return input <= 1
5320

5321
        self._test_compare_ops(LessOrEqualModel(), 1)
5322

5323
    def test_matmul(self):
5324
        class MatmulModel(torch.nn.Module):
5325
            def forward(self, input, other):
5326
                return torch.matmul(input, other)
5327

5328
        x = torch.randn(3, 4, requires_grad=True)
5329
        y = torch.randn(4, 5, requires_grad=True)
5330
        self.run_test(MatmulModel(), (x, y))
5331

5332
        x = torch.randint(10, (3, 4))
5333
        y = torch.randint(10, (4, 5))
5334
        self.run_test(MatmulModel(), (x, y))
5335

5336
    def test_matmul_batch(self):
5337
        class MatmulModel(torch.nn.Module):
5338
            def forward(self, input, other):
5339
                return torch.matmul(input, other)
5340

5341
        x = torch.randn(2, 3, 4, requires_grad=True)
5342
        y = torch.randn(2, 4, 5, requires_grad=True)
5343
        self.run_test(MatmulModel(), (x, y))
5344

5345
        x = torch.randint(10, (2, 3, 4))
5346
        y = torch.randint(10, (2, 4, 5))
5347
        self.run_test(MatmulModel(), (x, y))
5348

5349
    def _argmin_argmax_model(self, input):
        """Export argmin/argmax in several forms on ``input``.

        The six outputs are:
        - flattened argmin and argmax (no dim);
        - the same two with ``keepdim=True``;
        - argmin along dim 0 with ``keepdim=True``;
        - argmax along dim 1 with ``keepdim=True``.
        """

        class ArgminArgmaxModel(torch.nn.Module):
            def forward(self, input):
                flat_min = torch.argmin(input)
                flat_max = torch.argmax(input)
                kept_min = torch.argmin(input, keepdim=True)
                kept_max = torch.argmax(input, keepdim=True)
                dim0_min = torch.argmin(input, dim=0, keepdim=True)
                dim1_max = torch.argmax(input, dim=1, keepdim=True)
                return flat_min, flat_max, kept_min, kept_max, dim0_min, dim1_max

        self.run_test(ArgminArgmaxModel(), input)
5362

5363
    @skipIfUnsupportedMinOpsetVersion(9)
5364
    def test_argmin_argmax(self):
5365
        input = torch.randn(7, 3, 5)
5366
        self._argmin_argmax_model(input)
5367

5368
    # Argmin and Argmax with "select_last_index" are not supported before opset 12.
5369
    # "select_last_index" was added in opset 12 to handle the corner case where the
5370
    # same value appears multiple times in the tensor
5371
    @skipIfUnsupportedMinOpsetVersion(12)
5372
    def test_argmin_argmax_select_last_index(self):
5373
        input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
5374
        self._argmin_argmax_model(input)
5375

5376
        input = torch.ones(7, 3, 5)
5377
        self._argmin_argmax_model(input)
5378

5379
    def test_repeat(self):
5380
        class RepeatModel(torch.nn.Module):
5381
            def forward(self, x, y):
5382
                x2 = x.repeat(y.shape[0], 1)
5383
                y1 = y.view(-1, 1)
5384
                return x2 + y1
5385

5386
        x = torch.tensor([1, 2, 3])
5387
        y = torch.tensor([4, 5, 8, 9])
5388
        self.run_test(RepeatModel(), (x, y))
5389

5390
    @skipIfUnsupportedMinOpsetVersion(9)
5391
    def test_repeat_interleave(self):
5392
        class FlattenModel(torch.nn.Module):
5393
            def forward(self, x):
5394
                return x.repeat_interleave(2)
5395

5396
        for shape in ([3], [3, 4], [2, 3, 4]):
5397
            x = torch.randn(shape)
5398
            self.run_test(FlattenModel(), (x,))
5399

5400
        class DimsModel(torch.nn.Module):
5401
            def forward(self, x):
5402
                return x.repeat_interleave(4, dim=1)
5403

5404
        x = torch.tensor([[1, 2], [3, 4]])
5405
        self.run_test(DimsModel(), (x,))
5406

5407
        class DimsModel2(torch.nn.Module):
5408
            def forward(self, x):
5409
                repeats = torch.tensor([4])
5410
                return torch.repeat_interleave(x, repeats, dim=1)
5411

5412
        x = torch.tensor([[1, 2], [3, 4]])
5413
        self.run_test(DimsModel2(), (x,))
5414

5415
        class RepeatsDimsModel(torch.nn.Module):
5416
            def forward(self, x):
5417
                repeats = torch.tensor([1, 2])
5418
                return torch.repeat_interleave(x, repeats, dim=0)
5419

5420
        x = torch.tensor([[1, 2], [3, 4]])
5421
        self.run_test(RepeatsDimsModel(), (x,))
5422

5423
        class RepeatsDimsModel2(torch.nn.Module):
5424
            def forward(self, x):
5425
                repeats = torch.tensor([1, 2])
5426
                return torch.repeat_interleave(x, repeats, dim=1)
5427

5428
        x = torch.tensor([[1, 2], [3, 4]])
5429
        self.run_test(RepeatsDimsModel2(), (x,))
5430

5431
    @skipIfUnsupportedMinOpsetVersion(9)
5432
    def test_repeat_interleave_noop(self):
5433
        class Model(torch.nn.Module):
5434
            def forward(self, x):
5435
                return x.repeat_interleave(1, dim=1)
5436

5437
        x = torch.randn(4, 1, 8)
5438
        self.run_test(Model(), (x,))
5439

5440
    @skipIfUnsupportedMinOpsetVersion(13)
5441
    def test_dynamic_repeat_interleave(self):
5442
        class SingleDynamicModel(torch.nn.Module):
5443
            def forward(self, x):
5444
                repeats = torch.tensor(4)
5445
                return torch.repeat_interleave(x, repeats, dim=1)
5446

5447
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5448
        another_x = torch.tensor([[7, 8], [5, 6]])
5449
        self.run_test(
5450
            SingleDynamicModel(),
5451
            x,
5452
            additional_test_inputs=[another_x],
5453
            input_names=["input_1"],
5454
            dynamic_axes={"input_1": {1: "w"}},
5455
        )
5456

5457
        class NegDynamicModel(torch.nn.Module):
5458
            def forward(self, x):
5459
                repeats = torch.tensor(4)
5460
                return torch.repeat_interleave(x, repeats, dim=-1)
5461

5462
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5463
        another_x = torch.tensor([[7, 8], [5, 6]])
5464
        self.run_test(
5465
            NegDynamicModel(),
5466
            x,
5467
            additional_test_inputs=[another_x],
5468
            input_names=["input_1"],
5469
            dynamic_axes={"input_1": {1: "w"}},
5470
        )
5471

5472
        class SingleDynamicModelFloat(torch.nn.Module):
5473
            def forward(self, x):
5474
                repeats = torch.tensor([4])
5475
                return torch.repeat_interleave(x, repeats, dim=0)
5476

5477
        x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
5478
        another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
5479
        self.run_test(
5480
            SingleDynamicModelFloat(),
5481
            x,
5482
            additional_test_inputs=[another_x],
5483
            input_names=["input_1"],
5484
            dynamic_axes={"input_1": {0: "h"}},
5485
        )
5486

5487
        class DynamicRepeatsModel(torch.nn.Module):
5488
            def forward(self, x, repeats):
5489
                return torch.repeat_interleave(x, repeats, dim=1)
5490

5491
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5492
        another_x = torch.tensor([[7, 8], [5, 6]])
5493
        repeats = torch.tensor([2])
5494
        another_repeats = torch.tensor([4])
5495
        self.run_test(
5496
            DynamicRepeatsModel(),
5497
            (x, repeats),
5498
            additional_test_inputs=[(another_x, another_repeats)],
5499
            input_names=["input_1", "repeats_1"],
5500
            dynamic_axes={"input_1": {1: "w"}, "repeats_1": {0: "r"}},
5501
        )
5502

5503
        class DynamicRepeatsModel2(torch.nn.Module):
5504
            def forward(self, x, repeats):
5505
                return torch.repeat_interleave(x, repeats, dim=1)
5506

5507
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5508
        repeats = torch.tensor([2])
5509
        another_repeats = torch.tensor([4])
5510
        self.run_test(
5511
            DynamicRepeatsModel2(),
5512
            (x, repeats),
5513
            additional_test_inputs=[(x, another_repeats)],
5514
            input_names=["input_1", "repeats_1"],
5515
            dynamic_axes={"repeats_1": {0: "r"}},
5516
        )
5517

5518
        class DynamicFlattenModel(torch.nn.Module):
5519
            def forward(self, x):
5520
                return x.repeat_interleave(2)
5521

5522
        x = torch.tensor([1, 2, 3])
5523
        self.run_test(
5524
            DynamicFlattenModel(),
5525
            x,
5526
            input_names=["input_1"],
5527
            dynamic_axes={"input_1": {0: "w"}},
5528
        )
5529

5530
    @skipIfUnsupportedMinOpsetVersion(13)
5531
    def test_multiple_dynamic_repeat_interleave(self):
5532
        class DynamicRepeatsModel(torch.nn.Module):
5533
            def forward(self, x, repeats):
5534
                return torch.repeat_interleave(x, repeats, dim=1)
5535

5536
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5537
        repeats = torch.tensor([2, 3, 4])
5538
        another_repeats = torch.tensor([4, 3, 2])
5539
        self.run_test(
5540
            DynamicRepeatsModel(),
5541
            (x, repeats),
5542
            additional_test_inputs=[(x, another_repeats)],
5543
            input_names=["input_1", "repeats_1"],
5544
            dynamic_axes={"repeats_1": {0: "r"}},
5545
        )
5546

5547
        class DynamicRepeatsModel2(torch.nn.Module):
5548
            def forward(self, x, repeats):
5549
                return torch.repeat_interleave(x, repeats, dim=0)
5550

5551
        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
5552
        repeats = torch.tensor([2, 3])
5553
        another_repeats = torch.tensor([4, 3])
5554
        self.run_test(
5555
            DynamicRepeatsModel2(),
5556
            (x, repeats),
5557
            additional_test_inputs=[(x, another_repeats)],
5558
            input_names=["input_1", "repeats_1"],
5559
            dynamic_axes={"repeats_1": {0: "r"}},
5560
        )
5561

5562
    def test_view(self):
5563
        class ViewModel(torch.nn.Module):
5564
            def forward(self, input):
5565
                return input.view(4, 24)
5566

5567
        x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
5568
        self.run_test(ViewModel(), x)
5569

5570
    def test_view_dynamic(self):
5571
        class ViewModel(torch.nn.Module):
5572
            def forward(self, input, other):
5573
                return input.view(other.shape)
5574

5575
        x = torch.randn(2, 3, 4)
5576
        shape = torch.randn(6, 4)
5577
        self.run_test(
5578
            ViewModel(),
5579
            (x, shape),
5580
            input_names=["x", "shape"],
5581
            dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]},
5582
        )
5583
        self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0])
5584

5585
    def test_view_dynamic_zero_dim(self):
        """Export chained ``view(-1, 2)`` / ``view(1, -1)`` and also run it on an empty input.

        The model is exported from a 2-element input with a dynamic axis 0 and
        is additionally exercised with a zero-element tensor.
        """
        class ViewModel(torch.nn.Module):
            def forward(self, input):
                input = input.view(-1, 2)
                return input.view(1, -1)

        x = torch.ones(2)
        another_x = torch.empty((0,))
        self.run_test(
            ViewModel(),
            x,
            additional_test_inputs=[another_x],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0]},
        )

    def test_view_as(self):
        """Export ``Tensor.view_as`` reshaping a ``(2, 3, 4)`` tensor to ``(6, 4)``."""
        class ViewModel(torch.nn.Module):
            def forward(self, input, other):
                return input.view_as(other)

        x = torch.randn(2, 3, 4)
        y = torch.randn(6, 4)
        self.run_test(ViewModel(), (x, y))

    def test_linear(self):
        """Export ``nn.Linear`` applied twice, then ``F.linear`` on rank-2 and rank-3 inputs."""
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = torch.nn.Linear(16, 16)

            def forward(self, x):
                out = self.fc(x)
                out = self.fc(out)
                return out

        x = torch.randn(3, 16)
        self.run_test(LinearModel(), (x,))

        # Distinct name so the module-based model above is not shadowed.
        class FunctionalLinearModel(torch.nn.Module):
            def forward(self, input, weight, bias):
                return torch.nn.functional.linear(input, weight, bias)

        # input of rank 2
        x = torch.randn(2, 2)
        y = torch.randn(2, 2)
        z = torch.randn(1)
        self.run_test(FunctionalLinearModel(), (x, y, z))

        # input of rank 3
        x = torch.randn(3, 3, 3)
        y = torch.randn(3, 3)
        z = torch.randn(1)
        self.run_test(FunctionalLinearModel(), (x, y, z))

    @skipScriptTest()
    def test_weight_norm(self):
        """Export ``weight_norm``-wrapped Linear and Conv1d modules with various ``dim`` values."""
        # addmm for 3-d inputs converts to onnx::MatMul
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(model, x)

        # addmm for 2-d inputs converts to onnx::Gemm
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1)
        x = torch.randn(4, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3))
        x = torch.randn(1, 1, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
        x = torch.randn(1, 1, 5, requires_grad=True)
        self.run_test(model, x)

        model = torch.nn.utils.weight_norm(torch.nn.Conv1d(3, 6, 3), name="weight")
        x = torch.randn(3, 3, 5, requires_grad=True)
        self.run_test(model, x)

    @skipScriptTest()
    def test_weight_norm_nodim(self):
        """Export ``weight_norm(..., dim=None)`` on Linear for 3-d and 2-d inputs."""
        # addmm for 3-d inputs converts to onnx::MatMul
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
        x = torch.randn(3, 4, 5, requires_grad=True)
        self.run_test(model, x)

        # addmm for 2-d inputs converts to onnx::Gemm
        model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None)
        x = torch.randn(4, 5, requires_grad=True)
        self.run_test(model, x)

    def test_flatten(self):
        """Export full ``torch.flatten`` on 4-d, 0-d and 1-d inputs."""
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input)

        model = FlattenModel()

        # flatten with 4d input
        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(model, x)

        # flatten with 0d input
        x = torch.randn([])
        self.run_test(model, x)

        # flatten with 1d input
        x = torch.randn(4)
        self.run_test(model, x)

    def test_flatten2d(self):
        """Export ``torch.flatten(input, 1)`` (keep dim 0, flatten the rest)."""
        class FlattenModel(torch.nn.Module):
            def forward(self, input):
                return torch.flatten(input, 1)

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    def test_flatten2d_neg(self):
        """Export ``torch.flatten`` with negative ``end_dim`` values."""
        class FlattenModel(torch.nn.Module):
            def forward(self, x):
                return (
                    torch.flatten(x, 1, -1),
                    torch.flatten(x, 0, -2),
                    torch.flatten(x, 1, -2),
                )

        x = torch.randint(10, (1, 2, 3, 4))
        self.run_test(FlattenModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_flatten_dynamic_axes(self):
        """Export ``flatten(start_dim=2, end_dim=3)`` with a dynamic batch axis.

        Exported with batch 3, then also run with batch 5.
        """
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.flatten(x, start_dim=2, end_dim=3)

        batch_size = 3
        x = torch.randn(batch_size, 5, 4, 5)
        y = torch.randn(5, 5, 4, 5)
        model = MyModule()
        self.run_test(
            model,
            x,
            additional_test_inputs=[y],
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_getitem(self):
        """Export indexing into a scripted tensor list with a runtime (possibly negative) index."""
        class GetItemModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y, z, ind):
                # this will create prim::ListConstruct(x, y, z) + aten::__getitem__
                arr = [x, y, z]
                return arr[ind]

        x = torch.randn(3, 4, 5)
        y = torch.randn(1, 4, 5)
        z = torch.randn(2, 4, 5)
        ind = torch.tensor(1, dtype=torch.long)
        self.run_test(GetItemModel(), (x, y, z, ind))

        ind = torch.tensor(-2, dtype=torch.long)
        self.run_test(GetItemModel(), (x, y, z, ind))

    @skipDtypeChecking
    def test_item(self):
        """Export ``int(x[y[i]].item())`` from a scripted module taking a Python ``int``."""
        class M(torch.nn.Module):
            def forward(self, x, y, i: int):
                return int(x[y[i]].item())

        x = torch.arange(6, dtype=torch.float)
        y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
        i = 3
        self.run_test(torch.jit.script(M()), (x, y, i))

    @skipScriptTest()  # torch.nonzero(x, as_tuple=True) is not scriptable.
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_nonzero(self):
        """Export both the index-matrix and ``as_tuple=True`` forms of ``nonzero``."""
        class NonzeroModel(torch.nn.Module):
            def forward(self, x):
                return x.nonzero(), x.nonzero(as_tuple=True)

        # Zero out up to 20 random positions (indices may repeat) so the input
        # contains both zero and non-zero entries.
        x = torch.randn(60).index_fill_(0, torch.randint(0, 60, (20,)), 0).view(3, 4, 5)
        self.run_test(NonzeroModel(), (x,))

    def test_unbind(self):
        """Export ``unbind`` along the default dim, dim 1 and dim -2, keeping one slice."""
        class UnbindModel(torch.nn.Module):
            def forward(self, input):
                _, out, _ = input.unbind()
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel(), x)

        class UnbindModel2(torch.nn.Module):
            def forward(self, input):
                _, out, _, _ = input.unbind(1)
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel2(), x)

        class UnbindModel3(torch.nn.Module):
            def forward(self, input):
                _, out, _, _ = input.unbind(-2)
                return out

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel3(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_len(self):
        """Export ``len(input.unbind()) + input`` from a scripted module.

        Axis 0 is dynamic and a 5-row input is also run, presumably so the
        list length is not folded into the constant 4.
        """
        class LenModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return len(input.unbind()) + input

        x = torch.randn(4, 5)
        self.run_test(
            LenModel(),
            x,
            input_names=["input"],
            dynamic_axes={"input": {0: "seq"}},
            additional_test_inputs=(torch.randn(5, 5),),
        )

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_len_list(self):
        """Export ``torch.ones(len(input.shape))``, which depends only on the input's rank.

        ``remained_onnx_input_idx=[]`` presumably asserts that no graph input
        survives export -- TODO confirm against ``run_test``.
        """
        class LenListModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return torch.ones(len(input.shape))

        x = torch.randn(4, 5)
        self.run_test(LenListModel(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unbind_dynamic(self):
        """Export indexing into the list returned by ``unbind`` in scripted modules."""
        class UnbindModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.unbind()[1]

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel(), x)

        class UnbindModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.unbind(-1)[1]

        x = torch.randn(3, 4, 5)
        self.run_test(UnbindModel2(), x)

    @skipScriptTest()  # scripting is covered for opset >= 11 by test_split_script
    def test_split(self):
        """Export ``Tensor.split`` with explicit size lists, default and negative dims (traced)."""
        class SplitModel(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 2]), input.split([3, 2])[0]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel(), x)

        class SplitModel2(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel2(), x)

        class SplitModel3(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 2])

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel3(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_script(self):
        """Same models as ``test_split``, gated on opset 11 so the scripting path also runs."""
        class SplitModel(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 2]), input.split([3, 2])[0]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel(), x)

        class SplitModel2(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel2(), x)

        class SplitModel3(torch.nn.Module):
            def forward(self, input):
                return input.split([2, 1, 2])

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel3(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_split_size_as_list(self):
        """Export ``split`` driven by split sizes passed as a forward argument, then ``cat``.

        The sizes are supplied as 0-d tensors; the ``List[int]`` annotation is
        not enforced because scripting is skipped for this test.
        """
        class SplitModel(torch.nn.Module):
            def forward(self, input, split_sizes: List[int]):
                out = []
                split_list: List[Tensor] = input.split(split_sizes)

                for ob in split_list:
                    out.append(ob)  # noqa: PERF402
                return torch.cat(out, dim=0)

        x = torch.randn(6, 4, 3)
        split_sizes = [torch.tensor(2), torch.tensor(4)]
        self.run_test(SplitModel(), (x, split_sizes))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_size_with_slice(self):
        """Export ``torch.split`` whose sizes come from other inputs' ``size(1)``.

        Runs with all axes dynamic, then with ``remained_onnx_input_idx=[2]``
        (presumably only ``t`` is expected to remain a graph input when the
        sizes are static -- TODO confirm against ``run_test``).
        """
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        self.run_test(
            SplitModule(),
            (x, y, t),
            input_names=["x", "y", "t"],
            dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
        )
        self.run_test(SplitModule(), (x, y, t), remained_onnx_input_idx=[2])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_dynamic(self):
        """Export fixed-chunk-size ``split(2)`` with list indexing in scripted modules."""
        class SplitModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.split(2)[1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel(), x)

        class SplitModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                return input.split(2, -3)[1]

        x = torch.randn(5, 4, 3)
        self.run_test(SplitModel2(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_dynamic_axes(self):
        """Export ``split(1, dim=-1)`` with a dynamic batch axis on the input."""
        class Split(torch.nn.Module):
            def forward(self, x):
                return x.split(1, dim=-1)

        x = torch.randn(4, 384, 2)
        input_names = ["logits"]
        self.run_test(
            Split(),
            x,
            input_names=input_names,
            dynamic_axes={input_names[0]: {0: "batch"}},
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_chunk(self):
        """Export ``torch.chunk(x, 3)`` along dim 1 and dim -1 with dynamic axes.

        Exported from a width-18 input and additionally run on widths 13-15,
        which are not divisible by 3 (except 15).
        """
        class ChunkModel(torch.nn.Module):
            def __init__(self, dim=1):
                super().__init__()
                self.dim = dim

            def forward(self, x):
                return torch.chunk(x, 3, dim=self.dim)

        model = ChunkModel()
        model.eval()
        model_neg_dim = ChunkModel(-1)
        model_neg_dim.eval()
        x = torch.randn(1, 18)

        for dim_size_ in range(13, 16):
            y = torch.randn(1, dim_size_)
            self.run_test(
                model,
                x,
                additional_test_inputs=[y],
                input_names=["x"],
                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
            )

            self.run_test(
                model_neg_dim,
                x,
                additional_test_inputs=[y],
                input_names=["x"],
                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
            )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_dynamic_chunk(self):
        """Export ``torch.chunk`` whose chunk count is ``x.size(0)`` (data-shape dependent).

        Chunks along dim 1 and dim -1; exported from a width-18 input and
        additionally run on widths 13-15.
        """
        class ChunkModel(torch.nn.Module):
            def __init__(self, dim=1):
                super().__init__()
                self.dim = dim

            def forward(self, x):
                return torch.chunk(x, x.size(0), dim=self.dim)

        model = ChunkModel()
        model.eval()
        model_neg_dim = ChunkModel(-1)
        model_neg_dim.eval()
        x = torch.randn(3, 18)

        for dim_size_ in range(13, 16):
            y = torch.randn(3, dim_size_)
            self.run_test(
                model,
                x,
                additional_test_inputs=[y],
                input_names=["x"],
                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
            )

            self.run_test(
                model_neg_dim,
                x,
                additional_test_inputs=[y],
                input_names=["x"],
                dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
            )

    def test_concat(self):
        """Export ``torch.cat`` of three tensors differing only in dim 0."""
        class ConcatModel(torch.nn.Module):
            def forward(self, x, y, z):
                return torch.cat((x, y, z))

        x = torch.randn(3, 4, 5)
        y = torch.randn(1, 4, 5)
        z = torch.randn(2, 4, 5)
        self.run_test(ConcatModel(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_concat_dynamic(self):
        """Export ``torch.cat`` over the tensor list produced by ``unbind`` (scripted)."""
        class ConcatDynamicModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.cat(x.unbind())

        x = torch.randn(4, 5, 6)
        self.run_test(ConcatDynamicModel(), x)

    def test_stack(self):
        """Export ``torch.stack`` of three same-shaped tensors along dim 1."""
        class StackModel(torch.nn.Module):
            def forward(self, x, y, z):
                return torch.stack((x, y, z), 1)

        x = torch.randn(3, 4, 5)
        y = torch.randn(3, 4, 5)
        z = torch.randn(3, 4, 5)
        self.run_test(StackModel(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_stack_dynamic(self):
        """Export ``torch.stack(..., 1)`` over the tensor list produced by ``unbind`` (scripted)."""
        class StackDynamicModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.stack(x.unbind(), 1)

        x = torch.randn(4, 5, 6)
        self.run_test(StackDynamicModel(), x)

    def test_loop_dynamic(self):
        """Export a scripted ``for`` loop whose trip count is ``x.size(2)``."""
        class LoopModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(2)):
                    x = x + i
                return x

        model = LoopModel()
        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_loop_nested(self):
        """Export a scripted ``while`` loop nested inside a constant-count ``for`` loop."""
        class NestedLoopsModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(5):
                    a = 0
                    while a < 4:
                        a += 1
                    x = x + a
                return x

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, dtype=torch.long)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_loop_with_list(self):
        """Export a scripted loop that grows lists via ``append``, ``+`` and ``+=``.

        Also covers negative list indexing and an in-place tensor update in
        the loop body; the returned lists are partly stacked into tensors.
        """
        class ListLoopModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                res = []
                res1 = []
                arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
                res2 = torch.zeros(3, 4, dtype=torch.long)
                res3 = []
                res4 = []
                for i in range(len(arr)):
                    res.append(arr[i].sum(0, False))
                    res1.append(arr[-1 - i].sum(0, False))
                    res2 += 1
                    res3 = res3 + [arr[i].sum(0, False)]
                    res4 += [arr[-1 - i].sum(0, False)]
                return res, res1, res2, torch.stack(res3), torch.stack(res4)

        model = ListLoopModel()
        inputs = torch.randn(16)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_loop_transpose(self):
        """Export a scripted loop accumulating ``x[0].transpose(0, 1)`` in place."""
        class LoopModel(torch.nn.Module):
            def forward(self, x):
                res = torch.zeros_like(x[0])
                for i in range(x.size(0)):
                    res += x[0].transpose(0, 1)
                return res

        model = torch.jit.script(LoopModel())
        x = torch.randn(5, 3, 3)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_loop_multi_dim(self):
        """Export a scripted loop iterating over a flipped 3-d tensor, chaining gathers."""
        class LoopMultiDimModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                for x_ in torch.flip(x.narrow(0, 0, 7), [0]):
                    y = x_[0][y]
                return y

        model = LoopMultiDimModel()
        x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long)
        y = torch.ones(1, dtype=torch.long)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list(self):
        """Export scripted list ops (append, pop, insert, ``+=``, ``+``) and ``len``."""
        class ListModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tensors = x.unbind()
                res = []
                res.append(tensors[0])
                res.append(tensors[1])
                res.pop(1)

                res.insert(0, tensors[1])
                res.append(tensors[2])
                res += [tensors[3], tensors[4]]
                res = res + [tensors[5]]
                return torch.ones(len(res))

        model = ListModel()
        inputs = torch.randn(16, 1)
        self.run_test(model, inputs)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_append(self):
        """Export a scripted loop building a list of matmul results with ``+=``."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res += [torch.matmul(x[i], y)]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(16, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_append_nested(self):
        """Export list growth with ``+=`` inside two nested scripted loops."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    for j in range(x.size(1)):
                        res += [torch.matmul(x[i][j], y)]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::Identity of sequence in opset 14
    def test_list_append_nested_2(self):
        """Export conditional, nested appends into two lists that alias each other's elements."""
        class ListModel(torch.nn.Module):
            def forward(self, x):
                res = []
                res_replicate = []
                for i in range(x.size(0)):
                    # NOTE(review): ``res`` starts empty and only grows inside
                    # this branch, so ``len(res) > 2`` is never true at runtime
                    # and both returned lists stay empty. The branch is still
                    # exported as graph structure; confirm this is intended.
                    if len(res) > 2:
                        for j in range(x.size(1)):
                            res.append(x[i][j])
                        res_replicate.append(res[-1])
                        res.append(res_replicate[-1])
                return res, res_replicate

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        self.run_test(model, (x,))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_append_nested_mixed_dtype(self):
        """Export nested-loop appends of comparison results (``==`` / ``!=``) chosen per branch."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    for j in range(x.size(1)):
                        if i == j:
                            res.append(x == y)
                        else:
                            res.append(x != y)
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        y = torch.randn(3, 4)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_pop(self):
        """Export building a list in a scripted loop, then removing its last element with ``pop()``."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res += [torch.matmul(x[i], y)]
                res.pop()
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(16, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_pop_nested(self):
        """Export ``pop()`` inside the inner loop of nested scripted loops."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    for j in range(x.size(1)):
                        res += [torch.matmul(x[i][j], y)]
                        res.pop()
                    res += [torch.matmul(x[i][0], y)]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_del(self):
        """Export ``del res[2]`` on a list built in a scripted loop."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res += [torch.matmul(x[i], y)]
                del res[2]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(16, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_del_nested(self):
        """Export ``del res[i]`` (loop-variable index) inside nested scripted loops."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    for j in range(x.size(1)):
                        res += [torch.matmul(x[i][j], y)]
                        del res[i]
                    res += [torch.matmul(x[i][0], y)]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(4, 4, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_list_set(self):
        """Export list item assignment ``res[y] = x[y]`` with a tensor index."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res.append(x[i])
                res[y] = x[y]
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(12, 4)
        y = torch.tensor(2, dtype=torch.long)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_list_idx_sum(self):
        """Export list indexing with an index computed as ``sum(arange(n)[:y])``."""
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                indices = torch.arange(x.size(0))
                res = []
                for i in range(x.size(0)):
                    res.append(x[i])
                return res[torch.sum(indices[:y])]

        model = torch.jit.script(ListModel())
        x = torch.randn(12, 4)
        y = torch.tensor(2, dtype=torch.long)
        self.run_test(model, (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories(self):
        """Export ``zeros``/``ones`` sized from ``x.size()``, with and without dynamic axes.

        The second run uses ``remained_onnx_input_idx=[]``, presumably because
        ``x`` only contributes its static shape -- TODO confirm against ``run_test``.
        """
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return torch.zeros(x.size()) + torch.ones(x.size())

        x = torch.randn(2, 3, 4)
        self.run_test(
            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_factories_script(self):
        """Scripted variant of ``test_tensor_factories`` using ``x.shape`` and explicit float dtype."""
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                return torch.zeros(x.shape, dtype=torch.float) + torch.ones(
                    x.shape, dtype=torch.float
                )

        x = torch.randn(2, 3, 4)
        self.run_test(
            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tensor_like_factories_script(self):
        """Export scripted ``zeros_like``/``ones_like`` with explicit dtype, layout and device."""
        class TensorFactory(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                zeros = torch.zeros_like(
                    x,
                    dtype=torch.float,
                    layout=torch.strided,
                    device=torch.device("cpu"),
                )
                ones = torch.ones_like(
                    x,
                    dtype=torch.float,
                    layout=torch.strided,
                    device=torch.device("cpu"),
                )
                return zeros + ones

        x = torch.randn(2, 3, 4)
        self.run_test(
            TensorFactory(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_tensor_split(self):
        """Export ``tensor_split`` with index lists, dims, output indexing and out-of-range indices."""
        class TensorSplitModel(torch.nn.Module):
            def forward(self, input):
                return (
                    input.tensor_split([1, 3]),
                    # test with output indexing.
                    input.tensor_split([2, 4])[0],
                    # test split on specific dim.
                    input.tensor_split([1, 3, 4], dim=-2),
                    # test split on specific dim and output indexing.
                    input.tensor_split([0, 2], dim=-2)[-1],
                    # test with out of bound end index (5).
                    input.tensor_split([2, 3, 5]),
                )

        self.run_test(TensorSplitModel(), torch.randn(5, 4, 3))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_tensor_split_scalar(self):
        """Export ``tensor_split`` with a section count taken from ``x.size(1)``."""
        class TensorSplitModel(torch.nn.Module):
            def forward(self, x):
                return torch.tensor_split(x, x.size(1))

        self.run_test(TensorSplitModel(), torch.randn(1, 2, 3))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_tensor_split_dynamic_axes(self):
        """Export ``tensor_split(1, dim=-1)`` with a dynamic batch axis on the input."""
        class TensorSplitModel(torch.nn.Module):
            def forward(self, x):
                return x.tensor_split(1, dim=-1)

        x = torch.randn(4, 384, 2)
        input_names = ["logits"]
        self.run_test(
            TensorSplitModel(),
            x,
            input_names=input_names,
            dynamic_axes={input_names[0]: {0: "batch"}},
        )

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_eye(self):
        """Export ``torch.eye`` with static and shape-derived sizes and several dtypes.

        Exported from a ``(2, 3, 4)`` input with all axes dynamic, then also
        run on a ``(5, 6, 7)`` input.
        """
        class TensorFactory(torch.nn.Module):
            def forward(self, x):
                return (
                    torch.eye(x.size()[1], 3),
                    torch.eye(4, 4, dtype=torch.long),
                    torch.eye(x.size()[1], 2, dtype=torch.long),
                    torch.eye(x.shape[0]),
                    torch.eye(x.shape[0], dtype=torch.float64),
                )

        x = torch.randn(2, 3, 4)
        another_x = torch.randn(5, 6, 7)
        self.run_test(
            TensorFactory(),
            x,
            additional_test_inputs=[another_x],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1, 2]},
        )

    @skipIfUnsupportedMinOpsetVersion(13)
6442
    def test_diagonal(self):
        """Export ``torch.diagonal`` for several offset / dim combinations.

        Every variant is exported with all four input axes marked dynamic and
        then re-checked on a second input whose sizes differ on every axis.
        """

        class DiagonalModel(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        class DiagonalModelNegOffset(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x, offset=-1)

        class DiagonalModelPosOffset(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x, offset=1)

        class DiagonalModelWithDims(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x, offset=-1, dim1=1, dim2=2)

        class DiagonalModelWithNegativeDims(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x, offset=0, dim1=-2, dim2=-1)

        class DiagonalModelOffsetOverrun(torch.nn.Module):
            def forward(self, x):
                # For the (2, 4, ...) export input these offsets fall outside
                # the 2x4 matrix spanned by the default dims 0 and 1.
                return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5)

        variants = (
            DiagonalModel,
            DiagonalModelNegOffset,
            DiagonalModelPosOffset,
            DiagonalModelWithDims,
            DiagonalModelWithNegativeDims,
            DiagonalModelOffsetOverrun,
        )
        for variant in variants:
            # Fresh inputs per variant, created in the same order as before.
            x = torch.randn(2, 4, 5, 2)
            # Other test inputs to test dynamic behavior
            another_x = torch.randn(5, 6, 7, 8)
            self.run_test(
                variant(),
                x,
                additional_test_inputs=[another_x],
                input_names=["input_1"],
                dynamic_axes={"input_1": [0, 1, 2, 3]},
            )
6532

6533
    @skipIfUnsupportedMinOpsetVersion(9)
6534
    def test_inplace_zero(self):
        """Export in-place ``zero_``, returning both the result and the input."""

        class ZeroInPlace(torch.nn.Module):
            def forward(self, x):
                return x.zero_(), x

        data = torch.randn(2, 3, 4)
        self.run_test(
            ZeroInPlace(), data, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        # Without dynamic axes, the exported graph is expected to keep no inputs.
        self.run_test(ZeroInPlace(), data, remained_onnx_input_idx=[])
6542

6543
    @skipIfUnsupportedMinOpsetVersion(11)
6544
    def test_inplace_zero_qkv(self):
        """Export ``zero_`` applied to a slice along the first dimension."""

        class SliceZero(torch.nn.Module):
            def forward(self, x):
                return x[2:4].zero_()

        qkv = torch.randn(24, 3, 4)
        self.run_test(
            SliceZero(), qkv, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
6551

6552
    @skipIfUnsupportedMinOpsetVersion(9)
6553
    def test_new_zeros(self):
        """Export ``new_zeros`` sized by input shape slices, optionally as long."""

        class NewZerosModel(torch.nn.Module):
            def forward(self, x):
                inherited_dtype = x.new_zeros(x.shape[1:2])
                as_long = x.new_zeros(x.shape[2:], dtype=torch.long)
                return inherited_dtype, as_long

        sample = torch.randn(2, 3, 4)
        self.run_test(
            NewZerosModel(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        # The outputs only use the input's shape; no graph inputs should remain.
        self.run_test(NewZerosModel(), sample, remained_onnx_input_idx=[])
6563

6564
    @skipIfUnsupportedMinOpsetVersion(9)
6565
    def test_new_zeros_with_dtype(self):
        """Feed ``new_zeros`` of an int64 input into an ``Embedding`` lookup."""

        class EmbedZeros(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(50, 64)

            def forward(self, x):
                # new_zeros keeps x's int64 dtype, so the result is valid indices.
                return self.emb(x.new_zeros(x.shape))

        # Build the model first so its weights are drawn exactly as before.
        model = EmbedZeros()
        indices = torch.tensor([[2, 5, 6], [3, 2, 5]], dtype=torch.int64)
        self.run_test(model, indices, input_names=["x"], dynamic_axes={"x": [0, 1]})
6578

6579
    @skipIfUnsupportedMinOpsetVersion(9)
6580
    def test_new_ones(self):
        """Export ``new_ones`` sized by input shape slices, optionally as long."""

        class NewOnesModel(torch.nn.Module):
            def forward(self, x):
                inherited_dtype = x.new_ones(x.shape[1:2])
                as_long = x.new_ones(x.shape[2:], dtype=torch.long)
                return inherited_dtype, as_long

        sample = torch.randn(2, 3, 4)
        self.run_test(
            NewOnesModel(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        # The outputs only use the input's shape; no graph inputs should remain.
        self.run_test(NewOnesModel(), sample, remained_onnx_input_idx=[])
6590

6591
    @skipIfUnsupportedMinOpsetVersion(9)
6592
    @skipScriptTest()  # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable.
6593
    def test_zeros_ones_with_tensor_input(self):
        """Export ``zeros``/``ones`` whose leading size is a one-element tensor."""

        class ZerosAndOnes(torch.nn.Module):
            def forward(self, x):
                zeros = torch.zeros(x, 1)
                ones = torch.ones(x, 1)
                return zeros, ones

        size_tensor = torch.tensor([2])
        self.run_test(ZerosAndOnes(), (size_tensor,))
6600

6601
    @skipIfUnsupportedMinOpsetVersion(9)
6602
    @skipShapeChecking
6603
    def test_tolist(self):
        """Export a scripted ``Tensor.tolist()`` annotated as ``List[int]``.

        The module class is deliberately not named ``List``: that name would
        shadow the module-level ``typing.List`` used in the annotation below.
        """

        class ToListModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                res: List[int] = input.tolist()
                return res

        self.run_test(ToListModel(), (torch.randint(100, (1,)),))
6611

6612
    @skipIfUnsupportedMinOpsetVersion(9)
6613
    def test_list_pass(self):
        """Export ``new_zeros`` whose size is built by concatenating shapes.

        Covers ``torch.Size`` slices, whole ``torch.Size`` objects, Python lists
        mixing dims with a constant, and ``list(x.shape)``. Each case is run once
        with dynamic axes and once expecting no ONNX inputs to remain.
        """

        class Slice(torch.nn.Module):
            def forward(self, x, y):
                return x.new_zeros(x.shape[2:] + y.shape[1:])

        x = torch.randn(2, 3, 4, 5)
        y = torch.randn(1, 2, 3, 4)
        self.run_test(
            Slice(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
        )
        self.run_test(Slice(), (x, y), remained_onnx_input_idx=[])

        class Size(torch.nn.Module):
            def forward(self, x, y):
                return x.new_zeros(x.shape + y.shape)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(
            Size(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
        )
        self.run_test(Size(), (x, y), remained_onnx_input_idx=[])

        class Array(torch.nn.Module):
            def forward(self, x, y):
                arr1 = [x.shape[0], x.shape[1], 2]
                arr2 = [y.shape[0], y.shape[1]]
                return x.new_zeros(arr1 + arr2)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(
            Array(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
        )
        self.run_test(Array(), (x, y), remained_onnx_input_idx=[])

        # Named ``ListConcat`` rather than ``List`` so it does not shadow the
        # module-level ``typing.List`` import.
        class ListConcat(torch.nn.Module):
            def forward(self, x, y):
                l1 = list(x.shape)
                l2 = list(y.shape)
                return x.new_zeros(l1 + l2)

        x = torch.randn(2, 3, 4)
        y = torch.randn(1, 2, 3)
        self.run_test(
            ListConcat(),
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
        )
        self.run_test(ListConcat(), (x, y), remained_onnx_input_idx=[])
6673

6674
    @skipIfUnsupportedMinOpsetVersion(9)
6675
    def test_new_empty(self):
        """Export ``new_empty`` sized by an input dim, default and long dtypes."""

        class Empty(torch.nn.Module):
            def forward(self, x):
                # new_empty returns uninitialized memory, so each result is
                # zeroed to make the comparison deterministic.
                return (
                    x.new_empty(x.shape[0]).fill_(0),
                    x.new_empty(x.shape[0], dtype=torch.long) * 0,
                )

        x = torch.randn(2, 3, 4)
        self.run_test(Empty(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
        self.run_test(Empty(), x, remained_onnx_input_idx=[])
6686

6687
    @skipIfUnsupportedMinOpsetVersion(9)
6688
    def test_new_full(self):
        """Export ``new_full`` with input-derived sizes and an explicit dtype."""

        class NewFullModel(torch.nn.Module):
            def forward(self, x):
                filled = x.new_full(x.shape[1:2], 5)
                filled_long = x.new_full(x.shape[0:1], 1.3, dtype=torch.long)
                return filled, filled_long

        sample = torch.randn(2, 3, 4)
        self.run_test(
            NewFullModel(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        # The outputs only use the input's shape; no graph inputs should remain.
        self.run_test(NewFullModel(), sample, remained_onnx_input_idx=[])
6698

6699
    @skipIfUnsupportedMinOpsetVersion(9)
6700
    def test_inplace_list(self):
        """Export ``torch.cat`` over the results of ``add_`` and ``fill_``."""

        class Arithmetic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                return torch.cat([x.add_(3), y.fill_(0)])

        lhs = torch.randn(2, 3)
        rhs = torch.randn(2, 3)
        self.run_test(
            Arithmetic(),
            (lhs, rhs),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1], "y": [0, 1]},
        )
        # Only input 0 is expected to remain; presumably because y's values are
        # fully overwritten by fill_.
        self.run_test(Arithmetic(), (lhs, rhs), remained_onnx_input_idx=[0])
6715

6716
    @skipIfUnsupportedMinOpsetVersion(9)
6717
    def test_inplace_fill(self):
        """Export in-place ``fill_``, returning both the result and the input."""

        class FillInPlace(torch.nn.Module):
            def forward(self, x):
                return x.fill_(3), x

        data = torch.randn(2, 3, 4)
        self.run_test(
            FillInPlace(), data, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )
        # Without dynamic axes, the exported graph is expected to keep no inputs.
        self.run_test(FillInPlace(), data, remained_onnx_input_idx=[])
6725

6726
    def test_inplace_arithmetic(self):
        """Export chained in-place ``add_`` / ``mul_`` in a scripted module."""

        class Arithmetic(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, y):
                x.add_(3)
                y.mul_(x)
                return x, y

        self.run_test(Arithmetic(), (torch.randn(2, 3, 4), torch.randn(2, 3, 4)))
6737

6738
    def test_inplace_arithmetic_half(self):
        """Export in-place add/mul of a half tensor by a float tensor."""

        class InplaceAddModel(torch.nn.Module):
            def forward(self, x, y):
                return x.add_(y)

        class InplaceMulModel(torch.nn.Module):
            def forward(self, x, y):
                return x.mul_(y)

        half_lhs = torch.randn(2, 2, dtype=torch.half)
        float_rhs = torch.randn(2, 2, dtype=torch.float)
        # Loose tolerances because the left operand is float16.
        for model in (InplaceAddModel(), InplaceMulModel()):
            self.run_test(model, (half_lhs, float_rhs), rtol=1e-2, atol=1e-2)
6751

6752
    @skipIfUnsupportedMinOpsetVersion(9)
6753
    def test_inplace_with_loop(self):
        """Export a scripted loop that accumulates into a tensor with ``add_``."""

        class M(torch.nn.Module):
            def forward(self, x):
                a = torch.ones(12)
                for i in range(10):
                    a.add_(torch.ones(12))
                return a + x

        x = torch.randn(12)
        # Pass an explicit 1-tuple; ``(x)`` in the original was just ``x``.
        self.run_test(torch.jit.script(M()), (x,))
6772

6773
    @skipIfUnsupportedMinOpsetVersion(9)
6774
    def test_inplace_with_loop_2(self):
        """Export nested scripted loops mixing in-place and out-of-place updates.

        ``_bias`` is mutated with ``add_`` in both loops, ``a`` is updated with
        ``+=`` and ``b`` is rebound out of place. The aliases taken before the
        loops (``a_ref``, ``b_ref``) check whether those updates are visible.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                _bias = torch.ones(12)
                a = torch.ones(12)  # used in loop, altered.
                a_ref = a  # not used in loop, should be altered.
                b = x.clone()  # used in loop, not be altered.
                b_ref = b  # not used in loop, should not be altered.
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            a += _bias
                            _bias.add_(torch.ones(12))
                            b = b + torch.ones(12)

                    _bias.add_(torch.ones(12))
                    a += _bias
                # TODO: value for a_ref is incorrect.
                # a_ref += torch.ones(12,)
                b_ref += torch.ones(12)
                return _bias + x, a, b, b_ref

        x = torch.zeros(12)
        # Pass an explicit 1-tuple; ``(x)`` in the original was just ``x``.
        self.run_test(torch.jit.script(M()), (x,))
6817

6818
    @skipIfUnsupportedMinOpsetVersion(11)
6819
    def test_inplace_attr_with_loop(self):
        """Export a scripted nested loop that updates a module attribute with ``+=``."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._bias = torch.arange(12)

            def forward(self, x):
                self._bias = torch.arange(12)
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            self._bias += torch.arange(12)
                return self._bias + x

        x = torch.zeros(12)
        # Pass an explicit 1-tuple; ``(x)`` in the original was just ``x``.
        self.run_test(torch.jit.script(M()), (x,))
6844

6845
    @skipIfUnsupportedMinOpsetVersion(11)
6846
    def test_inplace_attr_copy_with_loop(self):
        """Export scripted loops that overwrite a module attribute via ``copy_``."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._bias = torch.arange(12)

            def forward(self, x):
                self._bias = torch.arange(12)
                for i in range(10):
                    if i == 3:
                        for j in range(5):
                            self._bias.copy_(torch.arange(12))
                        self._bias.copy_(self._bias + torch.arange(12))

                    self._bias.copy_(self._bias + torch.arange(12))
                return self._bias + x

        x = torch.zeros(12)
        # Pass an explicit 1-tuple; ``(x)`` in the original was just ``x``.
        self.run_test(torch.jit.script(M()), (x,))
6886

6887
    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::Identity of sequence in opset 14
6888
    def test_inplace_sequence_with_loop(self):
        """Export a scripted while-loop that appends to a ``List[Tensor]``.

        The helper ``process`` appends tokens to ``beam_hyps`` and updates the
        boolean ``done`` tensor by index; ``forward`` calls it ``x.shape[1]``
        times and returns the accumulated list.
        """

        class M(torch.nn.Module):
            def process(self, beam_hyps: List[Tensor], done: Tensor, x):
                batch_size = x.shape[0]
                for i in range(batch_size):
                    if done[i]:
                        continue

                    beam_idx = 0
                    for _, token in enumerate(x[i]):
                        beam_hyps.append(token)
                        beam_idx += 1

                        if beam_idx == 6:
                            break

                    done[i] = len(beam_hyps) > 4

                return beam_hyps, done

            def forward(self, x):
                beam_hyps: List[Tensor] = []
                batch_size = x.shape[0]
                cur_len = 0
                max_len = x.shape[1]
                done = torch.zeros(batch_size, dtype=torch.bool)
                while cur_len < max_len:
                    beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
                    cur_len = cur_len + 1

                return beam_hyps

        # Script once and reuse the result instead of compiling a second copy.
        m = torch.jit.script(M())
        x = torch.randn(8, 4, 3)
        self.run_test(m, (x,))
6923

6924
    @skipScriptTest()  # Sort with dynamic dim not supported in ONNX
6925
    def test_sort(self):
        """Export descending ``torch.sort`` over dims -2, -1, 0 and 1."""

        class DescendingSort(torch.nn.Module):
            def forward(self, x):
                return [torch.sort(x, dim=d, descending=True) for d in range(-2, 2)]

        self.run_test(DescendingSort(), torch.randn(3, 4))
6935

6936
    @skipIfUnsupportedMinOpsetVersion(11)
6937
    @skipScriptTest()  # Sort with dynamic dim not supported in ONNX
6938
    def test_sort_ascending(self):
        """Export ascending ``torch.sort`` over dims -2, -1, 0 and 1."""

        class AscendingSort(torch.nn.Module):
            def forward(self, x):
                return [torch.sort(x, dim=d, descending=False) for d in range(-2, 2)]

        self.run_test(AscendingSort(), torch.randn(3, 4))
6948

6949
    @skipIfUnsupportedMinOpsetVersion(11)
6950
    def test_argsort(self):
        """Export ascending ``torch.argsort`` along dim 1."""

        class ArgSortAlongRows(torch.nn.Module):
            def forward(self, x):
                return torch.argsort(x, dim=1, descending=False)

        self.run_test(ArgSortAlongRows(), torch.randn(3, 4))
6957

6958
    @skipIfUnsupportedMinOpsetVersion(9)
6959
    def test_masked_fill(self):
        """Export ``masked_fill`` with a constant mask and a data-derived mask."""

        class ConstantMaskFill(torch.nn.Module):
            def forward(self, x):
                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
                return x.masked_fill(mask, 2)

        class ThresholdMaskFill(torch.nn.Module):
            def forward(self, x):
                return x.masked_fill(x > 3, -1)

        # The (2, 3) constant mask broadcasts against the (4, 2, 3) input.
        self.run_test(ConstantMaskFill(), torch.zeros(4, 2, 3, requires_grad=True))

        ramp = torch.arange(16).view(2, 2, 4).to(torch.float32)
        self.run_test(ThresholdMaskFill(), ramp)
6974

6975
    @skipIfUnsupportedMinOpsetVersion(9)
6976
    def test_masked_fill_inplace(self):
        """Export scripted in-place ``masked_fill_`` with constant and derived masks."""

        class MaskedFillModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.bool)
                x.masked_fill_(mask, 2)
                return x

        zeros = torch.zeros(4, 2, 3, requires_grad=True)
        self.run_test(MaskedFillModel(), zeros)

        class MaskedFillModel2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                x.masked_fill_(x > 3, -1)
                return x

        ramp = torch.arange(16).view(2, 2, 4).to(torch.float32)
        self.run_test(MaskedFillModel2(), ramp)
6995

6996
    @skipIfUnsupportedMinOpsetVersion(11)
6997
    def test_masked_scatter(self):
        """Export ``masked_scatter`` from a constant source larger than the input."""

        class MaskedScatterModel(torch.nn.Module):
            def forward(self, x):
                mask = x.ge(0.5)
                source = torch.ones(100, 100) * 5
                return torch.masked_scatter(x, mask, source)

        self.run_test(MaskedScatterModel(), torch.randn(3, 4, 5, requires_grad=True))
7004

7005
    @skipIfUnsupportedMinOpsetVersion(11)
7006
    def test_masked_select(self):
        """Export ``masked_select`` with a mask derived from the input."""

        class MaskedSelectModel(torch.nn.Module):
            def forward(self, x):
                keep = x.ge(0.5)
                return torch.masked_select(x, keep)

        self.run_test(MaskedSelectModel(), torch.randn(3, 4, 5, requires_grad=True))
7013

7014
    @skipIfUnsupportedMinOpsetVersion(11)
7015
    def test_index_put_to_masked_fill(self):
        """Export boolean-mask assignment of scalars onto a cloned input."""

        class MaskedFillModel(torch.nn.Module):
            def forward(self, input_mask, some_const):
                mask = input_mask.clone()
                mask[mask != some_const] = 1
                mask[mask == some_const] = 0
                return mask

        values = torch.randn(2, 2, 2, requires_grad=True)
        threshold = torch.tensor(5, dtype=torch.float)
        self.run_test(MaskedFillModel(), (values, threshold))
7026

7027
    @skipIfUnsupportedMinOpsetVersion(11)
7028
    def test_index_put_to_masked_scatter(self):
        """Export boolean-mask assignment of a tensor onto a cloned input."""

        class MaskedScatterModel(torch.nn.Module):
            def forward(self, input_mask, some_const):
                mask = input_mask.clone()
                mask[mask != some_const] = torch.ones(8)
                return mask

        values = torch.randn(2, 2, 2, requires_grad=True)
        threshold = torch.tensor(5, dtype=torch.float)
        self.run_test(MaskedScatterModel(), (values, threshold))
7038

7039
    @skipIfUnsupportedMinOpsetVersion(11)
7040
    def test_index_put_with_1d_mask_to_masked_scatter(self):
        """Export assignment through a 1-D boolean mask over the first dim."""

        class MaskedScatterModel(torch.nn.Module):
            def forward(self, tensor, mask, some_const):
                tensor[mask] = some_const
                return tensor

        # Four of the eight rows are selected, matching some_const's leading dim.
        row_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
        target = torch.randn(8, 4, 5, requires_grad=True)
        replacement = torch.randn(4, 4, 5, dtype=torch.float)
        self.run_test(MaskedScatterModel(), (target, row_mask, replacement))
7050

7051
    @skipIfUnsupportedMinOpsetVersion(9)
7052
    def test_pixel_shuffle(self):
        """Export ``pixel_shuffle`` (upscale 2) with static and dynamic shapes."""

        class PixelShuffle(torch.nn.Module):
            def forward(self, x):
                return torch.pixel_shuffle(x, upscale_factor=2)

        export_input = torch.randn(2, 16, 4, 3, requires_grad=True)
        resized_input = torch.randn(4, 32, 8, 4, requires_grad=True)
        self.run_test(PixelShuffle(), export_input)
        self.run_test(
            PixelShuffle(),
            export_input,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
            additional_test_inputs=[resized_input],
        )
7067

7068
    @skipIfUnsupportedMinOpsetVersion(9)
7069
    def test_pixel_unshuffle(self):
        """Export ``pixel_unshuffle`` (downscale 2) with static and dynamic shapes."""

        class PixelUnshuffle(torch.nn.Module):
            def forward(self, x):
                return torch.pixel_unshuffle(x, downscale_factor=2)

        export_input = torch.randn(2, 16, 4, 6, requires_grad=True)
        resized_input = torch.randn(4, 32, 8, 4, requires_grad=True)
        self.run_test(PixelUnshuffle(), export_input)
        self.run_test(
            PixelUnshuffle(),
            export_input,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
            additional_test_inputs=[resized_input],
        )
7084

7085
    @skipIfUnsupportedMinOpsetVersion(9)
7086
    def test_reciprocal(self):
        """Export ``torch.reciprocal`` for long, float and double inputs."""

        class ReciprocalModel(torch.nn.Module):
            def forward(self, x):
                return torch.reciprocal(x)

        model = ReciprocalModel()
        base = torch.tensor([2, 4])
        for dtype in (torch.long, torch.float, torch.double):
            self.run_test(model, base.to(dtype))
7096

7097
    @skipIfUnsupportedMinOpsetVersion(9)
7098
    def test_scalar_type(self):
        """Check export when Python scalars and tensors of mixed dtypes interact.

        Covers arithmetic with ``x.size(0)``, comparisons across int32/float32
        inputs, matmul chains, ``torch.full`` with a tensor fill value, and
        concatenating a half tensor with a float tensor.
        """

        class ArithmeticModel(torch.nn.Module):
            def forward(self, x):
                return x.size(0) * 2 * x, 2 - x

        class ComparisonModel(torch.nn.Module):
            def forward(self, x, y):
                a = torch.tensor([12.0])
                return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))

        class MatMulModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x + torch.mm(x, x) + x

        class AddMMModel(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x

        class FullModel(torch.nn.Module):
            # add is used for exporting full
            def forward(self, x):
                return torch.full((3, 4), x)

        class CatModel(torch.nn.Module):
            def forward(self, fp16, fp32):
                return torch.cat([fp16, fp32])

        self.run_test(ArithmeticModel(), torch.ones(2, 3, dtype=torch.float32))

        int_input = torch.ones(2, 3, dtype=torch.int32)
        float_input = torch.ones(2, 3, dtype=torch.float32)
        self.run_test(ComparisonModel(), (int_input, float_input))

        self.run_test(MatMulModel(), torch.ones(3, 3))
        self.run_test(AddMMModel(), torch.ones(3, 3))
        self.run_test(FullModel(), torch.tensor(12.0))

        half_part = Tensor([0.5]).half()
        float_part = Tensor([1.5])
        self.run_test(CatModel(), (half_part, float_part))
7145

7146
    @skipIfUnsupportedMinOpsetVersion(9)
7147
    def test_scalar_type_does_not_trigger_upcast_type_promotion(self):
        """Multiplying a float16 tensor by a size-derived scalar stays float16."""

        class DoNotUpcastModel(torch.nn.Module):
            def forward(self, x):
                # 'scale' is exported as onnx float32 rank 0 tensor.
                # The following 'Mul' should NOT be promoted to float32.
                scale = x.size()[-1] ** -0.5
                return x * scale

        self.run_test(DoNotUpcastModel(), torch.ones(2, 3, dtype=torch.float16))
7157

7158
    @skipIfUnsupportedMinOpsetVersion(9)
7159
    def test_full_like(self):
        """Export ``full_like`` with a float fill value and an int dtype."""

        class FullLikeModel(torch.nn.Module):
            def forward(self, x):
                return torch.full_like(x, 1.3, dtype=torch.int)

        scalar_input = torch.tensor(12)
        self.run_test(FullLikeModel(), scalar_input)
7166

7167
    @skipIfUnsupportedMinOpsetVersion(9)
7168
    @skipDtypeChecking
7169
    def test_full_like_value(self):
        """Export ``full_like`` whose fill value is a tensor computed in forward."""

        class FullLikeModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.full_like(x, y + 2)

        template = torch.tensor(12)
        offset = torch.tensor(2)
        self.run_test(FullLikeModel(), (template, offset))
7178

7179
    def test_l1_norm(self):
        """Export an L1 ``torch.norm`` reduced over the last dim."""

        class L1NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p=1, dim=-1, keepdim=False)

        self.run_test(L1NormModel(), torch.randn(4, 2, 3, requires_grad=True))
7186

7187
    def test_l2_norm(self):
        """Export an L2 ``torch.norm`` reduced over dim -2."""

        class L2NormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p=2, dim=-2, keepdim=False)

        self.run_test(L2NormModel(), torch.randn(4, 2, 3, requires_grad=True))
7194

7195
    def test_frobenius_norm(self):
        """Export a Frobenius ``torch.norm`` reduced over dim 0."""

        class FroNormModel(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p="fro", dim=0, keepdim=False)

        self.run_test(FroNormModel(), torch.randn(4, 2, 3, requires_grad=True))
7202

7203
    def test_frobenius_norm_keepdim(self):
        """Export a Frobenius ``torch.norm`` over dims (0, 1) with keepdim."""

        class FroNormKeepDim(torch.nn.Module):
            def forward(self, x):
                return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)

        self.run_test(FroNormKeepDim(), torch.randn(4, 2, 3, requires_grad=True))
7210

7211
    def test_unfold(self):
        """Export ``Tensor.unfold`` on dim 2 with the first two axes dynamic."""

        class UnfoldModel(torch.nn.Module):
            def forward(self, x):
                return x.unfold(dimension=2, size=2, step=2)

        export_input = torch.randn(4, 2, 3, requires_grad=True)
        resized_input = torch.randn(2, 1, 3, requires_grad=True)
        self.run_test(
            UnfoldModel(),
            export_input,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
            additional_test_inputs=[resized_input],
        )
7225

7226
    def test_unfold_infer_shape(self):
        """Export ``unfold`` applied to a ``Conv1d`` output in a scripted module."""

        class UnfoldModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv(x)
                return x.unfold(dimension=2, size=2, step=2)

        # The input is drawn before the module so the RNG order is unchanged.
        signal = torch.randn(32, 3, 64)
        self.run_test(UnfoldModule(), signal)
7239

7240
    @skipIfUnsupportedMinOpsetVersion(12)
7241
    def test_unfold_dynamic_inputs(self):
        """Export ``unfold`` whose window size (and step) come from input dims."""

        class UnfoldSizeAndStepFromShape(torch.nn.Module):
            def forward(self, x):
                return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1)

        class UnfoldSizeFromShapeUnitStep(torch.nn.Module):
            def forward(self, x):
                return x.unfold(dimension=2, size=x.shape[1], step=1)

        first = torch.randn(4, 2, 4, requires_grad=True)
        self.run_test(UnfoldSizeAndStepFromShape(), first)

        second = torch.randn(4, 2, 4, requires_grad=True)
        self.run_test(UnfoldSizeFromShapeUnitStep(), second)
7255

7256
    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul with long inputs was added in ONNX opset 9.
    def test_mv(self):
        """Test torch.mv with float inputs and with integer inputs."""
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.mv(input, other)

        x = torch.randn(4, 5, requires_grad=True)
        y = torch.randn(5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (4, 5))
        y = torch.randint(10, (5,))
        self.run_test(MatmulModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul with long inputs was added in ONNX opset 9.
    def test_dot(self):
        """Test torch.dot with float inputs and with integer inputs."""
        class MatmulModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.dot(input, other)

        x = torch.randn(5, requires_grad=True)
        y = torch.randn(5, requires_grad=True)
        self.run_test(MatmulModel(), (x, y))

        x = torch.randint(10, (5,))
        y = torch.randint(10, (5,))
        self.run_test(MatmulModel(), (x, y))

    @skipScriptTest()  # SpectralNorm is not TorchScript compatible.
    def test_spectral_norm(self):
        """Test a Linear layer wrapped with torch.nn.utils.spectral_norm."""
        m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))

        x = torch.randn(6, 2)
        self.run_test(m, (x,))

    def test_prelu(self):
        """Test nn.PReLU with dims 1 and 2 of the input marked dynamic."""
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                return self.prelu(x)

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 4, 5)
        self.run_test(
            PReluModel(),
            x,
            input_names=["x"],
            dynamic_axes={"x": [1, 2]},
            additional_test_inputs=[y],
        )

    def test_prelu_scalar(self):
        """Test nn.PReLU applied to a 0-d tensor."""
        x = torch.scalar_tensor(1.0)
        self.run_test(torch.nn.PReLU(), x, input_names=["x"])

    def test_relu6(self):
        """Test nn.ReLU6 with dims 1 and 2 of the input marked dynamic."""
        class Relu6Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu6 = torch.nn.ReLU6()

            def forward(self, x):
                return self.relu6(x)

        # Scale up so that many values land outside ReLU6's [0, 6] clamp range.
        x = torch.randn(2, 3, 4) * 100.0
        y = torch.randn(2, 4, 5) * 100.0
        self.run_test(
            Relu6Model(),
            x,
            input_names=["x"],
            dynamic_axes={"x": [1, 2]},
            additional_test_inputs=[y],
        )

    def test_silu(self):
        """Test nn.SiLU."""
        class SiLUModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.silu = torch.nn.SiLU()

            def forward(self, x):
                return self.silu(x)

        x = torch.randn(2, 3, 4)
        self.run_test(SiLUModel(), x)

    @skipIfUnsupportedMinOpsetVersion(14)
    def test_tril(self):
        """Test torch.tril with default, positive, negative and input-provided diagonals."""
        class trilModel(torch.nn.Module):
            def forward(self, x):
                return torch.tril(x)

        x = torch.randn(2, 3, 4)
        self.run_test(trilModel(), x)

        class trilModelwithDiagonal(torch.nn.Module):
            def forward(self, x):
                return torch.tril(x, diagonal=1)

        x = torch.randn(2, 3, 4)
        self.run_test(trilModelwithDiagonal(), x)

        class trilModelwithNegDiagonal(torch.nn.Module):
            def forward(self, x):
                return torch.tril(x, diagonal=-1)

        x = torch.randn(2, 3, 4)
        self.run_test(trilModelwithNegDiagonal(), x)

        class trilModelWithDiagonalInput(torch.nn.Module):
            def forward(self, x, diagonal: int):
                return torch.tril(x, diagonal=diagonal)

        x = torch.randn(2, 3, 4)
        self.run_test(trilModelWithDiagonalInput(), (x, 5))

    @skipIfUnsupportedMinOpsetVersion(14)
    def test_triu(self):
        """Test torch.triu with default, positive, negative and input-provided diagonals."""
        class triuModel(torch.nn.Module):
            def forward(self, x):
                return torch.triu(x)

        x = torch.randn(2, 3, 4)
        self.run_test(triuModel(), x)

        class triuModelwithDiagonal(torch.nn.Module):
            def forward(self, x):
                return torch.triu(x, diagonal=1)

        x = torch.randn(2, 3, 4)
        self.run_test(triuModelwithDiagonal(), x)

        class triuModelwithNegDiagonal(torch.nn.Module):
            def forward(self, x):
                return torch.triu(x, diagonal=-1)

        x = torch.randn(2, 3, 4)
        self.run_test(triuModelwithNegDiagonal(), x)

        class triuModelWithDiagonalInput(torch.nn.Module):
            def forward(self, x, diagonal: int):
                return torch.triu(x, diagonal=diagonal)

        x = torch.randn(2, 3, 4)
        self.run_test(triuModelWithDiagonalInput(), (x, 5))

    def test_mish(self):
        """Test nn.Mish."""
        class MishModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mish = torch.nn.Mish()

            def forward(self, x):
                return self.mish(x)

        x = torch.randn(2, 3, 4)
        self.run_test(MishModel(), x)

    def test_remainder(self):
        """Test torch.remainder with broadcasting and float/long/int32 operand mixes."""
        class RemainderModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.remainder(input, other)

        x = torch.randn(4, 2, 3)
        y = torch.randn(1, 2, 1)
        self.run_test(RemainderModel(), (x, y))

        # Negative dividends check that the sign follows the divisor.
        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
        y = torch.tensor([2], dtype=torch.long)
        self.run_test(RemainderModel(), (x, y))

        x = x.to(torch.float)
        self.run_test(RemainderModel(), (x, y))

        y = y.to(torch.float)
        self.run_test(RemainderModel(), (x, y))

        x = x.to(torch.int32)
        self.run_test(RemainderModel(), (x, y))

    def test_remainder_scalar(self):
        """Test torch.remainder with a Python scalar divisor (float and int)."""
        class RemainderModel(torch.nn.Module):
            def __init__(self, scalar=2.55):
                super().__init__()
                self.scalar = scalar

            def forward(self, input):
                return torch.remainder(input, self.scalar)

        x = torch.randint(10, (2, 3))
        self.run_test(RemainderModel(), x)

        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
        self.run_test(RemainderModel(2), x)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fmod(self):
        """Test torch.fmod with broadcast float operands."""
        class FModModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.fmod(input, other)

        x = torch.randn(4, 2, 3)
        y = torch.randn(1, 2, 1)
        self.run_test(FModModel(), (x, y))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fmod_scalar(self):
        """Test torch.fmod of an integer tensor by a float scalar."""
        class FModModel(torch.nn.Module):
            def forward(self, input):
                return torch.fmod(input, 2.55)

        x = torch.randint(10, (2, 3))
        self.run_test(FModModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_glu(self):
        """Test torch.nn.functional.glu with its default dim."""
        class GluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.glu(x)

        x = torch.randn(2, 4, 5, 6, requires_grad=True)
        self.run_test(GluModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_gelu(self):
        """Test exact (approximate="none") GELU."""
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x, approximate="none")

        x = torch.randn(2, 4, 5, 6, requires_grad=True)
        self.run_test(GeluModel(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_tanh_gelu(self):
        """Test tanh-approximated GELU."""
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x, approximate="tanh")

        x = torch.randn(2, 4, 5, 6, requires_grad=True)
        self.run_test(GeluModel(), x)

    def test_add_inplace(self):
        """Test an in-place add of a Python scalar to the input."""
        class InplaceAddModel(torch.nn.Module):
            def forward(self, x):
                x += 12
                return x

        x = torch.randn(4, 2, 3, requires_grad=True)
        self.run_test(InplaceAddModel(), x)

    def test_addcmul(self):
        """Test torch.addcmul with default and explicit `value` on broadcast operands."""
        class AddcmulModel(torch.nn.Module):
            def forward(self, x, t1, t2):
                return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2)

        x = torch.randn(1, 3)
        t1 = torch.randn(3, 1)
        t2 = torch.randn(1, 3)
        self.run_test(AddcmulModel(), (x, t1, t2))

    def test_rsqrt(self):
        """Test Tensor.rsqrt on a float64 input."""
        class RsqrtModel(torch.nn.Module):
            def forward(self, x):
                return x.rsqrt()

        x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
        self.run_test(RsqrtModel(), x)

    def test_rsqrt_zeros(self):
        """Test Tensor.rsqrt on an all-zero float64 input."""
        class RsqrtModel(torch.nn.Module):
            def forward(self, x):
                return x.rsqrt()

        x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
        self.run_test(RsqrtModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unique(self):
        """Test sorted torch.unique returning values and counts."""
        class UniqueModel(torch.nn.Module):
            def forward(self, x):
                return torch.unique(
                    x, sorted=True, return_inverse=False, return_counts=True
                )

        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
        self.run_test(UniqueModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_unique_along_dim(self):
        """Test torch.unique along dim 0 returning values and inverse indices."""
        class UniqueModel(torch.nn.Module):
            def forward(self, x):
                return torch.unique(
                    x, dim=0, sorted=True, return_inverse=True, return_counts=False
                )

        x = torch.tensor([1, 3, 2, 3], dtype=torch.long)
        self.run_test(UniqueModel(), x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_cumsum(self):
        """Test torch.cumsum along dim 0."""
        class CumSum(torch.nn.Module):
            def forward(self, input):
                return torch.cumsum(input, dim=0)

        x = torch.randn(2, 3, 4)
        model = CumSum()
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_cumsum_with_cast(self):
        """Test torch.cumsum with dtype=float32 on int32 and bool inputs."""
        class CumSum(torch.nn.Module):
            def forward(self, input):
                return torch.cumsum(input, dim=0, dtype=torch.float32)

        model = CumSum()
        x = torch.tensor([2, 3, 4], dtype=torch.int32)
        self.run_test(model, x)
        x = torch.tensor([False, True, True])
        self.run_test(model, x)

    @skipScriptTest()  # error in propagate as assign input shape
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_embedding_bag(self):
        """Test nn.EmbeddingBag in sum mode (with offsets) and max mode (2-D input)."""
        model = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True)
        indices = torch.randint(10, (7,))
        offset = torch.tensor([0, 2, 5, 6])
        self.run_test(model, (indices, offset))

        model = torch.nn.EmbeddingBag(10, 5, mode="sum", include_last_offset=True)
        indices = torch.randint(10, (7,))
        offset = torch.tensor([0, 2, 5, 6])
        self.run_test(model, (indices, offset))

        # 2-D input: each row is one bag, so no offsets are passed.
        model = torch.nn.EmbeddingBag(10, 5, mode="max")
        indices = torch.randint(10, (7, 5))
        self.run_test(model, indices)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_embedding_bag_1d_per_sample_weights(self):
        """Test functional embedding_bag (sum) with offsets and 1-D per_sample_weights."""
        class EmbeddingModel(torch.nn.Module):
            def forward(self, embedding_matrix, input, offset, weights):
                return torch.nn.functional.embedding_bag(
                    input,
                    embedding_matrix,
                    offsets=offset,
                    mode="sum",
                    per_sample_weights=weights,
                )

        model = EmbeddingModel()
        x = torch.randint(7, (6,))
        w = torch.randn(6)
        offset = torch.tensor([0, 2, 5])
        embedding_matrix = torch.rand(10, 15)
        self.run_test(model, (embedding_matrix, x, offset, w))

    @skipIfUnsupportedMinOpsetVersion(11)
    @unittest.skip(
        "This test is broken with ONNXRuntime(17): "
        "when running with onnxruntime 1.17.0 this test fails with the following error: "
        "FAIL : Non-zero status code returned while running If node. "
        "Name:'/If' Status Message: if.cc:253 Compute "
        "If nodes condition input must have exactly one element. "
        "https://github.com/pytorch/pytorch/issues/119442"
    )
    def test_embedding_bag_2d_per_sample_weights(self):
        """Test embedding_bag on 2-D input with per_sample_weights and a dynamic dim 0."""
        class EmbeddingModel(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights):
                return torch.nn.functional.embedding_bag(
                    input, embedding_matrix, mode="sum", per_sample_weights=weights
                )

        embedding_matrix = torch.rand(10, 15)
        model = EmbeddingModel()
        x = torch.randint(7, (2, 3))
        w = torch.randn(2, 3)

        x2 = torch.randint(7, (4, 3))
        w2 = torch.randn(4, 3)
        self.run_test(
            model,
            (embedding_matrix, x, w),
            input_names=["embed", "x", "w"],
            dynamic_axes={"x": [0], "w": [0]},
            additional_test_inputs=[(embedding_matrix, x2, w2)],
        )

    @skipScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
    @skipIfUnsupportedMinOpsetVersion(11)
    @unittest.skip(
        "Due to ONNX Loop shape inference issue. "
        "https://msdata.visualstudio.com/Vienna/_workitems/edit/1352001"
    )
    def test_embedding_bag_dynamic_input(self):
        """Test embedding_bag with dynamic matrix, indices, weights and offsets."""
        class EmbeddingModel1D(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights, offsets):
                return torch.nn.functional.embedding_bag(
                    input,
                    embedding_matrix,
                    offsets=offsets,
                    mode="sum",
                    per_sample_weights=weights,
                )

        model = EmbeddingModel1D()
        x = torch.randint(7, (6,))
        w = torch.randn(6)
        offsets = torch.tensor([0, 2, 5], dtype=torch.long)
        embedding_matrix = torch.rand(10, 15)
        x2 = torch.randint(7, (2,))
        w2 = torch.randn(2)
        embedding_matrix2 = torch.rand(12, 25)
        offsets2 = torch.tensor([0], dtype=torch.long)
        self.run_test(
            model,
            (embedding_matrix, x, w, offsets),
            additional_test_inputs=[(embedding_matrix2, x2, w2, offsets2)],
            # Names must follow the positional input order:
            # (embedding_matrix, input, weights, offsets).
            input_names=["embedding_matrix", "x", "w", "offsets"],
            dynamic_axes={
                "embedding_matrix": [0, 1],
                "x": [0],
                "offsets": [0],
                "w": [0],
            },
        )

        class EmbeddingModel2D(torch.nn.Module):
            def forward(self, embedding_matrix, input, weights):
                return torch.nn.functional.embedding_bag(
                    input, embedding_matrix, mode="sum", per_sample_weights=weights
                )

        model = EmbeddingModel2D()
        x = torch.randint(7, (2, 3))
        w = torch.randn(2, 3)
        embedding_matrix = torch.rand(10, 15)
        x2 = torch.randint(7, (3, 5))
        w2 = torch.randn(3, 5)
        embedding_matrix2 = torch.rand(12, 25)
        self.run_test(
            model,
            (embedding_matrix, x, w),
            additional_test_inputs=[(embedding_matrix2, x2, w2)],
            input_names=["embedding_matrix", "x", "w"],
            dynamic_axes={"embedding_matrix": [0, 1], "x": [0, 1], "w": [0, 1]},
        )

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_meshgrid(self):
        """Test torch.meshgrid of three 1-D inputs with the default indexing."""
        class Meshgrid(torch.nn.Module):
            def forward(self, x, y, z):
                output1, output2, output3 = torch.meshgrid(x, y, z)
                return output1, output2, output3

        x = torch.randn(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.randn(5, requires_grad=True)
        self.run_test(Meshgrid(), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_meshgrid_indexing(self):
        """Test torch.meshgrid with both 'xy' and 'ij' indexing."""
        class Meshgrid(torch.nn.Module):
            def __init__(self, indexing):
                super().__init__()
                self.indexing = indexing

            def forward(self, x, y, z):
                output1, output2, output3 = torch.meshgrid(
                    x, y, z, indexing=self.indexing
                )
                return output1, output2, output3

        x = torch.randn(5, requires_grad=True)
        y = torch.zeros(6, requires_grad=True)
        z = torch.randn(7, requires_grad=True)
        for indexing in ("xy", "ij"):
            self.run_test(Meshgrid(indexing), (x, y, z))

    @skipIfUnsupportedMinOpsetVersion(8)
    def test_meshgrid_scalar(self):
        """Test torch.meshgrid where the last input is a 0-d tensor."""
        class Meshgrid(torch.nn.Module):
            def forward(self, x, y, z):
                output1, output2, output3 = torch.meshgrid(x, y, z)
                return output1, output2, output3

        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.tensor(2.0)
        self.run_test(Meshgrid(), (x, y, z))

    def test_baddbmm(self):
        """Test torch.baddbmm with a tensor `alpha` and a float `beta`."""
        class MyModule(torch.nn.Module):
            def forward(self, input, batch1, batch2):
                return torch.baddbmm(
                    input, batch1, batch2, alpha=torch.tensor(5), beta=3.5
                )

        x = torch.randn(10, 3, 5)
        batch1 = torch.randn(10, 3, 4)
        batch2 = torch.randn(10, 4, 5)
        model = MyModule()
        self.run_test(model, (x, batch1, batch2))

    def test_baddbmm_dynamic(self):
        """Test torch.baddbmm with `alpha` and `beta` passed as model inputs."""
        class MyModule(torch.nn.Module):
            def forward(self, input, batch1, batch2, alpha, beta):
                return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)

        x = torch.randn(10, 3, 5)
        batch1 = torch.randn(10, 3, 4)
        batch2 = torch.randn(10, 4, 5)
        alpha = torch.tensor(5)
        beta = torch.tensor(3.5)
        model = MyModule()
        self.run_test(model, (x, batch1, batch2, alpha, beta))

    def test_numel(self):
        """Test Tensor.numel() on an input whose dims are all dynamic."""
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return input.numel() * input

        x = torch.randn(2, 3, 5)
        x2 = torch.randn(4, 5, 6)
        model = MyModule()
        self.run_test(
            model,
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2]},
            additional_test_inputs=[(x2,)],
        )

    def test_numel_empty(self):
        """Test Tensor.numel() starting from an empty input with a dynamic dim 0."""
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return input.numel() * input

        x = torch.randn(0)
        x2 = torch.randn(4)
        model = MyModule()
        self.run_test(
            model,
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0]},
            additional_test_inputs=[(x2,)],
        )

    def test_dtype(self):
        """Test a scripted cast to another input's dtype followed by an add."""
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, other):
                return input.to(dtype=other.dtype) + other

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.run_test(MyModel(), (x, y))

    def test_dtype_eq(self):
        """Test a scripted branch on whether two inputs share a dtype."""
        class MyModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, other):
                if input.dtype == other.dtype:
                    return input + other
                return input

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.run_test(MyModel(), (x, y))

    def test_cast_to(self):
        """Test a scripted Tensor.to(other) cast to an int64 tensor, then an add."""
        class MyModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, other):
                return input.to(other) + other

        x = torch.randn(2, 3, 4)
        y = torch.tensor([1], dtype=torch.int64)
        model = MyModule()
        self.run_test(model, (x, y))

    def test_cast_to_bool(self):
        """Test Tensor.to(other) with a bool target, concatenated with that target."""
        class MyModule(torch.nn.Module):
            def forward(self, input, other):
                return torch.cat((input.to(other), other), 0)

        x = torch.randn(2, 3, 4)
        y = torch.zeros([2, 3, 4], dtype=torch.bool)
        model = MyModule()
        self.run_test(model, (x, y))

    # ONNX supports bfloat16 for opsets >= 13
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_cast_type_as_with_bfloat16(self):
        """Test a float16 -> bfloat16 (via type_as) -> float16 round trip."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.ones((3, 4), dtype=torch.bfloat16)
                x = x.type_as(y)
                return x.to(dtype=torch.float16)

        x = torch.ones(3, 4, dtype=torch.float16)
        model = MyModule()
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_type_as(self):
        """Test Tensor.type_as a float tensor for bool, double and int64 inputs."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([1.0])
                return x.type_as(y)

        a = torch.tensor([True, False], dtype=torch.bool)
        b = torch.randn(3, 4, dtype=torch.double)
        c = torch.ones((2, 2), dtype=torch.int64)
        model = MyModule()
        self.run_test(model, a)
        self.run_test(model, b)
        self.run_test(model, c)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_ones_bool(self):
        """Test a bool-cast input AND-ed with a bool ones tensor of the same shape."""
        class MyModule(torch.nn.Module):
            def forward(self, input):
                true = torch.ones(input.shape, dtype=torch.bool)
                return input.to(true) & true

        x = torch.randn(2, 3, 4)
        model = MyModule()
        self.run_test(model, x)

    def test_log(self):
        """Test torch.log on torch.rand inputs."""
        class Log(torch.nn.Module):
            def forward(self, input):
                return torch.log(input)

        x = torch.rand(2, 3, 4)
        model = Log()
        self.run_test(model, x)

    def test_log1p(self):
        """Test torch.log1p on torch.rand inputs."""
        class Log1p(torch.nn.Module):
            def forward(self, input):
                return torch.log1p(input)

        x = torch.rand(2, 3, 4)
        model = Log1p()
        self.run_test(model, x)

    def test_log10(self):
        """Test torch.log10 on torch.rand inputs."""
        class Log10(torch.nn.Module):
            def forward(self, input):
                return torch.log10(input)

        x = torch.rand(2, 3, 4)
        model = Log10()
        self.run_test(model, x)

    def test_log2(self):
        """Test torch.log2 on a 0-d tensor."""
        class Log2(torch.nn.Module):
            def forward(self, input):
                return torch.log2(input)

        x = torch.tensor(1.0)
        model = Log2()
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_round(self):
        """Test torch.round on float (including .5 ties) and int32 inputs."""
        class Round(torch.nn.Module):
            def forward(self, x):
                return torch.round(x)

        x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True)
        self.run_test(Round(), x)

        int_x = torch.tensor([9920, 1036, -1500, 35], dtype=torch.int32)
        self.run_test(Round(), int_x)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_round_with_decimals(self):
        """Test torch.round with decimals of 0, -2 and 3."""
        class Round(torch.nn.Module):
            def __init__(self, decimals):
                super().__init__()
                self.decimals = decimals

            def forward(self, x):
                return torch.round(x, decimals=self.decimals)

        x = torch.tensor([0.9920, -1234.0362, -1.58960, 3.5000])
        for decimals in (0, -2, 3):
            self.run_test(Round(decimals), x)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_default(self):
        """Test torch.stft with center=False and the default window and hop length."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                return torch.stft(x, n_fft=n_fft, center=False, return_complex=False)

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_hop_length(self):
        """Test torch.stft with hop_length=4, which divides the signal length."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                hop_length = 4
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    center=False,
                    hop_length=hop_length,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_non_divisible_hop_length(self):
        """Test torch.stft with hop_length=5, which does not divide the signal length."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                hop_length = 5
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    center=False,
                    hop_length=hop_length,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_window_int_same_size(self):
        """Test torch.stft with win_length equal to n_fft."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                win_length = 16
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    center=False,
                    win_length=win_length,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_window_int_different_size(self):
        """Test torch.stft with win_length (9) smaller than n_fft (16)."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                win_length = 9
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    center=False,
                    win_length=win_length,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_window_custom(self):
        """Test torch.stft with an explicit Hann window of size n_fft."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                window = torch.hann_window(16)
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    center=False,
                    window=window,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_wrong_custom_window_size(self):
        """Check that a window whose size differs from n_fft (without win_length) raises."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                window = torch.hann_window(10)
                return torch.stft(
                    x, n_fft=n_fft, window=window, center=False, return_complex=False
                )

        x = torch.randn((1, 32), requires_grad=True)
        with self.assertRaises((AssertionError, RuntimeError)):
            self.run_test(STFT(), x)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_wrong_window_length(self):
        """Check that win_length larger than n_fft raises RuntimeError."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                win_len = 17
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    win_length=win_len,
                    center=False,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        with self.assertRaises(RuntimeError):
            self.run_test(STFT(), x)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_window_size_with_win_len(self):
        """Test torch.stft with a 10-sample window and a matching win_length."""
        class STFT(torch.nn.Module):
            def forward(self, x):
                n_fft = 16
                window = torch.hann_window(10)
                win_len = 10
                return torch.stft(
                    x,
                    n_fft=n_fft,
                    window=window,
                    win_length=win_len,
                    center=False,
                    return_complex=False,
                )

        x = torch.randn((1, 32), requires_grad=True)
        self.run_test(STFT(), x, atol=1e-6)

    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_one_dimension(self):
        """STFT of a 1-D input of 32 samples with n_fft=16."""

        class OneDimSTFT(torch.nn.Module):
            def forward(self, x):
                return torch.stft(x, n_fft=16, center=False, return_complex=False)

        # Note: the original wrote `(32)`, which is just the int 32, not a tuple.
        signal = torch.randn(32, requires_grad=True)
        self.run_test(OneDimSTFT(), signal, atol=1e-6)
8122

8123
    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_wrong_input_size(self):
        """A 3-D input of shape (1, 1, 32) must raise RuntimeError."""

        class ThreeDimSTFT(torch.nn.Module):
            def forward(self, x):
                fft_size = 16
                return torch.stft(
                    x, n_fft=fft_size, center=False, return_complex=False
                )

        too_many_dims = torch.randn((1, 1, 32), requires_grad=True)
        with self.assertRaises(RuntimeError):
            self.run_test(ThreeDimSTFT(), too_many_dims)
8133

8134
    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_wrong_return_complex(self):
        """return_complex=True must surface as errors.SymbolicValueError."""

        class ComplexOutputSTFT(torch.nn.Module):
            def forward(self, x):
                return torch.stft(x, n_fft=16, center=False, return_complex=True)

        signal = torch.randn((1, 32), requires_grad=True)
        with self.assertRaises(errors.SymbolicValueError):
            self.run_test(ComplexOutputSTFT(), signal)
8144

8145
    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_normalize(self):
        """STFT with normalized=True on a 1-D input of 32 samples."""

        class NormalizedSTFT(torch.nn.Module):
            def forward(self, x):
                fft_size = 16
                return torch.stft(
                    x,
                    n_fft=fft_size,
                    center=False,
                    normalized=True,
                    return_complex=False,
                )

        signal = torch.randn(32, requires_grad=True)
        self.run_test(NormalizedSTFT(), signal, atol=1e-6)
8160

8161
    @skipIfUnsupportedMinOpsetVersion(17)
    def test_stft_not_onesided(self):
        """STFT with onesided=False on a 1-D input of 32 samples."""

        class TwoSidedSTFT(torch.nn.Module):
            def forward(self, x):
                fft_size = 16
                return torch.stft(
                    x,
                    n_fft=fft_size,
                    center=False,
                    onesided=False,
                    return_complex=False,
                )

        signal = torch.randn(32, requires_grad=True)
        self.run_test(TwoSidedSTFT(), signal, atol=1e-6)
8176

8177
    def test_constant_pad(self):
        """ConstantPad1d / ConstantPad2d with fill value 3.5."""
        # (module, input shape) pairs, exercised in order.
        cases = (
            (torch.nn.ConstantPad1d(2, 3.5), (2, 4, 4)),
            (torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5), (2, 2, 4, 4)),
        )
        for pad_module, input_shape in cases:
            self.run_test(pad_module, torch.randn(*input_shape))
8185

8186
    @common_utils.parametrize(
        "pad",
        [
            common_utils.subtest([2, 4], name="scalar_list"),
            common_utils.subtest(
                [
                    torch.tensor(2, dtype=torch.int64),
                    torch.tensor(4, dtype=torch.int64),
                ],
                name="scalar_tensor_list",
            ),
        ],
    )
    @skipIfUnsupportedMinOpsetVersion(11)  # Dynamic padding is added in opset 11
    def test_pad_types(self, pad):
        """Pad amounts given as Python ints or as 0-d int64 tensors."""

        class FunctionalPad(torch.nn.Module):
            def forward(self, x, pad: List[int]):
                return torch.nn.functional.pad(x, pad)

        model_inputs = (torch.randn(2, 2, 4, 4), pad)
        self.run_test(FunctionalPad(), model_inputs)
8208

8209
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_pad_circular(self):
        """Circular padding with pad=(1, 2, 1, 2) on a 4-D input."""

        class CircularPad(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.pad(x, (1, 2, 1, 2), mode="circular")

        # The original passed `(x)`, which is the bare tensor, not a 1-tuple.
        self.run_test(CircularPad(), torch.randn(2, 3, 3, 4))
8218

8219
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_pad_circular_negative(self):
        """Circular padding whose pad amounts (-1, -2) are negative, on a 3-D input."""

        class PadModel(torch.nn.Module):
            def forward(self, x):
                out = torch.nn.functional.pad(x, (-1, -2), mode="circular")
                return out

        x = torch.randn(2, 3, 6)
        # Pass the tensor directly; `(x)` in the original was not a tuple.
        self.run_test(PadModel(), x)
8229

8230
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_pad_circular_dynamic_axes(self):
        """Circular padding exported with all four input axes marked dynamic."""

        class CircularPad(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.pad(x, (2, 1, 2, 1), mode="circular")

        all_axes_dynamic = {"input_1": [0, 1, 2, 3]}
        self.run_test(
            CircularPad(),
            torch.randn(4, 3, 5, 6),
            input_names=["input_1"],
            dynamic_axes=all_axes_dynamic,
        )
8244

8245
    @skipIfUnsupportedMaxOpsetVersion(10)
    @skipScriptTest()  # TODO: the logic in symbolic_opset9 doesn't handle script
    def test_unsupported_pad(self):
        """At opset <= 10, a pad list passed as a model input must raise RuntimeError."""

        class FunctionalPad(torch.nn.Module):
            def forward(self, x, pad: List[int]):
                return torch.nn.functional.pad(x, pad)

        expected_msg = (
            "Unsupported: ONNX export of Pad.*"
            "The sizes of the padding must be constant"
        )
        model_inputs = (torch.randn(2, 2, 4, 4), [2, 4])
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            self.run_test(FunctionalPad(), model_inputs)
8263

8264
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_if_fold(self):
        """Models whose ``if`` conditions depend only on input dim/numel/dtype.

        All inputs are (3, 4) int32 tensors of ones; the models cover ==, !=, >,
        >=, <, <= comparisons and ``and`` combinations of those checks.
        """

        class IfDimEq2(torch.nn.Module):
            def forward(self, y):
                if y.dim() == 2:
                    y = y + 4
                    y = y + 2
                else:
                    y = y - 1
                return y

        class IfNumelGt1(torch.nn.Module):
            def forward(self, y):
                if y.numel() > 1:
                    y = y + 4
                else:
                    y = y + 2
                return y

        class IfDimNe3(torch.nn.Module):
            def forward(self, y):
                if y.dim() != 3:
                    y = y + 4
                    y = y + 2
                else:
                    return y
                return y

        class IfDimGe1(torch.nn.Module):
            def forward(self, y):
                if y.dim() >= 1:
                    y = y + 4
                else:
                    y = y - 1
                return y

        class IfDimLe1(torch.nn.Module):
            def forward(self, y):
                if y.dim() <= 1:
                    y = y + 4
                else:
                    y = y + 2
                return y

        class IfDimLt3AndInt(torch.nn.Module):
            def forward(self, y):
                if y.dim() < 3 and y.dtype == torch.int:
                    y = y + 4
                    y = y + 2
                else:
                    return y
                return y

        class IfDimEq3AndInt(torch.nn.Module):
            def forward(self, y):
                if y.dim() == 3 and y.dtype == torch.int:
                    y = y + 4
                    y = y + 2
                else:
                    y = y + 1
                return y

        class IfNonEmptyAndDimEq2(torch.nn.Module):
            def forward(self, y):
                if y.numel() != 0 and y.dim() == 2:
                    y = y + 4
                    y = y + 2
                else:
                    return y
                return y

        unary_models = (
            IfDimEq2,
            IfNumelGt1,
            IfDimNe3,
            IfDimGe1,
            IfDimLe1,
            IfDimLt3AndInt,
            IfDimEq3AndInt,
            IfNonEmptyAndDimEq2,
        )
        for model_cls in unary_models:
            # A fresh input per model, matching one allocation per case.
            self.run_test(model_cls(), torch.ones((3, 4), dtype=torch.int))

        class IfNumelEqual(torch.nn.Module):
            def forward(self, x, y):
                if x.numel() == y.numel():
                    y = x + y
                else:
                    y = y - x
                return y

        class IfNumelNotEqual(torch.nn.Module):
            def forward(self, x, y):
                if x.numel() != y.numel():
                    y = x + y
                else:
                    y = y - x
                return y

        for model_cls in (IfNumelEqual, IfNumelNotEqual):
            lhs = torch.ones((3, 4), dtype=torch.int)
            rhs = torch.ones((3, 4), dtype=torch.int)
            self.run_test(model_cls(), (lhs, rhs))
8382

8383
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_uninitialized(self):
        """Nested ifs where one inner branch returns early and the other falls through."""

        class EarlyReturnModel(torch.nn.Module):
            def forward(self, y):
                if y.shape[1] < 5:
                    if y.size(0) == 1:
                        y = y + 4
                    else:
                        return y
                return y

        self.run_test(EarlyReturnModel(), torch.ones((3, 4), dtype=torch.int))
8396

8397
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_uninitialized_dynamic(self):
        """Same early-return model as test_uninitialized, with dynamic axes 0 and 1.

        A second (6, 7) input is passed via additional_test_inputs.
        """

        class EarlyReturnModel(torch.nn.Module):
            def forward(self, y):
                if y.shape[1] < 5:
                    if y.size(0) == 1:
                        y = y + 4
                    else:
                        return y
                return y

        small = torch.ones((3, 4), dtype=torch.int)
        large = torch.ones((6, 7), dtype=torch.int)
        self.run_test(
            EarlyReturnModel(),
            small,
            additional_test_inputs=[large],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1]},
        )
8417

8418
    # onnx::Identity of sequence supported for ONNX opset >= 14
    @skipIfUnsupportedMinOpsetVersion(14)
    def test_uninitialized_tensorList(self):
        """Scripted model whose branches return a single-element tensor list."""

        class TensorListEarlyReturn(torch.nn.Module):
            def forward(self, x):
                if x[0].shape[0] < 5:
                    if x.size(0) == 1:
                        x = x + 4
                    else:
                        return [x]
                return [x]

        scripted = torch.jit.script(TensorListEarlyReturn())
        self.run_test(scripted, torch.ones((3, 4), dtype=torch.int))
8432

8433
    # onnx::Identity of sequence supported for ONNX opset >= 14
    @skipIfUnsupportedMinOpsetVersion(14)
    def test_uninitialized_tensorList_dynamic(self):
        """Scripted model returning ``list(x)`` from both branches, with dynamic axes."""

        class TensorListUnbindEarlyReturn(torch.nn.Module):
            def forward(self, x):
                if x[0].shape[0] < 5:
                    if x.size(0) == 1:
                        x += x
                    else:
                        return list(x)
                return list(x)

        scripted = torch.jit.script(TensorListUnbindEarlyReturn())
        sample = torch.ones((3, 4), dtype=torch.double)
        self.run_test(
            scripted,
            sample,
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1]},
        )
8452

8453
    # onnx::Identity of sequence supported for ONNX opset >= 14
    @skipIfUnsupportedMinOpsetVersion(14)
    def test_uninitialized_intList(self):
        """Scripted model returning an int list built from ``range(x.size(0))``."""

        class IntListEarlyReturn(torch.nn.Module):
            def forward(self, x):
                y = list(range(x.size(0)))
                if y[0] < 5:
                    # if x.size(0) != 3, ORT will throw type error.
                    if x.size(0) == 3:
                        y.append(10)
                    else:
                        return y
                return y

        scripted = torch.jit.script(IntListEarlyReturn())
        sample = torch.ones((3, 4), dtype=torch.int)
        self.run_test(
            scripted,
            sample,
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1]},
        )
8474

8475
    # onnx::Identity of sequence supported for ONNX opset >= 14
    @skipIfUnsupportedMinOpsetVersion(14)
    def test_uninitialized_tensorList_shape(self):
        """Branches return tensor lists of different lengths; inputs (3, 4) and (4, 6)."""

        class MixedLengthListModel(torch.nn.Module):
            def forward(self, x):
                if x.shape[1] < 5:
                    if x.size(0) == 1:
                        x = x + 4
                    else:
                        x_list = list(x)
                        x_list.append(x)
                        return x_list
                return [x, x]

        scripted = torch.jit.script(MixedLengthListModel())
        first = torch.ones((3, 4), dtype=torch.int)
        second = torch.ones((4, 6), dtype=torch.int)
        self.run_test(
            scripted,
            first,
            additional_test_inputs=[second],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1]},
        )
8498

8499
    # Sequence type as loop-carried dependencies only supported for ONNX opset >= 13
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_sequance_loopcarried(self):
        """A list grown inside a scripted loop, then stacked and transposed."""

        class SequenceLoopModel(torch.nn.Module):
            def forward(self, x):
                outputs = []
                for _ in range(3):
                    outputs += [x]
                return torch.stack(outputs).transpose(0, 1)

        scripted = torch.jit.script(SequenceLoopModel())
        self.run_test(scripted, torch.ones((3, 4), dtype=torch.int))
8511

8512
    def test_reflection_pad(self):
        """ReflectionPad1d and ReflectionPad2d on 3-D and 4-D inputs."""
        cases = (
            (torch.nn.ReflectionPad1d(2), (2, 4, 4)),
            (torch.nn.ReflectionPad2d((3, 0, 2, 1)), (2, 2, 4, 4)),
        )
        for pad_module, input_shape in cases:
            self.run_test(pad_module, torch.randn(*input_shape))
8520

8521
    def test_replication_pad(self):
        """ReplicationPad1d and ReplicationPad2d on 3-D and 4-D inputs."""
        cases = (
            (torch.nn.ReplicationPad1d(2), (2, 4, 4)),
            (torch.nn.ReplicationPad2d((3, 0, 2, 1)), (2, 2, 4, 4)),
        )
        for pad_module, input_shape in cases:
            self.run_test(pad_module, torch.randn(*input_shape))
8529

8530
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_im2col(self):
        """F.unfold with three kernel/dilation/padding settings, all with stride 3."""

        class Im2Col(torch.nn.Module):
            def forward(self, input):
                wide = torch.nn.functional.unfold(
                    input, kernel_size=(10, 15), dilation=2, padding=5, stride=3
                )
                small = torch.nn.functional.unfold(
                    input, kernel_size=(2, 2), dilation=1, padding=0, stride=3
                )
                pointwise = torch.nn.functional.unfold(
                    input, kernel_size=(1, 1), dilation=5, padding=2, stride=3
                )
                return (wide, small, pointwise)

        image = torch.rand(1, 1, 200, 100)
        self.run_test(Im2Col(), image)
8548

8549
    @skipIfNoLapack
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_det(self):
        """torch.linalg.det over a (2, 3) batch of 5x5 matrices."""

        class Determinant(torch.nn.Module):
            def forward(self, x):
                return torch.linalg.det(x)

        batched_matrices = torch.randn(2, 3, 5, 5)
        self.run_test(Determinant(), batched_matrices)
8558

8559
    def test_linalg_norm(self):
        """torch.linalg.norm with dim=1, dim=(0, 2), and with no dim argument."""

        class LinalgSingleDimModel(torch.nn.Module):
            def __init__(self, ord_val):
                super().__init__()
                self.ord = ord_val

            def forward(self, x):
                return torch.linalg.norm(x, ord=self.ord, dim=1)

        class LinalgMultiDimModel(torch.nn.Module):
            def __init__(self, ord_val):
                super().__init__()
                self.ord = ord_val

            def forward(self, x):
                return torch.linalg.norm(x, ord=self.ord, dim=(0, 2))

        class LinalgNoDimNoOrdModel(torch.nn.Module):
            def forward(self, x):
                return torch.linalg.norm(x)

        # Used for both the 1-D and 2-D "ord but no dim" cases below.
        class LinalgNoDimModel(torch.nn.Module):
            def __init__(self, ord_val):
                super().__init__()
                self.ord = ord_val

            def forward(self, x):
                return torch.linalg.norm(x, ord=self.ord)

        inf = float("inf")
        vector_ords = (None, 2, inf, -inf, -4, 1.5)
        matrix_ords = ("fro", inf, -inf, 1, -1)

        # Inputs are drawn in the same order as the case sequence below.
        single_dim_input = torch.randn(2, 3, 5, 5)
        for ord_val in vector_ords:
            self.run_test(LinalgSingleDimModel(ord_val), single_dim_input)

        multi_dim_input = torch.randn(2, 3, 5, 5)
        for ord_val in matrix_ords:
            self.run_test(LinalgMultiDimModel(ord_val), multi_dim_input)

        for shape in ((2, 3, 5, 5), (2, 3), (2,)):
            self.run_test(LinalgNoDimNoOrdModel(), torch.randn(*shape))

        vector_input = torch.randn(2)
        for ord_val in vector_ords:
            self.run_test(LinalgNoDimModel(ord_val), vector_input)

        matrix_input = torch.randn(2, 3)
        for ord_val in matrix_ords:
            self.run_test(LinalgNoDimModel(ord_val), matrix_input)
8632

8633
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_linalg_vector_norm_zero(self):
        """torch.linalg.vector_norm with ord=0 and no dim argument."""

        class VectorNormModel(torch.nn.Module):
            def __init__(self, ord_val):
                super().__init__()
                self.ord = ord_val

            def forward(self, x):
                return torch.linalg.vector_norm(x, ord=self.ord)

        self.run_test(VectorNormModel(0), torch.randn(2, 3, 5, 5))
8645

8646
    def test_linalg_vector_norm(self):
        """torch.linalg.vector_norm over every (ord, dim/keepdim) combination."""

        class LinalgVectorNormModel(torch.nn.Module):
            def __init__(self, ord_val, dim_info):
                super().__init__()
                self.ord = ord_val
                # dim_info is a (dim, keepdim) pair.
                self.dim, self.keepdim = dim_info

            def forward(self, x):
                return torch.linalg.vector_norm(
                    x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
                )

        sample = torch.randn(2, 3, 5, 5)
        ords = (2, float("inf"), -float("inf"), -4, 1.5)
        dim_settings = ((None, False), (1, False), ((1, 2), False), ((1, 2), True))
        # product() iterates in the same order as nested ord/dim loops.
        for ord_val, dim_info in itertools.product(ords, dim_settings):
            self.run_test(LinalgVectorNormModel(ord_val, dim_info), sample)
8664

8665
    def test_linalg_matrix_norm(self):
        """torch.linalg.matrix_norm for several ords, dims, and keepdim settings."""

        class LinalgMatrixNormModel(torch.nn.Module):
            def __init__(self, ord_val, dim_val=(-2, -1), keepdim_val=False):
                super().__init__()
                self.ord = ord_val
                self.dim = dim_val
                self.keepdim = keepdim_val

            def forward(self, x):
                return torch.linalg.matrix_norm(
                    x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
                )

        sample = torch.randn(2, 3, 5, 5)
        # Extra constructor args after ord: defaults, dim=(0, 2), then keepdim=True too.
        extra_arg_sets = ((), ((0, 2),), ((0, 2), True))
        for ord_val in ("fro", float("inf"), -float("inf"), 1, -1):
            for extra_args in extra_arg_sets:
                self.run_test(LinalgMatrixNormModel(ord_val, *extra_args), sample)
8684

8685
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_linalg_cross(self):
        """linalg.cross with dim=1 and with the default dim, y of shape (1, 3, 1, 3)."""

        class LinalgCross(torch.nn.Module):
            def forward(self, x, y):
                along_dim1 = torch.linalg.cross(x, y, dim=1)
                default_dim = torch.linalg.cross(x, y)
                return along_dim1, default_dim

        lhs = torch.randn(5, 3, 2, 3)
        rhs = torch.randn(1, 3, 1, 3)
        self.run_test(LinalgCross(), input_args=(lhs, rhs))
8694

8695
    # Checks that the output scalar type in the exported ONNX graph is not null.
    # https://github.com/pytorch/pytorch/issues/28607
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_trace_script(self):
        """A module calling a ``torch.jit.script`` helper that slices from a tensor offset."""

        @torch.jit.script
        def slice_columns_from(input, h_offset):
            return input[:, h_offset:]

        class CenterCrop(torch.nn.Module):
            def forward(self, input):
                last_col = torch.tensor(input.shape[1] - 1)
                return slice_columns_from(input, last_col)

        self.run_test(CenterCrop(), torch.randn(3, 4))
8709

8710
    @skipIfNoLapack
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_logdet(self):
        """torch.logdet over a (2, 3) batch of 5x5 matrices."""

        class LogDeterminant(torch.nn.Module):
            def forward(self, x):
                return torch.logdet(x)

        batched_matrices = torch.randn(2, 3, 5, 5)
        self.run_test(LogDeterminant(), batched_matrices)
8719

8720
    def test_dim(self):
        """ScriptModule scaling its input by ``dim()``, on an empty and a 3-D input."""

        class DimScaled(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                out = input * 2
                out *= out.dim()
                return out

        # The tuple is built before the loop, so both inputs are drawn up front.
        samples = (
            torch.randn(0, requires_grad=True),
            torch.randn(1, 2, 3, requires_grad=True),
        )
        for sample in samples:
            self.run_test(DimScaled(), sample)
8732

8733
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_dim_1(self):
        """Scripted loop over zero boxes sized by a dynamic leading input dim."""

        class ClipBoxesPerPose(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, poses):
                boxes = torch.zeros([poses.shape[0], 2, 4])
                batch_boxes = []
                for kp_boxes in boxes:
                    kp_boxes = torchvision.ops.clip_boxes_to_image(kp_boxes, (2, 3))
                    batch_boxes.append(kp_boxes)
                return batch_boxes

        poses = torch.rand(2, 2, 3)
        self.run_test(
            ClipBoxesPerPose(),
            (poses,),
            input_names=["x"],
            dynamic_axes={"x": [0]},
        )
8747

8748
    @skipIfUnsupportedMinOpsetVersion(12)
    @skipDtypeChecking
    def test_outer(self):
        """torch.outer on same-dtype and mixed-dtype 1-D input pairs."""

        class Outer(torch.nn.Module):
            def forward(self, x, y):
                return torch.outer(x, y)

        operand_pairs = (
            (torch.arange(1, 5), torch.arange(1, 4)),
            (
                torch.arange(1, 6).to(dtype=torch.float32),
                torch.arange(1, 4).to(dtype=torch.long),
            ),
            (
                torch.arange(2, 5).to(dtype=torch.float32),
                torch.arange(2, 4).to(dtype=torch.float64),
            ),
            (
                torch.arange(3, 6).to(dtype=torch.int32),
                torch.arange(4, 7).to(dtype=torch.long),
            ),
        )
        for lhs, rhs in operand_pairs:
            self.run_test(Outer(), input_args=(lhs, rhs))
8770

8771
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_movedim(self):
        """Tensor.movedim with single and tuple source/destination dims, incl. no-ops."""

        class Movedim(torch.nn.Module):
            def forward(self, x):
                return (
                    x.movedim(1, 3),
                    x.movedim(2, 0),
                    x.movedim(1, 1),
                    x.movedim((1, 2, 3), (3, 0, 1)),
                    x.movedim((0, 1, 2), (1, 2, 3)),
                    x.movedim((1, 3, 2), (1, 3, 2)),
                )

        self.run_test(Movedim(), torch.randn(5, 3, 4, 2))
8787

8788
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_moveaxis(self):
        """Tensor.moveaxis (alias of movedim) with the same cases as test_movedim."""

        class Moveaxis(torch.nn.Module):
            def forward(self, x):
                return (
                    x.moveaxis(1, 3),
                    x.moveaxis(2, 0),
                    x.moveaxis(1, 1),
                    x.moveaxis((1, 2, 3), (3, 0, 1)),
                    x.moveaxis((0, 1, 2), (1, 2, 3)),
                    x.moveaxis((1, 3, 2), (1, 3, 2)),
                )

        self.run_test(Moveaxis(), torch.randn(5, 3, 4, 2))
8805

8806
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_einsum(self):
        """torch.einsum: batch diagonal, batch matmul, inner product, transpose."""

        class BatchDiagonal(torch.nn.Module):
            def forward(self, x):
                eqn = "...ii ->...i"
                return torch.einsum(eqn, x)

        class BatchMatmul(torch.nn.Module):
            def forward(self, x, y):
                eqn = "bij, bjk -> bik"
                return torch.einsum(eqn, x, y)

        class InnerProduct(torch.nn.Module):
            def forward(self, x, y):
                eqn = "i,i"
                return torch.einsum(eqn, x, y)

        class Transpose(torch.nn.Module):
            def forward(self, x):
                eqn = "ij->ji"
                return torch.einsum(eqn, x)

        # Float and bool variants; each list is built fully before iterating.
        diag_inputs = [torch.randn(3, 5, 5), torch.randn(3, 5, 5).to(dtype=torch.bool)]
        for sample in diag_inputs:
            self.run_test(BatchDiagonal(), input_args=(sample,))

        lhs = torch.randn(5, 2, 3)
        rhs = torch.randn(5, 3, 4)
        self.run_test(BatchMatmul(), input_args=(lhs, rhs))

        u = torch.randn(5)
        v = torch.randn(5)
        self.run_test(InnerProduct(), input_args=(u, v))

        transpose_inputs = [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]
        for sample in transpose_inputs:
            self.run_test(Transpose(), input_args=(sample,))
8841

8842
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cosine_similarity(self):
        """nn.CosineSimilarity along dim=2 of two (5, 3, 2) inputs."""
        lhs = torch.randn(5, 3, 2)
        rhs = torch.randn(5, 3, 2)
        similarity = torch.nn.CosineSimilarity(dim=2)
        self.run_test(similarity, input_args=(lhs, rhs))
8847

8848
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_pairwise_distance(self):
        """nn.PairwiseDistance with p=2.0 on two (5, 3, 2) inputs."""
        lhs = torch.randn(5, 3, 2)
        rhs = torch.randn(5, 3, 2)
        distance = torch.nn.PairwiseDistance(p=2.0)
        self.run_test(distance, input_args=(lhs, rhs))
8853

8854
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cross(self):
        """torch.cross with dim=3 and without an explicit dim."""

        class CrossProduct(torch.nn.Module):
            def forward(self, x, y):
                explicit_dim = torch.cross(x, y, dim=3)
                implicit_dim = torch.cross(x, y)
                return explicit_dim, implicit_dim

        lhs = torch.randn(5, 3, 2, 3)
        rhs = torch.randn(5, 3, 2, 3)
        self.run_test(CrossProduct(), input_args=(lhs, rhs))
8863

8864
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_cdist(self):
        """torch.cdist between (5, 3, 3) and (5, 2, 3) point sets."""

        class PairwiseCdist(torch.nn.Module):
            def forward(self, x, y):
                return torch.cdist(x, y)

        points_a = torch.randn(5, 3, 3)
        points_b = torch.randn(5, 2, 3)
        self.run_test(PairwiseCdist(), input_args=(points_a, points_b))
8873

8874
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_crossentropyloss(self):
        """Run the cross-entropy cases for ignore_index in (-100, 1) and three input ranks.

        Logits are (3, 5, *extra) and targets (3, *extra); target entries equal to
        1 are overwritten with ignore_index.
        """
        batch, num_classes = 3, 5
        for ignore_index in (-100, 1):
            for extra_dims in ((), (2,), (2, 7)):
                logits = torch.randn(batch, num_classes, *extra_dims)
                target = torch.empty(batch, *extra_dims, dtype=torch.long).random_(
                    num_classes
                )
                target[target == 1] = ignore_index
                self._crossentropyloss(logits, target, ignore_index)
8892

8893
    def _crossentropyloss(self, x, y, ignore_index):
        """Export CrossEntropyLoss for each reduction, unweighted and weighted.

        Args:
            x: logits; the per-class weights below assume 5 classes on dim 1.
            y: integer class targets for ``x``.
            ignore_index: target value to ignore (-100 is PyTorch's default).
        """

        class CrossEntropyModule(torch.nn.Module):
            def __init__(self, reduction, weighted):
                super().__init__()
                loss_kwargs = {"reduction": reduction, "ignore_index": ignore_index}
                if weighted:
                    # Fresh random class weights per module, one per class.
                    loss_kwargs["weight"] = torch.randn(5)
                self.loss = torch.nn.CrossEntropyLoss(**loss_kwargs)

            def forward(self, input, target):
                return self.loss(input, target)

        # Order matters for RNG consumption: unweighted then weighted, per reduction.
        for reduction in ("none", "sum", "mean"):
            for weighted in (False, True):
                self.run_test(
                    CrossEntropyModule(reduction, weighted), input_args=(x, y)
                )
8989

8990
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_MSELoss(self):
        """Export MSELoss with all three reductions inside a single graph."""

        class MultiReductionMSE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss_none = torch.nn.MSELoss(reduction="none")
                self.loss_sum = torch.nn.MSELoss(reduction="sum")
                self.loss_mean = torch.nn.MSELoss(reduction="mean")

            def forward(self, input, target):
                unreduced = self.loss_none(input, target)
                summed = self.loss_sum(input, target)
                averaged = self.loss_mean(input, target)
                return unreduced, summed, averaged

        prediction = torch.randn(2, 3, 5)
        reference = torch.randn(2, 3, 5)
        self.run_test(MultiReductionMSE(), input_args=(prediction, reference))
9009

9010
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_kldiv_loss(self):
        """Run the KLDivLoss suite on 1-D, 3-D and 4-D inputs.

        The input is the log of uniform samples; the target holds raw uniform
        samples in [0, 1).
        """
        for shape in [(5,), (2, 3, 5), (2, 3, 5, 7)]:
            log_input = torch.rand(*shape).log()
            target = torch.rand(*shape)
            self._kldiv_loss(log_input, target)
9023

9024
    def _kldiv_loss(self, x, y):
        """Export KLDivLoss across reductions and both ``log_target`` modes.

        Variants built with ``log_target=True`` apply ``target.log()`` inside
        ``forward``, so every variant consumes the same ``(x, y)`` pair.
        """

        class KLDivLogTarget(torch.nn.Module):
            def __init__(self, **loss_kwargs):
                super().__init__()
                self.loss = torch.nn.KLDivLoss(log_target=True, **loss_kwargs)

            def forward(self, input, target):
                return self.loss(input, target.log())

        class KLDivPlainTarget(torch.nn.Module):
            def __init__(self, **loss_kwargs):
                super().__init__()
                self.loss = torch.nn.KLDivLoss(log_target=False, **loss_kwargs)

            def forward(self, input, target):
                return self.loss(input, target)

        variants = [
            (KLDivLogTarget, {"reduction": "none"}),
            (KLDivPlainTarget, {"reduction": "mean"}),
            (KLDivLogTarget, {"reduction": "sum"}),
            (KLDivPlainTarget, {"reduction": "batchmean"}),
            # NOTE: the deprecated ``size_average`` argument overrides
            # ``reduction`` (see the KLDivLoss docs), so this variant actually
            # reduces with "sum" rather than "batchmean".
            (KLDivLogTarget, {"reduction": "batchmean", "size_average": False}),
        ]
        for module_cls, loss_kwargs in variants:
            # Instantiate lazily so each module is built right before its run.
            self.run_test(module_cls(**loss_kwargs), input_args=(x, y))
9076

9077
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss(self):
        """Export NLLLoss(reduction="none") applied to log-softmax of doubled logits.

        Targets equal to 1 are replaced with -100, NLLLoss's default ignore_index.
        """
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="none")
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(2 * input), target)

        logits = torch.randn(batch, 16)
        labels = torch.empty(batch, dtype=torch.long).random_(0, num_classes)
        labels[labels == 1] = -100
        self.run_test(NLLModel(), (logits, labels))
9096

9097
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_none(self):
        """Export Conv2d -> LogSoftmax -> NLLLoss(reduction="none") on 2-D maps.

        The unpadded 3x3 conv turns 10x10 inputs into the 8x8 target grid.
        Targets equal to 1 are set to -100, the default ignore_index.
        """
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="none")
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        labels[labels == 1] = -100
        self.run_test(NLLModel(), (images, labels))
9117

9118
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_mean(self):
        """Export Conv2d -> LogSoftmax -> NLLLoss(reduction="mean") on 2-D maps.

        The unpadded 3x3 conv turns 10x10 inputs into the 8x8 target grid.
        Targets equal to 1 are set to -100, the default ignore_index.
        """
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="mean")
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        labels[labels == 1] = -100
        self.run_test(NLLModel(), (images, labels))
9138

9139
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_sum(self):
        """Export Conv2d -> LogSoftmax -> NLLLoss(reduction="sum") on 2-D maps.

        The unpadded 3x3 conv turns 10x10 inputs into the 8x8 target grid.
        Targets equal to 1 are set to -100, the default ignore_index.
        """
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="sum")
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        labels[labels == 1] = -100
        self.run_test(NLLModel(), (images, labels))
9159

9160
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_mean_weights(self):
        """Export a class-weighted, mean-reduced NLLLoss after Conv2d + LogSoftmax.

        Targets equal to 1 are set to -100, the default ignore_index.
        """
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # The loss (and its random weights) is built before the conv so
                # the RNG is consumed in the same order as before.
                self.loss = torch.nn.NLLLoss(
                    reduction="mean", weight=torch.randn(num_classes)
                )
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        labels[labels == 1] = -100
        self.run_test(NLLModel(), (images, labels))
9180

9181
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_mean_ignore_index(self):
        """Export a mean-reduced NLLLoss that ignores class 1 on 2-D maps."""
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        self.run_test(NLLModel(), (images, labels))
9198

9199
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_dynamic_ignore_index(self):
        """Export ``nll_loss`` whose ``ignore_index`` is derived from an input's shape.

        ``ignore_index`` is read from ``start_position.size(1)`` instead of a
        literal. The module then adds ``start_position`` (as float) to the
        loss, which ties that input into the graph.
        """
        import torch.nn.functional as F

        class LabelSmoothingCrossEntropy(torch.nn.Module):
            def __init__(self, epsilon: float = 0.1, reduction="mean"):
                super().__init__()
                # ``epsilon`` is stored but not read by ``forward``.
                self.epsilon = epsilon
                self.reduction = reduction

            def forward(self, preds, target, start_position):
                log_preds = F.log_softmax(preds, dim=-1)
                ignore_index = start_position.size(1)
                nll = F.nll_loss(
                    log_preds,
                    target,
                    reduction=self.reduction,
                    ignore_index=ignore_index,
                )
                return nll + start_position.float()

        N = 5
        preds = torch.randn(N, 16)
        target = torch.randint(5, (N,))
        start_position = torch.randint(10, (N, N))
        self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position))
9238

9239
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_nllloss_2d_mean_ignore_index_weights(self):
        """Export a class-weighted, mean-reduced NLLLoss that ignores class 1."""
        batch, num_classes = 5, 4

        class NLLModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Loss weights are drawn before the conv is initialised.
                self.loss = torch.nn.NLLLoss(
                    reduction="mean", weight=torch.randn(num_classes), ignore_index=1
                )
                self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
                self.log_softmax = torch.nn.LogSoftmax(dim=1)

            def forward(self, input, target):
                return self.loss(self.log_softmax(self.conv(input)), target)

        images = torch.randn(batch, 16, 10, 10)
        labels = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
        self.run_test(NLLModel(), (images, labels))
9258

9259
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_binary_cross_entropy_with_logits(self):
        """Export BCE-with-logits: plain, with ``weight``, with ``pos_weight``, with both.

        Targets are random 0/1 float tensors. Each ``pos_weight`` has the
        trailing shape of its inputs, so it broadcasts over dim 0.
        """
        logits = torch.randn(5)
        labels = torch.empty(5).random_(2)
        self._bce_logits(logits, labels)

        logits = torch.randn(3, 4)
        labels = torch.empty(3, 4).random_(2)
        self._bce_logits_wegiht(logits, labels, torch.tensor([3]))

        logits = torch.randn(3, 2, 4)
        labels = torch.empty(3, 2, 4).random_(2)
        self._bce_logits_posweight(logits, labels, torch.empty([2, 4]).random_(2))

        logits = torch.randn(3, 3, 4)
        labels = torch.empty(3, 3, 4).random_(2)
        self._bce_logits_loss_weight_posweight(
            logits, labels, torch.tensor([3]), torch.empty([3, 4]).random_(2)
        )
9280

9281
    def _bce_logits(self, x, y):
        """Export unweighted BCE-with-logits for the none/mean/sum reductions."""

        class BCELogitsNone(torch.nn.Module):
            def forward(self, input, target):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, reduction="none"
                )

        class BCELogitsMean(torch.nn.Module):
            def forward(self, input, target):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, reduction="mean"
                )

        class BCELogitsSum(torch.nn.Module):
            def forward(self, input, target):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, reduction="sum"
                )

        for module_cls in (BCELogitsNone, BCELogitsMean, BCELogitsSum):
            self.run_test(module_cls(), input_args=(x, y))
9305

9306
    def _bce_logits_wegiht(self, x, y, weight):
        """Export BCE-with-logits with ``weight`` passed as a graph input.

        The misspelled method name is kept because callers use it.
        """

        class BCELogitsWeightNone(torch.nn.Module):
            def forward(self, input, target, weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, reduction="none"
                )

        class BCELogitsWeightMean(torch.nn.Module):
            def forward(self, input, target, weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, reduction="mean"
                )

        class BCELogitsWeightSum(torch.nn.Module):
            def forward(self, input, target, weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, reduction="sum"
                )

        for module_cls in (BCELogitsWeightNone, BCELogitsWeightMean, BCELogitsWeightSum):
            self.run_test(module_cls(), input_args=(x, y, weight))
9330

9331
    def _bce_logits_posweight(self, x, y, pos_weight):
        """Export BCE-with-logits with ``pos_weight`` passed as a graph input."""

        class BCELogitsPosWeightNone(torch.nn.Module):
            def forward(self, input, target, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, pos_weight=pos_weight, reduction="none"
                )

        class BCELogitsPosWeightMean(torch.nn.Module):
            def forward(self, input, target, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, pos_weight=pos_weight, reduction="mean"
                )

        class BCELogitsPosWeightSum(torch.nn.Module):
            def forward(self, input, target, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, pos_weight=pos_weight, reduction="sum"
                )

        variants = (BCELogitsPosWeightNone, BCELogitsPosWeightMean, BCELogitsPosWeightSum)
        for module_cls in variants:
            self.run_test(module_cls(), input_args=(x, y, pos_weight))
9355

9356
    def _bce_logits_loss_weight_posweight(self, x, y, weight, pos_weight):
        """Export BCE-with-logits taking both ``weight`` and ``pos_weight`` as inputs."""

        class BCELogitsBothNone(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, pos_weight=pos_weight, reduction="none"
                )

        class BCELogitsBothMean(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, pos_weight=pos_weight, reduction="mean"
                )

        class BCELogitsBothSum(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(
                    input, target, weight=weight, pos_weight=pos_weight, reduction="sum"
                )

        for module_cls in (BCELogitsBothNone, BCELogitsBothMean, BCELogitsBothSum):
            self.run_test(module_cls(), input_args=(x, y, weight, pos_weight))
9396

9397
    def test_torch_mm(self):
        """Export a plain 2-D matrix product of a (2, 3) and a (3, 3) matrix."""

        class MatMul(torch.nn.Module):
            def forward(self, mat1, mat2):
                return torch.mm(mat1, mat2)

        lhs, rhs = torch.randn(2, 3), torch.randn(3, 3)
        self.run_test(MatMul(), input_args=(lhs, rhs))
9406

9407
    # The where op is not supported for opset < 9.
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where_with_bool_tensor(self):
        """Export ``torch.where`` whose condition is a computed boolean mask."""

        class KeepPositive(torch.nn.Module):
            def forward(self, mat1, mat2):
                return torch.where(mat1 > 0, mat1, mat2)

        candidates = torch.randn(2, 3)
        fallback = torch.ones(2, 3)
        self.run_test(KeepPositive(), input_args=(candidates, fallback))
9419

9420
    # The where op is not supported for opset < 9.
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where_with_byte_tensor(self):
        """Export ``torch.where`` driven by a uint8 condition tensor."""

        class SelectByMask(torch.nn.Module):
            def forward(self, cond, mat1, mat2):
                return torch.where(cond, mat1, mat2)

        # All ones except the element at [1, 2].
        mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8)
        self.run_test(
            SelectByMask(), input_args=(mask, torch.randn(2, 3), torch.ones(2, 3))
        )
9434

9435
    # ONNX IsInf was added in opset 10.
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_isinf(self):
        """Export ``isinf`` on a tensor mixing finite, NaN and +inf values."""

        class IsInf(torch.nn.Module):
            def forward(self, x):
                return torch.isinf(x)

        sample = torch.tensor(
            [[1, 2, float("inf")], [2, float("nan"), float("inf")]]
        )
        self.run_test(IsInf(), (sample,))
9443

9444
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_isfinite(self):
        """Export ``isfinite`` on a tensor containing NaN and both infinities."""

        class IsFinite(torch.nn.Module):
            def forward(self, x):
                return torch.isfinite(x)

        sample = torch.tensor(
            [[1, 2, float("inf")], [2, float("nan"), -float("inf")]]
        )
        self.run_test(IsFinite(), (sample,))
9452

9453
    # ONNX IsNaN was added in opset 9.
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_isnan(self):
        """Export ``isnan`` on a tensor mixing finite, NaN and +inf values."""

        class IsNaN(torch.nn.Module):
            def forward(self, x):
                return torch.isnan(x)

        sample = torch.tensor(
            [[1, 2, float("inf")], [2, float("nan"), float("inf")]]
        )
        self.run_test(IsNaN(), (sample,))
9461

9462
    # ONNX IsNaN and IsInf were added in opsets 9 and 10 respectively.
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_nan_to_num(self):
        """Export ``nan_to_num`` with default and explicit replacement values."""

        class NoParams(torch.nn.Module):
            def forward(self, x):
                return x.nan_to_num()

        class WithParams(torch.nn.Module):
            def forward(self, x):
                return x.nan_to_num(nan=2.3, posinf=4.5, neginf=6.7)

        def special_values():
            return torch.tensor(
                [[1, 2, float("inf")], [2, float("nan"), -float("inf")]]
            )

        for sample in (
            special_values(),
            torch.ones((2, 4), dtype=torch.int),
            torch.ones((2, 4), dtype=torch.half),
        ):
            self.run_test(NoParams(), (sample,))

        self.run_test(WithParams(), (special_values(),))
9483

9484
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_maximum_minimum(self):
        """Export elementwise max/min where one operand has a NaN and shapes broadcast."""

        class ModelWithNan(torch.nn.Module):
            def forward(self, x, y):
                return torch.maximum(x, y), torch.minimum(x, y)

        with_nan = torch.tensor([-2, -2, float("nan")])
        self.run_test(ModelWithNan(), (with_nan, torch.rand(1, 3)))
9493

9494
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_minimum_dtypes(self):
        """Export ``torch.minimum`` on operand pairs that require type promotion."""

        class MinimumModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.minimum(x, y)

        def check(lhs, rhs):
            self.run_test(MinimumModel(), (lhs, rhs))

        check(
            torch.randn((5, 5), dtype=torch.float16),
            torch.randn((5, 5), dtype=torch.float),
        )
        check(
            torch.randn((5, 5), dtype=torch.float16),
            torch.randint(10, (5, 5), dtype=torch.int16),
        )
        check(
            torch.randint(10, (5, 5), dtype=torch.int16),
            torch.randint(10, (5, 5), dtype=torch.int32),
        )
        ints = torch.randint(10, (5, 5), dtype=torch.int)
        check(ints, torch.full_like(ints, True))
9515

9516
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_maximum_dtypes(self):
        """Export ``torch.maximum`` on operand pairs that require type promotion."""

        class MaximumModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.maximum(x, y)

        def check(lhs, rhs):
            self.run_test(MaximumModel(), (lhs, rhs))

        check(
            torch.randn((5, 5), dtype=torch.float16),
            torch.randn((5, 5), dtype=torch.float),
        )
        check(
            torch.randn((5, 5), dtype=torch.float16),
            torch.randint(10, (5, 5), dtype=torch.int16),
        )
        check(
            torch.randint(10, (5, 5), dtype=torch.int16),
            torch.randint(10, (5, 5), dtype=torch.int32),
        )
        ints = torch.randint(10, (5, 5), dtype=torch.int)
        check(ints, torch.full_like(ints, True))
9537

9538
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_any(self):
        """Export ``Tensor.any``: full reduction, along dim 1, and with keepdim."""

        class AnyAll(torch.nn.Module):
            def forward(self, x):
                return x.any()

        class AnyAlongDim(torch.nn.Module):
            def forward(self, x):
                return x.any(dim=1)

        class AnyKeepdim(torch.nn.Module):
            def forward(self, x):
                return x.any(dim=1, keepdim=True)

        self.run_test(AnyAll(), (torch.tensor([[True, False], [False, False]]),))
        self.run_test(AnyAlongDim(), (torch.rand(3, 4).bool(),))
        self.run_test(AnyKeepdim(), (torch.rand(3, 4).bool(),))
9560

9561
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_all(self):
        """Export ``Tensor.all``: full reduction, along dim 1, and with keepdim."""

        class AllAll(torch.nn.Module):
            def forward(self, x):
                return x.all()

        class AllAlongDim(torch.nn.Module):
            def forward(self, x):
                return x.all(dim=1)

        class AllKeepdim(torch.nn.Module):
            def forward(self, x):
                return x.all(dim=1, keepdim=True)

        self.run_test(AllAll(), (torch.tensor([[True, False], [False, False]]),))
        self.run_test(AllAlongDim(), (torch.rand(3, 4).bool(),))
        self.run_test(AllKeepdim(), (torch.rand(3, 4).bool(),))
9583

9584
    def test_dropout(self):
        """Export a module containing ``Dropout(p=0.3)``."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(0.3)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        x = torch.randn(10, 3, 53)
        # ``(x)`` is just ``x``; pass an explicit one-element tuple like the
        # other tests in this class.
        self.run_test(M(), (x,))
9596

9597
    def test_rrelu_eval(self):
        """Export an ``RReLU(0.1, 0.3)`` module after switching it to eval mode."""
        sample = torch.tensor([0.5, -0.5])
        model = torch.nn.RReLU(0.1, 0.3)
        model.eval()
        self.run_test(model, sample)
9600

9601
    def test_shape_constant_fold(self):
        """Add a registered buffer's length (5) to the input.

        Per the test name, this presumably checks that the shape read is
        constant-folded during export.
        """

        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                return x + self.weight.shape[0]

        self.run_test(ShapeModule(), (torch.randn(2, 5),), rtol=1e-3, atol=1e-5)
9613

9614
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu(self):
        """Export ``CELU`` constructed with an explicit ``alpha=1.0``."""

        class CeluWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.celu = torch.nn.CELU(alpha=1.0)

            def forward(self, input):
                return self.celu(input)

        sample = torch.randn(2)
        self.run_test(CeluWrapper(), (sample,))
9626

9627
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_default(self):
        """Export ``CELU`` constructed with no arguments."""

        class CeluWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.celu = torch.nn.CELU()

            def forward(self, input):
                return self.celu(input)

        sample = torch.randn(2)
        self.run_test(CeluWrapper(), (sample,))
9639

9640
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_alpha(self):
        """Export ``CELU`` with a non-default ``alpha=2.0``."""

        class CeluWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.celu = torch.nn.CELU(alpha=2.0)

            def forward(self, input):
                return self.celu(input)

        sample = torch.randn(2)
        self.run_test(CeluWrapper(), (sample,))
9652

9653
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_cast(self):
        """Export default ``CELU`` on a float64 input.

        Per the test name, this presumably exercises dtype casting around CELU.
        """

        class CeluWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.celu = torch.nn.CELU()

            def forward(self, input):
                return self.celu(input)

        sample = torch.randn(2, 5, 7, dtype=torch.float64)
        self.run_test(CeluWrapper(), (sample,))
9665

9666
    def test_lower_tuple(self):
        """Export a forward that repeatedly packs, unpacks and rebinds tuples.

        The tuples ``a``, ``b`` and ``c`` are rebound inside nested loops and
        inside a size-dependent branch. Per the test name, this presumably
        exercises the exporter's tuple-lowering pass, so the statement structure
        is intentional and should not be simplified.
        """

        class TupleModule(torch.nn.Module):
            def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor:
                a = (input1, input2)
                b = a
                c = (input1, input2, input3)
                for i in range(5):
                    d = a[0]
                    for j in range(2):
                        e, f = a
                        a = (d, f)
                        f = c[2]
                        # Branch on a runtime size comparison; each path rebinds
                        # ``b`` to a different tuple.
                        if f.size(0) != input1.size(-1):
                            g = b[1]
                            b = (g, f)
                        else:
                            k = c[1:]  # tuple slice
                            b = (f, k[0])
                    m, n = b
                    c = (input1, n, m)
                p, q, r = c
                return p + q + r

        input1 = torch.randn(2)
        input2 = torch.randn(2)
        input3 = torch.randn(2)
        self.run_test(TupleModule(), (input1, input2, input3))
9693

9694
    def test_lower_tuple_2(self):
        """Export a loop that unpacks a tuple and repacks it unchanged.

        The loop body is an identity on ``a``, so the model returns
        ``(input1, input2)``.
        """

        class TupleModule(torch.nn.Module):
            def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
                a = (input1, input2)
                # Unpack-then-repack round trip; the loop index is unused.
                for x in range(5):
                    c, d = a
                    a = (c, d)
                return a

        input1 = torch.randn(2)
        input2 = torch.randn(2)
        self.run_test(TupleModule(), (input1, input2))
9706

9707
    def test_lower_tuple_3(self):
        """Export a model whose inputs and outputs are nested tuples of tensors.

        Each iteration unpacks both tuple inputs, conditionally adds
        elementwise, and swaps the roles of ``a`` and ``b``.
        """

        class TupleModule(torch.nn.Module):
            def forward(
                self,
                input1: Tuple[Tensor, Tensor],
                input2: Tuple[Tensor, Tensor],
            ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
                a = input1
                b = input2
                for x in range(5):
                    c, d = a
                    e, f = b
                    # All four inputs below are shape (2,), so at runtime
                    # this compares 2 == 2 and takes the first branch.
                    if c.shape[0] == e.shape[0]:
                        e = e + c
                    else:
                        f = f + d
                    a = (e, f)
                    b = (c, d)
                return a, b

        input1 = (torch.randn(2), torch.randn(2))
        input2 = (torch.randn(2), torch.randn(2))
        self.run_test(TupleModule(), (input1, input2))
9730

9731
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_where(self):
        """Export ternary ``torch.where(cond, input, other)`` with broadcasting.

        ``cond`` is (2, 3, 4) bool, ``input`` is (2, 1, 4) and ``other`` is
        (2, 3, 1), so both value operands are broadcast against the mask.
        """

        class Model(torch.nn.Module):
            def forward(self, cond, input, other):
                return torch.where(cond, input, other)

        # ``randint``'s upper bound is exclusive: ``randint(0, 1)`` produced an
        # all-False mask, so only ``other`` was ever selected. ``randint(0, 2)``
        # almost surely yields a mix, so both branches are compared.
        x = torch.randint(0, 2, (2, 3, 4), dtype=torch.bool)
        y = torch.randn(2, 1, 4)
        z = torch.ones(2, 3, 1)
        self.run_test(Model(), (x, y, z))
9741

9742
    @skipIfUnsupportedMinOpsetVersion(9)
    @skipScriptTest()  # the scripted variant (opset >= 11) is test_where_condition_script
    def test_where_condition(self):
        """Export single-argument ``torch.where`` (nonzero indices) stacked on dim 1.

        The first case thresholds one tensor; the second compares two bool
        tensors.
        """

        class Model1(torch.nn.Module):
            def forward(self, input):
                return torch.stack(torch.where(input > 0.5), dim=1)

        class Model2(torch.nn.Module):
            def forward(self, input, other):
                return torch.stack(torch.where(input > other), dim=1)

        mask = torch.randint(0, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model1(), (mask,))

        # ``lhs`` is all False and ``rhs`` all True (exclusive upper bounds),
        # so ``lhs > rhs`` selects no elements.
        lhs = torch.randint(0, 1, (2, 3, 4), dtype=bool)
        rhs = torch.randint(1, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model2(), (lhs, rhs))
9759

9760
    @skipIfUnsupportedOpsetVersion([13])
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_where_condition_script(self):
        """Same cases as ``test_where_condition``, for opsets >= 11 (excluding 13)."""

        class Model1(torch.nn.Module):
            def forward(self, input):
                return torch.stack(torch.where(input > 0.5), dim=1)

        class Model2(torch.nn.Module):
            def forward(self, input, other):
                return torch.stack(torch.where(input > other), dim=1)

        mask = torch.randint(0, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model1(), (mask,))

        # All-False vs all-True operands, so the comparison selects nothing.
        lhs = torch.randint(0, 1, (2, 3, 4), dtype=bool)
        rhs = torch.randint(1, 2, (2, 3, 4), dtype=bool)
        self.run_test(Model2(), (lhs, rhs))
9777

9778
    def test_empty_branch(self):
        """Export a ScriptModule whose ``if``/``else`` branches include bare ``pass``."""

        class EmptyBranchModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                out = input + 1
                # The ``pass`` branches are intentional: they are what this
                # test exercises. With the 3-D input below, the outer branch is
                # taken and the inner ``else: pass`` runs.
                if out.dim() > 2:
                    if out.dim() > 3:
                        out += 3
                    else:
                        pass
                else:
                    pass
                return out

        x = torch.randn(1, 2, 3, requires_grad=True)
        self.run_test(EmptyBranchModel(), x)
9794

9795
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_derive_index_scripting(self):
        """Export loops over ``range`` with negative/positive starts and steps.

        Each model indexes ``x`` with the loop variable, mixing negative and
        positive indices. The name suggests this targets ``aten::__derive_index``
        (TODO confirm). ``len(x)`` is 5 for every case here.
        """

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(4, -5, -2) -> 4, 2, 0, -2, -4
                for idx in range(len(x) - 1, -len(x), -2):
                    y = x[idx]
                    j += [x * y]
                return j

        x = torch.randn(5, 13)
        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(-5, 4, 2) -> -5, -3, -1, 1, 3
                for idx in range(-len(x), len(x) - 1, 2):
                    y = x[idx]
                    j += [x * y]
                return j

        x = torch.randn(5, 13)
        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(4, -5, -3) -> 4, 1, -2
                for idx in range(len(x) - 1, -len(x), -3):
                    y = x[idx]
                    j += [x * y]
                return j

        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(-5, 4, 3) -> -5, -2, 1
                for idx in range(-len(x), len(x) - 1, 3):
                    y = x[idx]
                    j += [x * y]
                return j

        self.run_test(MyModule(), x)
9838

9839
    # Scripting fails on list addition (``j += [...]``) for opsets < 11;
    # see test_derive_index_scripting for the scripted variant.
    @skipScriptTest()
    def test_derive_index(self):
        """Traced-only counterpart of ``test_derive_index_scripting``.

        Same four ``range`` patterns over ``len(x) == 5``.
        """

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(4, -5, -2) -> 4, 2, 0, -2, -4
                for idx in range(len(x) - 1, -len(x), -2):
                    y = x[idx]
                    j += [x * y]
                return j

        x = torch.randn(5, 13)
        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(-5, 4, 2) -> -5, -3, -1, 1, 3
                for idx in range(-len(x), len(x) - 1, 2):
                    y = x[idx]
                    j += [x * y]
                return j

        x = torch.randn(5, 13)
        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(4, -5, -3) -> 4, 1, -2
                for idx in range(len(x) - 1, -len(x), -3):
                    y = x[idx]
                    j += [x * y]
                return j

        self.run_test(MyModule(), x)

        class MyModule(torch.nn.Module):
            def forward(self, x: Tensor):
                j = []
                # range(-5, 4, 3) -> -5, -2, 1
                for idx in range(-len(x), len(x) - 1, 3):
                    y = x[idx]
                    j += [x * y]
                return j

        self.run_test(MyModule(), x)
9882

9883
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_if_transpose(self):
        """Export a scripted ``if`` whose branches return different shapes.

        After the first transpose, a (2, 3) input becomes (3, 2). The ``if``
        branch would return (2, 3) and the ``else`` branch (3, 2), so both
        output axes are declared dynamic.
        """

        class IfModel(torch.nn.Module):
            def forward(self, x):
                x = x.transpose(0, 1)
                if x.size(0) == 2:
                    return x.transpose(0, 1)
                else:
                    return x

        sample = torch.randn(2, 3)
        scripted = torch.jit.script(IfModel())
        self.run_test(
            scripted,
            sample,
            output_names=["output_1"],
            dynamic_axes={"output_1": [0, 1]},
        )
9900

9901
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_if_list(self):
        """Export a scripted ``if`` that builds a one-element list in each branch.

        The two candidate elements have different shapes: (2, 3) vs (3, 3).
        """

        class IfModel(torch.nn.Module):
            def forward(self, x, y, cond):
                res = []
                if cond:
                    res = res + [x]
                else:
                    res = res + [y]
                return res

        first = torch.randn(2, 3)
        second = torch.randn(3, 3)
        flag = torch.tensor(1, dtype=torch.bool)
        scripted = torch.jit.script(IfModel())
        self.run_test(scripted, (first, second, flag))
9916

9917
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_if_view(self):
        """Export a scripted ``if`` where one branch reshapes with ``view``.

        Before the transpose, the branches yield (2, 16, 4) and (2, 16, 8), so
        output dim 1 differs (4 vs 8) and is marked dynamic.
        """

        class IfModel(torch.nn.Module):
            def forward(self, x, y, cond):
                bs, seq = y.shape[:2]
                if cond:
                    res = x.view(bs, seq, -1)
                else:
                    res = y
                return res.transpose(1, 2)

        to_view = torch.randn(2, 16, 2, 2)
        passthrough = torch.randn(2, 16, 8)
        take_view = torch.tensor(1, dtype=torch.bool)
        scripted = torch.jit.script(IfModel())
        self.run_test(
            scripted,
            (to_view, passthrough, take_view),
            output_names=["output_1"],
            dynamic_axes={"output_1": [1]},
        )
9937

9938
    @skipScriptTest(
        skip_before_opset_version=11, reason="dynamic split support added in 11"
    )
    def test_split_tensor_scalar(self):
        """Export ``torch.split`` whose split size is read from the input's shape."""

        class SplitModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, x.size(1))

        sample = torch.randn(1, 2, 3, requires_grad=True)
        model = SplitModel()
        self.run_test(model, sample)
9948

9949
    def test_split_tensor_multi(self):
9950
        class SplitModel(torch.nn.Module):
9951
            def forward(self, x):
9952
                return torch.split(x, torch.ones(3))
9953

9954
        x = torch.randn(1, 2, 3, requires_grad=True)
9955

9956
        def run_model():
9957
            SplitModel(x)
9958

9959
        self.assertRaises(TypeError, run_model)
9960

9961
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_embedding(self):
        """Export functional ``embedding`` with and without ``padding_idx``.

        Some indices are forced to the padding row (1). The 3-D index case is
        also exported with ``TrainingMode.TRAINING``.
        """

        class EmbedModel(torch.nn.Module):
            def forward(self, input, emb):
                return torch.nn.functional.embedding(input, emb, padding_idx=1)

        class EmbedModelWithoutPaddingIdx(torch.nn.Module):
            def forward(self, input, emb):
                return torch.nn.functional.embedding(input, emb)

        padded_model = EmbedModel()

        indices = torch.randint(4, (4,))
        indices[2] = 1
        indices[0] = 1
        weights = torch.rand(10, 3)
        self.run_test(padded_model, (indices, weights))

        indices = torch.randint(4, (4, 3, 2))
        indices[2] = 1
        indices[0][1] = 1
        self.run_test(padded_model, (indices, weights))
        self.run_test(
            padded_model, (indices, weights), training=torch.onnx.TrainingMode.TRAINING
        )

        plain_model = EmbedModelWithoutPaddingIdx()
        indices = torch.randint(4, (4, 3, 2))
        self.run_test(plain_model, (indices, weights))
9988

9989
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_embedding_module(self):
        """Export ``nn.Embedding`` modules with and without ``padding_idx``.

        ``emb2`` overwrites its padding row with ones, so the exported model
        must preserve a non-zero padding row.
        """

        class EmbedModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 3, padding_idx=1)
                self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1)
                with torch.no_grad():
                    self.emb2.weight[1] = torch.ones(3)

            def forward(self, input):
                return self.emb(input), self.emb2(input)

        class EmbedModelWithoutPaddingIdx(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 3)

            def forward(self, input):
                return self.emb(input)

        padded_model = EmbedModel()

        indices = torch.randint(4, (4,))
        indices[2] = 1
        indices[0] = 1
        self.run_test(padded_model, (indices,))

        indices = torch.randint(4, (4, 3, 2))
        indices[2] = 1
        indices[0][1] = 1
        self.run_test(padded_model, (indices,))

        plain_model = EmbedModelWithoutPaddingIdx()
        indices = torch.randint(4, (4, 3, 2))
        self.run_test(plain_model, (indices,))
10023

10024
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_embedding_renorm(self):
        """Export ``nn.Embedding`` with ``max_norm`` renormalization.

        One case uses the default norm type; the other uses L1. The L1 case
        repeats index 4.
        """
        num_embeddings, dim = 7, 5
        cases = (
            ({"max_norm": 0.2}, [2, 1]),
            ({"max_norm": 0.5, "norm_type": 1.0}, [4, 3, 4, 2]),
        )
        for options, ids in cases:
            embedding = torch.nn.Embedding(num_embeddings, dim, **options)
            self.run_test(embedding, torch.tensor(ids))
10034

10035
    def _dispatch_rnn_test(self, name, *args, **kwargs):
10036
        if name == "elman":
10037
            self._elman_rnn_test(*args, **kwargs)
10038
        if name == "lstm":
10039
            self._lstm_test(*args, **kwargs)
10040
        if name == "gru":
10041
            self._gru_test(*args, **kwargs)
10042

10043
    def _elman_rnn_test(
        self,
        layers,
        nonlinearity,
        bidirectional,
        initial_state,
        packed_sequence,
        dropout,
        **extra_kwargs,
    ):
        """Export an Elman ``nn.RNN`` and re-run it with a different batch size.

        Args:
            layers: number of stacked RNN layers.
            nonlinearity: activation forwarded to ``nn.RNN``.
            bidirectional: whether the RNN is bidirectional.
            initial_state: if truthy, an explicit ``h0`` input is added.
            packed_sequence: 0 feeds a plain padded tensor. A non-zero value
                wraps the model with ``rnn_model_with_packed_sequence`` and
                appends a ``seq_lengths`` input; ``2`` additionally makes the
                model batch-first.
            dropout: dropout between RNN layers.
            **extra_kwargs: accepted and ignored (presumably so all ``_*_test``
                helpers share one call signature).
        """

        class ElmanWithStateModel(torch.nn.Module):
            def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
                super().__init__()

                self.batch_first = batch_first
                self.inner_model = torch.nn.RNN(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    layers,
                    nonlinearity=nonlinearity,
                    # Use the constructor argument rather than silently reading
                    # ``bidirectional`` from the enclosing function.
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input: rnn_utils.PackedSequence, hx=None):
                return self.inner_model(input, hx)

        class ElmanWithoutStateModel(torch.nn.Module):
            def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
                super().__init__()
                self.batch_first = batch_first
                self.inner_model = torch.nn.RNN(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    layers,
                    nonlinearity=nonlinearity,
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input: rnn_utils.PackedSequence):
                return self.inner_model(input)

        batch_first = packed_sequence == 2

        if initial_state:
            model = ElmanWithStateModel(
                layers=layers,
                bidirect=bidirectional,
                nonlinearity=nonlinearity,
                dropout=dropout,
                batch_first=batch_first,
            )
            if packed_sequence:
                model = (
                    rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
                        model, batch_first
                    )
                )
        else:
            model = ElmanWithoutStateModel(
                layers=layers,
                bidirect=bidirectional,
                nonlinearity=nonlinearity,
                dropout=dropout,
                batch_first=batch_first,
            )
            if packed_sequence:
                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
                    model, batch_first
                )

        def make_input(batch_size):
            """Build a padded random batch and the matching ONNX input names."""
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            # Lengths are sorted descending; presumably required by the packed
            # wrapper (pack_padded_sequence's default enforce_sorted) — confirm.
            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
            sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
            inputs = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]
            input_names = ["input"]

            directions = 2 if bidirectional else 1

            if initial_state:
                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                inputs.append(h0)
                input_names.append("h0")
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
                input_names.append("seq_lengths")
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input, input_names

        input, input_names = make_input(RNN_BATCH_SIZE)
        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
        if initial_state:
            dynamic_axes.update({"h0": [1]})
        export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}

        # test that the model still runs with a different batch size
        other_input, _ = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(
            model, input, additional_test_inputs=[other_input], **export_options
        )
10151

10152
    def _lstm_test(
        self,
        layers,
        bidirectional,
        initial_state,
        packed_sequence,
        dropout,
        **extra_kwargs,
    ):
        """Export an LSTM and re-run it with a different batch size.

        Args:
            layers: number of stacked LSTM layers.
            bidirectional: whether the LSTM is bidirectional.
            initial_state: if truthy, ``(h0, c0)`` is passed as an extra input.
            packed_sequence: 0 uses ``LstmFlatteningResultWithoutSeqLength``.
                A non-zero value uses the seq-length variant wrapped by
                ``rnn_model_with_packed_sequence`` and appends a
                ``seq_lengths`` input; ``2`` also makes it batch-first.
            dropout: dropout between LSTM layers.
            **extra_kwargs: accepted and ignored.
        """
        batch_first = packed_sequence == 2

        if not packed_sequence:
            model = lstm_flattening_result.LstmFlatteningResultWithoutSeqLength(
                RNN_INPUT_SIZE,
                RNN_HIDDEN_SIZE,
                layers,
                bidirectional,
                dropout,
                batch_first,
            )
        else:
            core = lstm_flattening_result.LstmFlatteningResultWithSeqLength(
                RNN_INPUT_SIZE,
                RNN_HIDDEN_SIZE,
                layers,
                bidirectional,
                dropout,
                batch_first,
            )
            wrapper_cls = (
                rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState
                if initial_state
                else rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState
            )
            model = wrapper_cls(core, batch_first)

        num_directions = 2 if bidirectional else 1

        def make_input(batch_size):
            """Return ``(model_input, input_names)`` for a random batch."""
            lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            lengths = sorted(map(int, lengths), reverse=True)
            padded = rnn_utils.pad_sequence(
                [torch.randn(length, RNN_INPUT_SIZE) for length in lengths],
                batch_first=batch_first,
            )
            parts = [padded]
            names = ["input"]

            if initial_state:
                state_shape = (num_directions * layers, batch_size, RNN_HIDDEN_SIZE)
                h0 = torch.randn(*state_shape)
                c0 = torch.randn(*state_shape)
                parts.append((h0, c0))
                names.extend(["h0", "c0"])
            if packed_sequence != 0:
                parts.append(torch.IntTensor(lengths))
                names.append("seq_lengths")
            return (parts[0] if len(parts) == 1 else tuple(parts)), names

        first_input, input_names = make_input(RNN_BATCH_SIZE)
        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
        if initial_state:
            dynamic_axes["h0"] = [1]
            dynamic_axes["c0"] = [1]

        # A second batch of a different size checks the dynamic batch axis.
        second_input, _ = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(
            model,
            first_input,
            additional_test_inputs=[second_input],
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
10227

10228
    def _gru_test(
        self,
        layers,
        bidirectional,
        initial_state,
        packed_sequence,
        dropout,
        **extra_kwargs,
    ):
        """Export an ``nn.GRU`` and re-run it with a different batch size.

        Args:
            layers: number of stacked GRU layers.
            bidirectional: whether the GRU is bidirectional.
            initial_state: if truthy, an explicit ``h0`` input is added.
            packed_sequence: 0 feeds a plain padded tensor. A non-zero value
                wraps the model with ``rnn_model_with_packed_sequence`` and
                appends a ``seq_lengths`` input; ``2`` additionally makes the
                model batch-first.
            dropout: dropout between GRU layers.
            **extra_kwargs: accepted and ignored (presumably so all ``_*_test``
                helpers share one call signature).
        """

        # Each model class passes its own ``bidirect`` argument to nn.GRU
        # instead of silently reading ``bidirectional`` from this function.
        class GRUWithStateModel(torch.nn.Module):
            def __init__(self, layers, bidirect, dropout, batch_first):
                super().__init__()

                self.batch_first = batch_first
                self.inner_model = torch.nn.GRU(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    num_layers=layers,
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input: rnn_utils.PackedSequence, hx):
                return self.inner_model(input, hx)

        class GRUWithoutStateModel(torch.nn.Module):
            def __init__(self, layers, bidirect, dropout, batch_first):
                super().__init__()
                self.batch_first = batch_first
                self.inner_model = torch.nn.GRU(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    num_layers=layers,
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input: rnn_utils.PackedSequence):
                return self.inner_model(input)

        class GRUNoSeqLengthWithoutStateModel(torch.nn.Module):
            def __init__(self, layers, bidirect, dropout, batch_first):
                super().__init__()
                self.batch_first = batch_first
                self.inner_model = torch.nn.GRU(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    num_layers=layers,
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input):
                return self.inner_model(input)

        class GRUNoSeqLengthWithStateModel(torch.nn.Module):
            def __init__(self, layers, bidirect, dropout, batch_first):
                super().__init__()
                self.batch_first = batch_first
                self.inner_model = torch.nn.GRU(
                    RNN_INPUT_SIZE,
                    RNN_HIDDEN_SIZE,
                    num_layers=layers,
                    bidirectional=bidirect,
                    dropout=dropout,
                    batch_first=batch_first,
                )

            def forward(self, input, hx):
                return self.inner_model(input, hx)

        batch_first = packed_sequence == 2

        if packed_sequence:
            if initial_state:
                model = GRUWithStateModel(
                    layers=layers,
                    bidirect=bidirectional,
                    dropout=dropout,
                    batch_first=batch_first,
                )
                model = (
                    rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState(
                        model, batch_first
                    )
                )
            else:
                model = GRUWithoutStateModel(
                    layers=layers,
                    bidirect=bidirectional,
                    dropout=dropout,
                    batch_first=batch_first,
                )
                model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState(
                    model, batch_first
                )
        else:
            if initial_state:
                model = GRUNoSeqLengthWithStateModel(
                    layers=layers,
                    bidirect=bidirectional,
                    dropout=dropout,
                    batch_first=batch_first,
                )
            else:
                model = GRUNoSeqLengthWithoutStateModel(
                    layers=layers,
                    bidirect=bidirectional,
                    dropout=dropout,
                    batch_first=batch_first,
                )

        def make_input(batch_size):
            """Build a padded random batch and the matching ONNX input names."""
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
            seq_lengths = sorted(map(int, seq_lengths), reverse=True)
            sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
            inputs = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]
            input_names = ["input"]

            directions = 2 if bidirectional else 1

            if initial_state:
                h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
                inputs.append(h0)
                input_names.append("h0")
            if packed_sequence != 0:
                inputs.append(torch.IntTensor(seq_lengths))
                input_names.append("seq_lengths")
            if len(inputs) == 1:
                input = inputs[0]
            else:
                input = tuple(inputs)
            return input, input_names

        input, input_names = make_input(RNN_BATCH_SIZE)
        dynamic_axes = {"input": [0, 1], "seq_lengths": [0]}
        if initial_state:
            dynamic_axes.update({"h0": [1]})
        export_options = {"input_names": input_names, "dynamic_axes": dynamic_axes}

        # test that the model still runs with a different batch size
        other_input, _ = make_input(RNN_BATCH_SIZE + 1)
        self.run_test(
            model, input, additional_test_inputs=[other_input], **export_options
        )
10377

10378
    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quantize_per_tensor(self):
        """Export per-tensor fake quantization with constant scale/zero point.

        Uses scale 1/127, zero point 0 and the signed range -128..127.
        """

        class FakeQuantizePerTensorModel(torch.nn.Module):
            def forward(self, input):
                scale = 1.0 / 127
                zero_point = 0
                quant_min = -128
                quant_max = 127
                return torch.fake_quantize_per_tensor_affine(
                    input, scale, zero_point, quant_min, quant_max
                )

        activations = torch.randn(6, 4, 3, 3)
        self.run_test(FakeQuantizePerTensorModel(), (activations,))
10392

10393
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
        """Export per-tensor fake quantization with scale/zero point as graph inputs."""

        class FakeQuantizePerTensorModel(torch.nn.Module):
            def forward(self, input, scale, zero_point):
                quant_min = -128
                quant_max = 127
                return torch.fake_quantize_per_tensor_affine(
                    input, scale, zero_point, quant_min, quant_max
                )

        activations = torch.randn(6, 4, 3, 3)
        scale_tensor = torch.tensor(1.0 / 127)
        zero_point_tensor = torch.tensor(0)
        inputs = (activations, scale_tensor, zero_point_tensor)
        self.run_test(FakeQuantizePerTensorModel(), inputs)
10407

10408
    @skipIfUnsupportedMinOpsetVersion(13)
    def test_fake_quantize_per_channel(self):
        """Export per-channel fake quantization along axis 1 (4 channels)."""

        class FakeQuantizePerChannelModel(torch.nn.Module):
            def forward(self, input):
                amax = torch.ones(4)
                scale = amax / 127.0
                zero_point = torch.zeros_like(amax, dtype=torch.int)
                # Quantize twice, first with the unsigned 0..255 range and
                # then with the signed -128..127 range, to hit different branches.
                y = torch.fake_quantize_per_channel_affine(
                    input, scale, zero_point, 1, 0, 255
                )
                return torch.fake_quantize_per_channel_affine(
                    y, scale, zero_point, 1, -128, 127
                )

        activations = torch.randn(6, 4, 3, 3)
        self.run_test(FakeQuantizePerChannelModel(), (activations,))
10425

10426
    @skipIfUnsupportedMinOpsetVersion(13)
    # RuntimeError: Can't redefine method:
    # forward on class: __torch__.torch.nn.modules.linear.Linear
    @skipScriptTest()
    def test_fake_quantize_activation(self):
        """Export a QAT-prepared ``nn.Linear`` with activation fake-quant.

        Observers and fake-quant are enabled, ``calculate_qparams`` is called
        on every ``FakeQuantize`` submodule, then observers are disabled and
        the model is switched to eval. The fixed weights and inputs are then
        installed and the model is exported.
        """
        from torch.ao import quantization

        linear = torch.nn.Linear(1, 1)
        linear.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(linear.train(), inplace=True)
        for toggle in (quantization.enable_observer, quantization.enable_fake_quant):
            linear.apply(toggle)
        fake_quants = [
            sub for sub in linear.modules() if isinstance(sub, quantization.FakeQuantize)
        ]
        for fake_quant in fake_quants:
            fake_quant.calculate_qparams()

        linear.apply(quantization.disable_observer)
        linear.eval()

        # Activation fake-quant is special: it limits the quantized range to
        # (0, 127), whereas standard 8-bit quantization uses (-128, 127) or
        # (0, 255). Fixed weight, bias and inputs (e.g. 150.0) check that ONNX
        # handles the overflow correctly.
        linear.weight = torch.nn.Parameter(torch.tensor([[1.0], [1.0], [1.0]]))
        linear.bias = torch.nn.Parameter(torch.tensor([0.0]))
        sample = torch.tensor([[150.0], [127.0], [-5.0]])
        self.run_test(linear, sample)
10455

10456
    def test_batchnorm_training(self):
        """Check a conv/BatchNorm stack exported in TRAINING and PRESERVE(train) modes."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn1 = torch.nn.BatchNorm2d(3, affine=False)
                self.cv1 = torch.nn.Conv2d(3, 3, 10)
                self.bn2 = torch.nn.BatchNorm2d(3, affine=True)
                self.cv2 = torch.nn.Conv2d(3, 3, 10)
                self.bn3 = torch.nn.BatchNorm2d(3, affine=False)

            def forward(self, x):
                x = self.bn1(x)
                x = self.cv1(x)
                x = self.bn2(x)
                x = self.cv2(x)
                x = self.bn3(x)
                return x

        x = torch.randn(10, 3, 20, 20) * 2
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances
        )
        # PRESERVE exports in the module's current mode, so switch to train first.
        model.train()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10491

10492
    def test_batchnorm_training_mode_fix_layer(self):
        """Check a BatchNorm stack in training export where ``bn3`` starts in eval mode."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
                self.cv1 = torch.nn.Conv2d(3, 3, 10)
                self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
                self.cv2 = torch.nn.Conv2d(3, 3, 10)
                self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
                # NOTE(review): Module.train() is recursive, so this per-layer eval
                # mode is overwritten by the model.train() call below (and likely by
                # the exporter in TRAINING mode) -- confirm this is intended.
                self.bn3.eval()

            def forward(self, x):
                x = self.bn1(x)
                x = self.cv1(x)
                x = self.bn2(x)
                x = self.cv2(x)
                x = self.bn3(x)
                return x

        x = torch.randn(10, 3, 128, 128)
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances
        )
        model.train()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10528

10529
    def test_batchnorm_eval_mode_train_layer(self):
        """Check a BatchNorm stack in eval export where ``bn3`` starts in train mode."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
                self.cv1 = torch.nn.Conv2d(3, 3, 10)
                self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
                self.cv2 = torch.nn.Conv2d(3, 3, 10)
                self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
                # NOTE(review): Module.eval() is recursive, so this per-layer train
                # mode is overwritten by the model.eval() call below (and likely by
                # the exporter in EVAL mode) -- confirm this is intended.
                self.bn3.train()

            def forward(self, x):
                x = self.bn1(x)
                x = self.cv1(x)
                x = self.bn2(x)
                x = self.cv2(x)
                x = self.bn3(x)
                return x

        x = torch.randn(10, 3, 128, 128)
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(model, (x,), training=torch.onnx.TrainingMode.EVAL, **tolerances)
        model.eval()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10565

10566
    def test_instancenorm_training(self):
        """Check a conv/InstanceNorm stack exported in TRAINING and PRESERVE(train) modes."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
                self.cv1 = torch.nn.Conv2d(3, 3, 10)
                self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
                self.cv2 = torch.nn.Conv2d(3, 3, 10)
                self.in3 = torch.nn.InstanceNorm2d(3, affine=True)

            def forward(self, x):
                x = self.in1(x)
                x = self.cv1(x)
                x = self.in2(x)
                x = self.cv2(x)
                x = self.in3(x)
                return x

        x = torch.randn(10, 3, 128, 128)
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances
        )
        # PRESERVE exports in the module's current mode, so switch to train first.
        model.train()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10601

10602
    def test_instancenorm_training_mode_fix_layer(self):
        """Check an InstanceNorm stack in training export where ``in3`` starts in eval mode."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
                self.cv1 = torch.nn.Conv2d(3, 3, 10)
                self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
                self.cv2 = torch.nn.Conv2d(3, 3, 10)
                self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
                # NOTE(review): Module.train() is recursive, so this per-layer eval
                # mode is overwritten by the model.train() call below -- confirm
                # this is intended.
                self.in3.eval()

            def forward(self, x):
                x = self.in1(x)
                x = self.cv1(x)
                x = self.in2(x)
                x = self.cv2(x)
                x = self.in3(x)
                return x

        x = torch.randn(10, 3, 128, 128)
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances
        )
        model.train()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10638

10639
    def test_instancenorm_eval_mode_train_layer(self):
        """Check an 8-channel InstanceNorm stack in eval export where ``in3`` starts in train mode."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
                self.cv1 = torch.nn.Conv2d(8, 8, 10)
                self.in2 = torch.nn.InstanceNorm2d(8, affine=False)
                self.cv2 = torch.nn.Conv2d(8, 8, 10)
                self.in3 = torch.nn.InstanceNorm2d(8, affine=True)
                # NOTE(review): Module.eval() is recursive, so this per-layer train
                # mode is overwritten by the model.eval() call below -- confirm
                # this is intended.
                self.in3.train()

            def forward(self, x):
                x = self.in1(x)
                x = self.cv1(x)
                x = self.in2(x)
                x = self.cv2(x)
                x = self.in3(x)
                return x

        x = torch.randn(10, 8, 128, 128)
        model = MyModule()
        tolerances = {"rtol": 1e-3, "atol": 1e-5}

        self.run_test(model, (x,), training=torch.onnx.TrainingMode.EVAL, **tolerances)
        model.eval()
        self.run_test(
            model, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances
        )
10675

10676
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_dropout_training(self):
        """Dropout exported in TRAINING mode must actually drop elements in ORT.

        The check is done twice: on the eager module and on its TorchScript
        version. Each check exports a fresh ONNX model and runs it in a fresh
        ORT session, so the second check really exercises the scripted export.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(0.4)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        model = MyModule()
        x = torch.randn(10)
        model.train()

        def export_and_run(module):
            # Export ``module`` in training mode and run it on ``x`` in a new session.
            model_onnx = io.BytesIO()
            torch.onnx.export(
                module,
                x,
                model_onnx,
                opset_version=self.opset_version,
                do_constant_folding=False,
                training=torch.onnx.TrainingMode.TRAINING,
            )
            ort_sess = verification._ort_session(model_onnx)
            return verification._run_onnx(ort_sess, (x,))

        ort_outs = export_and_run(model)
        assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))

        script_model = torch.jit.script(model)
        ort_outs = export_and_run(script_model)
        assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
10717

10718
    @skipIfUnsupportedMinOpsetVersion(12)
    def test_dropout_training_zero(self):
        """Dropout in TRAINING mode must zero about the same fraction in ORT and PyTorch.

        Masks are random, so the test compares the fraction of non-zero outputs
        (within 1%). It runs once for the eager module and once for its
        TorchScript version. PyTorch and ORT receive the same zero-free input,
        so every zero in an output was produced by dropout.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                dropout = self.dropout(x)
                return dropout

        model = MyModule()

        # ensure there are no zeros in the input
        x = torch.randn(10, 3, 128, 128)
        x = torch.where(x == 0, torch.ones_like(x), x)
        nb_elements = torch.numel(x)

        def export_and_run(module):
            # Export ``module`` in training mode and run it on ``x`` in a new session.
            model_onnx = io.BytesIO()
            torch.onnx.export(
                module,
                x,
                model_onnx,
                opset_version=self.opset_version,
                do_constant_folding=False,
                training=torch.onnx.TrainingMode.TRAINING,
            )
            ort_sess = verification._ort_session(model_onnx)
            return verification._run_onnx(ort_sess, (x,))

        def nonzero_ratio(array):
            # Fraction of elements that survived dropout.
            return np.count_nonzero(array) / nb_elements

        model.train()
        ort_outs = export_and_run(model)
        output = model(x).cpu().numpy()
        np.testing.assert_allclose(
            nonzero_ratio(output), nonzero_ratio(ort_outs[0]), rtol=0.01, atol=0.01
        )

        script_model = torch.jit.script(model)
        output = script_model(x).cpu().numpy()
        ort_outs = export_and_run(script_model)
        np.testing.assert_allclose(
            nonzero_ratio(output), nonzero_ratio(ort_outs[0]), rtol=0.01, atol=0.01
        )
10782

10783
    def test_conv_bn(self):
        """Check a Conv2d followed by BatchNorm2d exported in EVAL and TRAINING modes."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = MyModule()
        x = torch.randn(10, 3, 128, 128)
        self.run_test(model, (x,), training=torch.onnx.TrainingMode.EVAL)
        # Training-mode batch statistics need a looser tolerance.
        self.run_test(
            model,
            (x,),
            training=torch.onnx.TrainingMode.TRAINING,
            rtol=1e-3,
            atol=1e-5,
        )
10807

10808
    def test_multiple_conv_bn(self):
        """Check a conv/bn/relu/maxpool stack, where ``bn2`` is applied twice.

        The shared ``bn2`` follows both ``conv2`` and ``conv3``. The model is
        exported in TRAINING mode and then in EVAL mode.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.conv2 = torch.nn.Conv2d(
                    64, 2, kernel_size=1, stride=1, padding=0, bias=False
                )
                self.conv3 = torch.nn.Conv2d(
                    2, 2, kernel_size=3, stride=1, padding=1, bias=False
                )
                self.bn = torch.nn.BatchNorm2d(64)
                self.bn2 = torch.nn.BatchNorm2d(2)
                self.relu = torch.nn.ReLU(inplace=True)
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

            def forward(self, x):
                x = self.maxpool(self.relu(self.bn(self.conv1(x))))
                x = self.relu(self.bn2(self.conv2(x)))
                return self.relu(self.bn2(self.conv3(x)))

        model = MyModule()
        x = torch.randn(2, 3, 224, 224)
        self.run_test(
            model,
            (x,),
            training=torch.onnx.TrainingMode.TRAINING,
            rtol=1e-3,
            atol=1e-5,
        )
        self.run_test(model, (x,), training=torch.onnx.TrainingMode.EVAL)
10849

10850
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_nms(self):
        """Check torchvision.ops.nms with an IoU threshold of 0.5."""

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return torchvision.ops.nms(boxes, scores, 0.5)

        n_boxes = 100
        boxes = torch.rand(n_boxes, 4)
        # Offset the second corner by the first so x2 >= x1 and y2 >= y1.
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(n_boxes)
        self.run_test(Module(), (boxes, scores))
10862

10863
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_batched_nms(self):
        """Check torchvision.ops.batched_nms over boxes spread across 5 category ids."""

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return torchvision.ops.batched_nms(boxes, scores, idxs, 0.5)

        n_boxes = 100
        boxes = torch.rand(n_boxes, 4)
        # Offset the second corner by the first so x2 >= x1 and y2 >= y1.
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(n_boxes)
        category_ids = torch.randint(0, 5, size=(n_boxes,))
        self.run_test(Module(), (boxes, scores, category_ids))
10876

10877
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_clip_boxes_to_image(self):
        """Clip boxes to an image size read from a dynamically-shaped input.

        Only the shape of ``size`` is used. Its dims are marked dynamic so the
        exported model is re-checked with a 300x400 ``size`` tensor.
        """
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        size = torch.randn(200, 300)
        size_2 = torch.randn(300, 400)

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                height, width = size.shape[0], size.shape[1]
                return torchvision.ops.boxes.clip_boxes_to_image(
                    boxes, (height, width)
                )

        self.run_test(
            Module(),
            (boxes, size),
            input_names=["boxes", "size"],
            dynamic_axes={"size": [0, 1]},
            additional_test_inputs=[(boxes, size), (boxes, size_2)],
        )
10898

10899
    @skipScriptTest(
        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
    )
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_roi_align(self):
        """Check RoIAlign (5x5 output, scale 1.0, sampling ratio 2) on a single RoI."""
        feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        roi_align = torchvision.ops.RoIAlign((5, 5), 1.0, 2)
        self.run_test(roi_align, (feature_map, rois))
10908

10909
    @skipScriptTest(
        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
    )
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_roi_align_aligned(self):
        """Check RoIAlign(aligned=True) across several scales and sampling ratios."""
        # (roi, output_size, spatial_scale, sampling_ratio)
        cases = [
            ([[0, 1.5, 1.5, 3, 3]], (5, 5), 1.0, 2),
            ([[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 0.5, 3),
            ([[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 1.8, 2),
            ([[0, 0.2, 0.3, 4.5, 3.5]], (2, 2), 2.5, 0),
        ]
        for roi, output_size, spatial_scale, sampling_ratio in cases:
            feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
            rois = torch.tensor(roi, dtype=torch.float32)
            roi_align = torchvision.ops.RoIAlign(
                output_size, spatial_scale, sampling_ratio, aligned=True
            )
            self.run_test(roi_align, (feature_map, rois))
10933

10934
    @skipScriptTest(
        reason="Conditioning on input type via prim::isinstance unsupported in ONNX"
    )
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_roi_pool(self):
        """Check RoIPool with a 5x5 output and spatial scale 2.0 on a single RoI."""
        feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        output_size = (5, 5)
        roi_pool = torchvision.ops.RoIPool(output_size, 2.0)
        self.run_test(roi_pool, (feature_map, rois))
10945

10946
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_resize_images(self):
        """Check GeneralizedRCNNTransform.resize with fully dynamic image dims."""

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.transform = _init_test_generalized_rcnn_transform()

            def forward(self, images):
                # Called without a target; keep only the first returned element.
                return self.transform.resize(images, None)[0]

        small_image = torch.rand(3, 10, 20)
        large_image = torch.rand(3, 100, 150)
        self.run_test(
            TransformModule(),
            (small_image,),
            input_names=["input1"],
            dynamic_axes={"input1": [0, 1, 2]},
            additional_test_inputs=[(small_image,), (large_image,)],
        )
10965

10966
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_transform_images(self):
        """Check that a full GeneralizedRCNNTransform batches a list of differently-sized images."""

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.transform = _init_test_generalized_rcnn_transform()

            def forward(self, images: List[Tensor]):
                image_list = self.transform(images)[0]
                return image_list.tensors

        images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        test_images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        self.run_test(
            TransformModule(),
            (images,),
            additional_test_inputs=[(images,), (test_images,)],
        )
10984

10985
    def get_features(self, images):
        """Build a dict of random multi-scale feature maps sized from ``images``.

        The keys are "0".."4". Level k has spatial size (H // s, W // s) with
        s in (4, 8, 16, 32, 64), where H and W are the last two dims of
        ``images``. The batch size (2) and channel count (256) are fixed.
        """
        height, width = images.shape[-2:]
        strides = (4, 8, 16, 32, 64)
        return OrderedDict(
            (str(level), torch.rand(2, 256, height // stride, width // stride))
            for level, stride in enumerate(strides)
        )
10996

10997
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_rpn(self):
        """Check a RegionProposalNetwork with dynamic image and feature shapes.

        The model is exported on 150x150 images and re-checked on 80x80 images.
        """

        class RPNModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rpn = _init_test_rpn()

            def forward(self, images, features: Dict[str, Tensor]):
                # ImageList expects per-image sizes as (height, width), which are
                # the last two dims of an (..., H, W) image tensor.
                images_m = torchvision.models.detection.image_list.ImageList(
                    images, [(i.shape[-2], i.shape[-1]) for i in images]
                )
                return self.rpn(images_m, features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        # Eager sanity run before export.
        model(images, features)
        self.run_test(
            model,
            (images, features),
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
            additional_test_inputs=[(images, features), (images2, test_features)],
        )
11034

11035
    @skipIfUnsupportedMaxOpsetVersion(15)  # TODO: Opset 16 RoiAlign result mismatch
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_multi_scale_roi_align(self):
        """Check MultiScaleRoIAlign over two feature levels ("feat1", "feat2")."""

        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = torchvision.ops.MultiScaleRoIAlign(
                    ["feat1", "feat2"], 3, 2
                )
                self.image_sizes = [(512, 512)]

            def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor:
                return self.model(input, boxes, self.image_sizes)

        def make_inputs():
            # Two random feature levels plus 6 random boxes with x2 >= x1, y2 >= y1.
            feature_maps = OrderedDict(
                [
                    ("feat1", torch.rand(1, 5, 64, 64)),
                    ("feat2", torch.rand(1, 5, 16, 16)),
                ]
            )
            rois = torch.rand(6, 4) * 256
            rois[:, 2:] += rois[:, :2]
            return feature_maps, [rois]

        export_inputs = make_inputs()
        other_inputs = make_inputs()
        self.run_test(
            TransformModule(),
            export_inputs,
            additional_test_inputs=[export_inputs, other_inputs],
        )
11079

11080
    def test_set_(self):
        """Check in-place ``Tensor.set_``, where ``x`` is replaced by ``y``.

        Because ``x`` is fully overwritten, only input index 1 is expected to
        remain in the exported graph.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                x.set_(y)
                return x

        x = torch.ones(2, 3)
        y = torch.randn(4, 6)
        self.run_test(M(), (x, y), remained_onnx_input_idx=[1])

        # Same model with both inputs marked fully dynamic.
        y2 = torch.randn(5, 2)
        dynamic = {"x": [0, 1], "y": [0, 1]}
        self.run_test(
            M(),
            (x, y),
            remained_onnx_input_idx=[1],
            input_names=["x", "y"],
            dynamic_axes=dynamic,
            additional_test_inputs=[(y, y2)],
        )
11099

11100
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_set_attr_modules(self):
        """Check attribute reassignment inside nested submodules' ``forward``.

        The test runs once with dynamic input axes and once with every input
        pruned, since only the shape of ``x`` is used.
        """

        class InnerModule2(torch.nn.Module):
            def __init__(self, embedding_dim):
                super().__init__()
                # Stored so forward can rebuild ``weights`` from its own size.
                self.embedding_dim = embedding_dim
                self.weights = InnerModule2.get_embedding(embedding_dim)
                self.register_buffer("_float_tensor", torch.FloatTensor(1))
                self.const = 2

            @staticmethod
            def get_embedding(embedding_dim: int):
                emb = 4 / ((embedding_dim // 2) - 1)
                emb = torch.exp(
                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
                )
                return emb

            def forward(self, input, incremental_state: Optional[Tensor] = None):
                bsz, seq_len = input.shape[0], input.shape[1]
                self.const = 3
                if self.weights is None:
                    self.weights = InnerModule2.get_embedding(self.embedding_dim)
                # Match the dtype/device of the registered buffer.
                self.weights = self.weights.to(self._float_tensor)
                self.weights = self.weights * self.const
                if incremental_state is not None:
                    pos = seq_len
                    return self.weights[1 + pos, :].expand(bsz, 1, -1)
                return self.weights.index_select(
                    0, torch.ones((bsz * seq_len), dtype=torch.int64)
                ).view(bsz, seq_len, -1)

        class InnerModule(torch.nn.Module):
            def __init__(self, embedding_dim):
                super().__init__()
                self.weights = InnerModule.get_embedding(embedding_dim)
                self.module = InnerModule2(embedding_dim=8)

            @staticmethod
            def get_embedding(embedding_dim: int):
                emb = 4 / ((embedding_dim // 2) - 1)
                emb = torch.exp(
                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
                )
                return emb

            def forward(self, x):
                return self.module(x) + self.weights

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.module = InnerModule(embedding_dim=8)

            def forward(self, x):
                return self.module(x)

        x = torch.randn(3, 256)
        self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
        self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11159

11160
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_set_attr_modules_2(self):
        """Check float/tensor attributes reassigned in a submodule's ``forward``.

        The test runs once with dynamic input axes and once with every input
        pruned, since only the shape of ``x`` is used.
        """

        class InnerModule(torch.nn.Module):
            def __init__(self, embedding_dim):
                super().__init__()
                self.embedding_dim = embedding_dim
                self.const = 2.5
                self.weights = InnerModule.get_embedding(self.embedding_dim)
                self.register_buffer("_float_tensor", torch.FloatTensor(1))

            @staticmethod
            def get_embedding(embedding_dim: int):
                emb = 4 / ((embedding_dim // 2) - 1)
                emb = torch.exp(
                    torch.arange((embedding_dim // 2), dtype=torch.float) * -emb
                )
                return emb

            def forward(self, input, incremental_state: Optional[Tensor] = None):
                bsz, seq_len = input.shape[0], input.shape[1]
                self.const = 1.5
                self.weights = InnerModule.get_embedding(self.embedding_dim)
                table = self.weights.index_select(
                    0, torch.ones((bsz * seq_len), dtype=torch.int64)
                ).view(bsz, seq_len, -1)
                return table * self.const

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.module = InnerModule(embedding_dim=8)

            def forward(self, x):
                return self.module(x)

        x = torch.randn(3, 256)
        self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
        self.run_test(Module(), (x,), remained_onnx_input_idx=[])
11199

11200
    def test_set_attr(self):
        """Check a scripted module that overwrites a submodule's weight and a bool flag in ``forward``."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(3, 10, 2)
                self.b = False

            def forward(self, box_regression, weight):
                self.b = True
                self.conv.weight = weight
                w = torch.softmax(self.conv.weight, dim=0)
                self.conv.weight = w + w
                if self.b:
                    return box_regression + self.conv.weight
                else:
                    return box_regression - self.conv.weight

        scripted = torch.jit.script(MyModule())
        new_weight = torch.ones(3, 2)
        regression = torch.randn(3, 2)
        self.run_test(scripted, (regression, new_weight))
11221

11222
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_set_attr_2(self):
        """Check a scripted helper that conditionally reassigns ``conv.bias``."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors):
                if self.conv.bias is not None:
                    b = self.conv.bias
                    # Refines the Optional type for TorchScript.
                    assert b is not None
                    self.conv.bias = anchors + b
                elif self.conv.weight is not None:
                    self.conv.weight = torch.randn(3, 10)
                    self.conv.bias = self.conv.weight[:]

            def forward(self, anchors) -> Optional[Tensor]:
                self.set_cell_anchors(anchors)
                return self.conv.bias

        model = torch.jit.script(MyModule())
        anchors = torch.ones(3, 10, 3)
        # ``(anchors,)`` -- a bare ``(anchors)`` is just the tensor, not a tuple.
        self.run_test(model, (anchors,))
11246

11247
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_set_attr_3(self):
        """Check a scripted helper that reassigns weight/bias and writes into a caller tensor."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors, boxes):
                self.conv.weight = torch.ones(3, 10)
                if self.conv.bias is not None:
                    self.conv.bias = torch.randn(3, 10, 3)
                    self.conv.weight = anchors + self.conv.weight
                    # In-place write into the caller's tensor; (2, 3) broadcasts
                    # over the leading dim of ``boxes``.
                    boxes[:] = torch.zeros(2, 3)

            def forward(self, anchors) -> Tuple[Tensor, Tensor]:
                boxes = torch.ones(2, 2, 3)
                self.set_cell_anchors(anchors, boxes)
                if self.conv.bias is not None:
                    return self.conv.weight, boxes
                return anchors, boxes

        model = torch.jit.script(MyModule())
        anchors = torch.rand(3, 10)
        # ``(anchors,)`` -- a bare ``(anchors)`` is just the tensor, not a tuple.
        self.run_test(model, (anchors,))
11273

11274
    @skipIfUnsupportedMinOpsetVersion(11)
11275
    def test_set_attr_4(self):
        # Exports a scripted module whose helper assigns `self.conv.weight`
        # unconditionally and `self.conv.bias` in an if/else; forward() then
        # collects its outputs in a conditionally populated list.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors):
                self.conv.weight = torch.zeros(10, 3)
                if self.conv.bias is not None:
                    old_bias = self.conv.bias
                    # Optional refinement required by TorchScript.
                    assert old_bias is not None
                    self.conv.bias = anchors + old_bias
                else:
                    self.conv.bias = torch.ones(3, 10, 3)

            def forward(self, feature_maps, anchors) -> Tuple[Tensor, Tensor]:
                self.set_cell_anchors(anchors)
                outputs = []
                if self.conv.bias is not None:
                    new_bias = self.conv.bias
                    assert new_bias is not None
                    outputs += [new_bias]
                outputs += [feature_maps]
                return outputs[0], outputs[1]

        scripted = torch.jit.script(MyModule())
        feature_maps = torch.rand(5, 11, 30)
        anchors = torch.ones(3, 10, 3)
        self.run_test(scripted, (feature_maps, anchors))
11305

11306
    @skipIfUnsupportedMinOpsetVersion(11)
11307
    def test_set_attr_5(self):
        # Exports a scripted module that repeatedly reassigns `self.conv.weight`
        # and `self.conv.bias` inside nested loops guarded by a conditional.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors):
                self.conv.weight = torch.arange(10)
                for i in range(10):
                    if i == 3:
                        # The inner loop only runs on the i == 3 iteration.
                        for j in range(10):
                            w = self.conv.weight
                            self.conv.weight = torch.arange(10) + w

                    self.conv.weight = self.conv.weight + torch.arange(10)
                    # NOTE: `is not None` and `assert` is for passing torchscript.
                    if self.conv.bias is not None:
                        a = self.conv.bias
                        assert a is not None
                        self.conv.bias = anchors + a

            def forward(self, anchors):
                self.set_cell_anchors(anchors)
                return self.conv.weight, self.conv.bias

        model = torch.jit.script(MyModule())
        anchors = torch.ones(3, 10, 3)
        # Single input passed directly; wrapping it in bare parentheses would
        # not have made a tuple anyway.
        self.run_test(model, anchors)
11336

11337
    @skipIfUnsupportedMinOpsetVersion(11)
11338
    def test_set_attr_in_loop(self):
        # Exports a scripted module that reassigns conv parameters inside a
        # doubly nested loop and accumulates into a tensor argument in place.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

            def set_cell_anchors(self, anchors, boxes):
                self.conv.weight = torch.randn(3, 10)
                # The outer trip count is read from the freshly assigned
                # (3, 10) weight before the loop body reassigns it.
                for outer in range(self.conv.weight.size(0)):
                    for inner in range(10):
                        self.conv.bias = torch.randn(3, 10, 3)
                        self.conv.weight = anchors * outer
                        boxes[inner] += torch.ones(3, 3)

            def forward(self, anchors) -> Tuple[Tensor, Tensor]:
                boxes = torch.ones(10, 3, 3)
                self.set_cell_anchors(anchors, boxes)
                if self.conv.bias is not None:
                    return self.conv.weight, boxes
                return anchors, boxes

        scripted = torch.jit.script(MyModule())
        anchor_input = torch.rand(10)
        self.run_test(scripted, anchor_input)
11364

11365
    @skipIfUnsupportedMinOpsetVersion(13)
11366
    def test_set_attr_in_loop_with_list(self):
        # Like test_set_attr_in_loop, but the nested loop also appends to a
        # List[Tensor] module attribute that forward() returns.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(10, 3, 3)
                self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
                # Workaround placeholder for TorchScript; forward() resets it to [].
                self.boxes: List[Tensor] = [torch.ones(1)]

            def set_cell_anchors(self, anchors):
                self.conv.weight = torch.randn(3, 10)
                for outer in range(self.conv.weight.size(0)):
                    for _inner in range(10):
                        self.conv.bias = torch.randn(3, 10, 3)
                        self.conv.weight = anchors * outer
                        self.boxes.append(torch.ones(3, 3))

            def forward(self, anchors) -> Tuple[Tensor, List[Tensor]]:
                self.boxes = []
                self.set_cell_anchors(anchors)
                if self.conv.bias is not None:
                    return self.conv.weight, self.boxes
                return anchors, self.boxes

        scripted = torch.jit.script(MyModule())
        anchor_input = torch.rand(10)
        self.run_test(scripted, anchor_input)
11395

11396
    @skipIfUnsupportedMinOpsetVersion(11)
11397
    def test_index_put_if(self):
        # Exports a scripted function that performs slice assignments
        # (`t[:] = ...`, i.e. index_put) on locally created tensors in both
        # branches of an `if`.
        @torch.jit.script
        def check_init(
            input_data: Tensor, hidden_size: int, prev_state: Tensor
        ) -> Tuple[Tensor, Tensor]:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state buffers.
            full_shape = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(full_shape, device=input_data.device)
            state_copy = torch.zeros(full_shape, device=input_data.device)
            if prev_state.size(0) == 0:
                state[:] = (
                    torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    + state[:]
                )
                # Two consecutive writes to the same slice; only the second is
                # visible in the output.
                state_copy[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 2
                )
                state_copy[:] = (
                    torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 2
                )
            else:
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 4
                )
            return state, state_copy

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state = check_init(input_data, self.hidden_size, prev_state)
                return prev_state[0], prev_state[1]

        model = Example(10)
        random_data = torch.rand((1, 5, 30, 30))
        # A zero-sized state so that `prev_state.size(0) == 0` selects the first branch.
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        inputs = (random_data, empty_tensor)
        self.run_test(
            model,
            inputs,
            input_names=["random_data", "empty_tensor"],
            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
        )
        self.run_test(model, inputs, remained_onnx_input_idx=[])
11448

11449
    @skipIfUnsupportedMinOpsetVersion(11)
11450
    def test_index_put_if_2(self):
        # Slice assignments in an if/elif/elif chain: inside a loop in the first
        # branch, and reading the overwritten slice in the second branch.
        @torch.jit.script
        def check_init(
            input_data: Tensor, hidden_size: int, prev_state: Tensor
        ) -> Tuple[Tensor, Tensor]:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state buffers.
            full_shape = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(full_shape, device=input_data.device)
            state_copy = torch.zeros(full_shape, device=input_data.device)
            if prev_state.size(0) == 0:
                for step in range(2):
                    state[:] = (
                        torch.ones(
                            batch_size, hidden_size, spatial_size_0, spatial_size_1
                        )
                        * step
                    )
                    state_copy[:] = (
                        torch.ones(
                            batch_size, hidden_size, spatial_size_0, spatial_size_1
                        )
                        * step
                    )
            elif prev_state.size(0) == 1:
                current = state[:]
                state[:] = prev_state + current
            elif prev_state.size(0) == 2:
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 4
                )
            return state, state_copy

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state = check_init(input_data, self.hidden_size, prev_state)
                return prev_state[0], prev_state[1]

        model = Example(10)
        random_data = torch.rand((1, 5, 30, 30))
        # The empty state takes the first branch; `random_state` (size(0) == 1)
        # takes the second one.
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        random_state = torch.rand((1, 1, 10, 30, 30))
        inputs = (random_data, empty_tensor)
        self.run_test(
            model,
            inputs,
            input_names=["data", "state"],
            dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
            additional_test_inputs=[(random_data, random_state)],
        )
        self.run_test(
            model,
            inputs,
            input_names=["data", "state"],
            dynamic_axes={"state": [0, 1, 2, 3, 4]},
            additional_test_inputs=[(random_data, random_state)],
            remained_onnx_input_idx=[1],
        )
        self.run_test(model, inputs, remained_onnx_input_idx=[])
11515

11516
    @skipIfUnsupportedMinOpsetVersion(11)
11517
    def test_index_put_if_3(self):
        # Slice assignment in an `if` nested inside another `if`, applied after
        # the tensor was rebound by an out-of-place multiply.
        @torch.jit.script
        def check_init(
            input_data: Tensor, hidden_size: int, prev_state: Tensor
        ) -> Tensor:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state buffer.
            full_shape = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(full_shape, device=input_data.device)
            if prev_state.size(0) < 2:
                state = state * 3
                if prev_state.size(0) == 0:
                    state[:] = (
                        torch.ones(
                            batch_size, hidden_size, spatial_size_0, spatial_size_1
                        )
                        * 3
                    )
                else:
                    state = state + 2

            return state

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state = check_init(input_data, self.hidden_size, prev_state)
                return prev_state

        model = Example(4)
        random_data = torch.rand((1, 5, 4, 4))
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        inputs = (random_data, empty_tensor)
        self.run_test(
            model,
            inputs,
            input_names=["random_data", "empty_tensor"],
            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
        )
        self.run_test(model, inputs, remained_onnx_input_idx=[])
11561

11562
    @skipIfUnsupportedMinOpsetVersion(11)
11563
    def test_index_put_if_4(self):
        # Alternating out-of-place rebinds and slice assignments of the same
        # local inside a single `if` branch.
        @torch.jit.script
        def check_init(
            input_data: Tensor, hidden_size: int, prev_state: Tensor
        ) -> Tensor:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state buffer.
            full_shape = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(full_shape, device=input_data.device)
            if prev_state.size(0) == 0:
                state = state + 3
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 3
                )
                state = state + 3
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 4
                )
            else:
                state = state + 2
            return state

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state = check_init(input_data, self.hidden_size, prev_state)
                return prev_state

        model = Example(4)
        random_data = torch.rand((1, 5, 4, 4))
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        inputs = (random_data, empty_tensor)
        self.run_test(
            model,
            inputs,
            input_names=["random_data", "empty_tensor"],
            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
        )
        self.run_test(model, inputs, remained_onnx_input_idx=[])
11608

11609
    @skipIfUnsupportedMinOpsetVersion(11)
11610
    def test_index_put_if_5(self):
        # Slice assignment through an alias: `state_ref` keeps pointing at the
        # original zeros tensor, so it observes the first in-place write
        # (ones * 3) but not writes made after `state` is rebound.
        @torch.jit.script
        def check_init(
            input_data: Tensor, hidden_size: int, prev_state: Tensor
        ) -> Tuple[Tensor, Tensor]:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state buffer.
            full_shape = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(full_shape, device=input_data.device)
            state_ref = state
            if prev_state.size(0) == 0:
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 3
                )
                state = state + 3
                state[:] = (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 4
                )
            else:
                state = state + 2
            return state, state_ref

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data, prev_state):
                prev_state, state_ref = check_init(
                    input_data, self.hidden_size, prev_state
                )
                return prev_state, state_ref

        model = Example(4)
        random_data = torch.rand((1, 5, 4, 4))
        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
        inputs = (random_data, empty_tensor)
        self.run_test(
            model,
            inputs,
            input_names=["random_data", "empty_tensor"],
            dynamic_axes={"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]},
        )
        self.run_test(model, inputs, remained_onnx_input_idx=[])
11657

11658
    @skipIfUnsupportedMinOpsetVersion(11)
11659
    def test_list_append_in_block(self):
        # Appends to a list inside a loop and returns the whole list.
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                products = []
                for idx in range(x.size(0)):
                    products.append(torch.matmul(x[idx], y))
                return products

        scripted = torch.jit.script(ListModel())
        lhs = torch.randn(16, 3, 4)
        rhs = torch.randn(4, 5)
        self.run_test(scripted, (lhs, rhs))
11671

11672
    @skipIfUnsupportedMinOpsetVersion(13)
11673
    def test_list_append_in_nested_block(self):
        # Appends to a list from inside a doubly nested loop.
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                products = []
                for outer in range(x.size(0)):
                    for inner in range(x.size(1)):
                        products.append(torch.matmul(x[outer][inner], y))
                return products

        scripted = torch.jit.script(ListModel())
        lhs = torch.randn(4, 4, 3, 4)
        rhs = torch.randn(4, 5)
        self.run_test(scripted, (lhs, rhs))
11686

11687
    @skipIfUnsupportedMinOpsetVersion(13)
11688
    def test_list_pop_in_block(self):
        # Exports a scripted module that pops from a list inside loops.
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                elem = torch.matmul(x[0], y)
                for i in range(x.size(0)):
                    res.append(torch.matmul(x[i], y))
                # Pops every element appended above.
                for i in range(x.size(0)):
                    elem = res.pop()
                # Each iteration appends then pops, so `res` stays empty and
                # `elem` ends up holding the last product.
                for i in range(x.size(0)):
                    res.append(torch.matmul(x[i], y))
                    elem = res.pop()
                # `list.append` returns None, so return the list itself; this
                # way the popped value is actually compared with ONNX Runtime.
                res.append(elem)
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(16, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))
11706

11707
    @skipIfUnsupportedMinOpsetVersion(13)
11708
    def test_list_del_in_block(self):
        # Exports a scripted module that deletes list elements inside loops.
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                elem = torch.matmul(x[0], y)
                for i in range(x.size(0)):
                    res.append(torch.matmul(x[i], y))
                # Deletes every element appended above.
                for i in range(x.size(0)):
                    del res[0]
                # Each iteration appends then deletes, so `res` stays empty.
                for i in range(x.size(0)):
                    res.append(torch.matmul(x[i], y))
                    del res[0]
                # `list.append` returns None, so return the list itself; this
                # way the result is actually compared with ONNX Runtime.
                res.append(elem)
                return res

        model = torch.jit.script(ListModel())
        x = torch.randn(16, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))
11726

11727
    @skipIfUnsupportedMinOpsetVersion(11)
11728
    def test_list_unpack(self):
        # Unpacks a list built in a loop into a fixed number of names. `x` has
        # exactly 3 rows, so the 3-way unpack matches the list length.
        class ListModel(torch.nn.Module):
            def forward(self, x, y):
                res = []
                for i in range(x.size(0)):
                    res.append(torch.matmul(x[i], y))
                a, b, c = res
                return a, b

        model = torch.jit.script(ListModel())
        x = torch.randn(3, 3, 4)
        y = torch.randn(4, 5)
        self.run_test(model, (x, y))
11742

11743
    @skipIfUnsupportedMinOpsetVersion(11)
11744
    def test_index_put_inplace_ops(self):
        # Exports a scripted function that applies in-place `+=` and `/=` to an
        # indexed slice (`state[1]`), both inside an `if` and inside a loop.
        @torch.jit.script
        def check_init(input_data: Tensor, hidden_size: int) -> Tensor:
            batch_size = input_data.size(0)
            spatial_size_0 = input_data.size(2)
            spatial_size_1 = input_data.size(3)
            # Stacked (2, batch, hidden, dim2, dim3) state; only state[1] is updated.
            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state = torch.zeros(state_size, device=input_data.device)
            if input_data.size(0) == 1:
                state[1] += (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 2
                )
                state[1] /= (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * 3
                )
            for i in range(input_data.size(0)):
                state[1] += torch.ones(
                    batch_size, hidden_size, spatial_size_0, spatial_size_1
                )
                # Divide by (i + 1), not i. With i == 0 on the first iteration,
                # the divisor would be zero, turning state[1] into inf and making
                # the numeric comparison uninformative.
                state[1] /= (
                    torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
                    * (i + 1)
                )
            return state

        class Example(torch.nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

            def forward(self, input_data):
                state = check_init(input_data, self.hidden_size)
                return state

        model = Example(10)
        random_data = torch.rand((1, 5, 30, 30))
        # Single input passed directly; wrapping it in bare parentheses would
        # not have made a tuple anyway.
        self.run_test(
            model,
            random_data,
            input_names=["random_data"],
            dynamic_axes={"random_data": [0, 1, 2, 3]},
        )
        self.run_test(model, random_data, remained_onnx_input_idx=[])
11790

11791
    @skipIfUnsupportedMinOpsetVersion(11)
11792
    def test_input_mask_model(self):
        # Boolean-mask index_put: first a single mask over dim 0 (`y[mask, :]`),
        # then one mask per dimension (`y[mask_1, mask_2]`).
        class InputMaskModel(torch.nn.Module):
            def __init__(self, output_size):
                super().__init__()
                self.bias = torch.nn.Parameter(
                    torch.empty(output_size, dtype=torch.float)
                )
                with torch.no_grad():
                    self.bias.zero_()

            def forward(self, model_input, y):
                input_mask = (model_input <= 0) | (model_input > 25)
                y[input_mask, :] = 0.0
                output = y + self.bias
                return output

        output_size = 4
        m = InputMaskModel(output_size)
        x = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
        y = torch.tensor(
            [
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
            ],
            dtype=torch.float,
        )
        self.run_test(m, (x, y))

        # Distinct name so it does not shadow the model above. It has no
        # parameters, so it needs no custom __init__.
        class TwoMaskModel(torch.nn.Module):
            def forward(self, model_input_1, model_input_2, y):
                input_mask_1 = (model_input_1 <= 0) | (model_input_1 > 25)
                input_mask_2 = (model_input_2 < 1) | (model_input_2 >= 12)
                y[input_mask_1, input_mask_2] = 0.0
                return y

        m = TwoMaskModel()
        # mask_1 selects 1 row and mask_2 selects 3 columns; the two index sets
        # broadcast together during advanced indexing.
        x1 = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
        x2 = torch.tensor([0, 3, 12, 15], dtype=torch.int64)
        y = torch.tensor(
            [
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
                [0.1, 0.2, 0.3, 0.4],
            ],
            dtype=torch.float,
        )
        self.run_test(m, (x1, x2, y))
11846

11847
    @skipScriptTest()
11848
    def test_unsafe_chunk(self):
        # torch.unsafe_chunk: split the 18-wide dim 1 into three chunks.
        class ChunkModel(torch.nn.Module):
            def forward(self, x):
                return torch.unsafe_chunk(x, 3, dim=1)

        chunker = ChunkModel().eval()
        self.run_test(chunker, torch.randn(1, 18), input_names=["x"])
11857

11858
    def test_symbolic_shape_inference(self):
        # ConstantOfShape is tested in test_embedding_bag
        # Tile is tested in test_repeat
        # This test covers Shape, Reshape, Transpose and Gather.
        class ShapeModel(torch.nn.Module):
            def forward(self, x, y):
                # 4-element shape: the first three dims of x, plus -1.
                shape = x.size()[:3] + (-1,)
                # -> (x.size(0), x.size(1), x.size(2), numel(y) / product of those)
                y = y.reshape(shape)
                return y.transpose(1, 2)

        shape_model = ShapeModel().eval()
        x = torch.ones(2, 3, 4, 5)
        y = torch.ones(3, 4, 5, 2)
        self.run_test(
            shape_model,
            (x, y),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
        )
        self.run_test(shape_model, (x, y), remained_onnx_input_idx=[1])

        class ViewModel(torch.nn.Module):
            def forward(self, x):
                return x.view(-1)

        # Flatten a 0-d tensor.
        self.run_test(ViewModel().eval(), (torch.tensor(2.0),))

        # test prim::ListConstruct for Reshape input 1
        class ViewModel_2(torch.nn.Module):
            def forward(self, x):
                N, C, H, W = x.shape[0], x.shape[2], x.shape[3], x.shape[4]
                x1 = x.view(N, -1, C, H, W)
                x2 = x1.permute(0, 3, 4, 1, 2)
                return x2.reshape(N, -1, C)

        self.run_test(ViewModel_2().eval(), torch.ones(2, 3, 4, 5, 6))
11901

11902
    @skipIfUnsupportedMinOpsetVersion(9)
11903
    def test_symbolic_shape_inference_arange(self):
        # test Range whose end is computed from input sizes:
        # (2 * (size(-2) - 1) + size(-1)) // size(0).
        class ArangeModel(torch.nn.Module):
            def forward(self, signal):
                frame_step = 2
                frames, frame_length = signal.size()[-2:]
                subframe_length = signal.size()[0]
                output_size = frame_step * (frames - 1) + frame_length
                output_subframes = output_size // subframe_length
                return torch.arange(0, output_subframes)

        model = ArangeModel()
        model.eval()
        M, C, K, N = 1, 2, 3, 4
        x = torch.randint(5, (M, C, K, N))
        y = torch.randint(5, (M, C + 1, K + 1, N + 1))
        self.run_test(model, x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]})
        self.run_test(model, x, remained_onnx_input_idx=[])
        self.run_test(
            model,
            x,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
            additional_test_inputs=[(x,), (y,)],
        )
11934

11935
    @skipIfUnsupportedMinOpsetVersion(11)
11936
    def test_symbolic_shape_inference_box(self):
        # test NonZero (via the single-argument form of torch.where)
        class BoxModel(torch.nn.Module):
            def forward(self, boxes):
                min_size = 1e-2
                widths = boxes[:, 2] - boxes[:, 0]
                heights = boxes[:, 3] - boxes[:, 1]
                keep = (widths >= min_size) & (heights >= min_size)
                return torch.where(keep)[0]

        box_model = BoxModel().eval()
        # All-ones boxes have zero width and height, so nothing is kept and the
        # NonZero result is empty.
        x = torch.ones(2, 4)
        y = torch.ones(3, 5)
        self.run_test(box_model, x)
        self.run_test(
            box_model,
            x,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
            additional_test_inputs=[(x,), (y,)],
        )
11958

11959
    @skipIfUnsupportedMinOpsetVersion(11)
11960
    def test_symbolic_shape_inference_box_if(self):
        # test If whose branches return a tensor with a data-dependent shape.
        class BoxIfModel(torch.nn.Module):
            def forward(self, boxes, scores):
                score_thresh = 0.0
                kept_rows = torch.where(scores > score_thresh)[0]
                selected = boxes[kept_rows]
                if selected.numel() > 3:
                    return selected
                else:
                    return selected * 2

        box_if_model = BoxIfModel().eval()
        # Every score is positive, so row 0 of `boxes` is selected four times.
        boxes = torch.ones(2, 4)
        scores = torch.ones(1, 4)
        self.run_test(box_if_model, (boxes, scores))
11977

11978
    @skipIfUnsupportedMinOpsetVersion(11)
11979
    @skipDtypeChecking
11980
    def test_symbolic_shape_inference_arange_2(self):
        """Range whose start is a dynamic dim, for int64 and double outputs."""

        class ArangeModel(torch.nn.Module):
            def forward(self, start):
                return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64)

        class ArangeModel2(torch.nn.Module):
            def forward(self, start):
                return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double)

        for model_cls in (ArangeModel, ArangeModel2):
            x = torch.randn(2, 3, 4)
            self.run_test(
                model_cls(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
            )
            self.run_test(model_cls(), (x,), remained_onnx_input_idx=[])
12001

12002
    @skipIfUnsupportedMinOpsetVersion(9)
12003
    def test_symbolic_shape_inference_nonzero(self):
        """NonZero over ones_like / zeros_like of 1-D and 3-D inputs."""

        class OneLikeModel(torch.nn.Module):
            def forward(self, x):
                ones = torch.ones_like(
                    x,
                    dtype=torch.float,
                    layout=torch.strided,
                    device=torch.device("cpu"),
                )
                return torch.nonzero(ones)

        class ZeroLikeModel(torch.nn.Module):
            def forward(self, x):
                zeros = torch.zeros_like(
                    x,
                    dtype=torch.float,
                    layout=torch.strided,
                    device=torch.device("cpu"),
                )
                return torch.nonzero(zeros)

        for model_cls in (OneLikeModel, ZeroLikeModel):
            for shape in ((2,), (2, 3, 4)):
                x = torch.randn(*shape)
                all_axes = list(range(len(shape)))
                self.run_test(
                    model_cls(), x, input_names=["x"], dynamic_axes={"x": all_axes}
                )
                self.run_test(model_cls(), x, remained_onnx_input_idx=[])
12041

12042
    @skipIfUnsupportedMinOpsetVersion(9)
12043
    def test_symbolic_shape_inference_expand_1(self):
        """Static expand of a (6, 1) input to (4, 6, 2)."""

        class ExpandModel(torch.nn.Module):
            def forward(self, x):
                return x.expand(4, 6, 2)

        column = torch.randn(6, 1, requires_grad=True)
        self.run_test(ExpandModel(), (column,))
12050

12051
    @skipIfUnsupportedMinOpsetVersion(9)
12052
    def test_symbolic_shape_inference_expand_2(self):
        """Causal mask built from arange, repeat and a broadcast compare."""

        class M(torch.nn.Module):
            def forward(self, x):
                batch_size, seq_length = x.size()
                seq_ids = torch.arange(seq_length)
                repeated = seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                causal_mask = repeated <= seq_ids[None, :, None]
                return causal_mask.transpose(0, 1)

        tokens = torch.randn(3, 16)
        self.run_test(M(), (tokens,), input_names=["x"], dynamic_axes={"x": [0, 1]})
        self.run_test(M(), (tokens,), remained_onnx_input_idx=[])
12067

12068
    @skipIfUnsupportedMinOpsetVersion(10)
12069
    def test_symbolic_shape_inference_slice(self):
        """Negative slice whose length is taken from another input's shape."""

        class M(torch.nn.Module):
            def forward(self, x, position_bias):
                batch_size, seq_length = x.size()
                position_bias = position_bias[:, :, -seq_length:, :]
                return position_bias.transpose(0, 1)

        inputs = (torch.randn(3, 16), torch.randn(1, 3, 20, 8))
        self.run_test(
            M(),
            inputs,
            input_names=["x", "position_bias"],
            dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]},
        )
        # Without dynamic axes only input index 1 (position_bias) should remain.
        self.run_test(M(), inputs, remained_onnx_input_idx=[1])
12086

12087
    def test_symbolic_shape_inference_slice_2(self):
        """Constant negative slice (last two rows of dim 2) then transpose."""

        class M(torch.nn.Module):
            def forward(self, position_bias):
                last_two = position_bias[:, :, -2:, :]
                return last_two.transpose(0, 1)

        self.run_test(M(), (torch.randn(1, 3, 20, 8),))
12095

12096
    @skipIfUnsupportedMinOpsetVersion(9)
12097
    @skipScriptTest()
12098
    def test_symbolic_shape_inference_time(self):
        """Single-layer LSTM, GRU and RNN with dynamic sequence and batch axes."""

        def export_kwargs():
            # New list/dict objects per call, in case an export mutates them.
            return {"input_names": ["x", "y"], "dynamic_axes": {"x": [0, 1]}}

        seq_input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)

        lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
        self.run_test(lstm, (seq_input, (h0, c0)), **export_kwargs())

        gru = torch.nn.GRU(
            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
        )
        self.run_test(gru, (seq_input, h0), **export_kwargs())

        rnn = torch.nn.RNN(
            RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False
        )
        self.run_test(rnn, (seq_input, h0), **export_kwargs())
12123

12124
    def test_symbolic_shape_inference_dynamic_axes(self):
        """view(-1, last_dim) + transpose with named (dict-style) dynamic axes."""

        class M(torch.nn.Module):
            def forward(self, input_ids):
                input_shape = input_ids.size()
                flattened = input_ids.view(-1, input_shape[-1])
                return flattened.transpose(0, 1)

        ids = torch.randn(3, 16)
        self.run_test(
            M(),
            (ids,),
            input_names=["input_ids"],
            dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
        )
12138

12139
    @skipIfUnsupportedMinOpsetVersion(9)
12140
    def test_hann_window_periodic(self):
        """hann_window(periodic=True) whose length is a runtime int argument."""

        class HannWindowModule_Periodic(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.window_length = 0

            def forward(self, x, window_length: int):
                self.window_length = window_length
                window = torch.hann_window(
                    self.window_length, periodic=True, dtype=torch.float
                )
                return torch.add(x, window)

        win_length = 100
        x = torch.randn(win_length)
        self.run_test(HannWindowModule_Periodic(), (x, win_length))
12160

12161
    @skipIfUnsupportedMinOpsetVersion(9)
12162
    def test_hann_window_not_periodic(self):
        """hann_window(periodic=False) whose length is a runtime int argument."""

        class HannWindowModule_NotPeriodic(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.window_length = 0

            def forward(self, x, window_length: int):
                self.window_length = window_length
                window = torch.hann_window(
                    self.window_length, periodic=False, dtype=torch.float
                )
                return torch.add(x, window)

        win_length = 100
        x = torch.randn(win_length)
        self.run_test(HannWindowModule_NotPeriodic(), (x, win_length))
12182

12183
    @skipIfUnsupportedMinOpsetVersion(9)
12184
    @skipScriptTest()
12185
    def test_hann_window_default_values(self):
        """hann_window with default arguments, passed through relu and added to x."""

        class HannWindowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.window_length = 0

            def forward(self, x, window_length: int):
                self.window_length = window_length
                # Call relu via the already-imported torch namespace instead of
                # re-importing torch.nn.functional on every forward call.
                window = torch.hann_window(self.window_length)
                return torch.add(x, torch.nn.functional.relu(window))

        win_length = 100
        x = torch.randn(win_length, dtype=torch.float)
        module = HannWindowModule()
        # run_test executes the module itself; the previous unused eager call
        # (`output = module(...)`) was dead code.
        self.run_test(module, (x, win_length))
12203

12204
    @skipIfUnsupportedMinOpsetVersion(12)
12205
    def test_tensordot_dim_count(self):
        """tensordot contracting the last two dims of x with the first two of y."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.tensordot(x, y, 2)

        lhs = torch.randint(6, (7, 5, 3, 4))
        rhs = torch.randint(6, (3, 4, 9, 2))
        self.run_test(M(), (lhs, rhs))
12215

12216
    @skipIfUnsupportedMinOpsetVersion(12)
12217
    def test_tensordot_dim_list(self):
        """tensordot with explicit contraction dim lists, including negative dims."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.tensordot(x, y, ([1, -2, -1], [1, 0, 3]))

        lhs = torch.randint(6, (7, 4, 3, 5, 2))
        rhs = torch.randint(6, (5, 4, 4, 2, 6))
        self.run_test(M(), (lhs, rhs))
12227

12228
    @skipIfUnsupportedMinOpsetVersion(12)
12229
    def test_tensordot_dynamic_dim(self):
        """tensordot(x, y, 2) exported with every input axis marked dynamic."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.tensordot(x, y, 2)

        x = torch.randint(6, (7, 5, 3, 4))
        y = torch.randint(6, (3, 4, 9, 2))
        # Second input pair with different sizes on every axis.
        new_x = torch.randint(6, (8, 6, 2, 5))
        new_y = torch.randint(6, (2, 5, 3, 4))

        input_names = ["input_x", "input_y"]
        self.run_test(
            M(),
            (x, y),
            additional_test_inputs=[(new_x, new_y)],
            input_names=input_names,
            dynamic_axes={name: [0, 1, 2, 3] for name in input_names},
        )
12248

12249
    @skipIfUnsupportedMinOpsetVersion(9)
12250
    def test_to_device(self):
        """Tensor.to(other.device), with and without an explicit dtype."""

        class M_ToDevice(torch.nn.Module):
            def forward(self, x, y):
                return x.to(y.device), y

        class M_ToDeviceDtype(torch.nn.Module):
            def forward(self, x, y):
                return x.to(y.device, dtype=torch.long), y

        x = torch.randn(6)
        y = torch.randn(6)
        for model in (M_ToDevice(), M_ToDeviceDtype()):
            self.run_test(model, (x, y))
12264

12265
    @skipIfUnsupportedMinOpsetVersion(9)
12266
    def test_fill(self):
        """In-place fill_ with int and float arguments and with a float literal."""

        class FillModule(torch.nn.Module):
            def forward(self, x, filled_value: int):
                return x.fill_(filled_value)

        class FillFloatModule(torch.nn.Module):
            def forward(self, x, filled_value: float):
                return x.fill_(filled_value)

        class FillScalarModule(torch.nn.Module):
            def forward(self, x):
                res = x + 2
                res.fill_(2.5)
                return res, x

        for module, filled_value in ((FillModule(), 7), (FillFloatModule(), 7.5)):
            x = torch.randn((4, 5, 6))
            self.run_test(module, (x, filled_value))

        # Long input, so the 2.5 fill is written into an integer tensor.
        self.run_test(FillScalarModule(), torch.ones(2, 3, 4, dtype=torch.long))
12291

12292
    @skipIfUnsupportedMinOpsetVersion(9)
12293
    def test_index_add_normal(self):
        """index_add_ along dims 0, 1 and 2 using full-length permutation indices."""

        class M(torch.nn.Module):
            def __init__(self, dim, index, updates):
                super().__init__()
                self.dim = dim
                self.index = index
                self.updates = updates

            def forward(self, x):
                x.index_add_(self.dim, self.index, self.updates)
                return x

        col = torch.ones(5, 1)
        col_updates = torch.tensor([[1], [4], [7], [3], [2]], dtype=torch.float)
        self.run_test(M(0, torch.tensor([0, 2, 3, 1, 4]), col_updates), (col,))

        # The same (1, 4, 3) input object is used for the dim=1 and dim=2 cases.
        cube = torch.ones(1, 4, 3)
        updates_dim1 = torch.tensor(
            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
        )
        self.run_test(M(1, torch.tensor([0, 2, 3, 1]), updates_dim1), (cube,))

        updates_dim2 = torch.tensor(
            [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4]]], dtype=torch.float
        )
        self.run_test(M(2, torch.tensor([0, 2, 1]), updates_dim2), (cube,))
12322

12323
    @skipIfUnsupportedMinOpsetVersion(9)
12324
    def test_index_add_dim_size_differ(self):
        """index_add_ on dim 1 with fewer indices (3) than that dim's size (4)."""

        class M(torch.nn.Module):
            def __init__(self, dim, index, updates):
                super().__init__()
                self.dim = dim
                self.index = index
                self.updates = updates

            def forward(self, x):
                x.index_add_(self.dim, self.index, self.updates)
                return x

        model = M(
            1,
            torch.tensor([0, 2, 1]),
            torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6]]], dtype=torch.float),
        )
        self.run_test(model, (torch.ones(1, 4, 3),))
12340

12341
    @skipIfUnsupportedMinOpsetVersion(9)
12342
    def test_index_add_in_loop(self):
        """index_add_ repeated a random number (0-19) of times in a Python loop."""

        class M(torch.nn.Module):
            def __init__(self, dim, index, updates, loop_count):
                super().__init__()
                self.dim = dim
                self.index = index
                self.updates = updates
                self.loop_count = loop_count

            def forward(self, x):
                for _ in range(self.loop_count):
                    x.index_add_(self.dim, self.index, self.updates)
                return x

        updates = torch.tensor(
            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
        )
        index = torch.tensor([0, 2, 3, 1])
        loop_count = torch.randint(20, (1,))[0].item()
        model = M(1, index, updates, loop_count)
        self.run_test(model, (torch.ones(1, 4, 3),))
12363

12364
    @skipIfUnsupportedMinOpsetVersion(9)
12365
    def test_index_add_if(self):
        """Scripted if/else on a bool tensor with index_add_ in both branches."""

        class M(torch.nn.Module):
            def __init__(self, dim, updates, index_true, index_false):
                super().__init__()
                self.dim = dim
                self.updates = updates
                self.index_true = index_true
                self.index_false = index_false

            def forward(self, x, cond):
                if cond:
                    x.index_add_(self.dim, self.index_true, self.updates)
                else:
                    x.index_add_(self.dim, self.index_false, self.updates)
                return x

        updates = torch.tensor(
            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
        )
        scripted = torch.jit.script(
            M(1, updates, torch.tensor([0, 2, 3, 1]), torch.tensor([1, 0, 2, 3]))
        )
        x = torch.ones(1, 4, 3)
        cond = torch.tensor(1, dtype=torch.bool)
        self.run_test(scripted, (x, cond))
12391

12392
    @skipIfUnsupportedMinOpsetVersion(9)
12393
    def test_index_add_dynamic_axes(self):
        """index_add_ on dim 1 with the input's first two axes marked dynamic."""

        class M(torch.nn.Module):
            def __init__(self, dim, index, updates):
                super().__init__()
                self.dim = dim
                self.index = index
                self.updates = updates

            def forward(self, x):
                x.index_add_(self.dim, self.index, self.updates)
                return x

        updates = torch.tensor(
            [[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float
        )
        index = torch.tensor([0, 2, 3, 1])
        self.run_test(
            M(1, index, updates),
            (torch.ones(1, 4, 3),),
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1]},
        )
12417

12418
    def test_roll(self):
        """torch.roll with list or int shifts, several dims and negative values."""

        class M(torch.nn.Module):
            def __init__(self, shifts, dims):
                super().__init__()
                self.shifts = shifts
                self.dims = dims

            def forward(self, x):
                return torch.roll(x, self.shifts, self.dims)

        x = torch.randn(2, 3, 4)
        for shifts, dims in (
            ([1, 1], [1, 0]),
            ([0, 1, 2], [1, 0, 2]),
            (2, 1),
            ([-1, 3], [-2, -1]),
        ):
            self.run_test(M(shifts, dims), (x,))
12433

12434
    def test_sum(self):
        """Full torch.sum reduction over an input whose first axis is dynamic."""

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.sum(x)

        self.run_test(
            M(), (torch.ones(12, 3),), input_names=["x"], dynamic_axes={"x": [0]}
        )
12441

12442
    @skipShapeChecking
12443
    def test_sum_empty_tensor(self):
        """sum() over an empty slice and over inputs containing a zero-sized dim."""

        class M(torch.nn.Module):
            def forward(self, x):
                return x[0:0].sum(), x.sum()

        for shape in ((12,), (2, 0, 3), (0,)):
            self.run_test(M(), (torch.ones(*shape),))
12456

12457
    @skipIfUnsupportedMinOpsetVersion(11)
12458
    def test_broad_cast_tensors(self):
        """torch.broadcast_tensors on int and float pairs of differing rank."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.broadcast_tensors(x, y)

        int_shape_pairs = (((1,), (5,)), ((4, 2, 1, 4), (2, 3, 1)))
        for x_shape, y_shape in int_shape_pairs:
            x = torch.randint(5, x_shape)
            y = torch.randint(5, y_shape)
            self.run_test(M(), (x, y))

        self.run_test(M(), (torch.randn(2, 1, 4), torch.randn(5, 2, 3, 1)))
12478

12479
    @skipScriptTest()
12480
    @skipIfUnsupportedMinOpsetVersion(11)
12481
    def test_dist_normal(self):
        """Normal(loc, scale).sample() with broadcast-compatible loc/scale shapes."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.distributions.Normal(x, y).sample().size(0), x, y

        loc_scale_pairs = [
            (torch.tensor([0.0]), torch.tensor([[1.0], [2.0]])),
            (torch.tensor([0.0]), torch.tensor([1.0])),
            (
                torch.tensor([[[0.0], [10.0]], [[2.0], [8.0]], [[2.0], [8.0]]]),
                torch.tensor([[1.0], [3.0]]),
            ),
        ]
        for loc, scale in loc_scale_pairs:
            self.run_test(M(), (loc, scale))
12496

12497
    @skipScriptTest()
12498
    @skipIfUnsupportedMinOpsetVersion(11)
12499
    def test_dist_normal_correctness(self):
        """Check that exported Normal sampling matches the requested mean and std.

        20000 samples are drawn from Normal(loc, scale) inside the exported graph.
        The mean and standard deviation of the ONNX Runtime output must each be
        within 10% of the requested values.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.distributions.Normal(x, y).sample([20000])

        expected_mean = 5.0
        expected_std = 10.0

        model_export = M()
        dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
        model_onnx = io.BytesIO()
        torch.onnx.export(
            model_export, dummy_input, model_onnx, opset_version=self.opset_version
        )
        ort_sess = verification._ort_session(model_onnx)
        ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)

        actual_std = np.std(ort_out)
        actual_mean = np.mean(ort_out)

        # Compare the signed mean: abs(actual_mean) would also accept samples
        # centred at -expected_mean. unittest assertions, unlike bare `assert`,
        # still run under `python -O`.
        self.assertLessEqual(
            abs(actual_mean - expected_mean),
            expected_mean * 0.1,
            "the gap of mean between ort outputs and expected one is unacceptable.",
        )
        self.assertLessEqual(
            abs(actual_std - expected_std),
            expected_std * 0.1,
            "the gap of standard deviation between ort outputs and expected one "
            "is unacceptable.",
        )
12525

12526
    @skipScriptTest()
12527
    @skipIfUnsupportedMinOpsetVersion(11)
12528
    def test_nn_init_normal_correctness(self):
        """Check that an exported nn.init.normal_ fill matches its mean and std.

        A (1, 400, 50) tensor is filled inside the exported graph. The mean and
        standard deviation of the ONNX Runtime output must each be within 10% of
        the requested values.
        """
        expected_mean = 5.0
        expected_std = 10.0

        class M(torch.nn.Module):
            def forward(self):
                x = torch.ones([]).new_empty(1, 400, 50)
                torch.nn.init.normal_(x, expected_mean, expected_std)
                return x

        model_export = M()
        model_onnx = io.BytesIO()
        test_inputs = ()
        torch.onnx.export(
            model_export, test_inputs, model_onnx, opset_version=self.opset_version
        )
        ort_sess = verification._ort_session(model_onnx)
        ort_out = verification._run_onnx(ort_sess, inputs=test_inputs)

        actual_std = np.std(ort_out)
        actual_mean = np.mean(ort_out)

        # Compare the signed mean: abs(actual_mean) would also accept samples
        # centred at -expected_mean. unittest assertions, unlike bare `assert`,
        # still run under `python -O`.
        self.assertLessEqual(
            abs(actual_mean - expected_mean),
            expected_mean * 0.1,
            "the gap of mean between ort outputs and expected one is unacceptable.",
        )
        self.assertLessEqual(
            abs(actual_std - expected_std),
            expected_std * 0.1,
            "the gap of standard deviation between ort outputs and expected one "
            "is unacceptable.",
        )
12556

12557
    @skipScriptTest()
12558
    @skipIfUnsupportedMinOpsetVersion(11)
12559
    def test_dist_uniform(self):
        """Uniform(low, high).sample() with broadcast-compatible low/high shapes."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.distributions.Uniform(x, y).sample().size(0), x, y

        low_high_pairs = [
            (torch.tensor([0.0]), torch.tensor([10.0])),
            (torch.tensor([[0.0], [6.0]]), torch.tensor([[1.0], [7.0]])),
            (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]])),
        ]
        for low, high in low_high_pairs:
            self.run_test(M(), (low, high))
12569

12570
    @skipScriptTest()
12571
    @skipIfUnsupportedMinOpsetVersion(11)
12572
    def test_dist_uniform_correctness(self):
        """Check the range and mean of exported Uniform(low, high) samples.

        10000 samples are drawn inside the exported graph. Every ONNX Runtime
        output must lie in [low, high], and their mean must be within 5% of the
        midpoint.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.distributions.Uniform(x, y).sample([10000])

        expected_min = 5.0
        expected_max = 10.0
        expected_mean = (expected_min + expected_max) / 2

        model_export = M()
        dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max]))
        model_onnx = io.BytesIO()
        torch.onnx.export(
            model_export, dummy_input, model_onnx, opset_version=self.opset_version
        )
        ort_sess = verification._ort_session(model_onnx)

        ort_out = verification._run_onnx(ort_sess, inputs=dummy_input)
        actual_min = np.min(ort_out)
        actual_max = np.max(ort_out)
        actual_mean = np.mean(ort_out)

        # unittest assertions, unlike bare `assert`, still run under `python -O`.
        self.assertGreaterEqual(
            actual_min,
            expected_min,
            "the minimum value of ort outputs is out of scope.",
        )
        self.assertLessEqual(
            actual_max,
            expected_max,
            "the maximum value of ort outputs is out of scope.",
        )
        self.assertLessEqual(
            abs(actual_mean - expected_mean),
            expected_mean * 0.05,
            "the mean value of ort outputs is out of scope.",
        )
12603

12604
    @skipIfUnsupportedMinOpsetVersion(13)
12605
    def test_sequence_to_int(self):
        """torch.tensor from an int list whose length is x.size()[0]."""

        class M(torch.nn.Module):
            def forward(self, x):
                num_rows = x.size()[0]
                result = torch.tensor([2 for _row in range(num_rows)], dtype=torch.int)
                return x, result

        self.run_test(M(), (torch.randn(10, 5),))
12613

12614
    @skipIfUnsupportedMinOpsetVersion(13)
12615
    def test_sequence_to_float(self):
        """torch.tensor from a float list whose length is x.size()[0]."""

        class M(torch.nn.Module):
            def forward(self, x):
                num_rows = x.size()[0]
                result = torch.tensor(
                    [1.1 for _row in range(num_rows)], dtype=torch.float
                )
                return x, result

        self.run_test(M(), (torch.randn(10, 5),))
12625

12626
    @skipIfUnsupportedMinOpsetVersion(13)
12627
    def test_sequence_to_bool(self):
        """torch.tensor from a bool list whose length is x.size()[0]."""

        class M(torch.nn.Module):
            def forward(self, x):
                num_rows = x.size()[0]
                result = torch.tensor(
                    [False for _row in range(num_rows)], dtype=torch.bool
                )
                return x, result

        self.run_test(M(), (torch.randn(10, 5),))
12637

12638
    def test_tuple_output_from_if_with_raised_exception(self):
        """Scripted if/else where one branch raises and the other returns a tuple."""

        class M(torch.nn.Module):
            def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
                if float(t) < 0:
                    raise Exception("Negative input")
                else:
                    return torch.zeros(5), torch.zeros(5)

        scripted = torch.jit.script(M())
        self.run_test(scripted, (torch.zeros(1),))
12648

12649
    # NOTE: For quantization tests, choose the scale and zero point carefully
    #       so that inputs and outputs do not always overflow or underflow;
    #       otherwise the test results could be inaccurate.
12652
    @skipIfUnsupportedMinOpsetVersion(10)
12653
    def test_quantized_linear(self):
        """Quantized Linear with fixed weights, a non-zero bias and a fixed input."""
        model = torch.ao.nn.quantized.Linear(4, 8)
        # Set fixed weight to avoid flaky test.
        weight = torch.quantize_per_tensor(
            torch.arange(32, dtype=torch.float).view(8, 4), 0.5, 0, torch.qint8
        )
        # Set non-zero bias.
        bias = torch.arange(8, dtype=torch.float)
        model.set_weight_bias(weight, bias)
        # Set fixed input (values -8..7) to avoid flaky test. The random tensor
        # previously assigned here was overwritten immediately, so it is gone;
        # the local is also renamed so it no longer shadows the builtin `input`.
        float_input = torch.arange(16, dtype=torch.float).view(4, 4) - 8
        input_tensor = torch.quantize_per_tensor(float_input, 0.5, 128, torch.quint8)
        self.run_test(model, input_tensor)
12667

12668
    @skipIfUnsupportedMinOpsetVersion(10)
12669
    def test_quantized_conv1d(self):
        """Quantized Conv1d (16 -> 33 channels, kernel 3, stride 2)."""
        model = torch.ao.nn.quantized.Conv1d(16, 33, 3, stride=2)
        # Parameters default to all zeros: install random quantized weights and
        # a fixed bias ramp (-16..16).
        weight_q = torch.quantize_per_tensor(
            torch.randn(33, 16, 3), 0.5, 0, torch.qint8
        )
        model.set_weight_bias(weight_q, torch.arange(33).to(torch.float) - 16)
        x_q = torch.quantize_per_tensor(
            torch.randn(3, 16, 32), 0.5, 128, torch.quint8
        )
        self.run_test(model, x_q)
12681

12682
    @skipIfUnsupportedMinOpsetVersion(10)
12683
    def test_quantized_conv2d(self):
        """Quantized Conv2d (16 -> 33 channels, 3x3 kernel, stride 2)."""
        model = torch.ao.nn.quantized.Conv2d(16, 33, 3, stride=2)
        # Parameters default to all zeros: install random quantized weights and
        # a fixed bias ramp (-16..16).
        weight_q = torch.quantize_per_tensor(
            torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
        )
        model.set_weight_bias(weight_q, torch.arange(33).to(torch.float) - 16)
        x_q = torch.quantize_per_tensor(
            torch.randn(3, 16, 32, 32), 0.5, 128, torch.quint8
        )
        self.run_test(model, x_q)
12695

12696
    @skipIfUnsupportedMinOpsetVersion(10)
12697
    @skipIfQuantizationBackendQNNPack
12698
    def test_quantized_conv3d(self):
        """Quantized Conv3d with a non-cubic kernel (2, 3, 4) and stride (3, 1, 2)."""
        model = torch.ao.nn.quantized.Conv3d(16, 33, [2, 3, 4], stride=[3, 1, 2])
        # Parameters default to all zeros: install random quantized weights and
        # a fixed bias ramp (-16..16).
        weight_q = torch.quantize_per_tensor(
            torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
        )
        model.set_weight_bias(weight_q, torch.arange(33).to(torch.float) - 16)
        x_q = torch.quantize_per_tensor(
            torch.randn(3, 16, 8, 8, 8), 0.5, 128, torch.quint8
        )
        self.run_test(model, x_q)
12710

12711
    @skipIfUnsupportedMinOpsetVersion(10)
12712
    def test_quantized_adaptive_avg_pool2d(self):
        """AdaptiveAvgPool2d to a (5, 7) output applied to a quantized input."""
        model = torch.nn.AdaptiveAvgPool2d((5, 7))
        x_q = torch.quantize_per_tensor(
            torch.randn(4, 3, 10, 14), 0.2, 128, torch.quint8
        )
        self.run_test(model, x_q)
12717

12718
    @skipIfUnsupportedMinOpsetVersion(10)
12719
    def test_quantized_conv1d_relu(self):
        """Fused quantized ConvReLU1d (16 -> 33 channels, kernel 3, stride 2)."""
        model = torch.ao.nn.intrinsic.quantized.ConvReLU1d(16, 33, 3, stride=2)
        # Parameters default to all zeros: install random quantized weights and
        # a fixed bias ramp (-16..16).
        weight_q = torch.quantize_per_tensor(
            torch.randn(33, 16, 3), 0.5, 0, torch.qint8
        )
        model.set_weight_bias(weight_q, torch.arange(33).to(torch.float) - 16)
        x_q = torch.quantize_per_tensor(
            torch.randn(3, 16, 32), 0.5, 128, torch.quint8
        )
        self.run_test(model, x_q)
12731

12732
    @skipIfUnsupportedMinOpsetVersion(10)
12733
    def test_quantized_conv2d_relu(self):
12734
        model = torch.ao.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
12735
        # Manually initialize model weight and bias to random numbers.
12736
        # By default all zeros.
12737
        q_weight = torch.quantize_per_tensor(
12738
            torch.randn(33, 16, 3, 3), 0.5, 0, torch.qint8
12739
        )
12740
        bias = torch.arange(33).to(torch.float) - 16
12741
        model.set_weight_bias(q_weight, bias)
12742
        input = torch.randn(3, 16, 32, 32)
12743
        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12744
        self.run_test(model, q_input)
12745

12746
    @skipIfUnsupportedMinOpsetVersion(10)
12747
    @skipIfQuantizationBackendQNNPack
12748
    def test_quantized_conv3d_relu(self):
12749
        model = torch.ao.nn.intrinsic.quantized.ConvReLU3d(
12750
            16, 33, [2, 3, 4], stride=[3, 1, 2]
12751
        )
12752
        # Manually initialize model weight and bias to random numbers.
12753
        # By default all zeros.
12754
        q_weight = torch.quantize_per_tensor(
12755
            torch.randn(33, 16, 2, 3, 4), 0.5, 0, torch.qint8
12756
        )
12757
        bias = torch.arange(33).to(torch.float) - 16
12758
        model.set_weight_bias(q_weight, bias)
12759
        input = torch.randn(3, 16, 8, 8, 8)
12760
        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12761
        self.run_test(model, q_input)
12762

12763
    @skipIfUnsupportedMinOpsetVersion(10)
12764
    def test_quantized_conv_transpose1d(self):
12765
        model = torch.ao.nn.quantized.ConvTranspose1d(
12766
            16, 33, 3, output_padding=1, stride=2
12767
        )
12768
        # Manually initialize model weight and bias to random numbers.
12769
        # By default all zeros.
12770
        q_weight = torch.quantize_per_tensor(
12771
            torch.randn(16, 33, 3), 0.5, 0, torch.qint8
12772
        )
12773
        bias = torch.arange(33).to(torch.float) - 16
12774
        model.set_weight_bias(q_weight, bias)
12775
        input = torch.randn(3, 16, 32)
12776
        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12777
        self.run_test(model, q_input)
12778

12779
    @skipIfUnsupportedMinOpsetVersion(10)
12780
    def test_quantized_conv_transpose2d(self):
12781
        model = torch.ao.nn.quantized.ConvTranspose2d(
12782
            16, 33, 3, output_padding=(0, 1), stride=2
12783
        )
12784
        # Manually initialize model weight and bias to random numbers.
12785
        # By default all zeros.
12786
        q_weight = torch.quantize_per_tensor(
12787
            torch.randn(16, 33, 3, 3), 0.5, 0, torch.qint8
12788
        )
12789
        bias = torch.arange(33).to(torch.float) - 16
12790
        model.set_weight_bias(q_weight, bias)
12791
        input = torch.randn(3, 16, 32, 32)
12792
        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12793
        self.run_test(model, q_input)
12794

12795
    @skipIfUnsupportedMinOpsetVersion(10)
12796
    @skipIfQuantizationBackendQNNPack
12797
    def test_quantized_conv_transpose3d(self):
12798
        model = torch.ao.nn.quantized.ConvTranspose3d(
12799
            16, 33, [2, 3, 4], output_padding=(0, 1, 2), stride=[3, 1, 2]
12800
        )
12801
        # Manually initialize model weight and bias to random numbers.
12802
        # By default all zeros.
12803
        q_weight = torch.quantize_per_tensor(
12804
            torch.randn(16, 33, 2, 3, 4), 0.5, 0, torch.qint8
12805
        )
12806
        bias = torch.arange(33).to(torch.float) - 16
12807
        model.set_weight_bias(q_weight, bias)
12808
        input = torch.randn(3, 16, 8, 8, 8)
12809
        q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
12810
        self.run_test(model, q_input)
12811

12812
    @common_utils.parametrize(
12813
        "function_or_module",
12814
        [
12815
            common_utils.subtest(
12816
                torch.nn.ReLU(),
12817
                name="relu",
12818
            ),
12819
            common_utils.subtest(
12820
                torch.nn.LeakyReLU(),
12821
                name="leaky_relu",
12822
            ),
12823
            common_utils.subtest(
12824
                torch.ao.nn.quantized.LeakyReLU(2.0, 1),
12825
                name="quantized_leaky_relu",
12826
            ),
12827
            common_utils.subtest(
12828
                torch.ao.nn.quantized.Hardswish(2.0, 1),
12829
                name="quantized_hardswish",
12830
            ),
12831
            common_utils.subtest(
12832
                torch.nn.Sigmoid(),
12833
                name="sigmoid",
12834
            ),
12835
            common_utils.subtest(
12836
                torch.ao.nn.quantized.Sigmoid(2.0, 1),
12837
                name="quantized_sigmoid",
12838
            ),
12839
            common_utils.subtest(
12840
                torch.nn.Hardsigmoid(),
12841
                name="hardsigmoid",
12842
            ),
12843
            common_utils.subtest(
12844
                torch.nn.Tanh(),
12845
                name="tanh",
12846
            ),
12847
            common_utils.subtest(
12848
                torch.nn.Hardtanh(),
12849
                name="hardtanh",
12850
            ),
12851
            common_utils.subtest(
12852
                lambda x: torch.transpose(x, 0, 1),
12853
                name="transpose",
12854
            ),
12855
            common_utils.subtest(
12856
                lambda x: x.expand(2, 4, 2, 3),
12857
                name="expand",
12858
            ),
12859
            common_utils.subtest(
12860
                lambda x: x.view(1, 4, 6),
12861
                name="view",
12862
            ),
12863
            common_utils.subtest(
12864
                lambda x: x.select(1, 1),
12865
                name="select",
12866
            ),
12867
            common_utils.subtest(
12868
                torch.ao.nn.quantized.LayerNorm(
12869
                    [4, 2, 3],
12870
                    torch.nn.Parameter(torch.ones([4, 2, 3])),
12871
                    torch.nn.Parameter(torch.zeros([4, 2, 3])),
12872
                    2.0,
12873
                    1,
12874
                ),
12875
                name="layer_norm",
12876
            ),
12877
            common_utils.subtest(
12878
                torch.ao.nn.quantized.InstanceNorm1d(
12879
                    2,
12880
                    torch.nn.Parameter(torch.ones(4)),
12881
                    torch.nn.Parameter(torch.zeros(4)),
12882
                    2.0,
12883
                    1,
12884
                ),
12885
                name="instance_norm",
12886
            ),
12887
            common_utils.subtest(
12888
                torch.ao.nn.quantized.GroupNorm(
12889
                    2,
12890
                    4,
12891
                    torch.nn.Parameter(torch.zeros(4)),
12892
                    torch.nn.Parameter(torch.zeros(4)),
12893
                    2.0,
12894
                    1,
12895
                ),
12896
                name="group_norm",
12897
            ),
12898
            common_utils.subtest(
12899
                lambda x: torch.as_strided(x, (2, 2), (1, 2)),
12900
                name="as_strided",
12901
            ),
12902
        ],
12903
    )
12904
    @skipScriptTest()
12905
    @skipIfUnsupportedMinOpsetVersion(10)
12906
    def test_quantized_unary_ops(self, function_or_module):
12907
        input = torch.randn(1, 4, 2, 3)
12908
        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
12909

12910
        class Model(torch.nn.Module):
12911
            def __init__(self, function_or_module):
12912
                super().__init__()
12913
                self.function_or_module = function_or_module
12914

12915
            def forward(self, x):
12916
                return self.function_or_module(x)
12917

12918
        self.run_test(Model(function_or_module), q_input)
12919

12920
    @skipIfUnsupportedMinOpsetVersion(10)
12921
    def test_quantized_flatten(self):
12922
        class FlattenModel(torch.nn.Module):
12923
            def forward(self, input):
12924
                return torch.flatten(input)
12925

12926
        x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
12927
        self.run_test(FlattenModel(), x)
12928

12929
    @skipIfUnsupportedMinOpsetVersion(10)
12930
    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
12931
    def test_quantized_cat_when_concatinating_the_same_tensor(self):
12932
        class QuantizedSelfConcatenationModel(torch.nn.Module):
12933
            def forward(self, x):
12934
                return torch.ao.nn.quantized.QFunctional().cat((x, x), dim=1)
12935

12936
        q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
12937
        self.run_test(QuantizedSelfConcatenationModel(), q_input)
12938

12939
    @common_utils.parametrize(
12940
        "x, y",
12941
        [
12942
            common_utils.subtest(
12943
                [
12944
                    torch.quantize_per_tensor(
12945
                        torch.ones(2, 3), 0.26, 128, torch.quint8
12946
                    ),
12947
                    torch.quantize_per_tensor(
12948
                        torch.zeros(1, 3), 0.26, 128, torch.quint8
12949
                    ),
12950
                ],
12951
                name="different_shape",
12952
            ),
12953
            common_utils.subtest(
12954
                [
12955
                    torch.quantize_per_tensor(
12956
                        torch.ones(2, 3), 0.26, 128, torch.quint8
12957
                    ),
12958
                    torch.quantize_per_tensor(torch.ones(2, 3), 42, 1, torch.quint8),
12959
                ],
12960
                name="different_scale",
12961
            ),
12962
            common_utils.subtest(
12963
                [
12964
                    torch.quantize_per_tensor(
12965
                        torch.ones(2, 3), 0.26, 128, torch.quint8
12966
                    ),
12967
                    torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 63, torch.quint8),
12968
                ],
12969
                name="different_zero_point",
12970
            ),
12971
            common_utils.subtest(
12972
                [
12973
                    torch.quantize_per_tensor(
12974
                        torch.ones(2, 3), 0.26, 128, torch.quint8
12975
                    ),
12976
                    torch.quantize_per_tensor(torch.ones(2, 3), 0.1, 63, torch.quint8),
12977
                ],
12978
                name="different_zero_point_and_scale",
12979
            ),
12980
        ],
12981
    )
12982
    @skipIfUnsupportedMinOpsetVersion(10)
12983
    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
12984
    def test_quantized_cat(self, x: torch.Tensor, y: torch.Tensor):
12985
        class QuantizedConcatenationModel(torch.nn.Module):
12986
            def forward(self, x, y):
12987
                return torch.ao.nn.quantized.QFunctional().cat((x, y), dim=0)
12988

12989
        self.run_test(QuantizedConcatenationModel(), (x, y))
12990

12991
    @skipIfUnsupportedMinOpsetVersion(10)
12992
    # torch.jit.frontend.FrontendError:
12993
    # Cannot instantiate class 'QFunctional' in a script function
12994
    @skipScriptTest()
12995
    def test_quantized_arithmetic_qfunctional(self):
12996
        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
12997
        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
12998

12999
        class ArithmeticModel(torch.nn.Module):
13000
            def forward(self, x, y):
13001
                o = torch.ao.nn.quantized.QFunctional().add(x, y)
13002
                o = torch.ao.nn.quantized.QFunctional().mul(o, x)
13003
                return o
13004

13005
        self.run_test(ArithmeticModel(), (x, y))
13006

13007
    @skipIfUnsupportedMinOpsetVersion(10)
13008
    def test_quantized_arithmetic(self):
13009
        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
13010
        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
13011

13012
        class ArithmeticModel2(torch.nn.Module):
13013
            def forward(self, x, y):
13014
                o = torch.ops.quantized.add(x, y, 0.4, 100)
13015
                o = torch.ops.quantized.mul(o, x, 0.4, 100)
13016
                return o
13017

13018
        self.run_test(ArithmeticModel2(), (x, y))
13019

13020
    @skipIfUnsupportedMinOpsetVersion(10)
13021
    def test_quantize_per_tensor(self):
13022
        class Module(torch.nn.Module):
13023
            def forward(self, x):
13024
                return (
13025
                    torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
13026
                    torch.quantize_per_tensor(x, 0.2, 128, torch.quint8),
13027
                )
13028

13029
        x = torch.randn(4, 6)
13030
        self.run_test(Module(), x)
13031

13032
    @skipIfUnsupportedMinOpsetVersion(10)
13033
    def test_dequantize(self):
13034
        class Module(torch.nn.Module):
13035
            def forward(self, x):
13036
                return torch.dequantize(x)
13037

13038
        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
13039
        self.run_test(Module(), x)
13040

13041
    @skipIfUnsupportedMinOpsetVersion(13)
13042
    def test_qat_linear_per_channel(self):
13043
        class M(torch.nn.Module):
13044
            def __init__(self):
13045
                super().__init__()
13046
                self.quant = torch.ao.quantization.QuantStub()
13047
                self.linear = torch.nn.Linear(4, 3)
13048
                self.dequant = torch.ao.quantization.DeQuantStub()
13049

13050
            def forward(self, x):
13051
                x = self.quant(x)
13052
                x = self.linear(x)
13053
                x = self.dequant(x)
13054
                return x
13055

13056
        model = M()
13057
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13058
        model = torch.ao.quantization.prepare_qat(model)
13059
        # Set fixed weight and bias to avoid flaky test.
13060
        model.linear.weight = torch.nn.Parameter(
13061
            _construct_tensor_for_quantization_test((3, 4))
13062
        )
13063
        model.linear.bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float))
13064
        model = torch.ao.quantization.convert(model)
13065

13066
        # Set fixed input to avoid flaky test.
13067
        input = _construct_tensor_for_quantization_test((4, 4), offset=-8)
13068
        self.run_test(model, input)
13069

13070
    @unittest.skip(
13071
        "ORT fails with Validating no unexpected access using an invalid node_index on torch converted model"
13072
    )
13073
    @skipIfUnsupportedMinOpsetVersion(13)
13074
    def test_quantized_list_of_inputs_with_cat(self):
13075
        class TestModel(torch.nn.Module):
13076
            def __init__(self):
13077
                super().__init__()
13078
                self.quant = torch.ao.quantization.QuantStub()
13079
                self.dequant = torch.ao.quantization.DeQuantStub()
13080

13081
            def forward(self, x):
13082
                x = self.quant(x)
13083
                x = torch.cat([x, x], 1)
13084
                x = self.dequant(x)
13085
                return x
13086

13087
        model = TestModel()
13088
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13089
        model = torch.ao.quantization.prepare_qat(model)
13090
        model = torch.ao.quantization.convert(model)
13091
        x = torch.randn(2, 4, 6)
13092
        self.run_test(model, x)
13093

13094
    @skipIfUnsupportedMinOpsetVersion(13)
13095
    def test_qat_relu(self):
13096
        class M(torch.nn.Module):
13097
            def __init__(self):
13098
                super().__init__()
13099
                self.quant = torch.ao.quantization.QuantStub()
13100
                self.relu = torch.nn.ReLU()
13101
                self.dequant = torch.ao.quantization.DeQuantStub()
13102

13103
            def forward(self, x):
13104
                x = self.quant(x)
13105
                x = self.relu(x)
13106
                x = self.dequant(x)
13107
                return x
13108

13109
        model = M()
13110
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13111
        model = torch.ao.quantization.prepare_qat(model)
13112
        model = torch.ao.quantization.convert(model)
13113
        input = torch.randn(8, 4)
13114
        self.run_test(model, input)
13115

13116
    @skipIfUnsupportedMinOpsetVersion(13)
13117
    def test_qat_conv2d(self):
13118
        class M(torch.nn.Module):
13119
            def __init__(self):
13120
                super().__init__()
13121
                self.quant = torch.ao.quantization.QuantStub()
13122
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13123
                self.dequant = torch.ao.quantization.DeQuantStub()
13124

13125
            def forward(self, x):
13126
                x = self.quant(x)
13127
                x = self.conv(x)
13128
                x = self.dequant(x)
13129
                return x
13130

13131
        model = M()
13132
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13133
        model = torch.ao.quantization.prepare_qat(model)
13134
        # Set fixed weight and bias to avoid flaky test.
13135
        model.conv.weight = torch.nn.Parameter(
13136
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13137
        )
13138
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13139
        model = torch.ao.quantization.convert(model)
13140

13141
        # Set fixed input to avoid flaky test.
13142
        input = _construct_tensor_for_quantization_test(
13143
            (3, 4, 8, 8), offset=-384, max_val=12
13144
        )
13145
        self.run_test(model, input)
13146

13147
    @skipIfUnsupportedMinOpsetVersion(13)
13148
    def test_qat_conv2d_relu(self):
13149
        class M(torch.nn.Module):
13150
            def __init__(self):
13151
                super().__init__()
13152
                self.quant = torch.ao.quantization.QuantStub()
13153
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13154
                self.relu = torch.nn.ReLU()
13155
                self.dequant = torch.ao.quantization.DeQuantStub()
13156

13157
            def forward(self, x):
13158
                x = self.quant(x)
13159
                x = self.conv(x)
13160
                x = self.relu(x)
13161
                x = self.dequant(x)
13162
                return x
13163

13164
        model = M()
13165
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13166
        model = torch.ao.quantization.prepare_qat(model)
13167
        # Set fixed weight and bias to avoid flaky test.
13168
        model.conv.weight = torch.nn.Parameter(
13169
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13170
        )
13171
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13172
        model = torch.ao.quantization.convert(model)
13173

13174
        # Set fixed input to avoid flaky test.
13175
        input = _construct_tensor_for_quantization_test(
13176
            (3, 4, 8, 8), offset=-384, max_val=12
13177
        )
13178
        self.run_test(model, input)
13179

13180
    @skipIfUnsupportedMinOpsetVersion(13)
13181
    def test_qat_conv2d_relu_fused(self):
13182
        class M(torch.nn.Module):
13183
            def __init__(self):
13184
                super().__init__()
13185
                self.quant = torch.ao.quantization.QuantStub()
13186
                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
13187
                self.relu = torch.nn.ReLU()
13188
                self.dequant = torch.ao.quantization.DeQuantStub()
13189

13190
            def forward(self, x):
13191
                x = self.quant(x)
13192
                x = self.conv(x)
13193
                x = self.relu(x)
13194
                x = self.dequant(x)
13195
                return x
13196

13197
        model = M()
13198
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13199
        model = torch.ao.quantization.fuse_modules(model.eval(), [["conv", "relu"]])
13200
        model = torch.ao.quantization.prepare_qat(model.train())
13201
        # Set fixed weight and bias to avoid flaky test.
13202
        model.conv.weight = torch.nn.Parameter(
13203
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
13204
        )
13205
        model.conv.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13206
        model = torch.ao.quantization.convert(model)
13207

13208
        # Set fixed input to avoid flaky test.
13209
        input = _construct_tensor_for_quantization_test(
13210
            (3, 4, 8, 8), offset=-384, max_val=12
13211
        )
13212
        self.run_test(model, input)
13213

13214
    @skipIfUnsupportedMinOpsetVersion(13)
13215
    def test_qat_linear_relu_fused(self):
13216
        class M(torch.nn.Module):
13217
            def __init__(self):
13218
                super().__init__()
13219
                self.quant = torch.ao.quantization.QuantStub()
13220
                self.linear = torch.nn.Linear(4, 2)
13221
                self.relu = torch.nn.ReLU()
13222
                self.dequant = torch.ao.quantization.DeQuantStub()
13223

13224
            def forward(self, x):
13225
                x = self.quant(x)
13226
                x = self.linear(x)
13227
                x = self.relu(x)
13228
                x = self.dequant(x)
13229
                return x
13230

13231
        model = M()
13232
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13233
        model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
13234
        model = torch.ao.quantization.prepare_qat(model.train())
13235
        # Set fixed weight and bias to avoid flaky test.
13236
        model.linear.weight = torch.nn.Parameter(
13237
            _construct_tensor_for_quantization_test((2, 4), max_val=2)
13238
        )
13239
        model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
13240
        model = torch.ao.quantization.convert(model)
13241

13242
        # Set fixed input to avoid flaky test.
13243
        input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12)
13244
        self.run_test(model, input)
13245

13246
    @skipIfUnsupportedMinOpsetVersion(10)
13247
    def test_qat_maxpool2d(self):
13248
        class M(torch.nn.Module):
13249
            def __init__(self):
13250
                super().__init__()
13251
                self.quant = torch.ao.quantization.QuantStub()
13252
                self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
13253
                self.dequant = torch.ao.quantization.DeQuantStub()
13254

13255
            def forward(self, x):
13256
                x = self.quant(x)
13257
                x = self.pool(x)
13258
                x = self.dequant(x)
13259
                return x
13260

13261
        model = M()
13262
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13263
        model = torch.ao.quantization.prepare_qat(model.train())
13264
        model = torch.ao.quantization.convert(model)
13265

13266
        # Set fixed input to avoid flaky test.
13267
        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
13268
        self.run_test(model, input)
13269

13270
    @skipIfUnsupportedMinOpsetVersion(10)
13271
    @skipScriptTest()  # Scale and Zero-point must be a scalar in ORT:optimization
13272
    def test_qat_avg_pool2d(self):
13273
        model = torch.nn.Sequential(
13274
            torch.ao.quantization.QuantStub(),
13275
            torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
13276
            torch.ao.quantization.DeQuantStub(),
13277
        )
13278
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13279
        model = torch.ao.quantization.prepare_qat(model.train())
13280
        model = torch.ao.quantization.convert(model)
13281
        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
13282
        self.run_test(model, input)
13283

13284
    @skipIfUnsupportedMinOpsetVersion(11)
13285
    def test_qat_upsample_nearest2d(self):
13286
        model = torch.nn.Sequential(
13287
            torch.ao.quantization.QuantStub(),
13288
            torch.nn.UpsamplingNearest2d(scale_factor=1.5),
13289
            torch.ao.quantization.DeQuantStub(),
13290
        )
13291
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
13292
        model = torch.ao.quantization.prepare_qat(model.train())
13293
        model = torch.ao.quantization.convert(model)
13294
        input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
13295
        self.run_test(model, input)
13296

13297
    def test_0d_tensor_broadcast(self):
13298
        class fn(torch.nn.Module):
13299
            def forward(self, x, y):
13300
                a = torch.add(x, y)
13301
                b = torch.mul(y, y)
13302
                return a + b
13303

13304
        x = torch.ones(0)
13305
        y = torch.ones(1)
13306
        self.run_test(fn(), (x, y), input_names=["x", "y"], output_names=["output"])
13307

13308
    @skipIfUnsupportedMinOpsetVersion(9)
13309
    def test_convolution_allow_tf32(self):
13310
        class Module(torch.nn.Module):
13311
            def __init__(self, allow_tf32):
13312
                super().__init__()
13313

13314
                self.allow_tf32 = allow_tf32
13315
                weight = torch.rand(32, 3, 3, 3)
13316
                self.weight = torch.nn.Parameter(weight)
13317

13318
            def forward(self, x):
13319
                if self.allow_tf32:
13320
                    return torch._convolution(
13321
                        x,
13322
                        self.weight,
13323
                        None,
13324
                        [2, 2],
13325
                        [0, 0],
13326
                        [1, 1],
13327
                        False,
13328
                        [0, 0],
13329
                        1,
13330
                        False,
13331
                        False,
13332
                        True,
13333
                        True,
13334
                    )
13335
                else:
13336
                    return torch._convolution(
13337
                        x,
13338
                        self.weight,
13339
                        None,
13340
                        [2, 2],
13341
                        [0, 0],
13342
                        [1, 1],
13343
                        False,
13344
                        [0, 0],
13345
                        1,
13346
                        False,
13347
                        False,
13348
                        True,
13349
                    )
13350

13351
        x = torch.randn(1, 3, 224, 224)
13352
        self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
13353
        self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)
13354

13355
    @skipIfUnsupportedMinOpsetVersion(16)
13356
    @common_utils.parametrize(
13357
        "mode",
13358
        ("bilinear", "nearest", "bicubic"),
13359
    )
13360
    @common_utils.parametrize(
13361
        "padding_mode",
13362
        ("zeros", "border", "reflection"),
13363
    )
13364
    @common_utils.parametrize(
13365
        "align_corners",
13366
        (True, False),
13367
        name_fn=lambda align_corners: str(align_corners),
13368
    )
13369
    def test_grid_sample(self, mode, padding_mode, align_corners):
13370
        n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
13371

13372
        class GridSampleModule(torch.nn.Module):
13373
            def __init__(self, mode, padding_mode, align_corners) -> None:
13374
                super().__init__()
13375
                self.mode, self.padding_mode, self.align_corners = (
13376
                    mode,
13377
                    padding_mode,
13378
                    align_corners,
13379
                )
13380

13381
            def forward(self, input, grid):
13382
                return torch.nn.functional.grid_sample(
13383
                    input, grid, self.mode, self.padding_mode, self.align_corners
13384
                )
13385

13386
        atol_rtol = {}
13387
        if (mode, padding_mode) == ("bicubic", "border"):
13388
            if align_corners:
13389
                atol_rtol.update({"atol": 0.3, "rtol": 0.4})
13390
            else:
13391
                atol_rtol.update({"atol": 0.02, "rtol": 0.02})
13392
        input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
13393
        self.run_test(
13394
            GridSampleModule(mode, padding_mode, align_corners),
13395
            (input, grid),
13396
            **atol_rtol,
13397
        )
13398

13399
        # ONNX Opset 16 GridSample with 5D volumetric input is not supported.
13400
        d_in = 2
13401
        d_out = 3
13402
        volumetric_input_tensor = torch.randn(n, c, d_in, h_in, w_in)
13403
        volumetric_grid_tensor = torch.randn(n, d_out, h_out, w_out, 3)
13404
        for mode, padding_mode, align_corners in itertools.product(
13405
            (
13406
                "bilinear",
13407
                "nearest",
13408
            ),  # PyTorch grid_sample "bicubic" mode does not support 5D volumetric input.
13409
            (
13410
                "zeros",
13411
                "border",
13412
                "reflection",
13413
            ),
13414
            (
13415
                True,
13416
                False,
13417
            ),
13418
        ):
13419
            with self.assertRaises(
13420
                torch.onnx.errors.OnnxExporterError,
13421
            ):
13422
                self.run_test(
13423
                    GridSampleModule(mode, padding_mode, align_corners),
13424
                    (volumetric_input_tensor, volumetric_grid_tensor),
13425
                    **atol_rtol,
13426
                )
13427

13428
    class IfNoneInput(torch.nn.Module):
13429
        def forward(self, x) -> Optional[Tensor]:
13430
            y: Optional[Tensor] = None
13431
            if x.size(0) > 1:
13432
                y = x
13433
            return y
13434

13435
    class IfNoneOutput(torch.nn.Module):
13436
        def forward(self, x) -> Optional[Tensor]:
13437
            y: Optional[Tensor] = x
13438
            if x.size(0) > 1:
13439
                y = None
13440
            return y
13441

13442
    class LoopNoneInput(torch.nn.Module):
13443
        def forward(self, x) -> Optional[Tensor]:
13444
            y: Optional[Tensor] = None
13445
            for _ in range(x.size(0)):
13446
                y = x
13447
            return y
13448

13449
    class LoopNoneOutput(torch.nn.Module):
13450
        def forward(self, x) -> Optional[Tensor]:
13451
            y: Optional[Tensor] = x
13452
            for _ in range(x.size(0)):
13453
                y = None
13454
            return y
13455

13456
    @common_utils.parametrize(
        "module_class",
        (IfNoneOutput, IfNoneInput, LoopNoneOutput, LoopNoneInput),
        name_fn=lambda module_class: module_class.__name__,
    )
    @common_utils.parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
    @skipTraceTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
        """Check that a possibly-``None`` model output exports as an optional type.

        The module is scripted so its data-dependent ``if``/``for`` survives
        export (tracing would bake in one path). The graph output, and the
        first output of both branches of every ``If`` node, must have type
        ``optional(tensor(<elem>, ["condition"]))``. Afterwards the model is
        compared against the exported one via ``run_test``.
        """
        sym_dim = "condition"
        scripted = torch.jit.script(module_class())
        sample = torch.ones(x_size)

        buffer = io.BytesIO()
        torch.onnx.export(
            scripted,
            sample,
            buffer,
            opset_version=self.opset_version,
            # A dynamic dim 0 keeps the control-flow condition from being
            # folded into a constant.
            dynamic_axes={"x": {0: sym_dim}},
            input_names=["x"],
        )
        onnx_model = onnx.load_from_string(buffer.getvalue())

        elem_type = torch.onnx.JitScalarType.from_value(sample).onnx_type()
        want_type = onnx.helper.make_optional_type_proto(
            onnx.helper.make_tensor_type_proto(elem_type, (sym_dim,))
        )
        self.assertEqual(want_type, onnx_model.graph.output[0].type)

        # Both branches of each If node must agree on the optional output type.
        branch_output_types = [
            attr.g.output[0].type
            for node in onnx_model.graph.node
            if node.op_type == "If"
            for attr in node.attribute
            if attr.name in ("then_branch", "else_branch")
        ]
        for branch_type in branch_output_types:
            self.assertEqual(want_type, branch_type)

        self.run_test(
            module_class(),
            sample,
            # Same dynamic axis as above so the condition stays non-constant.
            dynamic_axes={"x": {0: sym_dim}},
            input_names=["x"],
        )
13500

13501
    @skipTraceTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    def test_uninitialized_optional(self):
        """Export nested conditionals over an ``Optional[Tensor]`` argument.

        ``forward`` reassigns ``y`` on only one path and returns early on
        another. This exercises how the exporter handles an optional value
        that is not set along every control-flow path.
        """

        class NestedOptional(torch.nn.Module):
            def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
                if y is not None:
                    if y.shape[1] < 5:
                        if y.size(0) == 1:
                            y = y + 4
                        else:
                            return y
                return y

        int_input = torch.ones((3, 4), dtype=torch.int)
        self.run_test(
            NestedOptional(),
            int_input,
            input_names=["y"],
            dynamic_axes={"y": {0: "y0", 1: "y1"}},
        )
13520

13521
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_device_eq(self):
        """Export a scripted comparison between a tensor's device and a constant device."""

        class DeviceGuard(torch.nn.Module):
            def forward(self, a):
                # Exercise both Tensor.device (prim::device)
                # and torch.device (prim::Constant).
                if a.device != torch.device("cpu"):
                    return a
                return torch.zeros_like(a)

        # Script (rather than trace) so the if-statement is preserved.
        scripted = torch.jit.script(DeviceGuard())

        # ONNX has no concept of a device, so the exported condition is always
        # false. A CPU input makes the PyTorch model take that same branch, so
        # the two results can be compared.
        cpu_input = torch.randn(3, 3, device="cpu")

        self.run_test(
            scripted,
            cpu_input,
            # Dynamic axes make the output shape depend on the input. Without
            # them the whole model would fold into a constant with no inputs.
            dynamic_axes={"a": {0: "a0"}},
            input_names=["a"],
        )
13546

13547
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_lerp(self):
        """Export ``Tensor.lerp`` with tensor/scalar-tensor ends and float/tensor weights."""

        class Lerp(torch.nn.Module):
            def forward(self, x):
                # Full-shape end tensor, Python-float weights.
                full_end_w04 = x.lerp(torch.full_like(x, 10), 0.4)
                full_end_w07 = x.lerp(torch.full_like(x, 20), 0.7)
                # Full-shape end tensor, 0-d tensor weight.
                full_end_tensor_w = x.lerp(torch.full_like(x, 30), torch.tensor(0.4))
                # Full-shape end tensor, element-wise weight.
                full_end_elem_w = x.lerp(torch.full_like(x, 40), x / 10.0)
                # 0-d end tensor combined with each kind of weight.
                scalar_end_elem_w = x.lerp(torch.tensor(10.0), x / 10.0)
                scalar_end_float_w = x.lerp(torch.tensor(10.0), 0.4)
                scalar_end_tensor_w = x.lerp(torch.tensor(10.0), torch.tensor(0.4))
                return (
                    full_end_w04,
                    full_end_w07,
                    full_end_tensor_w,
                    full_end_elem_w,
                    scalar_end_elem_w,
                    scalar_end_float_w,
                    scalar_end_tensor_w,
                )

        model = Lerp()
        self.run_test(model, torch.rand(5, 4, 3))
13562

13563
    @common_utils.parametrize("input_dtype", [torch.cfloat, torch.float])
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_print_tensor_within_torch_nn_module(self, input_dtype: torch.dtype):
        """Export a module whose forward prints a tensor and calls ``tolist``.

        For ``torch.cfloat`` inputs, ``run_test`` is expected to raise
        ``RuntimeError``. For ``torch.float`` inputs it must succeed.
        """

        class PrintTensorOnMyModel(torch.nn.Module):
            def forward(self, x):
                # 'print' has side effect calling 'resolve_conj' and 'resolve_neg'.
                first_column = x[:, 0]
                print(f"x_firsts: {first_column}")
                # 'tolist' has side effect calling 'resolve_conj' and 'resolve_neg'.
                # The annotation is required for TorchScript.
                _: List[float] = x.tolist()
                return first_column

        model = PrintTensorOnMyModel()
        sample = torch.randn(10, 5, dtype=input_dtype)
        if input_dtype != torch.cfloat:
            self.run_test(model, sample)
            return
        with self.assertRaises(RuntimeError):
            self.run_test(model, sample)
13589

13590
    @skipScriptTest()
    @skipIfUnsupportedMinOpsetVersion(16)
    @unittest.skipIf(
        not torch.hub._check_module_exists("torch_geometric"),
        "torch_geometric not installed.",
    )
    def test_sage_conv(self):
        """Export a model built around torch_geometric's ``SAGEConv`` with dynamic axes.

        Only ``SAGEConvBlock1`` is used by ``forward``. The other submodules
        are registered but unused.
        """
        from torch_geometric import nn as torch_geometric_nn

        # Input: two (1, 6) coordinate rows, concatenated to (2, 6) and
        # transposed to (6, 2). A kNN graph (k=2, self loops) is built over
        # them, and its row 0 / remaining rows are fed as separate inputs.
        coords0 = torch.randn(1, 6)
        coords1 = torch.randn(1, 6)
        coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
        adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True)
        edge_from = adj[0:1, :]
        edge_to = adj[1:, :]
        inputs = (coords0, coords1, edge_from, edge_to)

        class MySAGEConv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv(
                    2, 512, normalize=True
                )
                self.bano1 = torch_geometric_nn.BatchNorm(512)
                self.relu = torch.nn.ReLU()
                # Use the real torch.nn classes: ``torch.nn.Seq`` / ``Lin`` do
                # not exist, and referencing them made construction always fail.
                self.dense1 = torch.nn.Sequential(torch.nn.Linear(512, 1))
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, coords0, coords1, edge_from, edge_to):
                # Rebuild the graph and node features inside the model so that
                # they are part of the exported graph.
                adj = torch.cat((edge_from, edge_to), dim=0)
                gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
                x1 = self.SAGEConvBlock1(gra, edge_index=adj)
                x = torch.unsqueeze(torch.sum(x1), dim=0)
                return x

        input_names = ["coords0", "coords1", "edge_from", "edge_to"]
        output_names = ["outputs"]
        dynamic_axes = {
            "coords0": {0: "batch_size", 1: "features"},
            "coords1": {0: "batch_size", 1: "features"},
            "edge_from": {0: "batch_size", 1: "features"},
            "edge_to": {0: "batch_size", 1: "features"},
            "outputs": {0: "batch_size"},
        }
        self.run_test(
            MySAGEConv(),
            inputs,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
13642

13643
    # Cannot export with older opsets because of "ConstantFill" op
13644
    # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime
13645
    # There are still some issues that prevent us from enabling script tests for these scenarios:
13646
    # test_gru_*:
13647
    #   Operator aten::as_tensor is not supported by exporter yet.
13648
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
13649
    #   Operator aten::_pack_padded_sequence is not supported by exporter yet.
13650
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
13651
    # test_elman_*:
13652
    # Compiling in script mode fails with errors like:
13653
    #   torch.jit.frontend.UnsupportedNodeError: annotated assignments
13654
    #   without assigned value aren't supported
13655
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
13656
    # test_lstm_*:
13657
    #   Compiling in script mode fails with errors like:
13658
    #   RuntimeError: Arguments for call are not valid.
13659
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
13660
    @skipScriptTest()
    @skipIfUnsupportedMinOpsetVersion(9)
    @common_utils.parametrize(
        "name, nonlinearity",
        [
            ("elman", "relu"),
            ("elman", "tanh"),
            ("lstm", None),
            ("gru", None),
        ],
    )
    @common_utils.parametrize(**_parametrize_rnn_args("layers"))
    @common_utils.parametrize(**_parametrize_rnn_args("bidirectional"))
    @common_utils.parametrize(**_parametrize_rnn_args("initial_state"))
    @common_utils.parametrize(**_parametrize_rnn_args("packed_sequence"))
    @common_utils.parametrize(**_parametrize_rnn_args("dropout"))
    def test_rnn(self, *args, **kwargs):
        """Run one parametrized RNN export configuration.

        ``name`` (elman/lstm/gru) and ``nonlinearity`` come from the first
        parametrize decorator. The remaining arguments come from
        ``_parametrize_rnn_args``. Everything is forwarded unchanged to
        ``_dispatch_rnn_test``.
        """
        dispatch = self._dispatch_rnn_test
        dispatch(*args, **kwargs)
13678

13679

13680
if __name__ == "__main__":
    # Turn on dtype checking for TestCase comparisons before handing
    # control to the shared PyTorch test runner.
    setattr(common_utils.TestCase, "_default_dtype_check_enabled", True)
    common_utils.run_tests()
13683

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.