onnx-simplifier
onnx_simplifier.py · 536 lines · 21.9 KB
import argparse

import copy
import os
import sys
import re
import tempfile
from typing import List, Dict, Union, Optional, Tuple, Sequence
from rich.text import Text
from rich import print
import numpy as np

import onnx  # type: ignore
import onnx.checker  # type: ignore
import onnx.helper  # type: ignore
import onnx.shape_inference  # type: ignore
import onnx.numpy_helper  # type: ignore
try:
    import onnxruntime as rt  # type: ignore
except ImportError:
    command = [sys.executable, '-m', 'pip', 'install', 'onnxruntime']
    print(Text(f"Installing onnxruntime via `{' '.join(command)}`, please wait a moment...", style="bold magenta"))
    import subprocess
    subprocess.check_call(command)
    import onnxruntime as rt


import onnxsim.onnxsim_cpp2py_export as C
from . import model_info
from . import model_checking
from . import version


TensorShape = List[int]
TensorShapes = Dict[str, TensorShape]
TensorShapesWithOptionalKey = Dict[Optional[str], TensorShape]
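# For example, a TensorShapes mapping looks like {"data": [1, 3, 224, 224]};
# a TensorShapesWithOptionalKey may use None as the key when the input name
# is omitted (see parse_shapes below).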


def get_output_names(model: onnx.ModelProto) -> List[str]:
    output_names = [opt.name for opt in model.graph.output]
    return output_names


def remove_unused_output(
    model: onnx.ModelProto, unused_output: Sequence[str]
) -> onnx.ModelProto:
    unused_output_names = unused_output
    output_names = get_output_names(model)
    for unused_output_name in unused_output_names:
        if unused_output_name not in output_names:
            raise RuntimeError(
                f'The model doesn\'t have output named "{unused_output_name}"'
            )
    for graph_output in copy.deepcopy(model.graph.output):
        if graph_output.name in unused_output_names:
            model.graph.output.remove(graph_output)
    return model


def remove_initializer_from_input(model: onnx.ModelProto) -> onnx.ModelProto:
    initializer_names = [x.name for x in model.graph.initializer]
    for graph_input in copy.deepcopy(model.graph.input):
        if graph_input.name in initializer_names:
            model.graph.input.remove(graph_input)
    return model
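# Note: since ONNX IR version 4, initializers no longer need to be listed in
# graph.input; removing them from the inputs lets the simplifier treat them
# as constants that can be folded.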


def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: Optional[TensorShapesWithOptionalKey]) -> Optional[TensorShapes]:
    if input_shapes is None:
        return None

    def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]:
        initializer_names = [x.name for x in model.graph.initializer]
        return [ipt for ipt in model.graph.input if ipt.name not in initializer_names]

    def get_input_names(model: onnx.ModelProto) -> List[str]:
        input_names = [ipt.name for ipt in get_inputs(model)]
        return input_names

    input_names = get_input_names(model)
    if None in input_shapes:
        if len(input_names) == 1:
            input_shapes[input_names[0]] = input_shapes[None]
            del input_shapes[None]
        else:
            raise RuntimeError(
                'The model has more than one input, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')
    for x in input_shapes:
        if x not in input_names:
            raise RuntimeError(
                'The model doesn\'t have input named "{}"'.format(x))

    return input_shapes  # type: ignore
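# For example, for a model whose only (non-initializer) input is "data",
# check_and_update_input_shapes(model, {None: [1, 3, 224, 224]}) returns
# {"data": [1, 3, 224, 224]}.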


# A very, very large threshold
DEFAULT_TENSOR_SIZE_THRESHOLD = '1.5GB'


def simplify(
    model: Union[str, onnx.ModelProto],
    check_n: int = 0,
    perform_optimization: bool = True,
    skip_fuse_bn: bool = False,
    overwrite_input_shapes=None,
    test_input_shapes=None,
    skipped_optimizers: Optional[List[str]] = None,
    skip_constant_folding=False,
    skip_shape_inference=False,
    input_data=None,
    dynamic_input_shape: bool = False,
    custom_lib: Optional[str] = None,
    include_subgraph: bool = False,
    unused_output: Optional[Sequence[str]] = None,
    tensor_size_threshold: str = DEFAULT_TENSOR_SIZE_THRESHOLD,
    mutable_initializer: bool = False,
    *,
    input_shapes=None,
) -> Tuple[onnx.ModelProto, bool]:
    """
    :param model: onnx ModelProto object or file path
    :param check_n: The simplified model will be checked `check_n` times with random inputs
    :param perform_optimization: Whether to run the onnx optimizer on the model
    :param skip_fuse_bn: Skip the fuse_bn_into_conv onnx optimizer
    :param overwrite_input_shapes: Fixed input shapes that overwrite the (dynamic) input shapes
            stored in the model itself
    :param test_input_shapes: If the model has dynamic input shapes, a fixed input shape must be
            passed for generating random inputs and checking equality
    :param skipped_optimizers: Skip some specific onnx optimizers
    :param skip_constant_folding: Skip constant folding
    :param skip_shape_inference: Skip shape inference (sometimes shape inference crashes)
    :param input_data: Feed custom input data for checking if needed
    :param dynamic_input_shape: Deprecated. Not needed anymore.
    :param custom_lib: Shared library of onnxruntime custom ops
    :param include_subgraph: Simplify subgraphs (e.g. the true and false branch graphs of the "If" operator) instead of only the main graph
    :param unused_output: Names of unused outputs that will be eliminated from the model
    :param input_shapes: Deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead.
    :return: A tuple (simplified model, whether the check passed)
    """
    if dynamic_input_shape:
        print(
            Text(
                "WARNING: The argument `dynamic_input_shape=True` is not needed any more; onnxsim now supports dynamic input shapes natively, please refer to the latest documentation. An error will be raised in the future.",
                style="bold red",
            )
        )
    if input_shapes is not None:
        print(
            Text(
                "WARNING: The argument `input_shapes` is deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead. An error will be raised in the future.",
                style="bold red",
            )
        )
        overwrite_input_shapes = input_shapes
        test_input_shapes = input_shapes

    if not perform_optimization:
        # None means skip all optimizers
        skipped_optimizers = None
    elif skipped_optimizers is None:
        skipped_optimizers = []

    if skip_fuse_bn and skipped_optimizers is not None:
        skipped_optimizers.append("fuse_bn_into_conv")
    if isinstance(model, str):
        model = onnx.load(model)
    if overwrite_input_shapes is None:
        overwrite_input_shapes = {}
    overwrite_input_shapes = check_and_update_input_shapes(
        model, overwrite_input_shapes)
    test_input_shapes = check_and_update_input_shapes(
        model, test_input_shapes)

    for name, input_shape in overwrite_input_shapes.items():
        for ipt in model.graph.input:
            if ipt.name == name:
                for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
                    dim.dim_value = input_shape[i]
    if unused_output is not None:
        model = remove_unused_output(model, unused_output)
    if not mutable_initializer and model.ir_version >= 4:
        model = remove_initializer_from_input(model)

    # https://stackoverflow.com/a/60708339
    def parse_size(size: str) -> int:
        units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40}
        size = size.upper()
        if not re.match(r' ', size):
            size = re.sub(r'([KMGT]?B)', r' \1', size)
        number, unit = [string.strip() for string in size.split()]
        return int(float(number)*units[unit])
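    # The regex above rewrites e.g. "1.5GB" to "1.5 GB" before splitting, so
    # parse_size("1.5GB") == 1610612736 and parse_size("100KB") == 102400.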

    tensor_size_threshold = parse_size(tensor_size_threshold)
    if tensor_size_threshold > 2**31 - 9999:
        raise ValueError("tensor_size_threshold should be less than 2GB")

    try:
        model_bytes = model.SerializeToString()
        model_opt_bytes = C.simplify(
            model_bytes,
            skipped_optimizers,
            not skip_constant_folding,
            not skip_shape_inference,
            tensor_size_threshold,
        )
        if len(model_opt_bytes) == 0:
            raise ValueError("Simplified model larger than 2GB")
        model_opt = onnx.load_from_string(model_opt_bytes)
        check_ok = model_checking.compare(
            model_opt, model, check_n, test_input_shapes, input_data, custom_lib
        )
    except (ValueError, onnx.onnx_cpp2py_export.checker.ValidationError):
        print("[bold magenta]Simplified model larger than 2GB. Trying to save as external data...[/bold magenta]")
        # for large models, fall back to converting through a temporary file
        with tempfile.TemporaryDirectory() as tmpdirname:
            onnx.save(
                copy.deepcopy(model),
                os.path.join(tmpdirname, 'model.onnx'),
                save_as_external_data=True,
            )
            check_ok = C.simplify_path(
                os.path.join(tmpdirname, 'model.onnx'),
                os.path.join(tmpdirname, 'opt.onnx'),
                skipped_optimizers,
                not skip_constant_folding,
                not skip_shape_inference,
                tensor_size_threshold,
            )
            check_ok = model_checking.compare(
                os.path.join(tmpdirname, 'opt.onnx'),
                os.path.join(tmpdirname, 'model.onnx'),
                check_n, test_input_shapes, input_data, custom_lib
            )
            model_opt = onnx.load(os.path.join(tmpdirname, 'opt.onnx'))
    return model_opt, check_ok
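

# A minimal usage sketch of `simplify` (assuming an existing "model.onnx"):
#
#     import onnx
#     from onnxsim import simplify
#     model = onnx.load("model.onnx")
#     model_simp, check_ok = simplify(model, check_n=3)
#     assert check_ok, "Simplified ONNX model could not be validated"
#     onnx.save(model_simp, "model_simp.onnx")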


class PyModelExecutor(C.ModelExecutor):
    def Run(self, model_str: str, inputs_str: List[str]):
        model = onnx.ModelProto()
        model.ParseFromString(model_str)

        def deserialize_tp(tp_str):
            tp = onnx.TensorProto()
            tp.ParseFromString(tp_str)
            return tp

        input_tps = map(deserialize_tp, inputs_str)
        input_arrs = map(onnx.numpy_helper.to_array, input_tps)
        input_names = [x.name for x in model.graph.input]
        inputs = dict(zip(input_names, input_arrs))
        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0)
        sess_options.log_severity_level = 3
        sess = rt.InferenceSession(
            model.SerializeToString(),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        output_names = [x.name for x in sess.get_outputs()]
        run_options = rt.RunOptions()
        run_options.log_severity_level = 3
        output_arrs = sess.run(output_names, inputs, run_options=run_options)
        return [
            onnx.numpy_helper.from_array(x).SerializeToString() for x in output_arrs
        ]
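# PyModelExecutor is the Python-side callback for the native simplifier: the
# C++ core hands over a serialized ModelProto and serialized input tensors,
# and receives serialized outputs computed by onnxruntime on CPU (presumably
# so that subgraphs can be evaluated with a real runtime during folding).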


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Input ONNX model")
    parser.add_argument("output_model", help="Output ONNX model")
    parser.add_argument(
        "check_n",
        help="Check whether the output is correct with n random inputs",
        nargs="?",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--enable-fuse-bn",
        help="This option is deprecated. Fusing bn into conv is enabled by default.",
        action="store_true",
    )
    parser.add_argument(
        "--skip-fuse-bn", help="Skip fusing batchnorm into conv.", action="store_true"
    )
    parser.add_argument(
        "--skip-optimization",
        help="Skip all ONNX optimizers or some of them. To skip all optimizers, use `onnxsim a.onnx b.onnx --skip-optimization`. To skip some of the optimizers, use something like `onnxsim a.onnx b.onnx --skip-optimization fuse_bn_into_conv fuse_pad_into_pool`.",
        type=str,
        nargs="*",
    )
    parser.add_argument("--skip-constant-folding", help="Skip constant folding", action="store_true")
    parser.add_argument(
        "--input-shape",
        help="This argument has been renamed to --overwrite-input-shape, please refer to it",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--overwrite-input-shape",
        help='Overwrite the input shape. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use a visualization tool like netron to check the input name and dimension ordering (NCHW or NHWC).',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--test-input-shape",
        help='The input shape used to generate random inputs for testing, useful when the input shape is dynamic. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use a visualization tool like netron to check the input name and dimension ordering (NCHW or NHWC).',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--skip-optimizer",
        help="Deprecated. Refer to --skip-optimization",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--skip-shape-inference", help="Skip shape inference", action="store_true"
    )
    parser.add_argument(
        "--enable-onnxruntime-optimization",
        help="Enable ONNX Runtime's ORT_ENABLE_BASIC level optimization.",
        action="store_true",
    )
    parser.add_argument(
        "--dynamic-input-shape",
        help="Deprecated. Not needed any more.",
        action="store_true",
    )
    parser.add_argument(
        "--input-data-path",
        help='Custom input data for checking. The value should be "input_name1:xxx1.npy" "input_name2:xxx2.npy" ...; each file should be loadable with numpy.load().',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--custom-lib", help="Deprecated. Not needed any more.", type=str
    )
    parser.add_argument(
        "--include-subgraph",
        help='Experimental feature. Simplify subgraphs (e.g. the true and false branch graphs of the "If" operator) instead of only the main graph',
        action="store_true",
    )
    parser.add_argument(
        "--unused-output",
        help="Names of unused outputs that will be eliminated from the model",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--no-large-tensor",
        help="Some ops like Tile and ConstantOfShape can produce large tensors and make the simplified model much larger. Specify this flag to skip folding these ops, at the cost of some optimization chances. It can be followed by a threshold, for example, --no-large-tensor 1M or --no-large-tensor 100KB. A bare '--no-large-tensor' means '--no-large-tensor 1KB'.",
        type=str,
        const='1KB',
        default=DEFAULT_TENSOR_SIZE_THRESHOLD,
        nargs="?",
        dest="tensor_size_threshold",
    )
    parser.add_argument(
        "--mutable-initializer",
        help="By the ONNX specification, initializers can also serve as inputs. This allows users to overwrite their values at runtime, but some useful optimizations like fusing conv and bn are then no longer applicable. In almost all cases, having an initializer that is also an input is unintended (usually caused by an out-dated PyTorch), so onnxsim treats all initializers as immutable to enable all optimizations. If that is not what you want, specify '--mutable-initializer' to disable this behavior.",
        action="store_true",
    )
    parser.add_argument(
        "--save-as-external-data",
        help="Save parameters as external data. This makes the .onnx file much smaller, but the .onnx file will then depend on the external data file (.data).",
        action="store_true",
    )
    parser.add_argument('-v', '--version', action='version', version='onnxsim ' + version.version)

    args = parser.parse_args()

    if args.enable_fuse_bn:
        print(
            Text(
                'WARNING: "--enable-fuse-bn" is not needed any more, because fusing bn is enabled by default. The "--enable-fuse-bn" flag is ignored now and will raise an error in the future.',
                style="bold red",
            )
        )
    if args.dynamic_input_shape:
        print(
            Text(
                'WARNING: "--dynamic-input-shape" is not needed any more; onnxsim v0.4 now handles dynamic input shapes automatically. The "--dynamic-input-shape" flag is ignored now and will raise an error in the future.',
                style="bold red",
            )
        )
    assert not (args.input_shape is not None and args.overwrite_input_shape is not None)
    if args.input_shape:
        print(
            Text(
                'WARNING: "--input-shape" has been renamed to "--overwrite-input-shape". Please use the new name instead.',
                style="bold red",
            )
        )
        args.overwrite_input_shape = args.input_shape
    if args.include_subgraph:
        print(
            Text(
                "WARNING: subgraph optimization is not supported in v0.4 for now.",
                style="bold red",
            )
        )
    assert not (args.skip_optimizer is not None and args.skip_optimization is not None)
    if args.skip_optimizer:
        print(
            Text(
                'WARNING: "--skip-optimizer" has been renamed to "--skip-optimization". Please use the new name instead.',
                style="bold red",
            )
        )
        args.skip_optimization = args.skip_optimizer
    if args.skip_optimization is None:
        # the user didn't specify --skip-optimization
        args.skip_optimization = []
    elif len(args.skip_optimization) == 0:
        # the user specified --skip-optimization without any optimizer names;
        # setting it to None means skipping all optimizations
        args.skip_optimization = None
    if args.skip_fuse_bn and args.skip_optimization is not None:
        args.skip_optimization.append("fuse_bn_into_conv")

    perform_optimization = args.skip_optimization is not None

    def parse_shapes(shapes_arg):
        shapes = {}
        if shapes_arg is not None:
            for x in shapes_arg:
                if ':' not in x:
                    shapes[None] = list(map(int, x.split(',')))
                else:
                    pieces = x.split(':')
                    # handle input names that themselves contain a colon, like "input:0"
                    name, shape = ':'.join(
                        pieces[:-1]), list(map(int, pieces[-1].split(',')))
                    shapes.update({name: shape})
        return shapes
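    # For example, parse_shapes(["data:1,3,224,224"]) == {"data": [1, 3, 224, 224]}
    # and parse_shapes(["1,3,224,224"]) == {None: [1, 3, 224, 224]}.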

    test_input_shapes = parse_shapes(args.test_input_shape)
    overwrite_input_shapes = parse_shapes(args.overwrite_input_shape)

    if args.enable_onnxruntime_optimization:
        tmp_file = tempfile.NamedTemporaryFile()
        sess_options = rt.SessionOptions()
        # Set graph optimization level
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
        # Enable model serialization after graph optimization
        sess_options.optimized_model_filepath = tmp_file.name
        _ = rt.InferenceSession(args.input_model, sess_options, providers=["CPUExecutionProvider"])

        model = onnx.load(tmp_file.name)
    else:
        model = onnx.load(args.input_model)

    if args.tensor_size_threshold == DEFAULT_TENSOR_SIZE_THRESHOLD:
        for node in model.graph.node:
            if node.op_type in ["Tile", "ConstantOfShape"]:
                print(
                    Text(
                        'Your model contains "Tile" and/or "ConstantOfShape" ops. Folding these ops can make the simplified model much larger. If that is not expected, please specify "--no-large-tensor" (which will lose some optimization chances).',
                        style="bold magenta",
                    )
                )
                break

    if not args.mutable_initializer:
        initializer_names = {x.name for x in model.graph.initializer}
        input_names = {x.name for x in model.graph.input}
        if len(initializer_names.intersection(input_names)) > 0:
            print(
                Text(
                    'Your model contains initializers that are also inputs. This is usually caused by an out-dated PyTorch. onnxsim treats all initializers as immutable to enable all optimizations. If that is not wanted, please specify "--mutable-initializer" to disable this behavior.',
                    style="bold magenta",
                )
            )

    input_tensors = None
    if args.input_data_path is not None:
        input_tensors = {}
        for x in args.input_data_path:
            pieces = x.split(':')
            name, data = ':'.join(pieces[:-1]), pieces[-1]
            input_tensors.update({name: np.load(data)})

    print("Simplifying...")

    model_opt, check_ok = simplify(
        model,
        args.check_n,
        perform_optimization,
        False,
        overwrite_input_shapes,
        test_input_shapes,
        args.skip_optimization,
        args.skip_constant_folding,
        args.skip_shape_inference,
        input_tensors,
        False,
        args.custom_lib,
        args.include_subgraph,
        args.unused_output,
        args.tensor_size_threshold,
        args.mutable_initializer,
    )

    try:
        if not args.save_as_external_data:
            onnx.save(model_opt, args.output_model)
        else:
            raise ValueError("save_as_external_data")
    except ValueError:
        # for large models (>2GB), which onnx.save doesn't support,
        # or when --save-as-external-data is explicitly specified
        external_data_path = os.path.basename(args.output_model) + '.data'
        if os.path.exists(external_data_path):
            os.remove(external_data_path)
        onnx.save(
            copy.deepcopy(model_opt),
            args.output_model,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_path,
        )

    if check_ok:
        print("Finished! Here is the difference:")
        model_info.print_simplifying_info(model, model_opt)
    else:
        print(
            'Check failed. Please be careful when using the simplified model, or try specifying "--skip-fuse-bn" or "--skip-optimization" (run "onnxsim -h" for details).'
        )
        print("Here is the difference after simplification:")
        model_info.print_simplifying_info(model, model_opt)
        sys.exit(1)
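

# Command-line usage sketch (via the `onnxsim` entry point installed with
# this package):
#
#     onnxsim input.onnx output.onnx 3 --overwrite-input-shape data:1,3,224,224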