# Owner(s): ["module: onnx"]

"""
Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
          --no-onnx: no onnx python dependency
          --produce-onnx-test-data: generate onnx test data
          --accept: accept onnx updates and overwrite models
"""
import glob
import inspect
import io
import itertools
import operator
import os
import shutil
import tempfile

# Full diff for expect files
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

from pytorch_test_common import (
    BATCH_SIZE,
    flatten,
    RNN_HIDDEN_SIZE,
    RNN_INPUT_SIZE,
    RNN_SEQUENCE_LENGTH,
)
from torch.autograd import Function, Variable
from torch.nn import functional, Module
from torch.onnx._internal import diagnostics
from torch.onnx.symbolic_helper import (
    _get_tensor_dim_size,
    _get_tensor_sizes,
    parse_args,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack

unittest.TestCase.maxDiff = None

_onnx_test = False  # flag to produce onnx test cases.
_onnx_dep = True  # flag to import onnx package.


def export_to_pbtxt(model, inputs, *args, **kwargs):
    return torch.onnx.export_to_pretty_string(
        model, inputs, *args, google_printer=True, **kwargs
    )


def export_to_pb(model, inputs, *args, **kwargs):
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()
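
# Illustrative use of the helper above (a sketch, not executed by the suite):
# export_to_pb returns a serialized ONNX ModelProto as bytes, so it can be
# parsed and validated directly with the onnx package, mirroring what
# assertONNX does below:
#
#     model_bytes = export_to_pb(torch.nn.ReLU(), (torch.randn(2, 3),))
#     onnx.checker.check_model(onnx.ModelProto.FromString(model_bytes))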


class FuncModule(Module):
    def __init__(self, f, params=None):
        if params is None:
            params = ()
        super().__init__()
        self.f = f
        self.params = nn.ParameterList(list(params))

    def forward(self, *args):
        return self.f(*itertools.chain(args, self.params))
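
# FuncModule wraps a bare callable (usually a lambda in the tests below) so it
# can be exported like an nn.Module; params are stored in an nn.ParameterList
# and appended to the call, so the exporter records them as graph
# initializers. A hypothetical sketch of how it composes:
#
#     w = nn.Parameter(torch.ones(3))
#     m = FuncModule(torch.add, params=(w,))
#     m(torch.zeros(3))  # computes torch.add(zeros, w)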


class TestOperators(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        diagnostics.engine.clear()

    def assertONNX(self, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        m.eval()
        # Pop "subname" before exporting so it is not forwarded to
        # torch.onnx.export_to_pretty_string, which does not accept it.
        subname = kwargs.pop("subname", None)
        onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
        self.assertExpected(onnx_model_pbtxt, subname)
        if _onnx_dep:
            onnx_model_pb = export_to_pb(m, args, **kwargs)
            import onnx
            import onnx.checker
            import onnx.numpy_helper
            import onnx_test_common

            model_def = onnx.ModelProto.FromString(onnx_model_pb)
            onnx.checker.check_model(model_def)
            if _onnx_test:
                test_function = inspect.stack()[1][0].f_code.co_name
                test_name = test_function[0:4] + "_operator" + test_function[4:]
                output_dir = os.path.join(
                    onnx_test_common.pytorch_operator_dir, test_name
                )
                # Assume:
                #     1) the old test data should be deleted before the test runs.
                #     2) each test calls assertONNX only once; otherwise the data will be overwritten.
                assert not os.path.exists(output_dir), f"{output_dir} should not exist!"
                os.makedirs(output_dir)
                with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                    file.write(model_def.SerializeToString())
                data_dir = os.path.join(output_dir, "test_data_set_0")
                os.makedirs(data_dir)
                if isinstance(args, Variable):
                    args = (args,)
                for index, var in enumerate(flatten(args)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"input_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())
                outputs = m(*args)
                if isinstance(outputs, Variable):
                    outputs = (outputs,)
                for index, var in enumerate(flatten(outputs)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"output_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())
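
    # When run with --produce-onnx-test-data, assertONNX above emits the
    # standard ONNX backend-test layout under pytorch_operator_dir, roughly:
    #
    #     test_operator_<name>/model.onnx
    #     test_operator_<name>/test_data_set_0/input_<i>.pb
    #     test_operator_<name>/test_data_set_0/output_<i>.pb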

    def assertONNXRaises(self, err, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))

    def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        with self.assertRaisesRegex(err, reg):
            export_to_pbtxt(m, args, **kwargs)

    def test_basic(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))

    def test_view(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.view(1, 1), x)

    def test_index(self):
        x = torch.tensor([[0.0]], requires_grad=True)
        self.assertONNX(lambda x: x[0], x)

    def test_type_as(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.type_as(x), x)

    def test_addconstant(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: x + 1, x)

    def test_add_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_left_broadcast(self):
        x = torch.randn(3, requires_grad=True).double()
        y = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(2, 1, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_right_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_singleton_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(1, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_rsub(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: 1 - x, (x,))

    def test_mul_bool(self):
        x = torch.tensor([True, False, True, False])
        y = torch.tensor([True, True, False, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_mul_fp_bool(self):
        x = torch.tensor([9.4, 1.7, 3.6])
        y = torch.tensor([True, True, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_transpose(self):
        x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
        self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)

    def test_chunk(self):
        x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
        self.assertONNX(lambda x: x.chunk(2), x)

    def test_split(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, 2, 1), x)

    def test_split_with_sizes(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)

    def test_concat2(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))

    def test_mm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(torch.mm, (m1, m2))

    def test_addmm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        m3 = torch.randn(4, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3)
        )

    def test_permute2(self):
        x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
        self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)

    def test_pad(self):
        x = torch.tensor(
            [[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True
        )
        self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)

    def test_params(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=True,
        )

    def test_params_onnx_irv4(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=False,
        )

    def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                raise AssertionError()

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        # NB: Don't use an expect test here; the exact TypeError message
        # wobbles depending on the Python version.
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_pbtxt(FuncModule(MyFun().apply), (x, y))
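
        # Note (added for clarity): symbolic takes (g, x) while forward takes
        # (ctx, x, y); a symbolic staticmethod must mirror forward's input
        # signature, so the exporter raises before symbolic is ever invoked.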

    # TODO: Do an nn style test for these
    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_onnx_irv4(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x)

    def test_batchnorm_1d(self):
        x = torch.ones(2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm1d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_training(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(2),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
            keep_initializers_as_inputs=True,
        )

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(
            nn.Conv2d(16, 13, 3, bias=False), x, keep_initializers_as_inputs=True
        )

    def test_conv_onnx_irv4(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)

    def test_conv_onnx_irv4_opset8(self):
        # This test checks that for opset 8 (or lower), even if
        # keep_initializers_as_inputs is set to False, it is ignored,
        # and initializers are listed as ONNX graph inputs, in accordance
        # with ONNX IR v3 semantics (which apply to opset versions <= 8).
        x = torch.ones(1, 2, 5, 7, requires_grad=True)
        conv_node = nn.Conv2d(2, 4, 3, bias=False)
        conv_node.weight.data.fill_(1.0)
        self.assertONNX(
            conv_node, x, opset_version=8, keep_initializers_as_inputs=False
        )

    def test_conv_variable_length(self):
        x = torch.ones(5, 3, 6, 6, requires_grad=True)
        model = torch.nn.Conv2d(3, 2, 3)

        dynamic_axes = {
            "input_1": [0, 2, 3],
            "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"},
        }
        model_proto_file = tempfile.NamedTemporaryFile()
        torch.onnx.export(
            model,
            x,
            model_proto_file.name,
            verbose=True,
            input_names=["input_1"],
            output_names=["output_1"],
            dynamic_axes=dynamic_axes,
        )

        import onnx

        onnx_model = onnx.load(model_proto_file.name)
        onnx.checker.check_model(onnx_model)

        # Assert that default dynamic-axis names are generated when custom names are not provided
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
            == "input_1_dynamic_axes_1"
        )
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[2].dim_param
            == "input_1_dynamic_axes_2"
        )
        assert (
            onnx_model.graph.input[0].type.tensor_type.shape.dim[3].dim_param
            == "input_1_dynamic_axes_3"
        )

        # Assert that the custom names are applied when provided
        assert (
            onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param
            == "output_1_variable_dim_0"
        )
        assert (
            onnx_model.graph.output[0].type.tensor_type.shape.dim[1].dim_param
            == "output_1_variable_dim_1"
        )
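
        # dynamic_axes accepts both forms exercised above: a list of axis
        # indices, for which the exporter auto-generates names such as
        # "input_1_dynamic_axes_1", and a dict mapping an axis index to a
        # custom name such as "output_1_variable_dim_0".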

    def test_convtranspose(self):
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(
            nn.ConvTranspose2d(
                3, 3, 3, stride=3, bias=False, padding=1, output_padding=2
            ),
            x,
            keep_initializers_as_inputs=True,
        )

    def test_maxpool(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2), x)

    def test_maxpool_dilations(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(2, stride=1, dilation=2), x, opset_version=10)

    def test_avg_pool2d(self):
        x = torch.randn(20, 16, 50, 32)
        self.assertONNX(nn.AvgPool2d(3, stride=2), x)

    def test_maxpool_indices(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)

    @skipIfCaffe2
    def test_at_op(self):
        x = torch.randn(3, 4)

        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(
            MyModule(),
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
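
        # Note (added for clarity): g.at("add", ...) emits a raw aten::add
        # node rather than a native ONNX op; ONNX_ATEN_FALLBACK permits such
        # ATen nodes to remain in the exported graph for runtimes that can
        # execute them.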

    def test_clip(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)

    def test_clip_min(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(min=-0.1), x)

    def test_clip_max(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(max=0.1), x)

    def test_hardtanh(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)

    def test_full(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2.0), x)

    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)

    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.max(x, y), (x, y))

    def test_min(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.min(x, y), (x, y))

    def test_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x), x)

    def test_reduced_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2), x)

    def test_reduced_mean_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=(2, 3), keepdim=True), x)

    def test_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dtype=torch.double), x)

    def test_reduced_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=0, dtype=torch.double), x)

    def test_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x), x)

    def test_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dtype=torch.double), x)

    def test_reduced_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=0, dtype=torch.double), x)

    def test_reduced_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=(1, 2)), x)

    def test_reduced_sum_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)

    def test_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x), x)

    def test_reduced_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2), x)

    def test_reduced_prod_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)

    def test_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dtype=torch.double), x)

    def test_reduced_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=0, dtype=torch.double), x)

    def test_sqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sqrt(x), x)

    def test_rsqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.rsqrt(x), x)

    def test_equal(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.eq, (x, y))

    def test_lt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.lt, (x, y))

    def test_gt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.gt, (x, y))

    def test_le(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.le, (x, y))

    def test_ge(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.ge, (x, y))

    def test_exp(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.exp(), x)

    def test_sin(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sin(), x)

    def test_cos(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.cos(), x)

    def test_tan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.tan(), x)

    def test_asin(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.asin(), x)

    def test_acos(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.acos(), x)

    def test_slice(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[:, 1:2], x)

    def test_slice_dynamic(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10)

    def test_sign(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sign(), x)

    def test_narrow(self):
        x = torch.randn(3, 3, requires_grad=True)
        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)

    def test_atan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.atan(), x)

    def test_view_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)

    def test_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x), x)

    def test_flatten2D(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x, 1), x)

    def test_isnan(self):
        x = torch.tensor([1, float("nan"), 2])
        self.assertONNX(lambda x: torch.isnan(x), x)

    def test_argmax(self):
        x = torch.randn(4, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)

    def test_logsoftmax(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.LogSoftmax(dim=3), x)

    def test_pow(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: x.pow(y), (x, y))

    def test_elu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.ELU(), x)

    def test_selu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.SELU(), x)

    def test_repeat(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_repeat_dim_overflow(self):
        x = torch.randn(1, 2, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_norm_p1(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=1, dim=2), (x))

    def test_norm_p2(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))

    def test_upsample_nearest_scale(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(
                x, scale_factor=2.0, mode="nearest", recompute_scale_factor=False
            ),
            x,
        )

    def test_upsample_nearest_scale_default_scale_factor(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, scale_factor=2.0, mode="nearest"), x
        )

    def test_upsample_nearest_size(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, size=16, mode="nearest"), x
        )

    def test_unsqueeze(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)

    def test_batchnorm_noaffine(self):
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(128, affine=False, momentum=0.3),
            x,
            keep_initializers_as_inputs=True,
        )

    @skipIfCaffe2
    def test_embedding_bags(self):
        emb_bag = nn.EmbeddingBag(10, 8)
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.assertONNX(
            emb_bag,
            (input, offset),
            keep_initializers_as_inputs=True,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_implicit_expand(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x + 1, x)

    def test_reduce_sum_negative_indices(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sum(-1), x)

    def test_randn(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)

    def test_rand(self):
        x = torch.rand(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.rand(1, 2, 3, 4) + x, x)

    def test_rrelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.RReLU(), x)

    def test_prelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.PReLU(2), x, keep_initializers_as_inputs=True)

    def test_log_sigmoid(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.LogSigmoid(), x)

    def test_linear(self):
        x = torch.randn(3, 4)
        self.assertONNX(
            torch.nn.Linear(4, 5, bias=True), x, keep_initializers_as_inputs=True
        )

    def test_empty_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.empty_like(x), x)

    def test_zeros_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.zeros_like(x), x)

    def test_ones_like(self):
        x = torch.randn(6, 10, requires_grad=True)
        self.assertONNX(lambda x: torch.ones_like(x), x)

    def test_expand(self):
        x = torch.randn(6, 1, requires_grad=True)
        self.assertONNX(lambda x: x.expand(4, 6, 2), x)

    def test_ne(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))

    def test_reducemax(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.max(x), x)

    def test_reducemin(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.min(x), x)

    def test_erf(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: x.erf(), x)

    def test_dropout(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)

    def test_dropout_default(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(
                functional.dropout(
                    x,
                )
            ),
            x,
        )

    def test_dropout_training(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
        )

    def test_dropout_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x, training=False)),
            x,
            opset_version=12,
        )

    def test_dropout_training_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            opset_version=12,
            training=torch.onnx.TrainingMode.TRAINING,
        )
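
        # Export note (added for clarity): with training=False (or in eval
        # mode) dropout is the identity and is folded away; under
        # TrainingMode.TRAINING it stays a Dropout node, and from opset 12
        # Dropout carries an explicit training_mode input.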

    def test_nonzero(self):
        x = torch.tensor(
            [[[2.0, 2.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]], requires_grad=True
        )
        self.assertONNX(lambda x: torch.nonzero(x), x)

    def test_gather(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(lambda data, index: data.gather(1, index), (data, index))

    def test_gather_opset11(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(
            lambda data, index: data.gather(1, index), (data, index), opset_version=11
        )

    def test_scatter_add(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
        )

    def test_scatter_add_opset11(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=11,
        )

    def test_scatter_add_opset16(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=16,
        )
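
        # Note (added for clarity): opset 16 is exercised separately because
        # ONNX ScatterElements gained a "reduction" attribute in opset 16,
        # which lets scatter_add map onto a single node with reduction="add".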

    def test_master_opset(self):
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.assertONNX(operator.add, (x, y), opset_version=10)

    def test_std(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(
            lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x
        )

    def test_cumsum(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)

    # Github Issue: https://github.com/pytorch/pytorch/issues/71095
    #    def test_c2_op(self):
    #        class MyModel(torch.nn.Module):
    #            def __init__(self):
    #                super().__init__()
    #
    #            def forward(self, scores, bbox_deltas, im_info, anchors):
    #                a, b = torch.ops._caffe2.GenerateProposals(
    #                    (scores), (bbox_deltas), (im_info), (anchors),
    #                    2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
    #                )
    #                return a, b
    #
    #        model = MyModel()
    #        A = 4
    #        H = 10
    #        W = 8
    #        img_count = 3
    #        scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
    #        bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
    #                                     dtype=torch.float32)
    #        bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
    #        im_info = torch.ones(img_count, 3, dtype=torch.float32)
    #        anchors = torch.ones(A, 4, dtype=torch.float32)
    #        inputs = (scores, bbox_deltas, im_info, anchors)
    #        self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})

    def test_dict(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(
                    x_in[list(x_in.keys())[0]], list(x_in.keys())[0]  # noqa: RUF015
                )
                return x_out

        x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))
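
        # The trailing empty dict marks the end of the inputs: torch.onnx
        # treats a final dict argument as named inputs, so an empty dict must
        # be appended when the model's last real input is itself a dict.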

    def test_dict_str(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
                return x_out

        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))

    def test_arange_dynamic(self):
        class TestModel(torch.nn.Module):
            def forward(self, input):
                return torch.arange(input.shape[0], input.shape[0] + 5, 0.5)

        input = torch.randn(5, 3, 2)
        self.assertONNX(TestModel(), input, opset_version=11)

    def test_bitshift(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input):
                return input >> 1, input >> 2

        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.assertONNX(BitshiftModel(), input, opset_version=11)

    @skipIfCaffe2
    def test_layer_norm_aten(self):
        model = torch.nn.LayerNorm([10, 10])
        x = torch.randn(20, 5, 10, 10)
        self.assertONNX(
            model,
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_pixel_shuffle(self):
        x = torch.randn(2, 8, 3, 4).float()
        self.assertONNX(
            lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11
        )

    def test_frobenius_norm(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x)

    def test_unfold(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unfold(dimension=2, size=2, step=2), x)

    def test_remainder(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.remainder(x, y), (x, y))

    def test_fmod(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.fmod(x, y), (x, y), opset_version=10)

    def test_gelu(self):
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.functional.gelu(x), x)

    def test_unique(self):
        x = torch.randint(3, (2, 3, 4, 5)).float()
        self.assertONNX(
            lambda x: torch.unique(
                x, dim=0, sorted=True, return_inverse=False, return_counts=True
            ),
            x,
            opset_version=11,
        )

    def test_meshgrid(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))

    def test_meshgrid_indexing(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"),
            (x, y, z),
            opset_version=9,
        )

    def test_topk(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(lambda x, k: torch.topk(x, k), (x, k), opset_version=10)

    def test_topk_smallest_unsorted(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(
            lambda x, k: torch.topk(x, k, largest=False, sorted=False),
            (x, k),
            opset_version=11,
        )

    def test_baddbmm(self):
        x = torch.randn(10, 3, 5)
        b1 = torch.randn(10, 3, 4)
        b2 = torch.randn(10, 4, 5)
        self.assertONNX(lambda x, b1, b2: torch.baddbmm(x, b1, b2), (x, b1, b2))

    def test_round(self):
        x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
        self.assertONNX(lambda x: torch.round(x), x, opset_version=11)

    def test_dim(self):
        x = torch.ones((2, 2), requires_grad=True)
        self.assertONNX(lambda x: torch.scalar_tensor(x.dim()), x)

    @skipIfNoLapack
    def test_det(self):
        x = torch.randn(2, 3, 5, 5, device=torch.device("cpu"))
        self.assertONNX(lambda x: torch.det(x), x, opset_version=11)
        self.assertONNX(lambda x: torch.linalg.det(x), x, opset_version=11)

    def test_softmaxcrossentropy(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_softmaxcrossentropy_ignore_index(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(ignore_index=1), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_weights(self):
        x = torch.randn(3, 5)
        y = torch.empty(3, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(weight=torch.randn(5)), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_3d(self):
        x = torch.randn(3, 5, 2)
        y = torch.empty(3, 2, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_softmaxcrossentropy_3d_none(self):
        x = torch.randn(3, 5, 2)
        y = torch.empty(3, 2, dtype=torch.long).random_(5)
        self.assertONNX(
            torch.nn.CrossEntropyLoss(reduction="none"), (x, y), opset_version=12
        )

    def test_softmaxcrossentropy_4d(self):
        x = torch.randn(3, 5, 2, 1)
        y = torch.empty(3, 2, 1, dtype=torch.long).random_(5)
        self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)

    def test_lstm_none_sequence_lens(self):
        """Test symbolic shape inference for LSTM when the input sequence_lens = None."""
        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)

        class LSTMModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
                )

            def forward(self, x, h0, c0):
                a, b = self.rnn(x, (h0, c0))
                return torch.ones(b[0].shape)

        self.assertONNX(
            LSTMModel(),
            (input, h0, c0),
            input_names=["x", "y"],
            dynamic_axes={"x": {0: "batch"}},
            opset_version=12,
        )

    def test_dynamic_axes_add(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(2, 1, requires_grad=True)
        self.assertONNX(
            lambda x, y: torch.add(x, y),
            (m1, m2),
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}},
            opset_version=12,
        )

    def test_dynamic_axes_add_inputs_same_symbolic_shape(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        self.assertONNX(
            lambda x: torch.add(x, x),
            (m1,),
            input_names=["input_1"],
            dynamic_axes={"input_1": {1: "dim_1"}},
            opset_version=12,
        )

    def test_dynamic_axes_matmul(self):
        m1 = torch.randn(2, 2, 4, requires_grad=True)
        m2 = torch.randn(2, 4, 3, requires_grad=True)
        self.assertONNX(
            lambda x, y: torch.matmul(x, y),
            (m1, m2),
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}},
            opset_version=12,
        )

    def test_dynamic_axes_reduce_mean(self):
        m1 = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.mean(x, dim=1),
            (m1),
            input_names=["input"],
            dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}},
            opset_version=12,
        )

    def test_dynamic_axes_unchange(self):
        """Test ProcessUnchangeNode in symbolic shape inference."""
        m1 = torch.randn(2, 3, requires_grad=True)
        self.assertONNX(
            lambda x: torch.softmax(x, dim=0),
            (m1,),
            input_names=["input"],
            dynamic_axes={"input": {1: "dim_1"}},
            opset_version=12,
        )

    def test_aten_embedding_1(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}"
            )
            output = g.at(
                "embedding",
                weight,
                indices,
                custom_attributes_json_s=custom_attributes_json,
            )
            return output

        torch.onnx.register_custom_op_symbolic(
            "::embedding", embedding, _onnx_opset_version
        )

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)

        torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
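
        # parse_args (above) declares how the symbolic function's arguments
        # are unpacked: "v" passes the graph Value through unchanged, while
        # "i" and "b" require the argument to be a constant and convert it to
        # a Python int or bool.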

    # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
    @skipIfCaffe2
    def test_aten_embedding_2(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}"
            )
            output = g.at(
                "embedding",
                weight,
                indices,
                custom_attributes_json_s=custom_attributes_json,
            )

            # do shape inference and set it via setType
            indices_shape = _get_tensor_sizes(indices)
            if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
                output_type = weight.type().with_sizes(
                    indices_shape + [_get_tensor_dim_size(weight, 1)]
                )
                output.setType(output_type)
            return output

        torch.onnx.register_custom_op_symbolic(
            "::embedding", embedding, _onnx_opset_version
        )

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(
            model,
            (x, y),
            opset_version=_onnx_opset_version,
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": {0: "dim_0"}, "input_2": {0: "dim_1", 1: "dim_2"}},
            keep_initializers_as_inputs=False,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

        torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)

    # Without shapeValueMap, the onnx graph looks like:
    # graph(%0 : Float(*, 1, 128, 1, strides=[128, 128, 1, 1], requires_grad=0, device=cpu)):
    #   %2 : Long(4, strides=[1], device=cpu) = onnx::Shape(%0)
    #   %4 : Long(device=cpu) = onnx::Constant[value={0}]()
    #   %5 : Long(device=cpu) = onnx::Gather[axis=0](%2, %4)
    #   %6 : Long(device=cpu) = onnx::Constant[value={1}]()
    #   %7 : Long(device=cpu) = onnx::Constant[value={2}]()
    #   %8 : Long(device=cpu) = onnx::Constant[value={-1}]()
    #   %9 : int[] = prim::ListConstruct(%5, %6, %7, %8)
    #   %10 : Float(*, *, *, *, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9)
    #   ...
    # With shapeValueMap, it becomes:
    #   ...
    #   %10 : Float(*, 1, 2, 64, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9)
    #   ...
    def test_shape_value_map(self):
        class RSoftMax(torch.nn.Module):
            def __init__(self, radix, cardinality):
                super().__init__()
                self.radix = radix
                self.cardinality = cardinality

            def forward(self, x):
                batch = x.size(0)
                x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
                x = F.softmax(x, dim=1)
                x = x.reshape(batch, -1)
                return x

        radix = 2
        cardinality = 1
        x = torch.randn(10, 1, 128, 1)
        self.assertONNX(
            RSoftMax(radix, cardinality),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": {0: "dim_0"}},
        )


if __name__ == "__main__":
    no_onnx_dep_flag = "--no-onnx"
    _onnx_dep = no_onnx_dep_flag not in common_utils.UNITTEST_ARGS
    if no_onnx_dep_flag in common_utils.UNITTEST_ARGS:
        common_utils.UNITTEST_ARGS.remove(no_onnx_dep_flag)
    onnx_test_flag = "--produce-onnx-test-data"
    _onnx_test = onnx_test_flag in common_utils.UNITTEST_ARGS
    if onnx_test_flag in common_utils.UNITTEST_ARGS:
        common_utils.UNITTEST_ARGS.remove(onnx_test_flag)
    if _onnx_test:
        _onnx_dep = True
        import onnx_test_common

        for d in glob.glob(
            os.path.join(onnx_test_common.pytorch_operator_dir, "test_operator_*")
        ):
            shutil.rmtree(d)
    common_utils.run_tests()
