pytorch

Форк
0
/
test_utility_funs.py 
1985 строк · 71.7 Кб
1
# Owner(s): ["module: onnx"]
2

3
import copy
4
import functools
5
import io
6
import re
7
import warnings
8
from typing import Callable
9

10
import onnx
11
import parameterized
12
import pytorch_test_common
13

14
import torch
15
import torch.onnx
16
import torch.utils.cpp_extension
17
import torchvision
18
from autograd_helper import CustomFunction as CustomFunction2
19
from pytorch_test_common import (
20
    skipIfNoCuda,
21
    skipIfUnsupportedMaxOpsetVersion,
22
    skipIfUnsupportedMinOpsetVersion,
23
)
24
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
25
from torch.onnx._globals import GLOBALS
26
from torch.onnx.symbolic_helper import _unpack_list, parse_args
27
from torch.testing._internal import common_utils
28
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
29
from verify import verify
30

31

32
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
33
    """Remove test environment prefix added to module.
34

35
    Remove prefix to normalize scope names, since different test environments add
36
    prefixes with slight differences.
37

38
    Example:
39

40
        >>> _remove_test_environment_prefix_from_scope_name(
41
        >>>     "test_utility_funs.M"
42
        >>> )
43
        "M"
44
        >>> _remove_test_environment_prefix_from_scope_name(
45
        >>>     "test_utility_funs.test_abc.<locals>.M"
46
        >>> )
47
        "M"
48
        >>> _remove_test_environment_prefix_from_scope_name(
49
        >>>     "__main__.M"
50
        >>> )
51
        "M"
52
    """
53
    prefixes_to_remove = ["test_utility_funs", "__main__"]
54
    for prefix in prefixes_to_remove:
55
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
56
    return scope_name
57

58

59
class _BaseTestCase(pytorch_test_common.ExportTestCase):
    """Base test case with a helper that builds an export graph from a module."""

    def _model_to_graph(
        self,
        model,
        input,
        do_constant_folding=True,
        training=TrainingMode.EVAL,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=None,
        dynamic_axes=None,
    ):
        """Run ``utils._model_to_graph`` on ``model`` with the given options.

        The model is switched to train mode for ``TrainingMode.TRAINING`` and to
        eval mode for ``TrainingMode.EVAL``; any other ``training`` value leaves
        its mode untouched. ``dynamic_axes`` is validated first, and
        ``_disable_torch_constant_prop=True`` is always passed through.

        Returns:
            The ``(graph, params_dict, torch_out)`` triple.
        """
        torch.onnx.utils._setup_trace_module_map(model, False)

        modes = torch.onnx.TrainingMode
        if training == modes.TRAINING:
            model.train()
        elif training == modes.EVAL:
            model.eval()

        utils._validate_dynamic_axes(dynamic_axes, model, None, None)

        conversion_options = dict(
            do_constant_folding=do_constant_folding,
            _disable_torch_constant_prop=True,
            operator_export_type=operator_export_type,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        graph, params_dict, torch_out = utils._model_to_graph(
            model, input, **conversion_options
        )
        return graph, params_dict, torch_out
87

88

89
@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
    """Unit tests for the `unconvertible_ops` function."""

    def setUp(self):
        # Run the base-class setup first: overriding setUp without calling it
        # silently skips the per-test initialization the base classes perform.
        super().setUp()

        class EinsumModule(torch.nn.Module):
            def forward(self, x):
                return torch.einsum("ii", x)

        self.einsum_module = EinsumModule()

    def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        graph, unconvertible_ops = utils.unconvertible_ops(
            self.einsum_module, (x,), opset_version=9
        )
        # The returned graph keeps the original TorchScript nodes, in this order.
        nodes = graph.nodes()
        for expected_kind in (
            "prim::Constant",
            "prim::ListConstruct",
            "prim::Constant",
            "aten::einsum",
        ):
            self.assertEqual(next(nodes).kind(), expected_kind)
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(lambda x: x, name="nn_module"),
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_empty_list_when_all_ops_convertible(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
        self.assertEqual(unconvertible_ops, [])

    def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
        class SkipConnectionModule(torch.nn.Module):
            def forward(self, x):
                # In-place add followed by an in-place relu.
                out = x
                out += x
                out = torch.nn.functional.relu(out, inplace=True)
                return out

        module = SkipConnectionModule()
        x = torch.randn(4, 4)
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
        self.assertEqual(unconvertible_ops, [])
167

168

169
@parameterized.parameterized_class(
170
    [
171
        {"opset_version": opset}
172
        for opset in range(
173
            _constants.ONNX_BASE_OPSET,
174
            _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
175
        )
176
    ],
177
    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
178
)
179
class TestUtilityFuns(_BaseTestCase):
180
    opset_version = None
181

182
    def test_is_in_onnx_export(self):
        """``is_in_onnx_export()`` is True during export and False once it aborts."""
        test_self = self

        class MyModule(torch.nn.Module):
            def forward(self, x):
                # This forward runs while the exporter is tracing the module.
                test_self.assertTrue(torch.onnx.is_in_onnx_export())
                # Abort the export on purpose so the flag can be checked below
                # after an exception escaped the exporter.
                raise ValueError

        x = torch.randn(3, 4)
        f = io.BytesIO()
        # assertRaises fails the test if export unexpectedly succeeds, instead
        # of passing without ever reaching the check after the export.
        with self.assertRaises(ValueError):
            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
        self.assertFalse(torch.onnx.is_in_onnx_export())
197

198
    def test_validate_dynamic_axes_invalid_input_output_name(self):
        # Two of these keys name neither an input ("input1", "input2") nor an
        # output ("output"); each should produce exactly one warning.
        dynamic_axes = {
            "input1": {},
            "output": {},
            "invalid_name1": {},
            "invalid_name2": {},
        }
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            utils._validate_dynamic_axes(
                dynamic_axes, None, ["input1", "input2"], ["output"]
            )
            messages = [str(record.message) for record in caught]

        expected_messages = (
            "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
            "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
        )
        for expected in expected_messages:
            self.assertIn(expected, messages)
        self.assertEqual(len(messages), 2)
217

218
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                # The split sizes are taken from the sizes of the other inputs.
                out, out2 = torch.split(t, (x.size(1), y.size(1)), dim=1)
                return out, out2

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        inputs = (torch.randn(2, 3), torch.randn(2, 4), torch.randn(2, 7))
        names = ["x", "y", "t"]
        graph = self._model_to_graph(
            SplitModule(),
            inputs,
            input_names=names,
            dynamic_axes={name: [0, 1] for name in names},
        )[0]

        # Even with fully dynamic inputs, no SplitToSequence node may appear.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::SplitToSequence", kinds)
239

240
    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.transpose(a, 1, 0)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            TransposeModule(),
            (torch.ones(3, 2),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # Folding the constant transpose should leave no Transpose or Cast node.
        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Transpose", "onnx::Cast"):
            self.assertNotIn(folded_kind, kinds)
        self.assertEqual(len(kinds), 2)
258

259
    def test_constant_fold_reduceL2(self):
        class ReduceModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=2, dim=-2, keepdim=False)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _params, _outputs = self._model_to_graph(
            ReduceModule(),
            (torch.ones(2, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        # The L2 norm of the constant should be folded, not exported.
        self.assertNotIn("onnx::ReduceL2", [node.kind() for node in graph.nodes()])
275

276
    def test_constant_fold_reduceL1(self):
        class NormModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=1, dim=-2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _params, _outputs = self._model_to_graph(
            NormModule(),
            (torch.ones(2, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        # The L1 norm of the constant should be folded, not exported.
        self.assertNotIn("onnx::ReduceL1", [node.kind() for node in graph.nodes()])
292

293
    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            NarrowModule(),
            (torch.ones(1, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # narrow() on a constant should fold away completely.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Slice", kinds)
        self.assertNotIn("onnx::Cast", kinds)
        self.assertEqual(len(kinds), 2)
311

312
    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            SliceIndexExceedsDimModule(),
            (torch.ones(1, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # An out-of-range slice end must still be folded.
        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Slice", "onnx::Cast"):
            self.assertNotIn(folded_kind, kinds)
        self.assertEqual(len(kinds), 2)
333

334
    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[0:-1]  # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            SliceNegativeIndexModule(),
            (torch.ones(1, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # Negative slice ends and negative select indices must also fold.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Slice", kinds)
        self.assertNotIn("onnx::Cast", kinds)
356

357
    def test_constant_fold_gather(self):
        class GatherModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.select(a, dim=1, index=-2)
                c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
                return b + 1, c + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        model = GatherModule()
        model(x)  # sanity check: the module runs eagerly on this input
        graph = self._model_to_graph(
            model, (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )[0]

        # select/index_select on a constant should not survive as Gather.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Gather", kinds)
376

377
    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.unsqueeze(a, -2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            UnsqueezeModule(),
            (torch.ones(1, 2, 3),),
            input_names=["x"],
            dynamic_axes={"x": list(range(3))},
        )[0]

        # Unsqueezing a constant should be folded, leaving two nodes.
        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Unsqueeze", "onnx::Cast"):
            self.assertNotIn(folded_kind, kinds)
        self.assertEqual(len(kinds), 2)
395

396
    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3, 4, 5, 8, 7)
        # Note: the single input tensor is passed directly, not wrapped in a tuple.
        graph, _params, _outputs = self._model_to_graph(
            PReluModel(),
            x,
            input_names=["x"],
            dynamic_axes={"x": list(range(6))},
        )

        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Unsqueeze", kinds)
        self.assertNotIn("onnx::Cast", kinds)
        self.assertEqual(len(kinds), 5)
417

418
    def test_constant_fold_squeeze_without_axes(self):
        class SqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a) + x + torch.squeeze(a)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            SqueezeModule(),
            (torch.ones(2, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # Both axis-less squeezes of the constant should be folded.
        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Squeeze", "onnx::Cast"):
            self.assertNotIn(folded_kind, kinds)
        self.assertEqual(len(kinds), 4)
434

435
    def test_constant_fold_squeeze_with_axes(self):
        class SqueezeAxesModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a, dim=-3) + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            SqueezeAxesModule(),
            (torch.ones(2, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # A squeeze with an explicit (negative) axis should also fold.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Squeeze", kinds)
        self.assertNotIn("onnx::Cast", kinds)
        self.assertEqual(len(kinds), 2)
452

453
    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # The .to(torch.float) casts are deliberate. ONNX constant
                # folding appears to leave constant tensors that are not fed
                # into a known-foldable ONNX op out of the initializer graph.
                # Without these casts one of the constants is not extracted,
                # and constant folding fails. The original author considered
                # the test itself flawed: the right approach would be a
                # predicate describing the graph invariants that must hold
                # after constant folding. The asserts below attempt such a
                # predicate but are not fully correct.
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float)
                b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            ConcatModule(),
            (torch.ones(2, 3),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Concat", "onnx::Cast"):
            self.assertNotIn(folded_kind, kinds)
        self.assertEqual(len(kinds), 2)
488

489
    def test_constant_fold_lstm(self):
        # NOTE: despite the test name, the recurrent layer under test is a GRU.
        class GruNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        sequence = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph = self._model_to_graph(
            GruNet(),
            (sequence, h0),
            input_names=["input", "h0"],
            dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]},
        )[0]

        kinds = [node.kind() for node in graph.nodes()]
        for folded_kind in ("onnx::Slice", "onnx::Concat", "onnx::Unsqueeze"):
            self.assertNotIn(folded_kind, kinds)

        # From opset 13 onward, Unsqueeze takes "axes" as an input rather than
        # an attribute, which adds one node to the graph.
        expected_node_count = 3 if self.opset_version <= 12 else 4
        self.assertEqual(len(kinds), expected_node_count)
519

520
    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        A = torch.randn(2, 3)
        graph = self._model_to_graph(
            MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
        )[0]

        # Transposing the parameter is folded, so only the matmul remains.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Transpose", kinds)
        self.assertEqual(len(kinds), 1)
539

540
    def test_constant_fold_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                b = self.weight.reshape(1, -1, 1, 1)
                return x * b

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            ReshapeModule(),
            (torch.randn(4, 5),),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )[0]

        # Reshaping the buffer is folded, leaving a single node.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Reshape", kinds)
        self.assertEqual(len(kinds), 1)
562

563
    def test_constant_fold_div(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
                return div * x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )[0]

        # Dividing two constants is folded; only the final multiply remains.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Div", kinds)
        self.assertEqual(len(kinds), 1)
585

586
    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )[0]

        # Multiplying two constants is folded; only the final divide remains.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Mul", kinds)
        self.assertEqual(len(kinds), 1)
608

609
    def test_constant_fold_add(self):
        """Adding two constants is folded into a single exported parameter."""

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            # assertNotEqual (as used by the sibling tests) reports the
            # offending node kind; assertTrue(a != b) only says "False is not true".
            self.assertNotEqual(node.kind(), "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)

        # The single remaining parameter holds ones(5) + [1, 2, 3, 4, 5].
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))
639

640
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, _outputs = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Sub", kinds)
        self.assertEqual(len(kinds), 1)

        # The only parameter left is the folded difference ones(5) - [1..5].
        (folded_weight,) = list(params_dict.values())
        self.assertEqual(folded_weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0]))
670

671
    def test_constant_fold_sqrt(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sqrt = torch.sqrt(self.weight)
                return sqrt / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )[0]

        # sqrt of the buffer is folded; only the divide by x remains.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Sqrt", kinds)
        self.assertEqual(len(kinds), 1)
692

693
    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph = self._model_to_graph(
            ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )[0]

        # The buffer's static size is folded, so no Shape node is emitted.
        kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::Shape", kinds)
        self.assertEqual(len(kinds), 2)
712

713
    def test_constant_fold_upsample_scale_fold_as_constant(self):
        # The upsample scale is a plain constant rather than a model parameter,
        # so constant folding must not turn it into a graph initializer.
        # NOTE(review): unlike the other exports in this class, this call does
        # not pass opset_version=self.opset_version, so every parameterized
        # variant exports at the default opset — confirm this is intentional.
        upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        sample = torch.randn(1, 32, 224, 224)
        buffer = io.BytesIO()
        torch.onnx.export(upsample, sample, buffer)
        exported = onnx.load(io.BytesIO(buffer.getvalue()))
        self.assertEqual(len(exported.graph.initializer), 0)
722

723
    def test_verbose(self):
        """Doc strings are emitted only when exporting with ``verbose=True``."""

        class MyModule(torch.nn.Module):
            def forward(self, input):
                return torch.exp(input)

        x = torch.randn(3, 4)

        def is_model_stripped(f, verbose=None):
            """Export ``MyModule`` into ``f`` and report whether it has no doc strings.

            ``verbose=None`` exports without passing ``verbose`` at all, so the
            exporter's default is exercised.
            """
            if verbose is None:
                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
            else:
                torch.onnx.export(
                    MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
                )
            model = onnx.load(io.BytesIO(f.getvalue()))
            # strip_doc_string mutates its argument in place, so compare against
            # an explicitly independent deep copy rather than relying on
            # copy.copy semantics of the protobuf message.
            model_strip = copy.deepcopy(model)
            onnx.helper.strip_doc_string(model_strip)
            return model == model_strip

        # test verbose=False (default)
        self.assertTrue(is_model_stripped(io.BytesIO()))
        # test verbose=True
        self.assertFalse(is_model_stripped(io.BytesIO(), True))
746

747
    # NB: remove this test once DataParallel can be correctly handled
    def test_error_on_data_parallel(self):
        wrapped = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
        x = torch.randn(1, 2, 3, 4)
        # assertRaisesRegex treats this text as a regular expression, so the
        # unescaped "." characters match any character.
        expected_message = (
            "torch.nn.DataParallel is not supported by ONNX "
            "exporter, please use 'attribute' module to "
            "unwrap model from torch.nn.DataParallel. Try "
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            torch.onnx.export(
                wrapped, x, io.BytesIO(), opset_version=self.opset_version
            )
759

760
    @skipIfUnsupportedMinOpsetVersion(11)
761
    def test_sequence_dim(self):
762
        class Module(torch.nn.Module):
763
            def forward(self, x, y):
764
                return [x, y]
765

766
        model = Module()
767
        # Export with scripting to keep output as Sequence type.
768
        # Tracing unpacks the list.
769
        script_model = torch.jit.script(model)
770
        x = torch.randn(2, 3)
771

772
        # Case 1: dynamic axis
773
        f = io.BytesIO()
774
        y = torch.randn(2, 3)
775
        torch.onnx.export(
776
            script_model,
777
            (x, y),
778
            f,
779
            opset_version=self.opset_version,
780
            input_names=["x", "y"],
781
            dynamic_axes={"y": [1]},
782
        )
783
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
784
        loop_output_value_info_proto = onnx_model.graph.output[0]
785
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
786
            loop_output_value_info_proto.name, 1, [2, None]
787
        )
788
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)
789

790
        # Case 2: no dynamic axes.
791
        f = io.BytesIO()
792
        y = torch.randn(2, 3)
793
        torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
794
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
795
        loop_output_value_info_proto = onnx_model.graph.output[0]
796
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
797
            loop_output_value_info_proto.name, 1, [2, 3]
798
        )
799
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)
800

801
    def test_export_mode(self):
802
        class MyModule(torch.nn.Module):
803
            def forward(self, x):
804
                y = x + 1
805
                return y
806

807
        model = MyModule()
808
        x = torch.randn(10, 3, 128, 128)
809
        f = io.BytesIO()
810

811
        # set mode to in inference mode and export in training mode
812
        model.eval()
813
        old_state = model.training
814
        torch.onnx.export(
815
            model,
816
            (x,),
817
            f,
818
            opset_version=self.opset_version,
819
            training=torch.onnx.TrainingMode.TRAINING,
820
        )
821
        # verify that the model state is preserved
822
        self.assertEqual(model.training, old_state)
823

824
        # set mode to training mode and export in inference mode
825
        model.train()
826
        old_state = model.training
827
        torch.onnx.export(
828
            model,
829
            (x,),
830
            f,
831
            opset_version=self.opset_version,
832
            training=torch.onnx.TrainingMode.EVAL,
833
        )
834
        # verify that the model state is preserved
835
        self.assertEqual(model.training, old_state)
836

837
    def test_export_does_not_fail_on_frozen_scripted_module(self):
838
        class Inner(torch.nn.Module):
839
            def forward(self, x):
840
                if x > 0:
841
                    return x
842
                else:
843
                    return x * x
844

845
        class Outer(torch.nn.Module):
846
            def __init__(self):
847
                super().__init__()
848
                self.inner = torch.jit.script(Inner())
849

850
            def forward(self, x):
851
                return self.inner(x)
852

853
        x = torch.zeros(1)
854
        # Freezing is only implemented in eval mode. So we need to call eval()
855
        outer_module = Outer().eval()
856
        module = torch.jit.trace_module(outer_module, {"forward": (x)})
857
        # jit.freeze removes the training attribute in the module
858
        module = torch.jit.freeze(module)
859

860
        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
861

862
    @skipIfUnsupportedMinOpsetVersion(15)
863
    def test_local_function(self):
864
        class N(torch.nn.Module):
865
            def __init__(self, prob):
866
                super().__init__()
867
                self.dropout = torch.nn.Dropout(prob)
868

869
            def forward(self, x):
870
                return self.dropout(x)
871

872
        class M(torch.nn.Module):
873
            def __init__(self, num_layers):
874
                super().__init__()
875
                self.num_layers = num_layers
876
                self.lns = torch.nn.ModuleList(
877
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
878
                )
879
                self.celu1 = torch.nn.CELU(1.0)
880
                self.celu2 = torch.nn.CELU(2.0)
881
                self.dropout = N(0.5)
882

883
            def forward(self, x, y, z):
884
                res1 = self.celu1(x)
885
                res2 = self.celu2(y)
886
                for ln in self.lns:
887
                    z = ln(z)
888
                return res1 + res2, self.dropout(z)
889

890
        x = torch.randn(2, 3)
891
        y = torch.randn(2, 3)
892
        z = torch.randn(2, 3)
893

894
        # Export specified modules. Test against specifying modules that won't
895
        # exist in the exported model.
896
        # Model export in inference mode will remove dropout node,
897
        # thus the dropout module no longer exist in graph.
898
        f = io.BytesIO()
899
        torch.onnx.export(
900
            M(3),
901
            (x, y, z),
902
            f,
903
            opset_version=self.opset_version,
904
            export_modules_as_functions={
905
                torch.nn.CELU,
906
                torch.nn.Dropout,
907
                torch.nn.LayerNorm,
908
            },
909
        )
910

911
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
912

913
        # Check function definition
914
        funcs = onnx_model.functions
915
        celu_funcs = [f for f in funcs if f.name == "CELU"]
916
        self.assertEqual(len(celu_funcs), 1)
917
        self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation")
918
        self.assertEqual(len(celu_funcs[0].attribute), 3)
919
        ln_funcs = [f for f in funcs if f.name == "LayerNorm"]
920
        self.assertEqual(len(ln_funcs), 1)
921
        self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization")
922
        self.assertEqual(len(ln_funcs[0].attribute), 3)
923

924
        # Check local function nodes
925
        nodes = onnx_model.graph.node
926
        celu_ns = [n for n in nodes if n.op_type == "CELU"]
927
        ln_ns = [n for n in nodes if n.op_type == "LayerNorm"]
928
        self.assertEqual(len(celu_ns), 2)
929
        self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation")
930
        self.assertEqual(len(celu_ns[0].attribute), 3)
931
        self.assertEqual(len(ln_ns), 3)
932
        self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization")
933
        self.assertEqual(len(ln_ns[0].attribute), 3)
934

935
        # Export specified modules.
936
        f = io.BytesIO()
937
        torch.onnx.export(
938
            M(3),
939
            (x, y, z),
940
            f,
941
            opset_version=self.opset_version,
942
            export_modules_as_functions={torch.nn.CELU},
943
        )
944

945
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
946
        funcs = onnx_model.functions
947
        self.assertEqual(len(funcs), 1)
948
        self.assertEqual(funcs[0].name, "CELU")
949

950
        # Export with empty specified modules. Normal export.
951
        f = io.BytesIO()
952
        torch.onnx.export(
953
            M(3),
954
            (x, y, z),
955
            f,
956
            opset_version=self.opset_version,
957
            export_modules_as_functions=set(),
958
        )
959

960
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
961
        funcs = onnx_model.functions
962
        self.assertEqual(len(funcs), 0)
963

964
        # Export all modules. Should contain {M, CELU, LayerNorm}.
965
        f = io.BytesIO()
966
        torch.onnx.export(
967
            M(3),
968
            (x, y, z),
969
            f,
970
            opset_version=self.opset_version,
971
            export_modules_as_functions=True,
972
        )
973

974
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
975
        funcs = onnx_model.functions
976
        self.assertEqual(len(funcs), 3)
977

978
    @skipIfUnsupportedMinOpsetVersion(15)
979
    def test_local_function_overloads(self):
980
        class NWithOverloads(torch.nn.Module):
981
            def forward(self, x, y=None, z=None):
982
                if y is None:
983
                    return x + 1
984
                elif z is None:
985
                    return x + y
986
                else:
987
                    return x + y, x + z
988

989
        class M(torch.nn.Module):
990
            def __init__(self, num_layers):
991
                super().__init__()
992
                self.n = NWithOverloads()
993

994
            def forward(self, x, y, z):
995
                return self.n(x), self.n(x, y), self.n(x, y, z)
996

997
        x = torch.randn(2, 3)
998
        y = torch.randn(2, 3)
999
        z = torch.randn(2, 3)
1000

1001
        f = io.BytesIO()
1002
        torch.onnx.export(
1003
            M(3),
1004
            (x, y, z),
1005
            f,
1006
            opset_version=self.opset_version,
1007
            export_modules_as_functions={NWithOverloads},
1008
        )
1009

1010
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1011
        funcs = onnx_model.functions
1012
        self.assertEqual(len(funcs), 3)
1013
        func_names = [f.name for f in funcs]
1014
        self.assertIn("NWithOverloads", func_names)
1015
        self.assertIn("NWithOverloads.1", func_names)
1016
        self.assertIn("NWithOverloads.2", func_names)
1017

1018
    # Failing after ONNX 1.13.0
1019
    @skipIfUnsupportedMaxOpsetVersion(1)
1020
    def test_local_function_infer_scopes(self):
1021
        class M(torch.nn.Module):
1022
            def forward(self, x):
1023
                # Concatenation of scalars inserts unscoped tensors in IR graph.
1024
                new_tensor_shape = x.size()[:-1] + (1, 1, -1)
1025
                tensor = x.view(*new_tensor_shape)
1026
                return tensor
1027

1028
        x = torch.randn(4, 5)
1029
        f = io.BytesIO()
1030
        torch.onnx.export(
1031
            M(),
1032
            (x,),
1033
            f,
1034
            export_modules_as_functions=True,
1035
            opset_version=self.opset_version,
1036
            do_constant_folding=False,
1037
        )
1038

1039
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1040
        funcs = onnx_model.functions
1041
        self.assertIn("M", [f.name for f in funcs])
1042

1043
    @skipIfUnsupportedMinOpsetVersion(15)
1044
    def test_local_function_predefined_attributes(self):
1045
        class M(torch.nn.Module):
1046
            num_layers: int
1047

1048
            def __init__(self, num_layers):
1049
                super().__init__()
1050
                self.num_layers = num_layers
1051
                self.lns = torch.nn.ModuleList(
1052
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
1053
                )
1054

1055
            def forward(self, x):
1056
                for ln in self.lns:
1057
                    x = ln(x)
1058
                return x
1059

1060
        x = torch.randn(2, 3)
1061
        f = io.BytesIO()
1062
        model = M(3)
1063
        torch.onnx.export(
1064
            model,
1065
            (x,),
1066
            f,
1067
            export_modules_as_functions=True,
1068
            opset_version=self.opset_version,
1069
        )
1070

1071
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1072
        funcs = onnx_model.functions
1073
        m_funcs = [fn for fn in funcs if fn.name == "M"]
1074
        self.assertEqual(m_funcs[0].attribute, ["num_layers"])
1075
        ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"]
1076
        self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"])
1077

1078
        from onnx import helper
1079

1080
        m_node = [n for n in onnx_model.graph.node if n.op_type == "M"]
1081
        self.assertEqual(
1082
            m_node[0].attribute[0],
1083
            helper.make_attribute("num_layers", model.num_layers),
1084
        )
1085

1086
        ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"]
1087
        expected_ln_attrs = [
1088
            helper.make_attribute(
1089
                "elementwise_affine", model.lns[0].elementwise_affine
1090
            ),
1091
            helper.make_attribute("eps", model.lns[0].eps),
1092
        ]
1093
        for ln_node in ln_nodes:
1094
            self.assertIn(ln_node.attribute[0], expected_ln_attrs)
1095
            self.assertIn(ln_node.attribute[1], expected_ln_attrs)
1096

1097
    # This test cases checks the issue where an object does not have an attribute.
1098
    # When enabling `export_modules_as_functions = True`, the exporter could return an
1099
    # AttributeError. With this test case, we check that the export passes successfully
1100
    # without any AttributeError exceptions.
1101
    # See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that
1102
    # this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`.
1103
    @skipIfUnsupportedMinOpsetVersion(15)
1104
    def test_local_function_subset_of_predefined_attributes(self):
1105
        class M(torch.nn.Module):
1106
            num_layers: int
1107

1108
            def __init__(self, num_layers):
1109
                super().__init__()
1110
                self.embed_layer = torch.nn.Embedding.from_pretrained(
1111
                    torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
1112
                )
1113
                self.num_layers = num_layers
1114
                self.lns = torch.nn.ModuleList(
1115
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
1116
                )
1117

1118
            def forward(self, x):
1119
                e = self.embed_layer(torch.LongTensor([1]))
1120
                for ln in self.lns:
1121
                    x = ln(x)
1122
                return x, e
1123

1124
        x = torch.randn(2, 3)
1125
        f = io.BytesIO()
1126
        model = M(3)
1127
        torch.onnx.export(
1128
            model,
1129
            (x,),
1130
            f,
1131
            export_modules_as_functions=True,
1132
            opset_version=self.opset_version,
1133
            verbose=True,  # Allows the test case to print `Skipping module attribute 'freeze'`
1134
        )
1135

1136
    def test_node_scope(self):
1137
        class N(torch.nn.Module):
1138
            def __init__(self):
1139
                super().__init__()
1140
                self.relu = torch.nn.ReLU()
1141

1142
            def forward(self, x):
1143
                return self.relu(x)
1144

1145
        class M(torch.nn.Module):
1146
            def __init__(self, num_layers):
1147
                super().__init__()
1148
                self.num_layers = num_layers
1149
                self.lns = torch.nn.ModuleList(
1150
                    [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
1151
                )
1152
                self.gelu1 = torch.nn.GELU()
1153
                self.gelu2 = torch.nn.GELU()
1154
                self.relu = N()
1155

1156
            def forward(self, x, y, z):
1157
                res1 = self.gelu1(x)
1158
                res2 = self.gelu2(y)
1159
                for ln in self.lns:
1160
                    z = ln(z)
1161
                return res1 + res2, self.relu(z)
1162

1163
        x = torch.randn(2, 3)
1164
        y = torch.randn(2, 3)
1165
        z = torch.randn(2, 3)
1166

1167
        model = M(3)
1168
        expected_scope_names = {
1169
            "M::/torch.nn.modules.activation.GELU::gelu1",
1170
            "M::/torch.nn.modules.activation.GELU::gelu2",
1171
            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
1172
            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
1173
            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
1174
            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
1175
            "M::",
1176
        }
1177

1178
        graph, _, _ = self._model_to_graph(
1179
            model, (x, y, z), input_names=[], dynamic_axes={}
1180
        )
1181
        for node in graph.nodes():
1182
            self.assertIn(
1183
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
1184
                expected_scope_names,
1185
            )
1186

1187
        graph, _, _ = self._model_to_graph(
1188
            torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
1189
        )
1190
        for node in graph.nodes():
1191
            self.assertIn(
1192
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
1193
                expected_scope_names,
1194
            )
1195

1196
    def test_scope_of_constants_when_combined_by_cse_pass(self):
1197
        layer_num = 3
1198

1199
        class M(torch.nn.Module):
1200
            def __init__(self, constant):
1201
                super().__init__()
1202
                self.constant = constant
1203

1204
            def forward(self, x):
1205
                # 'self.constant' is designed to be the same for all layers,
1206
                # hence it is common sub expression.
1207
                return x + self.constant
1208

1209
        class N(torch.nn.Module):
1210
            def __init__(self, layers: int = layer_num):
1211
                super().__init__()
1212
                self.layers = torch.nn.ModuleList(
1213
                    [M(constant=torch.tensor(1.0)) for i in range(layers)]
1214
                )
1215

1216
            def forward(self, x):
1217
                for layer in self.layers:
1218
                    x = layer(x)
1219
                return x
1220

1221
        graph, _, _ = self._model_to_graph(
1222
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
1223
        )
1224

1225
        # NOTE: Duplicated constants are populated due to implicit casting in scalar_type_analysis,
1226
        #       so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
1227
        #       If CSE in exporter is improved later, this test needs to be updated.
1228
        #       It should expect 1 constant, with same scope as root.
1229
        expected_root_scope_name = "N::"
1230
        expected_layer_scope_name = "M::layers"
1231
        expected_constant_scope_name = [
1232
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
1233
            for i in range(layer_num)
1234
        ]
1235

1236
        constant_scope_names = []
1237
        for node in graph.nodes():
1238
            if node.kind() == "onnx::Constant":
1239
                constant_scope_names.append(
1240
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
1241
                )
1242
        self.assertEqual(constant_scope_names, expected_constant_scope_name)
1243

1244
    def test_scope_of_nodes_when_combined_by_cse_pass(self):
1245
        layer_num = 3
1246

1247
        class M(torch.nn.Module):
1248
            def __init__(self, constant, bias):
1249
                super().__init__()
1250
                self.constant = constant
1251
                self.bias = bias
1252

1253
            def forward(self, x):
1254
                # 'constant' and 'x' is designed to be the same for all layers,
1255
                # hence `x + self.constant` is common sub expression.
1256
                # 'bias' is designed to be different for all layers,
1257
                # hence `* self.bias` is not common sub expression.
1258
                return (x + self.constant) * self.bias
1259

1260
        class N(torch.nn.Module):
1261
            def __init__(self, layers: int = layer_num):
1262
                super().__init__()
1263

1264
                self.layers = torch.nn.ModuleList(
1265
                    [
1266
                        M(constant=torch.tensor([1.0]), bias=torch.randn(1))
1267
                        for i in range(layers)
1268
                    ]
1269
                )
1270

1271
            def forward(self, x):
1272
                y = []
1273
                for layer in self.layers:
1274
                    y.append(layer(x))
1275
                return y[0], y[1], y[2]
1276

1277
        graph, _, _ = self._model_to_graph(
1278
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
1279
        )
1280
        expected_root_scope_name = "N::"
1281
        expected_layer_scope_name = "M::layers"
1282
        expected_add_scope_names = [
1283
            f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
1284
        ]
1285
        expected_mul_scope_names = [
1286
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
1287
            for i in range(layer_num)
1288
        ]
1289

1290
        add_scope_names = []
1291
        mul_scope_names = []
1292
        for node in graph.nodes():
1293
            if node.kind() == "onnx::Add":
1294
                add_scope_names.append(
1295
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
1296
                )
1297
            elif node.kind() == "onnx::Mul":
1298
                mul_scope_names.append(
1299
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
1300
                )
1301
        self.assertEqual(add_scope_names, expected_add_scope_names)
1302
        self.assertEqual(mul_scope_names, expected_mul_scope_names)
1303

1304
    def test_aten_fallthrough(self):
1305
        # Test aten export of op with no symbolic
1306
        class Module(torch.nn.Module):
1307
            def forward(self, x):
1308
                return torch.erfc(x)
1309

1310
        x = torch.randn(2, 3, 4)
1311
        GLOBALS.export_onnx_opset_version = self.opset_version
1312
        graph, _, __ = self._model_to_graph(
1313
            Module(),
1314
            (x,),
1315
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1316
            input_names=["x"],
1317
            dynamic_axes={"x": [0, 1, 2]},
1318
        )
1319
        iter = graph.nodes()
1320
        self.assertEqual(next(iter).kind(), "aten::erfc")
1321

1322
    def test_custom_op_fallthrough(self):
1323
        # Test custom op
1324
        op_source = """
1325
        #include <torch/script.h>
1326

1327
        torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
1328
          return self + other;
1329
        }
1330

1331
        static auto registry =
1332
          torch::RegisterOperators("custom_namespace::custom_op", &custom_add);
1333
        """
1334

1335
        torch.utils.cpp_extension.load_inline(
1336
            name="custom_add",
1337
            cpp_sources=op_source,
1338
            is_python_module=False,
1339
            verbose=True,
1340
        )
1341

1342
        class FooModel(torch.nn.Module):
1343
            def forward(self, input, other):
1344
                # Calling custom op
1345
                return torch.ops.custom_namespace.custom_op(input, other)
1346

1347
        x = torch.randn(2, 3, 4, requires_grad=False)
1348
        y = torch.randn(2, 3, 4, requires_grad=False)
1349
        model = FooModel()
1350
        graph, _, __ = self._model_to_graph(
1351
            model,
1352
            (x, y),
1353
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
1354
            input_names=["x", "y"],
1355
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
1356
        )
1357
        iter = graph.nodes()
1358
        self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
1359

1360
    def test_custom_opsets_gelu(self):
1361
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
1362

1363
        def gelu(g, self, approximate):
1364
            return g.op("com.microsoft::Gelu", self).setType(self.type())
1365

1366
        torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9)
1367
        model = torch.nn.GELU(approximate="none")
1368
        x = torch.randn(3, 3)
1369
        f = io.BytesIO()
1370
        torch.onnx.export(
1371
            model,
1372
            (x,),
1373
            f,
1374
            opset_version=self.opset_version,
1375
            custom_opsets={"com.microsoft": 1},
1376
        )
1377

1378
        graph = onnx.load(io.BytesIO(f.getvalue()))
1379
        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
1380
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
1381
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1382
        self.assertEqual(graph.opset_import[1].version, 1)
1383

1384
    def test_register_aten_custom_op_symbolic(self):
1385
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
1386

1387
        def gelu(g, self, approximate):
1388
            return g.op("com.microsoft::Gelu", self).setType(self.type())
1389

1390
        torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9)
1391
        model = torch.nn.GELU(approximate="none")
1392
        x = torch.randn(3, 3)
1393
        f = io.BytesIO()
1394
        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
1395
        graph = onnx.load(io.BytesIO(f.getvalue()))
1396

1397
        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
1398
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1399

1400
    @skipIfNoLapack
1401
    def test_custom_opsets_inverse(self):
1402
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
1403

1404
        class CustomInverse(torch.nn.Module):
1405
            def forward(self, x):
1406
                return torch.inverse(x) + x
1407

1408
        def linalg_inv(g, self):
1409
            return g.op("com.microsoft::Inverse", self).setType(self.type())
1410

1411
        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9)
1412
        model = CustomInverse()
1413
        x = torch.randn(2, 3, 3)
1414
        f = io.BytesIO()
1415
        torch.onnx.export(
1416
            model,
1417
            (x,),
1418
            f,
1419
            opset_version=self.opset_version,
1420
            custom_opsets={"com.microsoft": 1},
1421
        )
1422

1423
        graph = onnx.load(io.BytesIO(f.getvalue()))
1424
        self.assertEqual(graph.graph.node[0].op_type, "Inverse")
1425
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
1426
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
1427
        self.assertEqual(graph.opset_import[1].version, 1)
1428

1429
    def test_onnx_fallthrough(self):
1430
        # Test aten export of op with symbolic for aten
1431
        class Module(torch.nn.Module):
1432
            def forward(self, x):
1433
                return torch.digamma(x)
1434

1435
        x = torch.randn(100, 128)
1436
        graph, _, __ = self._model_to_graph(
1437
            Module(),
1438
            (x,),
1439
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1440
            input_names=["x"],
1441
            dynamic_axes={"x": [0, 1]},
1442
        )
1443
        iter = graph.nodes()
1444
        self.assertEqual(next(iter).kind(), "aten::digamma")
1445

1446
    # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
1447
    @skipIfUnsupportedMaxOpsetVersion(10)
1448
    def test_prim_fallthrough(self):
1449
        # Test prim op
1450
        class PrimModule(torch.jit.ScriptModule):
1451
            @torch.jit.script_method
1452
            def forward(self, x):
1453
                if isinstance(x, list):
1454
                    y = x
1455
                else:
1456
                    y = [x]
1457
                return y
1458

1459
        x = torch.tensor([2])
1460
        model = PrimModule()
1461
        model.eval()
1462
        graph, _, __ = self._model_to_graph(
1463
            model,
1464
            (x,),
1465
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1466
            input_names=["x"],
1467
            dynamic_axes={"x": [0]},
1468
        )
1469
        iter = graph.nodes()
1470
        self.assertEqual(next(iter).kind(), "prim::ListConstruct")
1471

1472
    def test_custom_layer_tuple(self):
1473
        class CustomFunction(torch.autograd.Function):
1474
            @staticmethod
1475
            def symbolic(g, input):
1476
                return g.op("CustomNamespace::Custom", input, outputs=2)
1477

1478
            @staticmethod
1479
            def forward(ctx, input):
1480
                return input, input
1481

1482
        class Custom(torch.nn.Module):
1483
            def forward(self, input):
1484
                return CustomFunction.apply(input)
1485

1486
        model = Custom()
1487
        batch = torch.FloatTensor(1, 3)
1488

1489
        graph, _, _ = self._model_to_graph(
1490
            model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
1491
        )
1492
        iter = graph.nodes()
1493
        self.assertEqual(next(iter).kind(), "CustomNamespace::Custom")
1494

1495
    def test_autograd_onnx_fallthrough(self):
1496
        class CustomFunction(torch.autograd.Function):
1497
            @staticmethod
1498
            def forward(ctx, input):
1499
                ctx.save_for_backward(input)
1500
                return input.clamp(min=0)
1501

1502
            @staticmethod
1503
            def backward(ctx, grad_output):
1504
                (input,) = ctx.saved_tensors
1505
                grad_input = grad_output.clone()
1506
                grad_input[input < 0] = 0
1507
                return grad_input
1508

1509
        class Custom(torch.nn.Module):
1510
            def forward(self, input):
1511
                return CustomFunction.apply(input)
1512

1513
        model = Custom()
1514
        batch = torch.FloatTensor(1, 3)
1515

1516
        graph, _, _ = self._model_to_graph(
1517
            model,
1518
            batch,
1519
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
1520
            input_names=["batch"],
1521
            dynamic_axes={"batch": [0, 1]},
1522
        )
1523
        iter = graph.nodes()
1524
        self.assertEqual(next(iter).kind(), "prim::PythonOp")
1525

1526
    def test_autograd_module_name(self):
        """Two different autograd Functions export with different ``module`` attrs.

        ``CustomFunction`` is defined locally, while ``CustomFunction2`` is
        imported from ``autograd_helper``. With ``ONNX_FALLTHROUGH`` both calls
        stay as ``prim::PythonOp`` nodes, and the string attribute ``module``
        on those nodes must differ between the two.
        """

        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                grad_input[input < 0] = 0
                return grad_input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input) + CustomFunction2.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model,
            batch,
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["batch"],
            dynamic_axes={"batch": [0, 1]},
        )
        # Named ``nodes`` rather than ``iter`` to avoid shadowing the builtin.
        nodes = graph.nodes()
        autograd1 = next(nodes)
        autograd2 = next(nodes)
        self.assertEqual(autograd1.kind(), "prim::PythonOp")
        self.assertEqual(autograd2.kind(), "prim::PythonOp")
        self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))
1560

1561
    def test_unused_initializers(self):
        """Parameters of a submodule never used in ``forward`` are not exported.

        ``k_proj`` is built but never called, so only the transposed
        convolution's two parameters should end up in ``params_dict``.
        """

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2 = torch.nn.ConvTranspose2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
                )
                # Deliberately unused by forward(); its params must be dropped.
                self.k_proj = torch.nn.Linear(5, 5, bias=True)

            def forward(self, x):
                return self.conv2(x)

        sample = torch.randn(20, 16, 50, 100)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        _, params_dict, _ = self._model_to_graph(
            Model(),
            (sample,),
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        self.assertEqual(len(params_dict), 2)
1587

1588
    def test_scripting_param(self):
        """Graph inputs of an exported scripted module keep parameter names.

        The scripted module is exported in TRAINING mode with constant
        folding on. Every name yielded by ``named_parameters()`` must appear
        among the debug names of the resulting graph's inputs.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        sample = torch.randn(10, 3, 128, 128)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, _ = self._model_to_graph(
            model,
            (sample,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            training=torch.onnx.TrainingMode.TRAINING,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        input_debug_names = [value.debugName() for value in graph.inputs()]
        for param_name, _ in model.named_parameters():
            self.assertIn(
                param_name,
                input_debug_names,
                "Graph parameter names does not match model parameters.",
            )
1623

1624
    @skipIfNoCaffe2
    def test_modifying_params(self):
        """A forward() that mutates a parameter in place still verifies.

        ``forward`` bumps ``param`` through ``param.data.add_`` on every call.
        The model is verified against the Caffe2 backend with constant
        folding disabled.
        """

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.tensor([2.0]))

            def forward(self, x):
                squared = x * x
                # In-place side effect on the parameter; not part of the output.
                self.param.data.add_(1.0)
                return squared

        sample = torch.tensor([1, 2])
        # Imported locally: the caffe2 backend needs an extra build flag and
        # is only used by this test.
        import caffe2.python.onnx.backend as backend

        verify(MyModel(), sample, backend, do_constant_folding=False)
1642

1643
    def test_fuse_conv_bn(self):
        """In EVAL mode a Conv2d followed by BatchNorm2d exports as one Conv."""

        class Fuse(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                return self.bn(self.conv(x))

        sample = torch.randn(2, 3, 2, 2, requires_grad=True)
        graph, _, _ = self._model_to_graph(
            Fuse(),
            (sample,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        node_kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::BatchNormalization", node_kinds)
        # Exactly one node, and it is the (fused) convolution.
        self.assertEqual(node_kinds, ["onnx::Conv"])
1669

1670
    def test_fuse_resnet18(self):
        """No BatchNormalization node survives exporting ResNet-18 in EVAL mode."""
        model = torchvision.models.resnet18(weights=None)
        sample = torch.randn(2, 3, 224, 224, requires_grad=True)
        graph, _, _ = self._model_to_graph(
            model,
            (sample,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        node_kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("onnx::BatchNormalization", node_kinds)
1683

1684
    def test_onnx_function_substitution_pass(self):
        """A call to a scripted function is inlined during export.

        The ``prim::Constant`` that represents the scripted function ``f`` must
        be removed, and the following ``prim::CallFunction`` must be replaced
        by the inlined body. Only the ``onnx::Sub`` and ``onnx::Add`` nodes
        should remain.
        """

        @torch.jit.script
        def f(x: torch.Tensor, y: torch.Tensor):
            z = x - y
            return x + z

        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                return f(x, y)

        first = torch.tensor([11])
        second = torch.tensor([12])
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, _ = self._model_to_graph(
            MyModule(),
            (first, second),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": [0], "input_2": [0]},
        )

        node_kinds = [node.kind() for node in graph.nodes()]
        self.assertNotIn("prim::Constant", node_kinds)
        # onnx::Sub and onnx::Add nodes only.
        self.assertEqual(len(node_kinds), 2)
1714

1715
    def test_onnx_value_name(self):
        """Parameter names survive as ONNX graph input names.

        With ``keep_initializers_as_inputs=True`` the parameters appear as
        graph inputs after the data input (index 0). They must keep their
        attribute names ``in_weight`` and ``in_bias``.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Explicitly initialized: the legacy ``torch.Tensor(shape)``
                # constructor returns uninitialized memory.
                self.in_weight = torch.nn.Parameter(torch.randn(3, 3))
                self.in_bias = torch.nn.Parameter(torch.randn(3))

            def forward(self, x):
                start = 0
                end = None
                weight = self.in_weight
                bias = self.in_bias
                weight = weight[start:end, :]
                if bias is not None:
                    bias = bias[start:end]
                return torch.nn.functional.linear(x, weight, bias)

        model = MyModule()
        x = torch.randn(3, 3)
        f = io.BytesIO()

        model.eval()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            keep_initializers_as_inputs=True,
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(onnx_model.graph.input[1].name, "in_weight")
        self.assertEqual(onnx_model.graph.input[2].name, "in_bias")
1747

1748
    def test_onnx_node_naming(self):
        """Exported node names follow the ``/<submodule>/<OpType>`` scheme.

        The check runs for both the eager module and its TorchScript version.
        Each export gets its own ``BytesIO``. The original code reused one
        buffer, so ``getvalue()`` for the scripted export still held the eager
        export's bytes, and that check was not isolated.
        """

        class MainModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._module_1 = torch.nn.Linear(10, 10)
                self._module_2 = torch.nn.Linear(10, 10)
                self._module_3 = torch.nn.Linear(10, 10)
                self._module_4 = torch.nn.Linear(10, 10)

            def forward(self, x):
                y = self._module_1(x)
                z = self._module_2(y)
                z = self._module_3(y * z)
                z = self._module_4(y * z)
                return z

        module = MainModule()
        ref_node_names = [
            "/_module_1/Gemm",
            "/_module_2/Gemm",
            "/_module_3/Gemm",
            "/_module_4/Gemm",
            "/Mul",
            "/Mul_1",
        ]

        def assert_node_names(model):
            # Export ``model`` into a fresh buffer and check every node name.
            f = io.BytesIO()
            torch.onnx.export(model, torch.ones(1, 10), f, output_names=["y"])
            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
            for n in onnx_model.graph.node:
                self.assertIn(n.name, ref_node_names)

        assert_node_names(module)
        assert_node_names(torch.jit.script(module))
1786

1787
    def _test_deduplicate_initializers(self, torchscript=False):
        """Shared body of the initializer-deduplication tests.

        ``MyModule`` reuses ``layer1`` as ``layer3`` and shares a weight and a
        bias between ``layer1`` and ``layer2``. It also holds two distinct
        parameters with equal values (``param1``/``param2``).

        For the TRAINING export and the PRESERVE export (after
        ``model.train()``), the initializer names must equal the
        ``named_parameters()`` names. For the eval-mode export, ``param2`` is
        expected to be absent as well.

        Args:
            torchscript: If True, export ``torch.jit.script(MyModule())``
                instead of the eager module.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(3, 3)
                self.layer2 = torch.nn.Linear(3, 3)

                # Reusing layers.
                self.layer3 = self.layer1

                # Reusing parameters.
                self.layer2.weight = self.layer1.weight
                self.layer1.bias = self.layer2.bias

                # Parameter with different tensors equal in value.
                self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
                self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))

            def forward(self, x):
                return (
                    self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2
                )

        model = torch.jit.script(MyModule()) if torchscript else MyModule()

        sample = torch.randn(3, 3)
        expected_names = {name for name, _ in model.named_parameters()}

        def exported_initializer_names(**export_kwargs):
            # Export ``model`` and return the set of initializer names.
            buffer = io.BytesIO()
            torch.onnx.export(
                model,
                (sample,),
                buffer,
                opset_version=self.opset_version,
                **export_kwargs,
            )
            proto = onnx.load(io.BytesIO(buffer.getvalue()))
            return {init.name for init in proto.graph.initializer}

        # Training mode.
        model.train()
        self.assertSetEqual(
            exported_initializer_names(training=TrainingMode.TRAINING), expected_names
        )

        model.train()
        self.assertSetEqual(
            exported_initializer_names(training=TrainingMode.PRESERVE), expected_names
        )

        # Eval mode.
        model.eval()
        eval_names = exported_initializer_names()
        expected_names.remove("param2")
        self.assertSetEqual(eval_names, expected_names)
1847

1848
    def test_deduplicate_initializers(self):
        """Run the deduplication checks on the eager (non-scripted) module."""
        use_torchscript = False
        self._test_deduplicate_initializers(torchscript=use_torchscript)
1850

1851
    def test_deduplicate_initializers_torchscript(self):
        """Run the deduplication checks on the ``torch.jit.script`` module."""
        use_torchscript = True
        self._test_deduplicate_initializers(torchscript=use_torchscript)
1853

1854
    @skipIfNoCuda
    def test_deduplicate_initializers_diff_devices(self):
        """Equal-valued parameters on different devices.

        ``w_cpu`` and ``w_cuda`` both hold ``ones(3)``, one on CPU and one on
        CUDA. Only ``w_cpu`` is expected among the exported initializers.
        """
        cpu = torch.device("cpu")
        cuda = torch.device("cuda")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w_cpu = torch.nn.Parameter(torch.ones(3, device=cpu))
                self.w_cuda = torch.nn.Parameter(torch.ones(3, device=cuda))

            def forward(self, x, y):
                return x + self.w_cpu, y + self.w_cuda

        inputs = (
            torch.randn(3, 3, device=cpu),
            torch.randn(3, 3, device=cuda),
        )
        buffer = io.BytesIO()
        torch.onnx.export(Model(), inputs, buffer, opset_version=self.opset_version)
        proto = onnx.load(io.BytesIO(buffer.getvalue()))
        initializer_names = {init.name for init in proto.graph.initializer}
        self.assertSetEqual(initializer_names, {"w_cpu"})
1875

1876
    def test_duplicated_output_node(self):
        """Outputs that repeat the same tensor get distinct named outputs.

        ``forward`` returns ``out1`` three times and ``out2`` twice. The
        exported graph must keep all five output names in order. The node
        list is expected to start Gemm, Identity, Identity, Gemm, Identity:
        the repeated outputs are materialized through ``Identity`` nodes.
        """

        class DuplicatedOutputNet(torch.nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = torch.nn.Linear(input_size, num_classes)

            def forward(self, input0, input1):
                out1 = self.fc1(input0)
                out2 = self.fc1(input1)
                return out1, out1, out2, out1, out2

        # The unused hidden-size local ``H`` from the original is dropped.
        N, D_in, D_out = 64, 784, 10
        pt_model = DuplicatedOutputNet(D_in, D_out)

        f = io.BytesIO()
        x = torch.randn(N, D_in)
        input_names = ["input0", "input1"]
        output_names = [f"output-{i}" for i in range(5)]
        # Both dimensions of every input and output are dynamic.
        dynamic_axes = {
            name: {0: f"{name}_dim0", 1: f"{name}_dim1"}
            for name in input_names + output_names
        }

        torch.onnx.export(
            pt_model,
            (x, x),
            f,
            input_names=input_names,
            output_names=output_names,
            do_constant_folding=False,
            training=torch.onnx.TrainingMode.TRAINING,
            dynamic_axes=dynamic_axes,
            verbose=True,
            keep_initializers_as_inputs=True,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(onnx_model.graph.input[0].name, "input0")
        self.assertEqual(onnx_model.graph.input[1].name, "input1")
        for i in range(5):
            self.assertEqual(onnx_model.graph.output[i].name, f"output-{i}")
        self.assertEqual(
            [node.op_type for node in onnx_model.graph.node[:5]],
            ["Gemm", "Identity", "Identity", "Gemm", "Identity"],
        )
1925

1926
    def test_deduplicate_ignore_upsample_scale(self):
        """Upsample scales must be ignored by shared-weight deduplication.

        The upsample scale is a constant, not a model parameter. For each of
        the two exported ``Resize`` nodes, exactly one node must produce its
        scales input (input index 2), and that node must be a ``Constant``.
        """

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.upsample_1 = torch.nn.Upsample(scale_factor=2)
                self.upsample_2 = torch.nn.Upsample(scale_factor=2)

            def forward(self, x):
                return self.upsample_1(x), self.upsample_2(x)

        buffer = io.BytesIO()
        sample = torch.randn(1, 32, 224, 224)
        torch.onnx.export(Model(), sample, buffer)
        onnx_model = onnx.load(io.BytesIO(buffer.getvalue()))
        all_nodes = list(onnx_model.graph.node)

        # aten::upsample converts to onnx::Resize.
        resize_nodes = [node for node in all_nodes if node.op_type == "Resize"]
        self.assertEqual(len(resize_nodes), 2)
        for resize in resize_nodes:
            scales_name = resize.input[2]
            producers = [node for node in all_nodes if node.output[0] == scales_name]
            self.assertEqual(len(producers), 1)
            self.assertEqual(producers[0].op_type, "Constant")
1951

1952
    def test_bad_symbolic_registration(self):
        """A symbolic whose ``parse_args`` descriptors mismatch its arity fails loudly.

        ``cat`` declares one descriptor (``"v"``) but takes two arguments
        after ``g``. Exporting ``torch.cat`` with it registered must raise the
        AssertionError below. The registration is always removed, even if the
        assertion fails, so the broken symbolic cannot leak into later tests.
        """
        _onnx_opset_version = 9

        @parse_args("v")
        def cat(g, tensor_list, dim):
            tensors = _unpack_list(tensor_list)
            return g.op("Concat", *tensors, axis_i=dim)

        class CatModel(torch.nn.Module):
            def forward(self, x):
                return torch.cat((x, x, x), 0)

        model = CatModel()
        x = torch.randn(2, 3)
        f = io.BytesIO()

        torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version)
        try:
            self.assertExpectedRaisesInline(
                AssertionError,
                lambda: torch.onnx.export(
                    model, (x,), f, opset_version=_onnx_opset_version
                ),
                (
                    "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
                    "'cat'. If you believe this is not due to custom symbolic implementation within your code or an external "
                    "library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to "
                    "report this bug."
                ),
            )
        finally:
            # Previously skipped when the assertion failed, which left the
            # broken ``::cat`` symbolic registered for subsequent tests.
            torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)
1982

1983

1984
if __name__ == "__main__":
    # Hand control to PyTorch's shared test runner when run as a script.
    common_utils.run_tests()
1986

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.