onnxruntime

Форк
0
/
usability_checker.py 
739 строк · 30.9 Кб
1
# Copyright (c) Microsoft Corporation. All rights reserved.
2
# Licensed under the MIT License.
3
from __future__ import annotations
4

5
import argparse
6
import logging
7
import os
8
import pathlib
9
import tempfile
10
from collections import deque
11
from enum import IntEnum
12

13
import onnx
14

15
from ..onnx_model_utils import ModelProtoWithShapeInfo, get_producer_consumer_maps, is_fixed_size_tensor, optimize_model
16

17

18
class _SupportedOpsChecker:
19
    """
20
    Class to process the md file with list of supported ops and caveats for an execution provider.
21
    e.g. /tools/ci_build/github/android/nnapi_supported_ops.md
22
         /tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
23
         /tools/ci_build/github/apple/coreml_supported_neuralnetwork_ops.md
24
    """
25

26
    def __init__(self, filename):
27
        self._filename = filename
28
        self._ops = {}  # op to caveats
29
        self._ops_seen = set()
30

31
        with open(filename) as f:
32
            for line in f:
33
                # we're looking for a markdown table with 2 columns. first is op name. second is caveats
34
                # op name is domain:op
35
                if line.startswith("|"):
36
                    pieces = line.strip().split("|")
37
                    if len(pieces) == 4:  # pre-first '|'. op, caveat, post-last '|'
38
                        domain_op = pieces[1]
39
                        caveat = pieces[2]
40
                        caveat = caveat.replace("<br/>", " ")  # remove some HTML tags
41
                        # skip lines that don't have the ':' which separates the domain and op
42
                        # e.g. the table header will fail this check
43
                        if ":" in domain_op:
44
                            self._ops[domain_op] = caveat
45

46
    def is_op_supported(self, node):
47
        domain = node.domain if node.domain else "ai.onnx"
48
        domain_op = domain + ":" + node.op_type
49

50
        is_supported = domain_op in self._ops
51
        if is_supported:
52
            self._ops_seen.add(domain_op)
53

54
        return is_supported
55

56
    def get_caveats(self):
57
        caveats = []
58
        for op in sorted(self._ops_seen):
59
            caveat = self._ops[op]
60
            if caveat:
61
                caveats.append(f"{op}:{caveat}")
62

63
        return caveats
64

65

66
class PartitioningInfo:
    """
    Summary of how a graph (plus any subgraphs merged into it) is expected to be partitioned for an EP.
    """

    class TryWithEP(IntEnum):
        """Recommendation on whether to try the model with the EP. Larger values are more positive."""

        NO = 0
        MAYBE = 1
        YES = 2

    def __init__(
        self,
        num_nodes: int,
        num_supported_nodes: int,
        num_partitions: int,
        supported_ops_checker: _SupportedOpsChecker,
        supported_groups: list[list[onnx.NodeProto]],
        unsupported_ops: set[str],
        nodes_unsupported_due_to_op: int,
        nodes_unsupported_due_to_dynamic_input: int,
        num_unsupported_nodes_due_to_rank: int,
        ops_with_unsupported_rank: set[str],
    ):
        """
        :param num_nodes: Total number of nodes in the graph.
        :param num_supported_nodes: Number of nodes the EP is expected to handle.
        :param num_partitions: Number of partitions (groups of supported nodes).
        :param supported_ops_checker: Checker used to determine op support. Provides the caveats to report.
        :param supported_groups: The partitions. Each is a list of nodes.
        :param unsupported_ops: 'domain:op' names of unsupported operators.
        :param nodes_unsupported_due_to_op: Number of nodes unsupported because of their operator.
        :param nodes_unsupported_due_to_dynamic_input: Number of nodes unsupported because of dynamic input shapes.
        :param num_unsupported_nodes_due_to_rank: Number of nodes unsupported because of input rank.
        :param ops_with_unsupported_rank: 'domain:op' names of operators with an unsupported input rank.
        """
        self.num_nodes = num_nodes
        self.num_supported_nodes = num_supported_nodes
        self.num_partitions = num_partitions
        self.supported_ops_checker = supported_ops_checker
        self.supported_groups = supported_groups
        self.unsupported_ops = unsupported_ops
        self.nodes_unsupported_due_to_op = nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input = nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank = num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank = ops_with_unsupported_rank

        self.num_subgraphs = 0
        self.num_nodes_in_subgraphs = 0

    def merge(self, other: PartitioningInfo):
        """
        Merge the information from another PartitioningInfo instance into this one.
        """
        self.num_nodes += other.num_nodes
        self.num_supported_nodes += other.num_supported_nodes
        self.num_partitions += other.num_partitions
        self.supported_groups.extend(other.supported_groups)
        self.unsupported_ops.update(other.unsupported_ops)
        self.nodes_unsupported_due_to_op += other.nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input += other.nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank += other.num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank.update(other.ops_with_unsupported_rank)

        # hard assumption that we merge into the main graph partitioning info
        self.num_subgraphs += 1
        self.num_nodes_in_subgraphs += other.num_nodes

    def _pct_supported(self) -> float:
        """Percentage of nodes the EP can handle. 0 for an empty graph so callers never divide by zero."""
        if self.num_nodes == 0:
            return 0.0

        return self.num_supported_nodes / self.num_nodes * 100

    def suitability(self):
        """
        :return: TryWithEP value indicating whether the EP is likely to be beneficial for the model.
        """
        # semi-arbitrary choices that err on the side of MAYBE.
        # having 1 partition is always preferred, but if that is small it may not be useful.
        # having 2 partitions may be okay if they cover most nodes
        # more than 2 partitions and the device copy cost is almost guaranteed to outweigh the benefit of using the NPU
        # NOTE: This assumes the EP is not CPU based and there is device copy overhead to consider
        pct_supported = self._pct_supported()
        if self.num_partitions == 1:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.YES
            elif pct_supported > 50:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        if self.num_partitions == 2:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        return PartitioningInfo.TryWithEP.NO

    def print_analysis(self, logger: logging.Logger, ep_name: str):
        """
        Analyze the partitioning information and log the analysis
        :param logger: Logger to use
        :param ep_name: Execution provider name to use in the log messages
        """

        logger.info(
            f"{self.num_partitions} partitions with a total of {self.num_supported_nodes}/{self.num_nodes} "
            f"nodes can be handled by the {ep_name} EP."
        )

        if self.supported_groups:
            logger.info(
                f'\tPartition sizes: [{", ".join([str(len(partition)) for partition in self.supported_groups])}]'
            )

            # dump full groups if debug output is enabled. check first so the potentially large strings are only
            # built when they will actually be logged.
            if logger.isEnabledFor(logging.DEBUG):
                for group in self.supported_groups:
                    logger.debug(f'Nodes in group: {",".join([f"{node.op_type}:{node.name}" for node in group])}')

        logger.info(f"Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}")
        if self.unsupported_ops:
            logger.info(f'\tUnsupported ops: {",".join(sorted(self.unsupported_ops))}')

        caveats = self.supported_ops_checker.get_caveats()
        if caveats:
            indent = " " * 5
            logger.info(
                "\tCaveats that have not been checked and may result in a node not actually being supported:  "
                f'{"".join([os.linesep + indent + caveat for caveat in caveats])}'
            )

        if self.nodes_unsupported_due_to_dynamic_input:
            logger.info(
                "Unsupported nodes due to input having a dynamic shape=%d",
                self.nodes_unsupported_due_to_dynamic_input,
            )

        if self.num_unsupported_nodes_due_to_rank:
            logger.info(f"Unsupported nodes due to rank of input data={self.num_unsupported_nodes_due_to_rank}")
            logger.info(f"\tOps with unsupported rank: {','.join(sorted(self.ops_with_unsupported_rank))}")

        if self.num_subgraphs > 0:
            # TODO: CoreML has a flag. NNAPI doesn't. Either should be able to support a subgraph when treated as a
            # separate graph (only extra detail would be making sure implicit inputs are handled).
            # Merging the subgraph into the parent graph would be more complex.
            #   e.g. for CoreML we could potentially convert Loop to while_loop and If to cond if the subgraphs in the
            #        control flow node are fully supported.
            #        NNAPI also has While and If.

            # It most likely will be necessary to support merging in If nodes with fully supported subgraphs,
            # as the subgraphs in those are often very simple, so the performance cost of going to the CPU EP and back
            # is high.
            logger.info(
                f"{self.num_nodes_in_subgraphs} nodes are in {self.num_subgraphs} subgraphs. "
                "Check EP as to whether subgraphs are supported."
            )

        pct_nodes_using_ep = self._pct_supported()
        if self.num_partitions == 0:
            logger.info(f"{ep_name} cannot run any nodes in this model.")
        elif self.num_partitions == 1:
            if pct_nodes_using_ep > 75:
                logger.info(
                    f"{ep_name} should work well for this model as there is one partition "
                    f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model."
                )
            elif pct_nodes_using_ep > 50:
                logger.info(
                    f"{ep_name} may work well for this model, however only {pct_nodes_using_ep:.1f}% of nodes "
                    "will use it. Performance testing is required to validate."
                )
            else:
                logger.info(
                    f"{ep_name} will probably not work well for this model as only {pct_nodes_using_ep:.1f}% "
                    "of nodes will use it."
                )

        elif self.num_partitions == 2 and pct_nodes_using_ep > 75:
            logger.info(
                f"{ep_name} can be considered for this model as there are two partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes. "
                "Performance testing is required to validate."
            )
        else:
            logger.info(
                f"{ep_name} is not recommended with this model as there are {self.num_partitions} partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model. "
                "This will most likely result in worse performance than just using the CPU EP."
            )
231

232

233
def _check_partitioning_for_graph(
    graph: onnx.GraphProto,
    node_to_producers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    node_to_consumers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    supported_ops_checker: _SupportedOpsChecker,
    outer_scope_initializers: set[str],
    require_fixed_input_sizes: bool,
    value_info: dict[str, onnx.ValueInfoProto],
    max_rank: int = 999,  # max rank if EP has a limitation
):
    """
    Estimate the partitioning of the nodes directly in `graph`. Subgraphs are not descended into.

    :param graph: Graph to process.
    :param node_to_producers: Map of node to the nodes producing its inputs (used to compute in-degrees).
    :param node_to_consumers: Map of node to the nodes consuming its outputs.
    :param supported_ops_checker: Checker with info on supported ops.
    :param outer_scope_initializers: Names of initializers from enclosing graphs. These are treated as fixed size.
    :param require_fixed_input_sizes: If True, a node is only supported if all its inputs have fixed sizes.
    :param value_info: Map of value name to ValueInfoProto with shape information. May be empty, in which case
                       the rank check is skipped (other than the Transpose 'perm' special case).
    :param max_rank: Maximum rank of input tensors the EP supports.
    :return: PartitioningInfo for the graph.
    """
    # initializers have fixed sizes. a set is used as this is checked for every input of every node.
    initializers = {i.name for i in graph.initializer}

    def _is_fixed_shape_value(value):
        # an empty name denotes an omitted optional input. there is no tensor, so it can't make the node dynamic.
        if not value:
            return True

        if value in value_info:
            return is_fixed_size_tensor(value_info[value])

        if value in initializers or value in outer_scope_initializers:
            return True

        # if something has an unknown shape (e.g. something downstream of a Reshape with dynamic input for the shape)
        # it won't have an entry in value_info
        return False

    #
    # Replicate logic from /onnxruntime/core/providers/partitioning_utils.cc:CreateSupportedPartitionNodeGroups
    # to roughly estimate number of partitions for nodes that is_node_supported_fn returns true for.
    #
    # We keep the structure and variable names as close as possible to the C++ implementation to simplify keeping them
    # in sync if future updates are needed.
    #
    # NOTE: CreateSupportedPartitionNodeGroups was recently updated to be QDQ aware so that partitions did not split
    # QDQ node groups. This code does not need to be QDQ aware as splitting a QDQ node group does not affect the total
    # number of partitions or supported nodes.
    #

    # we don't currently support a callback for additional group closure checks in the python implementation
    on_group_closed_fn = None

    supported_groups = []
    # number of inputs from unprocessed nodes (in-degree) per node
    in_degree = {}
    # nodes that are ready to process
    nodes_to_process = deque()  # deque of Node instances
    # nodes that will be processed when considering the next partition node group
    nodes_to_process_with_next_group = deque()

    # initialize in-degrees and find root nodes
    for node in graph.node:
        node_input_edge_count = len(node_to_producers[node]) if node in node_to_producers else 0
        in_degree[node] = node_input_edge_count
        if node_input_edge_count == 0:
            # node is only dependent on graph input or initializers
            nodes_to_process.append(node)

    supported_group = []
    # the partition node group's border is the aggregate of its nodes' output nodes
    supported_group_border = set()
    num_supported_nodes = 0
    num_unsupported_nodes_due_to_op = 0
    num_unsupported_nodes_due_to_dynamic_input = 0
    num_unsupported_nodes_due_to_rank = 0
    unsupported_ops = set()
    ops_with_unsupported_rank = set()

    def close_group():
        # finalize the current partition node group (if any) and start a new, empty one
        if supported_group:
            keep_partition = not on_group_closed_fn or on_group_closed_fn(supported_group)

            if keep_partition:
                supported_groups.append(supported_group.copy())

            supported_group.clear()
            supported_group_border.clear()

    while nodes_to_process or nodes_to_process_with_next_group:
        if not nodes_to_process:
            close_group()
            nodes_to_process = nodes_to_process_with_next_group
            nodes_to_process_with_next_group = deque()
            continue

        node = nodes_to_process.popleft()

        is_op_supported = supported_ops_checker.is_op_supported(node)
        is_input_shape_supported = not require_fixed_input_sizes or all(_is_fixed_shape_value(i) for i in node.input)

        is_rank_supported = True
        if value_info:
            for node_input in node.input:
                if node_input and node_input in value_info and value_info[node_input].type.HasField("tensor_type"):
                    input_rank = len(value_info[node_input].type.tensor_type.shape.dim)
                    if input_rank > max_rank:
                        is_rank_supported = False
                        break

        # special-case if we can infer the rank from the length of the 'perm' Transpose attribute
        # e.g. this works with SegmentAnything where dynamic Reshape operators result in no shape info.
        # 'perm' is optional (the default reverses the dims), so look it up by name instead of assuming that
        # node.attribute[0] exists.
        if node.op_type == "Transpose":
            perm = next((attr.ints for attr in node.attribute if attr.name == "perm"), None)
            if perm is not None and len(perm) > max_rank:
                is_rank_supported = False

        is_node_supported = is_op_supported and is_input_shape_supported and is_rank_supported

        if not is_node_supported:
            if node in supported_group_border:
                # an unsupported node on the border will be processed after the current partition node group
                # so skip any additional processing/counting here
                nodes_to_process_with_next_group.append(node)
                continue

            if not is_op_supported:
                unsupported_ops.add(f'{node.domain if node.domain else "ai.onnx"}:{node.op_type}')
                num_unsupported_nodes_due_to_op += 1

            if not is_input_shape_supported:
                num_unsupported_nodes_due_to_dynamic_input += 1

            if not is_rank_supported:
                num_unsupported_nodes_due_to_rank += 1
                ops_with_unsupported_rank.add(f'{node.domain if node.domain else "ai.onnx"}:{node.op_type}')

        if is_node_supported:
            num_supported_nodes += 1

            # add node to the partition node group
            supported_group.append(node)

            # remove node from the border and add its outputs to the border
            if node in supported_group_border:
                supported_group_border.remove(node)

            # for each consumer node add to supported_group_border
            if node in node_to_consumers:
                for consumer in node_to_consumers[node]:
                    supported_group_border.add(consumer)

        # adjust in-degrees of the node outputs and add any new nodes to process
        if node in node_to_consumers:
            for consumer in node_to_consumers[node]:
                consumer_node_in_degree = in_degree[consumer]
                consumer_node_in_degree -= 1
                if consumer_node_in_degree == 0:
                    nodes_to_process.append(consumer)

                in_degree[consumer] = consumer_node_in_degree

    close_group()

    num_nodes = len(graph.node)
    num_partitions = len(supported_groups)

    info = PartitioningInfo(
        num_nodes,
        num_supported_nodes,
        num_partitions,
        supported_ops_checker,
        supported_groups,
        unsupported_ops,
        num_unsupported_nodes_due_to_op,
        num_unsupported_nodes_due_to_dynamic_input,
        num_unsupported_nodes_due_to_rank,
        ops_with_unsupported_rank,
    )

    return info
398

399

400
def check_partitioning(
    main_graph: onnx.GraphProto,
    supported_ops_checker: _SupportedOpsChecker,
    require_fixed_input_sizes: bool,
    max_rank: int = 999,
) -> PartitioningInfo:
    """
    Estimate the partitions the graph (and its subgraphs) will be split into for an execution provider.

    A node is treated as supported when its operator is listed by `supported_ops_checker`, the rank of its inputs
    (where shape information is available) does not exceed `max_rank`, and, if `require_fixed_input_sizes` is set,
    all its inputs have fixed sizes. Other EP specific limitations (e.g. NNAPI EP only supports 2D Conv) are not
    checked, so partitions may not be 100% accurate. The caveats for the supported operators are available from
    `supported_ops_checker` and are printed by PartitioningInfo.print_analysis so the user can check them manually.

    :param main_graph: Graph to process. Graph-valued node attributes (subgraphs) are processed recursively and
                       their results are merged into the returned info.
    :param supported_ops_checker: Checker with info on supported ops.
    :param require_fixed_input_sizes: If True, require that the inputs to a potentially supported node are fixed size
                                      tensors for it to be considered as supported. This requires
                                      onnx.shape_inference.infer_shapes to have been run on the model to populate the
                                      shape information.
                                      If False, no shape information is used for the main graph.
                                      NOTE(review): subgraphs still receive their own value_info in that case, so the
                                      rank check may apply there. Confirm whether that is intended.
    :param max_rank: Set if EP has a limitation on the rank of tensors it supports.
    :return: PartitioningInfo for the main graph with any subgraph information merged in.
    :raises ValueError: If require_fixed_input_sizes is True but the graph has no shape information.
    """
    if require_fixed_input_sizes and not main_graph.value_info and len(main_graph.node) > 1:
        raise ValueError("Run onnx.shape_inference.infer_shapes on the model to populate the shape information.")

    def _update_value_info(graph: onnx.GraphProto, value_to_shape: dict[str, onnx.ValueInfoProto]):
        # later entries win, so value_info overrides outputs which override inputs
        for values in (graph.input, graph.output, graph.value_info):
            for value in values:
                value_to_shape[value.name] = value

    # producer/consumer lookups cover the whole model, so one pair of maps serves every graph level
    node_to_producers, node_to_consumers = get_producer_consumer_maps(main_graph)

    def _check_graph(
        graph: onnx.GraphProto,
        outer_scope_value_info: dict[str, onnx.ValueInfoProto] | None,
        outer_scope_initializers: set[str] | None = None,
        partitioning_info: PartitioningInfo | None = None,
    ) -> PartitioningInfo:
        # value info is only tracked when the caller asked for it (non-None). local values shadow outer ones.
        value_info = {}
        if outer_scope_value_info is not None:
            value_info = dict(outer_scope_value_info)
            _update_value_info(graph, value_info)

        scope_initializers = set() if outer_scope_initializers is None else outer_scope_initializers

        graph_info = _check_partitioning_for_graph(
            graph,
            node_to_producers,
            node_to_consumers,
            supported_ops_checker,
            scope_initializers,
            require_fixed_input_sizes,
            value_info,
            max_rank,
        )

        if partitioning_info is None:
            # first call is for the main graph
            partitioning_info = graph_info
        else:
            partitioning_info.merge(graph_info)

        # each level of nesting gets its own copy so sibling subgraphs don't see each other's initializers
        nested_initializers = set(scope_initializers)
        nested_initializers.update(init.name for init in graph.initializer)

        subgraphs = (attr.g for node in graph.node for attr in node.attribute if attr.HasField("g"))
        for subgraph in subgraphs:
            partitioning_info = _check_graph(subgraph, value_info, nested_initializers, partitioning_info)

        return partitioning_info

    return _check_graph(main_graph, {} if require_fixed_input_sizes else None)
492

493

494
def _check_ep_partitioning(
    model: onnx.ModelProto, supported_ops_config: pathlib.Path, require_fixed_input_sizes: bool, max_rank: int = 999
):
    """
    Load an EP's supported operators list and estimate how the model's main graph would be partitioned.

    :param model: Model to check.
    :param supported_ops_config: Markdown file listing the operators the EP supports.
    :param require_fixed_input_sizes: If True, nodes with dynamically sized inputs are treated as unsupported.
    :param max_rank: Maximum tensor rank the EP supports.
    :return: PartitioningInfo for the model.
    """
    ops_checker = _SupportedOpsChecker(supported_ops_config)
    return check_partitioning(
        model.graph,
        ops_checker,
        require_fixed_input_sizes=require_fixed_input_sizes,
        max_rank=max_rank,
    )
500

501

502
def check_nnapi_partitions(model, require_fixed_input_sizes: bool):
    """
    Estimate how the model would be partitioned for the NNAPI EP.

    nnapi_supported_ops.md is read from next to this script when present (running from the ORT python package).
    Otherwise it is read from tools/ci_build/github/android in the ORT repository this script is assumed to live in.

    :param model: ONNX model to check.
    :param require_fixed_input_sizes: If True, nodes with dynamically sized inputs are treated as unsupported.
    :return: PartitioningInfo for the model.
    """
    config_name = "nnapi_supported_ops.md"
    this_dir = pathlib.Path(__file__).parent

    config_path = this_dir / config_name
    if not config_path.exists():
        # running from the ORT repo
        config_path = this_dir.parents[3].joinpath("tools", "ci_build", "github", "android", config_name)

    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes)
514

515

516
def check_coreml_partitions(model: onnx.ModelProto, require_fixed_input_sizes: bool, config_filename: str):
    """
    Estimate how the model would be partitioned for the CoreML EP.

    :param model: ONNX model to check.
    :param require_fixed_input_sizes: If True, nodes with dynamically sized inputs are treated as unsupported.
    :param config_filename: Name of the markdown file listing the supported ops. It is read from next to this
                            script when present (ORT python package), otherwise from tools/ci_build/github/apple in
                            the ORT repository this script is assumed to live in.
    :return: PartitioningInfo for the model.
    """
    this_dir = pathlib.Path(__file__).parent

    config_path = this_dir / config_filename
    if not config_path.exists():
        # running from the ORT repo
        config_path = this_dir.parents[3].joinpath("tools", "ci_build", "github", "apple", config_filename)

    # inputs with a rank above 5 are treated as unsupported for CoreML
    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes, max_rank=5)
529

530

531
def check_shapes(graph: onnx.GraphProto, logger: logging.Logger | None = None):
    """
    Check the shapes of graph inputs, values and graph outputs to determine if they have static or dynamic sizes.
    NNAPI does not support dynamically sized values. CoreML does, but it will most likely cost performance.
    :param graph: Graph to check. If shape inferencing has been run the checks on values will be meaningful.
    :param logger: Optional logger for diagnostic information. If None nothing is logged.
    :return: Tuple of List of inputs with dynamic shapes, Number of dynamic values found
    """

    # it's OK if the input is dynamically sized and we do a Resize early to a fixed size.
    # it's not good if lots of ops have dynamic inputs

    num_fixed_values = 0
    num_dynamic_values = 0

    dynamic_inputs = []
    for i in graph.input:
        if not is_fixed_size_tensor(i):
            dynamic_inputs.append(i)
            # split/join to remove repeated whitespace and newlines from str(i)
            if logger:
                logger.info(f"Input is not a fixed size tensor: {' '.join(str(i).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    dynamic_outputs = []
    for o in graph.output:
        if not is_fixed_size_tensor(o):
            dynamic_outputs.append(o)
            if logger:
                logger.info(f"Output is not a fixed size tensor: {' '.join(str(o).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    # check we have value info.
    # special case some test graphs with a single node which only have graph input and output values, and
    # a model where all inputs are dynamic (results in no value_info)
    # the logger is optional (and checker() calls this without one), so guard the warning like the other log calls.
    if logger and not graph.value_info and not (len(graph.node) == 1 or len(dynamic_inputs) == len(graph.input)):
        logger.warning(
            "Unable to check shapes within model. "
            "ONNX shape inferencing should be run on the model prior to checking."
        )

    for vi in graph.value_info:
        if is_fixed_size_tensor(vi):
            num_fixed_values += 1
        else:
            num_dynamic_values += 1

    if logger:
        logger.info(
            f"Num values with fixed shape={num_fixed_values}. Num values with dynamic shape={num_dynamic_values}"
        )

        if dynamic_inputs:
            if dynamic_outputs:
                logger.info(
                    "Model has dynamic inputs and outputs. Consider re-exporting model with fixed sizes "
                    "if NNAPI or CoreML can be used with this model."
                )
            else:
                logger.info(
                    """Model has dynamically sized inputs but fixed sized outputs.
                       If the sizes become fixed early in the model (e.g. pre-processing of a dynamic input size
                       results in a fixed input size for the majority of the model) performance with NNAPI and CoreML,
                       if applicable, should not be significantly impacted."""
                )

    return dynamic_inputs, num_dynamic_values
602

603

604
def checker(model_path: pathlib.Path, logger: logging.Logger):
    """
    Check how well the NNAPI and CoreML (NeuralNetwork and MLProgram) EPs are expected to work with a model.

    :param model_path: Path to the ONNX model. It is loaded via ModelProtoWithShapeInfo to get shape information.
    :param logger: Logger for the analysis output.
    :return: True if any of the EPs has a suitability other than NO.
    """
    model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path)
    model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info

    # only the list of dynamic inputs is used here. the count of dynamic values is not needed.
    dynamic_inputs, _ = check_shapes(model_with_shape_info.graph)

    def check_ep(ep_name, checker_func):
        """
        Log the partitioning analysis for one EP and return its suitability. If the model has dynamic inputs and the
        result is not YES, also check with shapes ignored and return the better of the two results.
        """
        logger.info(f"Checking {ep_name}")

        # check with shape info first so supported nodes takes into account values with dynamic shapes
        require_fixed_input_sizes = True
        partition_info = checker_func(model_with_shape_info, require_fixed_input_sizes)
        if logger.getEffectiveLevel() <= logging.INFO:
            partition_info.print_analysis(logger, ep_name)

        suitability = partition_info.suitability()
        logger.info(f"Model should perform well with {ep_name} as is: {suitability.name}")

        if suitability != PartitioningInfo.TryWithEP.YES and dynamic_inputs:
            logger.info("--------")
            logger.info("Checking if model will perform better if the dynamic shapes are fixed...")
            require_fixed_input_sizes = False
            partition_info_with_fixed_shapes = checker_func(model_with_shape_info, require_fixed_input_sizes)

            if logger.getEffectiveLevel() <= logging.INFO:
                # analyze and log detailed info
                logger.info("Partition information if the model was updated to make the shapes fixed:")
                partition_info_with_fixed_shapes.print_analysis(logger, ep_name)

            fixed_shape_suitability = partition_info_with_fixed_shapes.suitability()
            logger.info(
                f"Model should perform well with {ep_name} if modified to have fixed input shapes: "
                f"{fixed_shape_suitability.name}"
            )

            if fixed_shape_suitability != PartitioningInfo.TryWithEP.NO:
                logger.info("Shapes can be altered using python -m onnxruntime.tools.make_dynamic_shape_fixed")

            if fixed_shape_suitability.value > suitability.value:
                suitability = fixed_shape_suitability

        logger.info("================")
        logger.info("")

        return suitability

    nnapi_suitability = check_ep("NNAPI", check_nnapi_partitions)

    # Check for NeuralNetwork CoreML model
    def check_nn_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_neuralnetwork_ops.md")

    # Check for MLProgram CoreML model
    def check_mlprogram_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_mlprogram_ops.md")

    coreml_nn_suitability = check_ep("CoreML NeuralNetwork", check_nn_coreml)
    coreml_mlprogram_suitability = check_ep("CoreML MLProgram", check_mlprogram_coreml)

    if (
        nnapi_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.YES
    ) and logger.getEffectiveLevel() > logging.INFO:
        # the logger filters out INFO records in this branch, so the hint must use a higher level to be visible
        logger.warning("Re-run with log level of INFO for more details on the NNAPI/CoreML issues.")

    return (
        nnapi_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.NO
    )
675

676

677
def analyze_model(
    model_path: pathlib.Path | str, skip_optimize: bool = False, logger: logging.Logger | None = None
):
    """
    Analyze the provided model to determine if it's likely to work well with the NNAPI or CoreML Execution Providers
    :param model_path: Model to analyze. A str path is also accepted.
    :param skip_optimize: Skip optimizing to BASIC level before checking. When exporting to ORT format we will do this
                          optimization.
    :param logger: Logger for output. If not provided the 'usability_checker' logger is used with level INFO.
    :return: True if either the NNAPI or CoreML Execution Providers may work well with this model.
    """
    # normalize so .name and .resolve() below also work when a str path is passed
    model_path = pathlib.Path(model_path)

    if not logger:
        logger = logging.getLogger("usability_checker")
        logger.setLevel(logging.INFO)

    logger.info(f"Checking {model_path} for usability with ORT Mobile.")

    # the optimized model is written to a temporary directory, so the check must run before it is cleaned up
    with tempfile.TemporaryDirectory() as tmp:
        if not skip_optimize:
            tmp_path = pathlib.Path(tmp) / model_path.name
            optimize_model(model_path, tmp_path, use_external_initializers=True)
            model_path = tmp_path

        try_eps = checker(model_path.resolve(strict=True), logger)

    return try_eps
701

702

703
def parse_args():
    """
    Parse the command line arguments.
    :return: argparse.Namespace with log_level, skip_optimize and model_path.
    """
    parser = argparse.ArgumentParser(
        os.path.basename(__file__), description="""Analyze an ONNX model for usage with the ORT mobile"""
    )

    # run_analyze_model handles warning and error levels as well, so allow selecting them
    parser.add_argument(
        "--log_level", choices=["debug", "info", "warning", "error"], default="info", help="Logging level"
    )
    parser.add_argument(
        "--skip_optimize",
        action="store_true",
        help="Don't optimize the model to BASIC level prior to analyzing. "
        "Optimization will occur when exporting the model to ORT format, so in general "
        "should not be skipped unless you have a specific reason to do so.",
    )
    parser.add_argument("model_path", type=pathlib.Path, help="Provide path to ONNX model")

    return parser.parse_args()
719

720

721
def run_analyze_model():
    """
    Command line entry point: parse the arguments, configure logging and analyze the model.
    """
    args = parse_args()

    # with no handler anywhere in the logger hierarchy, logging falls back to its last resort handler which only
    # emits WARNING and above, so the INFO level analysis would never be shown.
    # basicConfig is a no-op if the root logger already has handlers.
    logging.basicConfig(format="%(levelname)s: %(message)s")

    logger = logging.getLogger("default")
    log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
    logger.setLevel(log_levels.get(args.log_level, logging.ERROR))

    model_path = args.model_path.resolve()
    analyze_model(model_path, args.skip_optimize, logger)
736

737

738
if __name__ == "__main__":
    # allow running this module directly as a command line tool
    run_analyze_model()
740

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.