# Owner(s): ["module: onnx"]

import io

import numpy as np

import onnx
import pytorch_test_common
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion

import torch
from torch.onnx import _constants, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils


def expect_tensor(scalar_type, shape=None):
    def verify(actual_type):
        np.testing.assert_equal(actual_type.scalarType(), scalar_type)
        # if shape is not None:
        #     np.testing.assert_equal(actual_type.sizes(), shape)
        if shape is not None:
            np.testing.assert_equal(actual_type.varyingSizes(), shape)

    return verify


def as_graphcontext(graph: torch.Graph) -> jit_utils.GraphContext:
    return jit_utils.GraphContext(
        graph=graph,
        block=graph.block(),
        opset=_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET,
        original_node=None,  # type: ignore[arg-type]
        params_dict={},
        env={},
        values_in_env=set(),
    )


def g_op(graph: torch.Graph, op_name: str, *args, **kwargs):
    return as_graphcontext(graph).op(op_name, *args, **kwargs)
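
# Usage sketch (illustrative only, not exercised directly by the tests): the helpers
# above build a bare TorchScript graph, wrap it in a GraphContext at the max exporter
# opset, and emit ONNX nodes directly, e.g.
#
#     g = torch._C.Graph()
#     x = g.addInput()
#     y = g_op(g, "Relu", x)
#
# which lets each test drive ONNX shape inference without tracing a module.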


class TestONNXShapeInference(pytorch_test_common.ExportTestCase):
    def setUp(self):
        self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
        GLOBALS.export_onnx_opset_version = self.opset_version

    def run_test(self, g, n, type_assertion_funcs):
        if not isinstance(type_assertion_funcs, list):
            type_assertion_funcs = [type_assertion_funcs]

        torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version)
        for out, type_assertion_func in zip(n.outputs(), type_assertion_funcs):
            type_assertion_func(out.type())

    def create_empty_graph(self):
        g = torch._C.Graph()
        # kick off initialization for ConstantMap.
        torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version)
        return g

    def insert_tensor_constant(self, g, tensor):
        return g_op(g, "Constant", value_t=tensor)

    def test_cast(self):
        # Test cast with input of unknown scalar type.
        g = self.create_empty_graph()
        input = g.addInput()
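        # to_i=1 is onnx.TensorProto.FLOAT, so the output should be inferred as Float
        # even though the input's scalar type is unknown.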
        cast_out = g_op(g, "Cast", input, to_i=1)
        self.run_test(g, cast_out.node(), expect_tensor("Float"))

    def test_constant_of_shape(self):
        # Test ConstantOfShape with input of onnx::Shape node.
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4))
        shape = g_op(g, "Shape", constant)
        constant_of_shape = g_op(
            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
        )
        self.run_test(
            g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
        )

    def test_constant_of_shape_static(self):
        # Test ConstantOfShape with input of prim::ListConstruct of static tensor
        rank = 4
        g = self.create_empty_graph()
        constants = [
            self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank)
        ]
        shape = g_op(g, "prim::ListConstruct", *constants)
        shape.setType(torch._C.ListType.ofInts())
        constant_of_shape = g_op(
            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
        )
        self.run_test(
            g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4))
        )

    def test_constant_of_shape_dynamic(self):
        # Test ConstantOfShape with input of prim::ListConstruct of dynamic tensor
        rank = 4
        g = self.create_empty_graph()
        inputs = [g.addInput() for i in range(rank)]
        shape = g_op(g, "prim::ListConstruct", *inputs)
        shape.setType(torch._C.ListType.ofInts())
        constant_of_shape = g_op(
            g, "ConstantOfShape", shape, value_t=torch.tensor([2.0])
        )
        self.run_test(
            g,
            constant_of_shape.node(),
            expect_tensor("Float", shape=(None, None, None, None)),
        )

    def test_gather_dynamic_index(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(
            input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
        )
        indices = g.addInput()
        indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None]))
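        # Gather replaces the gathered axis with the indices' shape: rank-4 data with
        # rank-1 dynamic indices on axis 1 should give shape (None, None, 16, 16).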
        output = g_op(g, "Gather", input, indices, axis_i=1)
        self.run_test(
            g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16]))
        )

    def test_gather_scalar_index(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(
            input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16])
        )
        indices = self.insert_tensor_constant(g, torch.tensor(1))
        output = g_op(g, "Gather", input, indices, axis_i=1)
        self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16])))

    def test_reshape(self):
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5))
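        # In ONNX Reshape, 0 copies the corresponding input dim and -1 infers the
        # remaining size, so [2, 0, -1] on (2, 16, 5, 5) should give (2, 16, 25).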
        constant_2 = self.insert_tensor_constant(g, torch.tensor([2, 0, -1]))
        shape = g_op(g, "Reshape", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(2, 16, 25)))

        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
        constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 4]))
        shape = g_op(g, "Reshape", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(10, 16, 4)))

        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4))
        constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 0]))
        shape = g_op(g, "Reshape", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(8, 16, 5)))

    def test_reshape_symbolic(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_sizes([None, None, 2, 8]))
        constant = self.insert_tensor_constant(g, torch.tensor([0, 0, -1]))
        output = g_op(g, "Reshape", input, constant)
        self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16)))

    @skipIfUnsupportedMinOpsetVersion(14)
    def test_reshape_allowzero(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_sizes([3, 4, 0]))
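        # With allowzero=1 (opset 14+), a 0 in the target shape is kept as a literal
        # zero dim instead of copying the input dim, so [0, 4, 3] should give (0, 4, 3).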
        constant = self.insert_tensor_constant(g, torch.tensor([0, 4, 3]))
        output = g_op(g, "Reshape", input, constant, allowzero_i=1)
        self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3)))

    def test_slice(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_sizes([None, None]))
        start_input = g.addInput()
        start_input.setType(start_input.type().with_sizes([None]))
        end = self.insert_tensor_constant(g, torch.tensor([3]))
        axis = self.insert_tensor_constant(g, torch.tensor([0]))
        step = self.insert_tensor_constant(g, torch.tensor([1]))
        slice = g_op(g, "Slice", input, start_input, end, axis, step)
        self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))

    def test_slice_with_dynamic_start_index(self):
        g = self.create_empty_graph()
        input = self.insert_tensor_constant(g, torch.ones(2, 3, 4, 5))
        start_input = g.addInput()
        start_input.setType(start_input.type().with_sizes([2]))
        end = self.insert_tensor_constant(g, torch.tensor([3, 4]))
        axis = self.insert_tensor_constant(g, torch.tensor([1, -1]))
        slice = g_op(g, "Slice", input, start_input, end, axis)
        self.run_test(g, slice.node(), expect_tensor(None, shape=(2, None, 4, None)))

    def test_broadcast_matmul(self):
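        # ONNX MatMul follows numpy.matmul semantics: batch dims broadcast, a rank-1
        # first input gets a leading dim prepended, a rank-1 second input gets a
        # trailing dim appended, and those temporary dims are dropped from the result.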
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 5, 1, 1)))

        # test when first input is of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1))
        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 1, 1)))

        # test when second input is of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(2))
        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=(5, 1)))

        # test when both inputs are of rank 1
        g = self.create_empty_graph()
        constant = self.insert_tensor_constant(g, torch.ones(2))
        constant_2 = self.insert_tensor_constant(g, torch.ones(2))
        shape = g_op(g, "MatMul", constant, constant_2)
        self.run_test(g, shape.node(), expect_tensor("Float", shape=()))

    def test_expand(self):
        g = self.create_empty_graph()
        input = g.addInput()
        constant = self.insert_tensor_constant(g, torch.ones(2, 4))
        input.setType(constant.type().with_sizes([None, None]))
        shape = g_op(g, "Shape", input)
        expand = g_op(g, "Expand", constant, shape)
        self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))

    def test_pad(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
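        # ONNX Pad takes a pads tensor of length 2 * rank ([begins..., ends...]);
        # ones(6) pads each of the three dims by 1 on both sides, so the expected
        # output shape is (3 + 2, 320 + 2, 100 + 2) = (5, 322, 102).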
        constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102)))

    def test_pad_with_dynamic_input_shape(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, None, None]))
        constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        pad = g_op(g, "Pad", input, constant, none, mode_s="constant")
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None)))

    def test_pad_with_dynamic_pad_size(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
        pad_size = g.addInput()
        pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6]))
        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        pad = g_op(g, "Pad", input, pad_size, none, mode_s="constant")
        self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None)))

    def test_resize(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
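        # Resize multiplies each input dim by its scale factor: scales [1, 1, 2, 2]
        # on (4, 32, 64, 64) should give (4, 32, 128, 128).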
        scales = self.insert_tensor_constant(
            g, torch.tensor([1, 1, 2, 2], dtype=torch.float)
        )
        resize = g_op(
            g,
            "Resize",
            input,
            none,
            scales,
            coordinate_transformation_mode_s="align_corners",
            cubic_coeff_a_f=-0.75,
            mode_s="linear",
            nearest_mode_s="floor",
        )
        self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128)))

    def test_resize_after_concat(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64]))
        none = g_op(g, "prim::Constant").setType(torch.NoneType.get())
        scale_1 = self.insert_tensor_constant(
            g, torch.tensor([1, 1], dtype=torch.float)
        )
        scale_2 = self.insert_tensor_constant(
            g, torch.tensor([2, 2], dtype=torch.float)
        )
        # `scales` values should be statically known due to constant folding in shape inference.
        scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0)
        resize = g_op(
            g,
            "Resize",
            input,
            none,
            scales,
            coordinate_transformation_mode_s="align_corners",
            cubic_coeff_a_f=-0.75,
            mode_s="linear",
            nearest_mode_s="floor",
        )
        self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128)))

    def test_reduce_prod_with_axes(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
        reduce_prod = g_op(g, "ReduceProd", input, axes_i=[0])
        self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))

    def test_reduce_prod_without_axes(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.long).with_sizes([2]))
        reduce_prod = g_op(g, "ReduceProd", input)
        self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,)))

    def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self):
        g = self.create_empty_graph()
        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([4, 16]))
        length = g.addInput()
        length.setType(length.type().with_dtype(torch.long).with_sizes([4]))
        padded, batch_size = g_op(g, "prim::PackPadded", input, length, outputs=2)
        # `prim::PackPadded` only occurs in tracing mode, so its outputs inherit
        # their shape and data type from the traced graph.
        padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None]))
        batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None]))
        # `Gather` should use the data type of `batch_size` as the data type of its output.
        gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long))
        gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0)
        self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,)))

    def test_squeeze_after_dynamic_if(self):
        from torch.onnx.symbolic_opset11 import squeeze as squeeze11

        g = self.create_empty_graph()

        input = g.addInput()
        input.setType(input.type().with_dtype(torch.float).with_sizes([1, None, 5]))

        # Type is intentionally not bool to test that
        # the added "Cast" node doesn't stop shape inference.
        cond = g.addInput()
        cond.setType(input.type().with_dtype(torch.int32).with_sizes([1]))
        if_op, (if_context, else_context), new_node = jit_utils.add_op_with_blocks(
            as_graphcontext(g), "If", cond, n_blocks=2
        )
        block1_output = if_context.op("Add", input, input)
        block2_output = else_context.op("Identity", input)
        utils._add_output_to_block(if_context.block, block1_output)
        utils._add_output_to_block(else_context.block, block2_output)
        if_output = torch._C._jit_pass_fixup_onnx_controlflow_node(
            new_node, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
        )[0]
        torch._C._jit_pass_onnx_node_shape_type_inference(
            new_node, {}, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
        )

        # The exporter emits an "If" node instead of a raw "Squeeze" when it does not
        # know whether the dimension being squeezed has size 1.
        squeezed = squeeze11(as_graphcontext(g), if_output, dim=0)
        assert squeezed.node().kind() == "onnx::Squeeze"
        self.run_test(g, squeezed.node(), expect_tensor("Float", shape=(None, 5)))


class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
    def setUp(self):
        super().setUp()
        self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET

    def test_setType_maintains_output_shape_for_single_custom_op(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x):
                return torch.inverse(x) + x

        def linalg_inv_settype(g, self):
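            # setType copies the declared type (dtype and static sizes) onto the custom
            # op's output, so the exported value_info should carry static dims.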
            return g.op("com.microsoft::Inverse", self).setType(self.type())

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        model_proto = onnx.load(io.BytesIO(f.getvalue()))
        model_value_info = model_proto.graph.value_info
        self.assertIsNotNone(model_value_info)
        assert model_value_info
        dims = model_value_info[0].type.tensor_type.shape.dim
        for i in range(len(dims)):
            # If the node output has shape info, it should have dim_value;
            # otherwise it has dim_param with a dynamic shape.
            self.assertTrue(dims[i].HasField("dim_value"))
        for dim, rank in zip(dims, x.size()):
            self.assertEqual(dim.dim_value, rank)

    def test_no_setType_for_single_custom_op(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x):
                return torch.inverse(x) + x

        def linalg_inv_no_settype(g, self):
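            # Without setType the exporter has no type information for the custom op's
            # output, so its dims should be exported as dynamic dim_param entries.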
            return g.op("com.microsoft::Inverse", self)

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        model_proto = onnx.load(io.BytesIO(f.getvalue()))
        model_value_info = model_proto.graph.value_info
        self.assertIsNotNone(model_value_info)
        assert model_value_info
        dims = model_value_info[0].type.tensor_type.shape.dim
        for i in range(len(dims)):
            # If the node output has shape info, it should have dim_value;
            # otherwise it has dim_param with a dynamic shape.
            self.assertTrue(dims[i].HasField("dim_param"))

    def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes(
        self,
    ):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x):
                return torch.inverse(x) + x

        def linalg_inv_settype(g, self):
            return g.op("com.microsoft::Inverse", self).setType(
                self.type().with_dtype(torch.float).with_sizes([None, 3, 3])
            )

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 3)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x,),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
            input_names=["x"],
            dynamic_axes={"x": {0: "batch"}},
        )

        model_proto = onnx.load(io.BytesIO(f.getvalue()))
        model_value_info = model_proto.graph.value_info
        self.assertIsNotNone(model_value_info)
        assert model_value_info
        dims = model_value_info[0].type.tensor_type.shape.dim
        # The first axis should be dynamic, as specified via dynamic_axes at export time.
        self.assertTrue(dims[0].HasField("dim_param"))
        for i in range(1, len(dims)):
            # If the node output has shape info, it should have dim_value;
            # otherwise it has dim_param with a dynamic shape.
            self.assertTrue(dims[i].HasField("dim_value"))
            self.assertEqual(dims[i].dim_value, x.size()[i])

    def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self):
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):
            def forward(self, x, y, z):
                x = torch.inverse(x)
                return x + y + z

        def linalg_inv_settype(g, self):
            return g.op("com.microsoft::Inverse", self).setType(
                self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10])
            )

        torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9)
        model = CustomInverse()
        x = torch.randn(2, 3, 10, 10)
        y = torch.randn(2, 3, 10, 10)
        z = torch.randn(2, 3, 10, 10)
        f = io.BytesIO()
        torch.onnx.export(
            model,
            (x, y, z),
            f,
            opset_version=self.opset_version,
            custom_opsets={"com.microsoft": 1},
        )

        model_proto = onnx.load(io.BytesIO(f.getvalue()))
        # To validate the shape of the Inverse op, find its output name and then use
        # it to look up the corresponding value_info entry.
        output_name = ""
        for node in model_proto.graph.node:
            if node.op_type == "Inverse":
                output_name = node.output[0]
                break
        assert output_name
        model_value_info = model_proto.graph.value_info
        self.assertIsNotNone(model_value_info)
        assert model_value_info
        for value_info in model_value_info:
            assert value_info.name
            if value_info.name == output_name:
                dims = value_info.type.tensor_type.shape.dim
                for i in range(len(dims)):
                    # If the node output has shape info, it should have dim_value;
                    # otherwise it has dim_param with a dynamic shape.
                    self.assertTrue(dims[i].HasField("dim_value"))
                for dim, rank in zip(dims, x.size()):
                    self.assertEqual(dim.dim_value, rank)


if __name__ == "__main__":
    common_utils.run_tests()