8
import pytorch_test_common
9
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion
12
from torch.onnx import _constants, utils
13
from torch.onnx._globals import GLOBALS
14
from torch.onnx._internal import jit_utils
15
from torch.testing._internal import common_utils
18
def expect_tensor(scalar_type, shape=None):
19
def verify(actual_type):
20
np.testing.assert_equal(actual_type.scalarType(), scalar_type)
24
np.testing.assert_equal(actual_type.varyingSizes(), shape)
29
def as_graphcontext(graph: torch.Graph) -> jit_utils.GraphContext:
30
return jit_utils.GraphContext(
33
opset=_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET,
41
def g_op(graph: torch.Graph, op_name: str, *args, **kwargs):
    """Create an operator through a ``GraphContext`` built around ``graph``.

    Args:
        graph: Graph handed to ``as_graphcontext``.
        op_name: Operator name forwarded to ``GraphContext.op``.
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged (callers in this
            file pass attributes such as ``value_t=`` or ``axis_i=``).

    Returns:
        Whatever ``GraphContext.op`` returns for the requested operator.
    """
    context = as_graphcontext(graph)
    return context.op(op_name, *args, **kwargs)
45
class TestONNXShapeInference(pytorch_test_common.ExportTestCase):
47
self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
48
GLOBALS.export_onnx_opset_version = self.opset_version
50
def run_test(self, g, n, type_assertion_funcs):
    """Run ONNX shape/type inference on ``g`` and check the output types of ``n``.

    Args:
        g: Graph passed to the graph-level shape/type inference pass.
        n: Node whose outputs are inspected after inference.
        type_assertion_funcs: A single callable, or a list of callables.
            Each one is called with the type of the output at the same
            position. Pairing is done with ``zip``, so surplus outputs or
            surplus callables are silently left unchecked.
    """
    # Accept a bare callable for the common single-output case.
    if not isinstance(type_assertion_funcs, list):
        type_assertion_funcs = [type_assertion_funcs]

    # No initializers are supplied, hence the empty params dict.
    torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version)
    for out, type_assertion_func in zip(n.outputs(), type_assertion_funcs):
        type_assertion_func(out.type())
58
def create_empty_graph(self):
61
torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version)
64
def insert_tensor_constant(self, g, tensor):
    """Add a ``Constant`` operator carrying ``tensor`` to graph ``g``.

    Args:
        g: Graph to add the constant to.
        tensor: Tensor passed as the ``value_t`` attribute.

    Returns:
        The result of ``g_op`` for the new ``Constant``.
    """
    return g_op(g, "Constant", value_t=tensor)
69
g = self.create_empty_graph()
71
cast_out = g_op(g, "Cast", input, to_i=1)
72
self.run_test(g, cast_out.node(), expect_tensor("Float"))
74
def test_constant_of_shape(self):
76
g = self.create_empty_graph()
77
constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4))
78
shape = g_op(g, "Shape", constant)
79
constant_of_shape = g_op(
80
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
83
g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
86
def test_constant_of_shape_static(self):
89
g = self.create_empty_graph()
91
self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank)
93
shape = g_op(g, "prim::ListConstruct", *constants)
94
shape.setType(torch._C.ListType.ofInts())
95
constant_of_shape = g_op(
96
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
99
g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
102
def test_constant_of_shape_dynamic(self):
105
g = self.create_empty_graph()
106
inputs = [g.addInput() for i in range(rank)]
107
shape = g_op(g, "prim::ListConstruct", *inputs)
108
shape.setType(torch._C.ListType.ofInts())
109
constant_of_shape = g_op(
110
g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
114
constant_of_shape.node(),
115
expect_tensor("Float", shape=(None, None, None, None)),
118
def test_gather_dynamic_index(self):
119
g = self.create_empty_graph()
122
input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
124
indices = g.addInput()
125
indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None]))
126
output = g_op(g, "Gather", input, indices, axis_i=1)
128
g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16]))
131
def test_gather_scalar_index(self):
132
g = self.create_empty_graph()
135
input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
137
indices = self.insert_tensor_constant(g, torch.tensor(1))
138
output = g_op(g, "Gather", input, indices, axis_i=1)
139
self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16])))
141
def test_reshape(self):
    """Check inferred ``Reshape`` output shapes for fully constant inputs.

    Both the data and the requested shape are constants, so every output
    dimension is expected to be resolved to a concrete value.
    """
    # (data shape, requested shape, expected inferred shape)
    cases = [
        ((2, 16, 5, 5), [2, 0, -1], (2, 16, 25)),
        ((2, 16, 5, 4), [-1, 0, 4], (10, 16, 4)),
        ((2, 16, 5, 4), [-1, 0, 0], (8, 16, 5)),
    ]
    for data_shape, requested, expected in cases:
        # subTest keeps the remaining cases running and names the failing one.
        with self.subTest(data_shape=data_shape, requested=requested):
            g = self.create_empty_graph()
            data = self.insert_tensor_constant(g, torch.ones(*data_shape))
            target = self.insert_tensor_constant(g, torch.tensor(requested))
            reshaped = g_op(g, "Reshape", data, target)
            self.run_test(
                g, reshaped.node(), expect_tensor("Float", shape=expected)
            )
160
def test_reshape_symbolic(self):
    """``Reshape`` on an input with unknown leading dims keeps them unknown.

    The input is typed ``[None, None, 2, 8]`` and reshaped with
    ``[0, 0, -1]``; the expected inferred shape is ``(None, None, 16)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(graph_input.type().with_sizes([None, None, 2, 8]))
    target = self.insert_tensor_constant(g, torch.tensor([0, 0, -1]))
    output = g_op(g, "Reshape", graph_input, target)
    self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16)))
168
@skipIfUnsupportedMinOpsetVersion(14)
def test_reshape_allowzero(self):
    """``Reshape`` with ``allowzero=1`` on a ``[3, 4, 0]`` input.

    The requested shape ``[0, 4, 3]`` is expected to be inferred as
    ``(0, 4, 3)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(graph_input.type().with_sizes([3, 4, 0]))
    target = self.insert_tensor_constant(g, torch.tensor([0, 4, 3]))
    output = g_op(g, "Reshape", graph_input, target, allowzero_i=1)
    self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3)))
177
def test_slice(self):
    """``Slice`` of a rank-2 input with unknown dims and non-constant starts.

    Ends, axes and steps are constants; the expected inferred shape is
    ``(None, None)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    data = g.addInput()
    data.setType(data.type().with_sizes([None, None]))
    starts = g.addInput()
    starts.setType(starts.type().with_sizes([None]))
    ends = self.insert_tensor_constant(g, torch.tensor([3]))
    axes = self.insert_tensor_constant(g, torch.tensor([0]))
    steps = self.insert_tensor_constant(g, torch.tensor([1]))
    sliced = g_op(g, "Slice", data, starts, ends, axes, steps)
    self.run_test(g, sliced.node(), expect_tensor(None, shape=(None, None)))
189
def test_slice_with_dynamic_start_index(self):
    """``Slice`` of a constant tensor whose starts come from a graph input.

    Axes ``1`` and ``-1`` are sliced, and the expected inferred shape is
    ``(2, None, 4, None)``: the sliced axes are unknown while the other
    two keep their sizes.
    """
    g = self.create_empty_graph()
    data = self.insert_tensor_constant(g, torch.ones(2, 3, 4, 5))
    # Starts are a graph input: only their length (2) is declared.
    starts = g.addInput()
    starts.setType(starts.type().with_sizes([2]))
    ends = self.insert_tensor_constant(g, torch.tensor([3, 4]))
    axes = self.insert_tensor_constant(g, torch.tensor([1, -1]))
    sliced = g_op(g, "Slice", data, starts, ends, axes)
    self.run_test(g, sliced.node(), expect_tensor(None, shape=(2, None, 4, None)))
199
def test_broadcast_matmul(self):
    """Check inferred ``MatMul`` output shapes, including rank-1 operands."""
    # (left operand shape, right operand shape, expected inferred shape)
    cases = [
        ((5, 1, 2), (3, 1, 2, 1), (3, 5, 1, 1)),
        # first input is of rank 1
        ((2,), (3, 1, 2, 1), (3, 1, 1)),
        # second input is of rank 1
        ((5, 1, 2), (2,), (5, 1)),
        # both inputs are of rank 1
        ((2,), (2,), ()),
    ]
    for lhs_shape, rhs_shape, expected in cases:
        # subTest keeps the remaining cases running and names the failing one.
        with self.subTest(lhs_shape=lhs_shape, rhs_shape=rhs_shape):
            g = self.create_empty_graph()
            lhs = self.insert_tensor_constant(g, torch.ones(*lhs_shape))
            rhs = self.insert_tensor_constant(g, torch.ones(*rhs_shape))
            product = g_op(g, "MatMul", lhs, rhs)
            self.run_test(
                g, product.node(), expect_tensor("Float", shape=expected)
            )
227
def test_expand(self):
    """``Expand`` of a constant to the ``Shape`` of a rank-2 dynamic input.

    The target shape is only known at runtime, so the expected inferred
    shape is ``(None, None)`` with dtype ``Float``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``setType``.
    graph_input = g.addInput()
    constant = self.insert_tensor_constant(g, torch.ones(2, 4))
    # Reuse the constant's type for the input, with both sizes unknown.
    graph_input.setType(constant.type().with_sizes([None, None]))
    shape = g_op(g, "Shape", graph_input)
    expand = g_op(g, "Expand", constant, shape)
    self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
237
g = self.create_empty_graph()
239
input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
240
constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
241
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
242
pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
243
self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102)))
245
def test_pad_with_dynamic_input_shape(self):
    """``Pad`` with constant pads on an input typed ``[3, None, None]``.

    The six pad entries are all 1. The expected inferred shape is
    ``(5, None, None)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(
        graph_input.type().with_dtype(torch.float).with_sizes([3, None, None])
    )
    pads = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
    # Third Pad argument is passed as a None-typed constant.
    none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
    pad = g_op(g, "Pad", graph_input, pads, none, mode_s="constant")
    self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None)))
254
def test_pad_with_dynamic_pad_size(self):
    """``Pad`` with pads supplied as a graph input on a static-shape input.

    The pad amounts are not constants, so the expected inferred shape is
    ``(None, None, None)`` with dtype ``Float``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(
        graph_input.type().with_dtype(torch.float).with_sizes([3, 320, 100])
    )
    pad_size = g.addInput()
    pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6]))
    # Third Pad argument is passed as a None-typed constant.
    none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
    pad = g_op(g, "Pad", graph_input, pad_size, none, mode_s="constant")
    self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None)))
264
def test_resize(self):
265
g = self.create_empty_graph()
267
input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
268
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
269
scales = self.insert_tensor_constant(
270
g, torch.tensor([1, 1, 2, 2], dtype=torch.float)
278
coordinate_transformation_mode_s="align_corners",
279
cubic_coeff_a_f=-0.75,
281
nearest_mode_s="floor",
283
self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128)))
285
def test_resize_after_concat(self):
286
g = self.create_empty_graph()
288
input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
289
none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
290
scale_1 = self.insert_tensor_constant(
291
g, torch.tensor([1, 1], dtype=torch.float)
293
scale_2 = self.insert_tensor_constant(
294
g, torch.tensor([2, 2], dtype=torch.float)
297
scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0)
304
coordinate_transformation_mode_s="align_corners",
305
cubic_coeff_a_f=-0.75,
307
nearest_mode_s="floor",
309
self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128)))
311
def test_reduce_prod_with_axes(self):
    """``ReduceProd`` over axis 0 of a ``Long`` input of size ``[2]``.

    The expected inferred output is ``Long`` with shape ``(1,)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(graph_input.type().with_dtype(torch.long).with_sizes([2]))
    reduce_prod = g_op(g, "ReduceProd", graph_input, axes_i=[0])
    self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))
318
def test_reduce_prod_without_axes(self):
    """``ReduceProd`` with no ``axes`` attribute on a ``Long`` input of size ``[2]``.

    The expected inferred output is ``Long`` with shape ``(1,)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(graph_input.type().with_dtype(torch.long).with_sizes([2]))
    reduce_prod = g_op(g, "ReduceProd", graph_input)
    self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))
325
def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
    """A node consuming a ``prim::PackPadded`` output picks up its dtype.

    The second ``PackPadded`` output is typed ``Long`` by hand, then fed to
    ``Gather``; the expected inferred output is ``Long`` with shape
    ``(None,)``.
    """
    g = self.create_empty_graph()
    # Bind a real graph input; without it the name would resolve to the
    # builtin ``input`` function, which has no ``type``/``setType``.
    graph_input = g.addInput()
    graph_input.setType(
        graph_input.type().with_dtype(torch.float).with_sizes([4, 16])
    )
    length = g.addInput()
    length.setType(length.type().with_dtype(torch.long).with_sizes([4]))
    padded, batch_size = g_op(
        g, "prim::PackPadded", graph_input, length, outputs=2
    )
    # The output types are assigned manually here rather than inferred.
    padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None]))
    batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None]))

    gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long))
    gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0)
    self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,)))
341
def test_squeeze_after_dynamic_if(self):
342
from torch.onnx.symbolic_opset11 import squeeze as squeeze11
344
g = self.create_empty_graph()
347
input.setType(input.type().with_dtype(torch.float).with_sizes([1, None, 5]))
352
cond.setType(input.type().with_dtype(torch.int32).with_sizes([1]))
353
if_op, (if_context, else_context), new_node = jit_utils.add_op_with_blocks(
354
as_graphcontext(g), "If", cond, n_blocks=2
356
block1_output = if_context.op("Add", input, input)
357
block2_output = else_context.op("Identity", input)
358
utils._add_output_to_block(if_context.block, block1_output)
359
utils._add_output_to_block(else_context.block, block2_output)
360
if_output = torch._C._jit_pass_fixup_onnx_controlflow_node(
361
new_node, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
363
torch._C._jit_pass_onnx_node_shape_type_inference(
364
new_node, {}, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
369
squeezed = squeeze11(as_graphcontext(g), if_output, dim=0)
370
assert squeezed.node().kind() == "onnx::Squeeze"
371
self.run_test(g, squeezed.node(), expect_tensor("Float", shape=(None, 5)))
374
class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
377
self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
379
def test_setType_maintains_output_shape_for_single_custom_op(self):
380
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
382
class CustomInverse(torch.nn.Module):
383
def forward(self, x):
384
return torch.inverse(x) + x
386
def linalg_inv_settype(g, self):
387
return g.op("com.microsoft::Inverse", self).setType(self.type())
389
torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
390
model = CustomInverse()
391
x = torch.randn(2, 3, 3)
397
opset_version=self.opset_version,
398
custom_opsets={"com.microsoft": 1},
401
model_proto = onnx.load(io.BytesIO(f.getvalue()))
402
model_value_info = model_proto.graph.value_info
403
self.assertIsNotNone(model_value_info)
404
assert model_value_info
405
dims = model_value_info[0].type.tensor_type.shape.dim
406
for i in range(len(dims)):
409
self.assertTrue(dims[i].HasField("dim_value"))
410
for dim, rank in zip(dims, x.size()):
411
self.assertEqual(dim.dim_value, rank)
413
def test_no_setType_for_single_custom_op(self):
414
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
416
class CustomInverse(torch.nn.Module):
417
def forward(self, x):
418
return torch.inverse(x) + x
420
def linalg_inv_no_settype(g, self):
421
return g.op("com.microsoft::Inverse", self)
423
torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9)
424
model = CustomInverse()
425
x = torch.randn(2, 3, 3)
431
opset_version=self.opset_version,
432
custom_opsets={"com.microsoft": 1},
435
model_proto = onnx.load(io.BytesIO(f.getvalue()))
436
model_value_info = model_proto.graph.value_info
437
self.assertIsNotNone(model_value_info)
438
assert model_value_info
439
dims = model_value_info[0].type.tensor_type.shape.dim
440
for i in range(len(dims)):
443
self.assertTrue(dims[i].HasField("dim_param"))
445
def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes(
448
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
450
class CustomInverse(torch.nn.Module):
451
def forward(self, x):
452
return torch.inverse(x) + x
454
def linalg_inv_settype(g, self):
455
return g.op("com.microsoft::Inverse", self).setType(
456
self.type().with_dtype(torch.float).with_sizes([None, 3, 3])
459
torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
460
model = CustomInverse()
461
x = torch.randn(2, 3, 3)
467
opset_version=self.opset_version,
468
custom_opsets={"com.microsoft": 1},
470
dynamic_axes={"x": {0: "batch"}},
473
model_proto = onnx.load(io.BytesIO(f.getvalue()))
474
model_value_info = model_proto.graph.value_info
475
self.assertIsNotNone(model_value_info)
476
assert model_value_info
477
dims = model_value_info[0].type.tensor_type.shape.dim
479
self.assertTrue(dims[0].HasField("dim_param"))
480
for i in range(1, len(dims)):
483
self.assertTrue(dims[i].HasField("dim_value"))
484
self.assertEqual(dims[i].dim_value, x.size()[i])
486
def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self):
487
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
489
class CustomInverse(torch.nn.Module):
490
def forward(self, x, y, z):
494
def linalg_inv_settype(g, self):
495
return g.op("com.microsoft::Inverse", self).setType(
496
self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10])
499
torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
500
model = CustomInverse()
501
x = torch.randn(2, 3, 10, 10)
502
y = torch.randn(2, 3, 10, 10)
503
z = torch.randn(2, 3, 10, 10)
509
opset_version=self.opset_version,
510
custom_opsets={"com.microsoft": 1},
513
model_proto = onnx.load(io.BytesIO(f.getvalue()))
517
for node in model_proto.graph.node:
518
if node.op_type == "Inverse":
519
output_name = node.output[0]
522
model_value_info = model_proto.graph.value_info
523
self.assertIsNotNone(model_value_info)
524
assert model_value_info
525
for value_info in model_value_info:
526
assert value_info.name
527
if value_info.name == output_name:
528
dims = value_info.type.tensor_type.shape.dim
529
for i in range(len(dims)):
532
self.assertTrue(dims[i].HasField("dim_value"))
533
for dim, rank in zip(dims, x.size()):
534
self.assertEqual(dim.dim_value, rank)
537
if __name__ == "__main__":
    # Run this module's tests through the shared PyTorch test runner.
    common_utils.run_tests()