# onnx-simplifier
# 536 lines · 21.9 KB
1import argparse2
3import copy4import os5import sys6import re7import tempfile8from typing import List, Dict, Union, Optional, Tuple, Sequence9from rich.text import Text10from rich import print11import numpy as np12
13import onnx # type: ignore14import onnx.checker # type: ignore15import onnx.helper # type: ignore16import onnx.shape_inference # type: ignore17import onnx.numpy_helper # type: ignore18try:19import onnxruntime as rt # type: ignore20except ImportError:21command = [sys.executable, '-m', 'pip', 'install', 'onnxruntime']22print(Text(f"Installing onnxruntime by `{' '.join(command)}`, please wait for a moment..", style="bold magenta"))23import subprocess24subprocess.check_call(command)25import onnxruntime as rt26
27
28import onnxsim.onnxsim_cpp2py_export as C29from . import model_info30from . import model_checking31from . import version32
33
# Shape of a single tensor, e.g. [1, 3, 224, 224].
TensorShape = List[int]
# Mapping from input/output name to its shape.
TensorShapes = Dict[str, TensorShape]
# Like TensorShapes, but a None key is shorthand for "the model's only input"
# (resolved to the real name by check_and_update_input_shapes).
TensorShapesWithOptionalKey = Dict[Optional[str], TensorShape]
38
def get_output_names(model: onnx.ModelProto) -> List[str]:
    """Return the names of all graph outputs of *model*, in graph order."""
    return [graph_output.name for graph_output in model.graph.output]
43
def remove_unused_output(
    model: onnx.ModelProto, unused_output: Sequence[str]
) -> onnx.ModelProto:
    """Remove the named outputs from the model's graph, in place.

    Raises RuntimeError if any requested name is not an actual graph output.
    Returns the (mutated) model for convenience.
    """
    known_output_names = [graph_output.name for graph_output in model.graph.output]
    for name in unused_output:
        if name not in known_output_names:
            raise RuntimeError(
                f'The model doesn\'t have output named "{name}"'
            )
    # Iterate over a snapshot since model.graph.output is mutated in the loop.
    for graph_output in copy.deepcopy(model.graph.output):
        if graph_output.name in unused_output:
            model.graph.output.remove(graph_output)
    return model
59
def remove_initializer_from_input(model: onnx.ModelProto) -> onnx.ModelProto:
    """Strip graph inputs that are really initializers (i.e. constants), in place."""
    initializer_names = {init.name for init in model.graph.initializer}
    # Snapshot the inputs because the loop removes from model.graph.input.
    for graph_input in copy.deepcopy(model.graph.input):
        if graph_input.name in initializer_names:
            model.graph.input.remove(graph_input)
    return model
67
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: Optional[TensorShapesWithOptionalKey]) -> Optional[TensorShapes]:
    """Validate user-supplied input shapes and resolve the anonymous key.

    :param model: model whose (non-initializer) graph inputs are checked against.
    :param input_shapes: mapping from input name to shape; a ``None`` key is
        shorthand for "the model's only input" and is replaced by that input's
        real name. Passing ``None`` for the whole argument means "no shapes".
    :return: a new mapping keyed only by real input names, or ``None``.
    :raises RuntimeError: if the ``None`` key is used while the model has more
        than one input, or if a named input does not exist in the model.
    """
    if input_shapes is None:
        return None

    def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]:
        # Initializers may also appear in graph.input; they are constants,
        # not real inputs, so exclude them here.
        initializer_names = [x.name for x in model.graph.initializer]
        return [ipt for ipt in model.graph.input if ipt.name not in initializer_names]

    def get_input_names(model: onnx.ModelProto) -> List[str]:
        return [ipt.name for ipt in get_inputs(model)]

    # Fix: work on a copy so the caller's dict is not mutated as a side effect
    # (the original code deleted/added keys in the caller-owned mapping).
    input_shapes = dict(input_shapes)
    input_names = get_input_names(model)
    if None in input_shapes:
        if len(input_names) == 1:
            input_shapes[input_names[0]] = input_shapes[None]
            del input_shapes[None]
        else:
            raise RuntimeError(
                'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')
    for x in input_shapes:
        if x not in input_names:
            raise RuntimeError(
                'The model doesn\'t have input named "{}"'.format(x))

    return input_shapes  # type: ignore
95
# A very very large threshold
# Default cap on the size of tensors produced by constant folding; parsed by
# parse_size() inside simplify(). NOTE: the name contains a historical typo
# ("THRESHOLDHOLD"); it is kept as-is because it is part of the public API.
DEFAULT_TENSOR_SIZE_THRESHOLDHOLD = '1.5GB'
99
def simplify(
    model: Union[str, onnx.ModelProto],
    check_n: int = 0,
    perform_optimization: bool = True,
    skip_fuse_bn: bool = False,
    overwrite_input_shapes=None,
    test_input_shapes=None,
    skipped_optimizers: Optional[List[str]] = None,
    skip_constant_folding=False,
    skip_shape_inference=False,
    input_data=None,
    dynamic_input_shape: bool = False,
    custom_lib: Optional[str] = None,
    include_subgraph: bool = False,
    unused_output: Optional[Sequence[str]] = None,
    tensor_size_threshold: str = DEFAULT_TENSOR_SIZE_THRESHOLDHOLD,
    mutable_initializer: bool = False,
    *,
    input_shapes=None,
) -> Tuple[onnx.ModelProto, bool]:
    """Simplify an ONNX model via the native C++ simplifier and verify the result.

    :param model: onnx ModelProto object or file path
    :param check_n: The simplified model will be checked for `check_n` times by random inputs
    :param perform_optimization: Whether to run onnx optimizer on the model
    :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer
    :param overwrite_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape
            for generating random inputs and checking equality.
    :param test_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape
            for generating random inputs and checking equality.
    :param skipped_optimizers: Skip some specific onnx optimizers
    :param skip_constant_folding: Skip constant folding
    :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash)
    :param input_data: Feed custom input data for checking if needed
    :param dynamic_input_shape: Deprecated. Not needed anymore.
    :param custom_lib: onnxruntime custom ops's shared library
    :param include_subgraph: Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph
    :param unused_output: name of unused outputs that will be eliminated from the model
    :param tensor_size_threshold: human-readable size string (e.g. "1KB", "1.5GB");
            tensors produced by constant folding larger than this are not folded
    :param mutable_initializer: keep initializers listed as graph inputs so their
            values can be overwritten at runtime (disables some optimizations)
    :param input_shapes: Deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead.
    :return: A tuple (simplified model, success(True) or failed(False))
    """
    # NOTE(review): `include_subgraph` is accepted but never referenced in this
    # body — subgraph simplification appears unsupported here; confirm upstream.
    if dynamic_input_shape:
        print(
            Text(
                "WARNING: The argument `dynamic_input_shape=True` is not needed any more, onnxsim can now support dynamic input shapes natively, please refer to the latest documentation. An error will be raised in the future.",
                style="bold red",
            )
        )
    if input_shapes is not None:
        # Deprecated alias: the legacy `input_shapes` fans out to both new args.
        print(
            Text(
                "WARNING: The argument `input_shapes` is deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead. An error will be raised in the future.",
                style="bold red",
            )
        )
        overwrite_input_shapes = input_shapes
        test_input_shapes = input_shapes

    if not perform_optimization:
        # None means skip all optimizers
        skipped_optimizers = None
    elif skipped_optimizers is None:
        skipped_optimizers = []

    if skip_fuse_bn and skipped_optimizers is not None:
        skipped_optimizers.append("fuse_bn_into_conv")
    if isinstance(model, str):
        model = onnx.load(model)
    if overwrite_input_shapes is None:
        overwrite_input_shapes = {}
    # Resolve the anonymous (None) key and validate names against the graph.
    overwrite_input_shapes = check_and_update_input_shapes(
        model, overwrite_input_shapes)
    test_input_shapes = check_and_update_input_shapes(
        model, test_input_shapes)

    # Pin the requested static dims directly into the model's input value infos.
    for name, input_shape in overwrite_input_shapes.items():
        for ipt in model.graph.input:
            if ipt.name == name:
                for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
                    dim.dim_value = input_shape[i]
    if unused_output is not None:
        model = remove_unused_output(model, unused_output)
    # IR version >= 4 allows initializers that are not graph inputs, so the
    # duplicated inputs can be dropped unless the user wants them mutable.
    if not mutable_initializer and model.ir_version >= 4:
        model = remove_initializer_from_input(model)

    # https://stackoverflow.com/a/60708339
    def parse_size(size: str) -> int:
        # Parse a human-readable size ("100KB", "1.5GB") into bytes.
        units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40}
        size = size.upper()
        # Insert a space between the number and the unit if there isn't one yet.
        if not re.match(r' ', size):
            size = re.sub(r'([KMGT]?B)', r' \1', size)
        number, unit = [string.strip() for string in size.split()]
        return int(float(number)*units[unit])

    tensor_size_threshold = parse_size(tensor_size_threshold)
    # Protobuf messages are limited to 2GB; keep a safety margin below 2**31.
    if tensor_size_threshold > 2**31 - 9999:
        raise ValueError("tensor_size_threshold should be less than 2GB")

    try:
        # Fast path: round-trip the model through the C++ simplifier in memory.
        model_bytes = model.SerializeToString()
        model_opt_bytes = C.simplify(
            model_bytes,
            skipped_optimizers,
            not skip_constant_folding,
            not skip_shape_inference,
            tensor_size_threshold,
        )
        if len(model_opt_bytes) == 0:
            # The C++ side signals an over-2GB result with an empty buffer.
            raise ValueError("Simplified model larger than 2GB")
        model_opt = onnx.load_from_string(model_opt_bytes)
        check_ok = model_checking.compare(
            model_opt, model, check_n, test_input_shapes, input_data, custom_lib
        )
    except (ValueError, onnx.onnx_cpp2py_export.checker.ValidationError):
        print("[bold magenta]Simplified model larger than 2GB. Trying to save as external data...[/bold magenta]")
        # large models try to convert through a temporary file
        with tempfile.TemporaryDirectory() as tmpdirname:
            onnx.save(
                copy.deepcopy(model),
                os.path.join(tmpdirname, 'model.onnx'),
                save_as_external_data=True,
            )
            # NOTE(review): this check_ok is immediately overwritten by the
            # compare() call below — confirm whether C.simplify_path's status
            # was meant to be combined with the comparison result.
            check_ok = C.simplify_path(
                os.path.join(tmpdirname, 'model.onnx'),
                os.path.join(tmpdirname, 'opt.onnx'),
                skipped_optimizers,
                not skip_constant_folding,
                not skip_shape_inference,
                tensor_size_threshold,
            )
            check_ok = model_checking.compare(
                os.path.join(tmpdirname, 'opt.onnx'),
                os.path.join(tmpdirname, 'model.onnx'),
                check_n, test_input_shapes, input_data, custom_lib
            )
            model_opt = onnx.load(os.path.join(tmpdirname, 'opt.onnx'))
    return model_opt, check_ok
237
class PyModelExecutor(C.ModelExecutor):
    """Executor callback handed to the C++ side: runs a serialized model on
    serialized TensorProto inputs via onnxruntime and returns serialized
    TensorProto outputs."""

    def Run(self, model_str: str, inputs_str: List[str]):
        proto = onnx.ModelProto()
        proto.ParseFromString(model_str)

        def to_tensor_proto(serialized):
            # Deserialize one TensorProto from its wire-format bytes.
            tensor = onnx.TensorProto()
            tensor.ParseFromString(serialized)
            return tensor

        # Pair graph inputs (by position) with the provided serialized tensors.
        feed = {
            graph_input.name: onnx.numpy_helper.to_array(to_tensor_proto(raw))
            for graph_input, raw in zip(proto.graph.input, inputs_str)
        }
        options = rt.SessionOptions()
        # Disable ORT graph optimizations and quiet the logs: we want the
        # model executed exactly as given.
        options.graph_optimization_level = rt.GraphOptimizationLevel(0)
        options.log_severity_level = 3
        session = rt.InferenceSession(
            proto.SerializeToString(),
            sess_options=options,
            providers=["CPUExecutionProvider"],
        )
        fetch_names = [output.name for output in session.get_outputs()]
        run_opts = rt.RunOptions()
        run_opts.log_severity_level = 3
        results = session.run(fetch_names, feed, run_options=run_opts)
        # Serialize every output array back into TensorProto wire format.
        return [
            onnx.numpy_helper.from_array(arr).SerializeToString() for arr in results
        ]
268
def main():
    """Command-line entry point for `onnxsim`: parse args, simplify, save, report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Input ONNX model")
    parser.add_argument("output_model", help="Output ONNX model")
    parser.add_argument(
        "check_n",
        help="Check whether the output is correct with n random inputs",
        nargs="?",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--enable-fuse-bn",
        help="This option is deprecated. Fusing bn into conv is enabled by default.",
        action="store_true",
    )
    parser.add_argument(
        "--skip-fuse-bn", help="Skip fusing batchnorm into conv.", action="store_true"
    )
    parser.add_argument(
        "--skip-optimization",
        help="Skip all ONNX optimizers or some of them. To skip all optimizers, use `onnxsim a.onnx b.onnx --skip-optimization`. To skip some of optimizers, use something like `onnxsim a.onnx b.onnx --skip-optimization fuse_bn_into_conv fuse_pad_into_pool`.",
        type=str,
        nargs="*",
    )
    parser.add_argument("--skip-constant-folding", help="Skip constant folding", action="store_true")
    parser.add_argument(
        "--input-shape",
        help="This argument has been renamed to --overwrite-input-shape, please refer to it",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--overwrite-input-shape",
        help='Overwrite the input shape. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--test-input-shape",
        help='The input shape to generated random inputs for test, useful when the input shape is dynamic. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--skip-optimizer",
        help="Deprecated. Refer to --skip-optimization",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--skip-shape-inference", help="Skip shape inference", action="store_true"
    )
    parser.add_argument(
        "--enable-onnxruntime-optimization",
        help="Enable ONNX Runtime's ORT_ENABLE_BASIC level optimization.",
        action="store_true",
    )
    parser.add_argument(
        "--dynamic-input-shape",
        help="Deprecated. Not needed any more.",
        action="store_true",
    )
    parser.add_argument(
        "--input-data-path",
        help='input data, The value should be "input_name1:xxx1.bin" "input_name2:xxx2.bin ...", input data should be a binary data file.',
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--custom-lib", help="Deprecated. Not needed any more.", type=str
    )
    parser.add_argument(
        "--include-subgraph",
        help='Experimental feature. Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph',
        action="store_true",
    )
    parser.add_argument(
        "--unused-output",
        help="Name of unused outputs that will be eliminated from the model",
        type=str,
        nargs="+",
    )
    # Bare '--no-large-tensor' means 1KB (const=); absent means the huge default.
    parser.add_argument(
        "--no-large-tensor",
        help="Some ops like Tile and ConstantOfShape can produce large tensor and make the model size much larger. Specifying this flag to skip folding these ops, with loss of some optimization chances. It can be followed with a threshold, for example, --no-large-tensor 1M or --no-large-tensor 100KB. A simple '--no-large-tensor' means '--no-large-tensor 1KB'.",
        type=str,
        const='1KB',
        default=DEFAULT_TENSOR_SIZE_THRESHOLDHOLD,
        nargs="?",
        dest="tensor_size_threshold",
    )
    parser.add_argument(
        "--mutable-initializer",
        help="By ONNX specification, initializers can also serve as inputs. This allows users to overwrite their values during runtime, but some useful optimizations like fuse-conv-and-bn will not be applicable anymore. In almost all cases, having an initializer that is also an input is unintended (usually caused by a out-dated PyTorch). So onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, you can specify '--mutable-initializer' to disable this behavior.",
        action="store_true",
    )
    parser.add_argument(
        "--save-as-external-data",
        help="Save parameters as external data. This will make the .onnx file much smaller, but the .onnx file will depend on the external data file (.data).",
        action="store_true",
    )
    parser.add_argument('-v', '--version', action='version', version='onnxsim ' + version.version)

    args = parser.parse_args()

    # --- Deprecation shims: warn and remap legacy flags onto current ones. ---
    if args.enable_fuse_bn:
        print(
            Text(
                'WARNING: "--enable-fuse-bn" is not needed any more, because fuse bn is enabled by default. "--enable-fuse-bn" flag is ignored now and will raise an error in the future.',
                style="bold red",
            )
        )
    if args.dynamic_input_shape:
        print(
            Text(
                'WARNING: "--dynamic-input-shape" is not needed any more, onnxsim v0.4 now handles dynamic input shapes automatically. "--dynamic-input-shape" flag is ignored now and will raise an error in the future.',
                style="bold red",
            )
        )
    # The old and new spellings are mutually exclusive.
    assert not (args.input_shape is not None and args.overwrite_input_shape is not None)
    if args.input_shape:
        print(
            Text(
                'WARNING: "--input-shape" is renamed to "--overwrite-input-shape". Please use it instead.',
                style="bold red",
            )
        )
        args.overwrite_input_shape = args.input_shape
    if args.include_subgraph:
        print(
            Text(
                "WARNING: subgraph optimization is not supported in v0.4 for now.",
                style="bold red",
            )
        )
    assert not (args.skip_optimizer is not None and args.skip_optimization is not None)
    if args.skip_optimizer:
        print(
            Text(
                'WARNING: "--skip-optimizer" is renamed to "--skip-optimization". Please use it instead.',
                style="bold red",
            )
        )
        args.skip_optimization = args.skip_optimizer
    if args.skip_optimization is None:
        # user doesn't specify --skip-optimization
        args.skip_optimization = []
    elif len(args.skip_optimization) == 0:
        # user specify --skip-optimization without any certain optimizer name
        # set it to None means skip all optimizations
        args.skip_optimization = None
    if args.skip_fuse_bn and args.skip_optimization is not None:
        args.skip_optimization.append("fuse_bn_into_conv")

    perform_optimization = False if args.skip_optimization is None else True

    def parse_shapes(shapes_arg):
        # Parse "name:d0,d1,..." items into {name: [d0, d1, ...]}; a bare
        # "d0,d1,..." (no colon) is stored under the None key, meaning
        # "the model's only input".
        shapes = {}
        if shapes_arg is not None:
            for x in shapes_arg:
                if ':' not in x:
                    shapes[None] = list(map(int, x.split(',')))
                else:
                    pieces = x.split(':')
                    # for the input name like input:0
                    name, shape = ':'.join(
                        pieces[:-1]), list(map(int, pieces[-1].split(',')))
                    shapes.update({name: shape})
        return shapes

    test_input_shapes = parse_shapes(args.test_input_shape)
    overwrite_input_shapes = parse_shapes(args.overwrite_input_shape)

    if args.enable_onnxruntime_optimization:
        # Run ORT's basic graph optimizations first, serializing the optimized
        # model to a temp file that is then loaded back as the input model.
        tmp_file = tempfile.NamedTemporaryFile()
        sess_options = rt.SessionOptions()
        # Set graph optimization level
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
        # To enable model serialization after graph optimization
        sess_options.optimized_model_filepath = tmp_file.name
        _ = rt.InferenceSession(args.input_model, sess_options, providers=["CPUExecutionProvider"])

        model = onnx.load(tmp_file.name)
    else:
        model = onnx.load(args.input_model)

    # Warn about ops that can blow up the model size when constant-folded,
    # but only if the user kept the default threshold.
    if args.tensor_size_threshold == DEFAULT_TENSOR_SIZE_THRESHOLDHOLD:
        for node in model.graph.node:
            if node.op_type in ["Tile", "ConstantOfShape"]:
                print(
                    Text(
                        'Your model contains "Tile" ops or/and "ConstantOfShape" ops. Folding these ops can make the simplified model much larger. If it is not expected, please specify "--no-large-tensor" (which will lose some optimization chances)',
                        style="bold magenta",
                    )
                )
                break

    if not args.mutable_initializer:
        # Inform the user when initializers double as inputs (usually an
        # artifact of old PyTorch exporters); they will be treated as constants.
        initializer_names = set([x.name for x in model.graph.initializer])
        input_names = set([x.name for x in model.graph.input])
        if len(initializer_names.intersection(input_names)) > 0:
            print(
                Text(
                    'Your model contains initializers that are also inputs. This is usually caused by an out-dated PyTorch. onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, please specify "--mutable-initializer" to disable this behavior.',
                    style="bold magenta",
                )
            )

    # Load user-provided test data ("name:file" pairs, read with np.load).
    input_tensors = None
    if args.input_data_path is not None:
        input_tensors = {}
        for x in args.input_data_path:
            pieces = x.split(':')
            name, data = ':'.join(pieces[:-1]), pieces[-1]
            input_tensors.update({name: np.load(data)})

    print("Simplifying...")

    model_opt, check_ok = simplify(
        model,
        args.check_n,
        perform_optimization,
        False,
        overwrite_input_shapes,
        test_input_shapes,
        args.skip_optimization,
        args.skip_constant_folding,
        args.skip_shape_inference,
        input_tensors,
        False,
        args.custom_lib,
        args.include_subgraph,
        args.unused_output,
        args.tensor_size_threshold,
        args.mutable_initializer,
    )

    try:
        if not args.save_as_external_data:
            onnx.save(model_opt, args.output_model)
        else:
            # Force the external-data path below by raising the same error
            # onnx.save raises for over-2GB models.
            raise ValueError("save_as_external_data")
    except ValueError:
        # large models (>2GB) which onnx.save doesn't support,
        # or explicitly specified --save-as-external-data
        # NOTE(review): `location` is presumably interpreted relative to the
        # model file, hence basename rather than a full path — confirm against
        # onnx external-data docs.
        external_data_path = os.path.basename(args.output_model) + '.data'
        if os.path.exists(external_data_path):
            os.remove(external_data_path)
        onnx.save(
            copy.deepcopy(model_opt),
            args.output_model,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_path,
        )

    if check_ok:
        print("Finish! Here is the difference:")
        model_info.print_simplifying_info(model, model_opt)
    else:
        print(
            'Check failed. Please be careful to use the simplified model, or try specifying "--skip-fuse-bn" or "--skip-optimization" (run "onnxsim -h" for details).'
        )
        print("Here is the difference after simplification:")
        model_info.print_simplifying_info(model, model_opt)
        # Non-zero exit code so scripts can detect the failed equivalence check.
        sys.exit(1)