# Owner(s): ["module: onnx"]

import copy
import functools
import io
import re
import warnings
from typing import Callable

import onnx
import parameterized
import pytorch_test_common

import torch
import torch.onnx
import torch.utils.cpp_extension
import torchvision
from autograd_helper import CustomFunction as CustomFunction2
from pytorch_test_common import (
    skipIfNoCuda,
    skipIfUnsupportedMaxOpsetVersion,
    skipIfUnsupportedMinOpsetVersion,
)
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack

from verify import verify

def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
    """Remove test environment prefix added to module.

    Remove prefix to normalize scope names, since different test environments add
    prefixes with slight differences.

    Example:
        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "test_utility_funs.M"
        >>> )
        "M"
        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "test_utility_funs.test_abc.<locals>.M"
        >>> )
        "M"
        >>> _remove_test_environment_prefix_from_scope_name(
        >>>     "__main__.M"
        >>> )
        "M"
    """
    prefixes_to_remove = ["test_utility_funs", "__main__"]
    for prefix in prefixes_to_remove:
        scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
    return scope_name

class _BaseTestCase(pytorch_test_common.ExportTestCase):
    def _model_to_graph(
        self,
        model,
        input,
        do_constant_folding=True,
        training=TrainingMode.EVAL,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=None,
        dynamic_axes=None,
    ):
        torch.onnx.utils._setup_trace_module_map(model, False)
        if training == torch.onnx.TrainingMode.TRAINING:
            model.train()
        elif training == torch.onnx.TrainingMode.EVAL:
            model.eval()
        utils._validate_dynamic_axes(dynamic_axes, model, None, None)
        graph, params_dict, torch_out = utils._model_to_graph(
            model,
            input,
            do_constant_folding=do_constant_folding,
            _disable_torch_constant_prop=True,
            operator_export_type=operator_export_type,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        return graph, params_dict, torch_out

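# NOTE: `_model_to_graph` above mirrors the internal export pipeline so tests can
# inspect the TorchScript IR graph before it is serialized. A minimal sketch of
# how the tests below consume it (the names here are illustrative only):
#
#   graph, params_dict, _ = self._model_to_graph(
#       MyModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
#   )
#   kinds = [node.kind() for node in graph.nodes()]  # e.g. ["onnx::Add"]
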
@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
    """Unit tests for the `unconvertible_ops` function."""

    def setUp(self):
        class EinsumModule(torch.nn.Module):
            def forward(self, x):
                return torch.einsum("ii", x)

        self.einsum_module = EinsumModule()

    def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        graph, unconvertible_ops = utils.unconvertible_ops(
            self.einsum_module, (x,), opset_version=9
        )

        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
        self.assertEqual(next(nodes).kind(), "prim::Constant")
        self.assertEqual(next(nodes).kind(), "aten::einsum")
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

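    # For reference, `utils.unconvertible_ops` returns the traced JIT graph
    # together with the list of op names that have no symbolic registered at the
    # requested opset. A hedged usage sketch:
    #
    #   graph, ops = utils.unconvertible_ops(model, (inputs,), opset_version=9)
    #   if ops:  # e.g. ["aten::einsum"]
    #       print("cannot export at opset 9:", ops)
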
    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
        self.assertEqual(unconvertible_ops, ["aten::einsum"])

    @common_utils.parametrize(
        "jit_function",
        [
            common_utils.subtest(lambda x: x, name="nn_module"),
            common_utils.subtest(
                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
                name="traced",
            ),
            common_utils.subtest(torch.jit.script, name="scripted"),
        ],
    )
    def test_it_returns_empty_list_when_all_ops_convertible(
        self, jit_function: Callable
    ):
        module = jit_function(self.einsum_module)
        x = torch.randn(4, 4)

        # Einsum is supported since opset 12.
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
        self.assertEqual(unconvertible_ops, [])

    def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
        class SkipConnectionModule(torch.nn.Module):
            def forward(self, x):
                out = x
                out += x
                out = torch.nn.functional.relu(out, inplace=True)
                return out

        module = SkipConnectionModule()
        x = torch.randn(4, 4)
        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
        self.assertEqual(unconvertible_ops, [])

@parameterized.parameterized_class(
    [
        {"opset_version": opset}
        for opset in range(
            _constants.ONNX_BASE_OPSET,
            _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
        )
    ],
    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class TestUtilityFuns(_BaseTestCase):
    opset_version = None

    def test_is_in_onnx_export(self):
        test_self = self

        class MyModule(torch.nn.Module):
            def forward(self, x):
                test_self.assertTrue(torch.onnx.is_in_onnx_export())
                raise ValueError
                return x + 1

        x = torch.randn(3, 4)
        f = io.BytesIO()
        try:
            torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
        except ValueError:
            self.assertFalse(torch.onnx.is_in_onnx_export())

    def test_validate_dynamic_axes_invalid_input_output_name(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            utils._validate_dynamic_axes(
                {"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}},
                None,
                ["input1", "input2"],
                ["output"],
            )
            messages = [str(warning.message) for warning in w]
        self.assertIn(
            "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertIn(
            "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
            messages,
        )
        self.assertEqual(len(messages), 2)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        graph, _, _ = self._model_to_graph(
            SplitModule(),
            (x, y, t),
            input_names=["x", "y", "t"],
            dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::SplitToSequence")

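    # Rationale (informal): the split sizes above are computed from other input
    # shapes, so after shape inference they are known constants; the exporter can
    # then lower `aten::split` to static `onnx::Slice` ops instead of the dynamic
    # `onnx::SplitToSequence`, which is what this test asserts.
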
    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.transpose(a, 1, 0)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(3, 2)
        graph, _, __ = self._model_to_graph(
            TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

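    # The constant-folding tests below all follow the same pattern: the module
    # computes something from literal tensors only, the exporter folds that
    # subgraph away at export time, and the test asserts the folded op kind no
    # longer appears. The remaining node count is the folded Constant(s) plus the
    # op that consumes the runtime input `x`.
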
    def test_constant_fold_reduceL2(self):
        class ReduceModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=2, dim=-2, keepdim=False)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL2")

    def test_constant_fold_reduceL1(self):
        class NormModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.norm(a, p=1, dim=-2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::ReduceL1")

    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceIndexExceedsDimModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = a[0:-1]  # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(
            SliceNegativeIndexModule(),
            (x,),
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Cast")

    def test_constant_fold_gather(self):
        class GatherModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.select(a, dim=1, index=-2)
                c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
                return b + 1, c + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 3)
        model = GatherModule()
        model(x)
        graph, _, __ = self._model_to_graph(
            GatherModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Gather")

    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = torch.unsqueeze(a, -2)
                return b + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(1, 2, 3)
        graph, _, __ = self._model_to_graph(
            UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(
            PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 5)

    def test_constant_fold_squeeze_without_axes(self):
        class SqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a) + x + torch.squeeze(a)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 4)

    def test_constant_fold_squeeze_with_axes(self):
        class SqueezeAxesModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
                return torch.squeeze(a, dim=-3) + x

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Squeeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why did I insert a Cast here? There appears to be intentional
                # behavior in ONNX constant folding where constant tensors which
                # are not attached to any known-to-be-foldable onnx
                # operations don't get extracted into the initializer graph. So
                # without these casts, we will actually fail to pull out one of
                # the constants, thus failing constant folding. I think the
                # test is wrong, but I don't have time to write a more correct
                # test (I think the right way to go about it is to set up
                # a predicate for what invariant graphs should hold after
                # constant folding, and then verify this predicate holds.
                # I think the asserts below are an attempt at this predicate,
                # but it is not right!)
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float)
                b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.ones(2, 3)
        graph, _, __ = self._model_to_graph(
            ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Cast")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_lstm(self):
        class GruNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = self._model_to_graph(
            GruNet(),
            (input, h0),
            input_names=["input", "h0"],
            dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Slice")
            self.assertNotEqual(node.kind(), "onnx::Concat")
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")

        if self.opset_version <= 12:
            self.assertEqual(len(list(graph.nodes())), 3)
        else:
            # Unsqueeze takes "axes" as an input rather than an attribute for opset >= 13.
            self.assertEqual(len(list(graph.nodes())), 4)

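    # Illustrative sketch of the opset difference asserted above (schematic, not
    # actual exporter output):
    #
    #   opset <= 12:  Unsqueeze(data, axes=[0])        # "axes" is an attribute
    #   opset >= 13:  Unsqueeze(data, axes_constant)   # "axes" is a tensor input,
    #                                                  # fed by an extra Constant node
    #
    # That extra Constant input is why the folded graph has one more node at 13+.
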
    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        A = torch.randn(2, 3)
        graph, _, __ = self._model_to_graph(
            MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Transpose")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                b = self.weight.reshape(1, -1, 1, 1)
                return x * b

        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        x = torch.randn(4, 5)
        graph, _, __ = self._model_to_graph(
            ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Reshape")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_div(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
                return div * x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Div")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Mul")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))

    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, params_dict, __ = self._model_to_graph(
            Module(),
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sub")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0]))

    def test_constant_fold_sqrt(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sqrt = torch.sqrt(self.weight)
                return sqrt / x

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sqrt")
        self.assertEqual(len(list(graph.nodes())), 1)

    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Shape")
        self.assertEqual(len(list(graph.nodes())), 2)

    def test_constant_fold_upsample_scale_fold_as_constant(self):
        # The upsample scale is a constant, not a model parameter, so it should
        # not be added as an initializer after constant folding.
        model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        x = torch.randn(1, 32, 224, 224)
        f = io.BytesIO()
        torch.onnx.export(model, x, f)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(len(onnx_model.graph.initializer), 0)

    def test_verbose(self):
        class MyModule(torch.nn.Module):
            def forward(self, input):
                return torch.exp(input)

        x = torch.randn(3, 4)

        def is_model_stripped(f, verbose=None):
            if verbose is None:
                torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
            else:
                torch.onnx.export(
                    MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
                )
            model = onnx.load(io.BytesIO(f.getvalue()))
            model_strip = copy.copy(model)
            onnx.helper.strip_doc_string(model_strip)
            return model == model_strip

        # test verbose=False (default)
        self.assertTrue(is_model_stripped(io.BytesIO()))
        # test verbose=True
        self.assertFalse(is_model_stripped(io.BytesIO(), True))

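    # Background for the check above (to the best of my understanding):
    # `verbose=True` keeps doc_string fields (source locations) on the exported
    # nodes, so stripping doc strings changes the proto; with the default
    # `verbose=False` there is nothing to strip and the model compares equal.
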
    # NB: remove this test once DataParallel can be correctly handled
    def test_error_on_data_parallel(self):
        model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
        x = torch.randn(1, 2, 3, 4)
        f = io.BytesIO()
        with self.assertRaisesRegex(
            ValueError,
            "torch.nn.DataParallel is not supported by ONNX "
            "exporter, please use 'attribute' module to "
            "unwrap model from torch.nn.DataParallel. Try ",
        ):
            torch.onnx.export(model, x, f, opset_version=self.opset_version)

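    # A hedged sketch of the unwrapping the error message asks for: DataParallel
    # stores the wrapped network on its `module` attribute, so exporting the
    # underlying model would look like
    #
    #   torch.onnx.export(model.module, x, f, opset_version=self.opset_version)
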
    @skipIfUnsupportedMinOpsetVersion(11)
    def test_sequence_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return [x, y]

        model = Module()
        # Export with scripting to keep output as Sequence type.
        # Tracing unpacks the list.
        script_model = torch.jit.script(model)
        x = torch.randn(2, 3)

        # Case 1: dynamic axis
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(
            script_model,
            (x, y),
            f,
            opset_version=self.opset_version,
            input_names=["x", "y"],
            dynamic_axes={"y": [1]},
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, None]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

        # Case 2: no dynamic axes.
        f = io.BytesIO()
        y = torch.randn(2, 3)
        torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        loop_output_value_info_proto = onnx_model.graph.output[0]
        ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
            loop_output_value_info_proto.name, 1, [2, 3]
        )
        self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)

    def test_export_mode(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = x * 2
                return y

        model = MyModule()
        x = torch.randn(10, 3, 128, 128)
        f = io.BytesIO()

        # Set model to inference mode and export in training mode.
        model.eval()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.TRAINING,
        )
        # Verify that the model state is preserved.
        self.assertEqual(model.training, old_state)

        # Set model to training mode and export in inference mode.
        model.train()
        old_state = model.training
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            training=torch.onnx.TrainingMode.EVAL,
        )
        # Verify that the model state is preserved.
        self.assertEqual(model.training, old_state)

    def test_export_does_not_fail_on_frozen_scripted_module(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                if x > 0:
                    return x
                else:
                    return x * x

        class Outer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inner = torch.jit.script(Inner())

            def forward(self, x):
                return self.inner(x)

        x = torch.zeros(1)
        # Freezing is only implemented in eval mode, so we need to call eval().
        outer_module = Outer().eval()
        module = torch.jit.trace_module(outer_module, {"forward": (x)})
        # jit.freeze removes the training attribute from the module.
        module = torch.jit.freeze(module)

        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function(self):
        class N(torch.nn.Module):
            def __init__(self, prob):
                super().__init__()
                self.dropout = torch.nn.Dropout(prob)

            def forward(self, x):
                return self.dropout(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
                )
                self.celu1 = torch.nn.CELU(1.0)
                self.celu2 = torch.nn.CELU(2.0)
                self.dropout = N(0.5)

            def forward(self, x, y, z):
                res1 = self.celu1(x)
                res2 = self.celu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.dropout(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        f = io.BytesIO()

        # Export specified modules. Test against specifying modules that won't
        # exist in the exported model.
        # Model export in inference mode will remove the dropout node, so the
        # dropout module no longer exists in the graph.
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={
                torch.nn.CELU,
                torch.nn.Dropout,
                torch.nn.LayerNorm,
            },
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))

        # Check function definitions.
        funcs = onnx_model.functions
        celu_funcs = [f for f in funcs if f.name == "CELU"]
        self.assertEqual(len(celu_funcs), 1)
        self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_funcs[0].attribute), 3)
        ln_funcs = [f for f in funcs if f.name == "LayerNorm"]
        self.assertEqual(len(ln_funcs), 1)
        self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_funcs[0].attribute), 3)

        # Check local function nodes.
        nodes = onnx_model.graph.node
        celu_ns = [n for n in nodes if n.op_type == "CELU"]
        ln_ns = [n for n in nodes if n.op_type == "LayerNorm"]
        self.assertEqual(len(celu_ns), 2)
        self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation")
        self.assertEqual(len(celu_ns[0].attribute), 3)
        self.assertEqual(len(ln_ns), 3)
        self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization")
        self.assertEqual(len(ln_ns[0].attribute), 3)

        # Export specified modules.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={torch.nn.CELU},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 1)
        self.assertEqual(funcs[0].name, "CELU")

        # Export with empty specified modules. Normal export.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=set(),
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 0)

        # Export all modules. Should contain {M, CELU, LayerNorm}.
        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions=True,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)

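    # For reference, `export_modules_as_functions` accepts either a bool (export
    # every module as an ONNX local function) or a set of nn.Module types to
    # export selectively, as exercised above; an empty set disables the feature.
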
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_overloads(self):
        class NWithOverloads(torch.nn.Module):
            def forward(self, x, y=None, z=None):
                if y is None:
                    return x + 1
                elif z is None:
                    return x + y
                else:
                    return x + y, x + z

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.n = NWithOverloads()

            def forward(self, x, y, z):
                return self.n(x), self.n(x, y), self.n(x, y, z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        f = io.BytesIO()
        torch.onnx.export(
            M(3),
            (x, y, z),
            f,
            opset_version=self.opset_version,
            export_modules_as_functions={NWithOverloads},
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertEqual(len(funcs), 3)
        func_names = [f.name for f in funcs]
        self.assertIn("NWithOverloads", func_names)
        self.assertIn("NWithOverloads.1", func_names)
        self.assertIn("NWithOverloads.2", func_names)

    # Failing after ONNX 1.13.0
    @skipIfUnsupportedMaxOpsetVersion(1)
    def test_local_function_infer_scopes(self):
        class M(torch.nn.Module):
            def forward(self, x):
                # Concatenation of scalars inserts unscoped tensors in the IR graph.
                new_tensor_shape = x.size()[:-1] + (1, 1, -1)
                tensor = x.view(*new_tensor_shape)
                return tensor

        x = torch.randn(4, 5)
        f = io.BytesIO()
        torch.onnx.export(
            M(),
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            do_constant_folding=False,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        self.assertIn("M", [f.name for f in funcs])

    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                for ln in self.lns:
                    x = ln(x)
                return x

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
        )

        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        funcs = onnx_model.functions
        m_funcs = [fn for fn in funcs if fn.name == "M"]
        self.assertEqual(m_funcs[0].attribute, ["num_layers"])
        ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"]
        self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"])

        from onnx import helper

        m_node = [n for n in onnx_model.graph.node if n.op_type == "M"]
        self.assertEqual(
            m_node[0].attribute[0],
            helper.make_attribute("num_layers", model.num_layers),
        )

        ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"]
        expected_ln_attrs = [
            helper.make_attribute(
                "elementwise_affine", model.lns[0].elementwise_affine
            ),
            helper.make_attribute("eps", model.lns[0].eps),
        ]
        for ln_node in ln_nodes:
            self.assertIn(ln_node.attribute[0], expected_ln_attrs)
            self.assertIn(ln_node.attribute[1], expected_ln_attrs)

    # This test case checks the scenario where a module lacks an attribute that
    # the exporter looks for.
    # When `export_modules_as_functions=True` is enabled, the exporter could raise
    # an AttributeError. With this test case, we check that the export passes
    # without any AttributeError exceptions.
    # See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that
    # this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`.
    @skipIfUnsupportedMinOpsetVersion(15)
    def test_local_function_subset_of_predefined_attributes(self):
        class M(torch.nn.Module):
            num_layers: int

            def __init__(self, num_layers):
                super().__init__()
                self.embed_layer = torch.nn.Embedding.from_pretrained(
                    torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
                )
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
                )

            def forward(self, x):
                e = self.embed_layer(torch.LongTensor([1]))
                for ln in self.lns:
                    x = ln(x)
                return x, e

        x = torch.randn(2, 3)
        f = io.BytesIO()
        model = M(3)
        torch.onnx.export(
            model,
            (x,),
            f,
            export_modules_as_functions=True,
            opset_version=self.opset_version,
            verbose=True,  # Allows the test case to print `Skipping module attribute 'freeze'`
        )

    def test_node_scope(self):
        class N(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class M(torch.nn.Module):
            def __init__(self, num_layers):
                super().__init__()
                self.num_layers = num_layers
                self.lns = torch.nn.ModuleList(
                    [torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
                )
                self.gelu1 = torch.nn.GELU()
                self.gelu2 = torch.nn.GELU()
                self.relu = N()

            def forward(self, x, y, z):
                res1 = self.gelu1(x)
                res2 = self.gelu2(y)
                for ln in self.lns:
                    z = ln(z)
                return res1 + res2, self.relu(z)

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

        model = M(3)
        expected_scope_names = {
            "M::/torch.nn.modules.activation.GELU::gelu1",
            "M::/torch.nn.modules.activation.GELU::gelu2",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.0",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.1",
            "M::/torch.nn.modules.normalization.LayerNorm::lns.2",
            "M::/N::relu/torch.nn.modules.activation.ReLU::relu",
        }

        graph, _, _ = self._model_to_graph(
            model, (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

        graph, _, _ = self._model_to_graph(
            torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
        )
        for node in graph.nodes():
            self.assertIn(
                _remove_test_environment_prefix_from_scope_name(node.scopeName()),
                expected_scope_names,
            )

    def test_scope_of_constants_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant):
                super().__init__()
                self.constant = constant

            def forward(self, x):
                # 'self.constant' is designed to be the same for all layers,
                # hence it is a common subexpression.
                return x + self.constant

        class N(torch.nn.Module):
            def __init__(self, layers: int = layer_num):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [M(constant=torch.tensor(1.0)) for i in range(layers)]
                )

            def forward(self, x):
                for layer in self.layers:
                    x = layer(x)
                return x

        graph, _, _ = self._model_to_graph(
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
        )

        # NOTE: Duplicated constants are populated due to implicit casting in scalar_type_analysis,
        # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
        # If CSE in the exporter is improved later, this test needs to be updated:
        # it should then expect 1 constant, with the same scope as the root.
        expected_root_scope_name = "N::"
        expected_layer_scope_name = "M::layers"
        expected_constant_scope_name = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
            for i in range(layer_num)
        ]

        constant_scope_names = []
        for node in graph.nodes():
            if node.kind() == "onnx::Constant":
                constant_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
        self.assertEqual(constant_scope_names, expected_constant_scope_name)

    def test_scope_of_nodes_when_combined_by_cse_pass(self):
        layer_num = 3

        class M(torch.nn.Module):
            def __init__(self, constant, bias):
                super().__init__()
                self.constant = constant
                self.bias = bias

            def forward(self, x):
                # 'constant' and 'x' are designed to be the same for all layers,
                # hence `x + self.constant` is a common subexpression.
                # 'bias' is designed to be different for all layers,
                # hence `* self.bias` is not a common subexpression.
                return (x + self.constant) * self.bias

        class N(torch.nn.Module):
            def __init__(self, layers: int = layer_num):
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [
                        M(constant=torch.tensor([1.0]), bias=torch.randn(1))
                        for i in range(layers)
                    ]
                )

            def forward(self, x):
                y = []
                for layer in self.layers:
                    y.append(layer(x))
                return y[0], y[1], y[2]

        graph, _, _ = self._model_to_graph(
            N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
        )
        expected_root_scope_name = "N::"
        expected_layer_scope_name = "M::layers"
        expected_add_scope_names = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
        ]
        expected_mul_scope_names = [
            f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
            for i in range(layer_num)
        ]

        add_scope_names = []
        mul_scope_names = []
        for node in graph.nodes():
            if node.kind() == "onnx::Add":
                add_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
            elif node.kind() == "onnx::Mul":
                mul_scope_names.append(
                    _remove_test_environment_prefix_from_scope_name(node.scopeName())
                )
        self.assertEqual(add_scope_names, expected_add_scope_names)
        self.assertEqual(mul_scope_names, expected_mul_scope_names)

    def test_aten_fallthrough(self):
        # Test aten export of op with no symbolic
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.erfc(x)

        x = torch.randn(2, 3, 4)
        GLOBALS.export_onnx_opset_version = self.opset_version
        graph, _, __ = self._model_to_graph(
            Module(),
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "aten::erfc")

    def test_custom_op_fallthrough(self):
        # Test custom op
        op_source = """
        #include <torch/script.h>

        torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
          return self + other;
        }

        static auto registry =
          torch::RegisterOperators("custom_namespace::custom_op", &custom_add);
        """

        torch.utils.cpp_extension.load_inline(
            name="custom_add",
            cpp_sources=op_source,
            is_python_module=False,
            verbose=True,
        )

        class FooModel(torch.nn.Module):
            def forward(self, input, other):
                # Calling custom op
                return torch.ops.custom_namespace.custom_op(input, other)

        x = torch.randn(2, 3, 4, requires_grad=False)
        y = torch.randn(2, 3, 4, requires_grad=False)
        model = FooModel()
        graph, _, __ = self._model_to_graph(
            model,
            (x, y),
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")

    def test_custom_opsets_gelu(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)

        def gelu(g, self, approximate):
            return g.op("com.microsoft::Gelu", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9)
        model = torch.nn.GELU(approximate="none")
        x = torch.randn(3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
        self.assertEqual(graph.opset_import[1].version, 1)

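    # For reference, `custom_opsets={"com.microsoft": 1}` pins the version of the
    # custom domain recorded in the model's opset_import table, while the default
    # ONNX domain keeps `opset_version`; the two version asserts above check
    # exactly that pairing.
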
    def test_register_aten_custom_op_symbolic(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)

        def gelu(g, self, approximate):
            return g.op("com.microsoft::Gelu", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9)
        model = torch.nn.GELU(approximate="none")
        x = torch.randn(3, 3)
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))

        self.assertEqual(graph.graph.node[0].op_type, "Gelu")
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")

    @skipIfNoLapack
    def test_custom_opsets_inverse(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x):
                return torch.inverse(x) + x

        def linalg_inv(g, self):
            return g.op("com.microsoft::Inverse", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.node[0].op_type, "Inverse")
        self.assertEqual(graph.opset_import[0].version, self.opset_version)
        self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
        self.assertEqual(graph.opset_import[1].version, 1)

    def test_onnx_fallthrough(self):
        # Test aten export of op with symbolic for aten
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.digamma(x)

        x = torch.randn(100, 128)
        graph, _, __ = self._model_to_graph(
            Module(),
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0, 1]},
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "aten::digamma")

    # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
    @skipIfUnsupportedMaxOpsetVersion(10)
    def test_prim_fallthrough(self):
        # Test prim op
        class PrimModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if isinstance(x, list):
                    y = x
                else:
                    y = [x]
                return y

        x = torch.tensor([2])
        model = PrimModule()
        model.eval()
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["x"],
            dynamic_axes={"x": [0]},
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "prim::ListConstruct")

    def test_custom_layer_tuple(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def symbolic(g, input):
                return g.op("CustomNamespace::Custom", input, outputs=2)

            @staticmethod
            def forward(ctx, input):
                return input, input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "CustomNamespace::Custom")

    def test_autograd_onnx_fallthrough(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                grad_input[input < 0] = 0
                return grad_input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model,
            batch,
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["batch"],
            dynamic_axes={"batch": [0, 1]},
        )
        iter = graph.nodes()
        self.assertEqual(next(iter).kind(), "prim::PythonOp")

    def test_autograd_module_name(self):
        class CustomFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                grad_input[input < 0] = 0
                return grad_input

        class Custom(torch.nn.Module):
            def forward(self, input):
                return CustomFunction.apply(input) + CustomFunction2.apply(input)

        model = Custom()
        batch = torch.FloatTensor(1, 3)

        graph, _, _ = self._model_to_graph(
            model,
            batch,
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
            input_names=["batch"],
            dynamic_axes={"batch": [0, 1]},
        )
        iter = graph.nodes()
        autograd1 = next(iter)
        autograd2 = next(iter)
        self.assertEqual(autograd1.kind(), "prim::PythonOp")
        self.assertEqual(autograd2.kind(), "prim::PythonOp")
        self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))

    def test_unused_initializers(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2 = torch.nn.ConvTranspose2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
                )
                self.k_proj = torch.nn.Linear(5, 5, bias=True)

            def forward(self, x):
                x = self.conv2(x)
                return x

        x = torch.randn(20, 16, 50, 100)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        _, params_dict, __ = self._model_to_graph(
            Model(),
            (x,),
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        self.assertEqual(len(params_dict), 2)

    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            training=torch.onnx.TrainingMode.TRAINING,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        graph_input_params = [param.debugName() for param in graph.inputs()]
        for item in dict(model.named_parameters()):
            self.assertIn(
                item,
                graph_input_params,
                "Graph parameter names do not match model parameters.",
            )

    @skipIfNoCaffe2
    def test_modifying_params(self):
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.tensor([2.0]))

            def forward(self, x):
                y = x * x
                self.param.data.add_(1.0)
                return y

        x = torch.tensor([1, 2])
        # Move the import to local scope, as the caffe2 backend requires an
        # additional build flag and is only used in this test case.
        import caffe2.python.onnx.backend as backend

        verify(MyModel(), x, backend, do_constant_folding=False)

    def test_fuse_conv_bn(self):
        class Fuse(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=True
                )
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)

        x = torch.randn(2, 3, 2, 2, requires_grad=True)
        graph, _, __ = self._model_to_graph(
            Fuse(),
            (x,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
            self.assertEqual(node.kind(), "onnx::Conv")

        self.assertEqual(len(list(graph.nodes())), 1)

    def test_fuse_resnet18(self):
        model = torchvision.models.resnet18(weights=None)
        x = torch.randn(2, 3, 224, 224, requires_grad=True)
        graph, _, __ = self._model_to_graph(
            model,
            (x,),
            training=TrainingMode.EVAL,
            input_names=["x"],
            dynamic_axes={"x": [0, 1, 2, 3]},
        )

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::BatchNormalization")

    def test_onnx_function_substitution_pass(self):
        @torch.jit.script
        def f(x: torch.Tensor, y: torch.Tensor):
            z = x - y
            return x + z

        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                return f(x, y)

        input_1 = torch.tensor([11])
        input_2 = torch.tensor([12])
        GLOBALS.export_onnx_opset_version = self.opset_version
        GLOBALS.operator_export_type = OperatorExportTypes.ONNX
        graph, _, __ = self._model_to_graph(
            MyModule(),
            (input_1, input_2),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX,
            input_names=["input_1", "input_2"],
            dynamic_axes={"input_1": [0], "input_2": [0]},
        )
        # Check that the prim::Constant node in the graph for representing the
        # scripted function `f` is removed and the following prim::CallFunction
        # is replaced by an inline graph with onnx::Sub and onnx::Add nodes.
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "prim::Constant")
        self.assertEqual(
            len(list(graph.nodes())), 2
        )  # onnx::Sub and onnx::Add nodes only.

    def test_onnx_value_name(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3))
                self.in_bias = torch.nn.Parameter(torch.Tensor(3))

            def forward(self, x):
                start = 0
                end = None
                weight = self.in_weight
                bias = self.in_bias
                weight = weight[start:end, :]
                if bias is not None:
                    bias = bias[start:end]
                return torch.nn.functional.linear(x, weight, bias)

        model = MyModule()
        x = torch.randn(3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            keep_initializers_as_inputs=True,
        )
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.input[1].name, "in_weight")
        self.assertEqual(graph.graph.input[2].name, "in_bias")

    def test_onnx_node_naming(self):
        class MainModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._module_1 = torch.nn.Linear(10, 10)
                self._module_2 = torch.nn.Linear(10, 10)
                self._module_3 = torch.nn.Linear(10, 10)
                self._module_4 = torch.nn.Linear(10, 10)

            def forward(self, x):
                y = self._module_1(x)
                z = self._module_2(y)
                z = self._module_3(y * z)
                z = self._module_4(y * z)
                return z

        module = MainModule()
        ref_node_names = [
            "/_module_1/Gemm",
            "/_module_2/Gemm",
            "/_module_3/Gemm",
            "/_module_4/Gemm",
            "/Mul",
            "/Mul_1",
        ]
        f = io.BytesIO()

        torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        for n in onnx_model.graph.node:
            self.assertIn(n.name, ref_node_names)

        torch.onnx.export(
            torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
        )
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        for n in onnx_model.graph.node:
            self.assertIn(n.name, ref_node_names)

    def _test_deduplicate_initializers(self, torchscript=False):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(3, 3)
                self.layer2 = torch.nn.Linear(3, 3)

                # Reusing layers.
                self.layer3 = self.layer1

                # Reusing parameters.
                self.layer2.weight = self.layer1.weight
                self.layer1.bias = self.layer2.bias

                # Parameters with different tensors equal in value.
                self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
                self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))

            def forward(self, x):
                return (
                    self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2
                )

        model = torch.jit.script(MyModule()) if torchscript else MyModule()

        x = torch.randn(3, 3)
        param_name_set = {k for k, _ in model.named_parameters()}

        # Test training mode.
        model.train()
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            training=TrainingMode.TRAINING,
            opset_version=self.opset_version,
        )
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

        model.train()
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            training=TrainingMode.PRESERVE,
            opset_version=self.opset_version,
        )
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

        # Test eval mode.
        model.eval()
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))
        param_name_set.remove("param2")
        self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)

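    # Summary of the behavior exercised above (informal): initializer
    # deduplication only runs for inference-mode export, so in TRAINING and
    # PRESERVE modes every named parameter survives as an initializer, while in
    # eval mode `param2` (value-equal to `param1`) is deduplicated away.
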
    def test_deduplicate_initializers(self):
        self._test_deduplicate_initializers(torchscript=False)

    def test_deduplicate_initializers_torchscript(self):
        self._test_deduplicate_initializers(torchscript=True)

    @skipIfNoCuda
    def test_deduplicate_initializers_diff_devices(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w_cpu = torch.nn.Parameter(
                    torch.ones(3, device=torch.device("cpu"))
                )
                self.w_cuda = torch.nn.Parameter(
                    torch.ones(3, device=torch.device("cuda"))
                )

            def forward(self, x, y):
                return x + self.w_cpu, y + self.w_cuda

        x = torch.randn(3, 3, device=torch.device("cpu"))
        y = torch.randn(3, 3, device=torch.device("cuda"))
        f = io.BytesIO()
        torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})

    def test_duplicated_output_node(self):
        class DuplicatedOutputNet(torch.nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = torch.nn.Linear(input_size, num_classes)

            def forward(self, input0, input1):
                out1 = self.fc1(input0)
                out2 = self.fc1(input1)
                return out1, out1, out2, out1, out2

        N, D_in, H, D_out = 64, 784, 500, 10
        pt_model = DuplicatedOutputNet(D_in, D_out)

        f = io.BytesIO()
        x = torch.randn(N, D_in)
        dynamic_axes = {
            "input0": {0: "input0_dim0", 1: "input0_dim1"},
            "input1": {0: "input1_dim0", 1: "input1_dim1"},
            "output-0": {0: "output-0_dim0", 1: "output-0_dim1"},
            "output-1": {0: "output-1_dim0", 1: "output-1_dim1"},
            "output-2": {0: "output-2_dim0", 1: "output-2_dim1"},
            "output-3": {0: "output-3_dim0", 1: "output-3_dim1"},
            "output-4": {0: "output-4_dim0", 1: "output-4_dim1"},
        }

        torch.onnx.export(
            pt_model,
            (x, x),
            f,
            input_names=["input0", "input1"],
            output_names=["output-0", "output-1", "output-2", "output-3", "output-4"],
            do_constant_folding=False,
            training=torch.onnx.TrainingMode.TRAINING,
            dynamic_axes=dynamic_axes,
            verbose=True,
            keep_initializers_as_inputs=True,
        )

        graph = onnx.load(io.BytesIO(f.getvalue()))
        self.assertEqual(graph.graph.input[0].name, "input0")
        self.assertEqual(graph.graph.input[1].name, "input1")
        for i in range(5):
            self.assertEqual(graph.graph.output[i].name, f"output-{i}")
        self.assertEqual(graph.graph.node[0].op_type, "Gemm")
        self.assertEqual(graph.graph.node[1].op_type, "Identity")
        self.assertEqual(graph.graph.node[2].op_type, "Identity")
        self.assertEqual(graph.graph.node[3].op_type, "Gemm")
        self.assertEqual(graph.graph.node[4].op_type, "Identity")

    def test_deduplicate_ignore_upsample_scale(self):
        # The upsample scale is a constant, not a model parameter, so it should
        # be ignored by shared-weight deduplication.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.upsample_1 = torch.nn.Upsample(scale_factor=2)
                self.upsample_2 = torch.nn.Upsample(scale_factor=2)

            def forward(self, x):
                return self.upsample_1(x), self.upsample_2(x)

        f = io.BytesIO()
        x = torch.randn(1, 32, 224, 224)
        torch.onnx.export(Model(), x, f)
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        # aten::upsample converts to onnx::Resize
        resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
        self.assertEqual(len(resize_nodes), 2)
        for resize_node in resize_nodes:
            scale_node = [
                n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2]
            ]
            self.assertEqual(len(scale_node), 1)
            self.assertEqual(scale_node[0].op_type, "Constant")

    def test_bad_symbolic_registration(self):
        _onnx_opset_version = 9

        @parse_args("v")
        def cat(g, tensor_list, dim):
            tensors = _unpack_list(tensor_list)
            return g.op("Concat", *tensors, axis_i=dim)

        torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version)

        class CatModel(torch.nn.Module):
            def forward(self, x):
                return torch.cat((x, x, x), 0)

        model = CatModel()
        x = torch.randn(2, 3)
        f = io.BytesIO()
        self.assertExpectedRaisesInline(
            AssertionError,
            lambda: torch.onnx.export(
                model, (x,), f, opset_version=_onnx_opset_version
            ),
            (
                "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
                "'cat'. If you believe this is not due to custom symbolic implementation within your code or an external "
                "library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to "
                "report this bug."
            ),
        )
        torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)


if __name__ == "__main__":
    common_utils.run_tests()