onnx-simplifier
185 lines · 6.4 KB
1import os2from typing import List, Dict, Optional, Union3from collections import OrderedDict4
5import onnx6import onnx.checker7import numpy as np8import onnxruntime as rt9
# Mapping from tensor name to its concrete numpy value.
Tensors = Dict[str, np.ndarray]
# A single tensor shape: one int per dimension (0 marks a dynamic dim).
TensorShape = List[int]
# Mapping from input name to its shape.
# NOTE(review): the Optional key suggests None may stand for "the sole input"
# somewhere, but generate_rand_input asserts None is absent — confirm intent.
TensorShapes = Dict[Optional[str], TensorShape]
14
def compare(
    model_opt: Union[str, onnx.ModelProto],
    model_ori: Union[str, onnx.ModelProto],
    n_times: int = 5,
    input_shapes: Optional[TensorShapes] = None,
    input_data: Optional[Tensors] = None,
    custom_lib: Optional[str] = None,
    verbose: bool = True,
) -> bool:
    """
    Run the original and the simplified model on identical inputs and
    check that every output tensor matches numerically.

    :param model_opt: The simplified ONNX model (path or in-memory proto)
    :param model_ori: The original ONNX model (path or in-memory proto)
    :param n_times: Generate n random inputs
    :param input_shapes: Shapes of generated random inputs
    :param input_data: User-given data instead of random generated data
    :param custom_lib: ONNX Runtime custom lib for custom ops
    :param verbose: Print a per-tensor diff report when a mismatch is found
    :return: True iff all outputs are close (rtol=1e-4, atol=1e-5) on every run
    """

    def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]:
        # dim_value is 0 for symbolic/dynamic dimensions.
        return [dim.dim_value for dim in v.type.tensor_type.shape.dim]

    def get_value_info_all(
        m: onnx.ModelProto, name: str
    ) -> Optional[onnx.ValueInfoProto]:
        # Search value_info first, then graph inputs, then graph outputs.
        for v in m.graph.value_info:
            if v.name == name:
                return v

        for v in m.graph.input:
            if v.name == name:
                return v

        for v in m.graph.output:
            if v.name == name:
                return v

        return None

    def get_shape(m: onnx.ModelProto, name: str) -> TensorShape:
        """
        Note: This method relies on onnx shape inference, which is not reliable.
        So only use it on input or output tensors
        """
        v = get_value_info_all(m, name)
        if v is not None:
            return get_shape_from_value_info_proto(v)
        raise RuntimeError('Cannot get shape of "{}"'.format(name))

    def get_elem_type(m: onnx.ModelProto, name: str) -> Optional[int]:
        # Returns the onnx TensorProto.DataType enum value, or None if the
        # tensor is not found in the graph.
        v = get_value_info_all(m, name)
        if v is not None:
            return v.type.tensor_type.elem_type
        return None

    def get_np_type_from_elem_type(elem_type: int) -> type:
        # Map an onnx TensorProto.DataType enum value to a numpy/python type.
        # Fixed: the return annotation was `int`, but a *type object* is
        # returned.
        # NOTE(review): index 16 is BFLOAT16 in ONNX; numpy has no bfloat16,
        # so float16 appears twice (indices 10 and 16) as a stand-in.
        np_types = (
            None,            # 0: UNDEFINED — deliberately unmapped
            np.float32,
            np.uint8,
            np.int8,
            np.uint16,
            np.int16,
            np.int32,
            np.int64,
            str,
            bool,
            np.float16,
            np.double,
            np.uint32,
            np.uint64,
            np.complex64,
            np.complex128,
            np.float16,
        )
        assert len(np_types) == 17
        np_type = np_types[elem_type]
        assert np_type is not None
        return np_type

    def get_input_names(model: onnx.ModelProto) -> List[str]:
        # Graph inputs that also appear as initializers are weights, not
        # true runtime inputs, so subtract them.
        input_names = list(
            {ipt.name for ipt in model.graph.input}
            - {x.name for x in model.graph.initializer}
        )
        return input_names

    def generate_rand_input(
        model: Union[str, onnx.ModelProto],
        input_shapes: Optional[TensorShapes] = None,
    ):
        # Build a {input_name: random ndarray} dict covering every runtime
        # input of the model; user-provided shapes override inferred ones.
        if input_shapes is None:
            input_shapes = {}
        if isinstance(model, str):
            model = onnx.load(model, load_external_data=False)
        input_names = get_input_names(model)
        full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names}
        if None in input_shapes:
            # Was a bare `assert`, which vanishes under `python -O`; raise an
            # explicit error instead so the invalid key is always rejected.
            raise ValueError(
                "input_shapes must not contain a None key; "
                "name every input explicitly"
            )
        full_input_shapes.update(input_shapes)  # type: ignore
        for name, shape in full_input_shapes.items():
            # Non-batch dynamic dims cannot be guessed — require the user to
            # pin them.
            if any(dim <= 0 for dim in shape[1:]):
                raise RuntimeError(
                    'The shape of input "{}" has dynamic size, '
                    "please set an input shape manually with --test-input-shape".format(name)
                )
            if len(shape) > 0 and shape[0] <= 0:
                # A dynamic leading dim is assumed to be the batch size.
                # Fixed typo in the hint ("set the it manually" -> "set it
                # manually").
                print(f'shape[0] of input "{name}" is dynamic, we assume it presents batch size and set it as 1 when testing. If it is not wanted, please set it manually by --test-input-shape (see `onnxsim -h` for the details).')
                shape[0] = 1

        inputs = {
            ipt: np.array(
                np.random.rand(*full_input_shapes[ipt]),
                dtype=get_np_type_from_elem_type(get_elem_type(model, ipt)),
            )
            for ipt in input_names
        }
        return inputs

    def forward(
        model: Union[str, onnx.ModelProto],
        inputs: Tensors,
        custom_lib: Optional[str] = None,
    ) -> Dict[str, np.ndarray]:
        # Run one inference pass on CPU and return {output_name: value}.
        sess_options = rt.SessionOptions()
        if custom_lib is not None:
            if os.path.exists(custom_lib):
                sess_options.register_custom_ops_library(custom_lib)
            else:
                raise ValueError("No such file '{}'".format(custom_lib))
        # Level 0 disables graph optimizations so the given graph runs as-is.
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0)
        sess_options.log_severity_level = 3
        if isinstance(model, onnx.ModelProto):
            model = model.SerializeToString()
        sess = rt.InferenceSession(
            model,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        outputs = [x.name for x in sess.get_outputs()]
        run_options = rt.RunOptions()
        run_options.log_severity_level = 3
        res = OrderedDict(
            zip(outputs, sess.run(outputs, inputs, run_options=run_options))
        )
        return res

    if input_shapes is None:
        input_shapes = {}
    onnx.checker.check_model(model_opt)
    for i in range(n_times):
        print(f'Checking {i}/{n_times}...')
        if input_data is None:
            inputs = generate_rand_input(model_opt, input_shapes=input_shapes)
        else:
            # User-supplied data is identical across iterations; repeats only
            # exercise runtime nondeterminism.
            inputs = input_data
        res_ori = forward(model_ori, inputs, custom_lib)
        res_opt = forward(model_opt, inputs, custom_lib)

        for name in res_opt.keys():
            if not np.allclose(res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5):
                if verbose:
                    print(
                        "Tensor {} changes after optimization. The max diff is {}.".format(
                            name, np.max(np.abs(res_opt[name] - res_ori[name]))
                        )
                    )
                    print("After optimization:")
                    print(res_opt[name])
                    print("Before optimization:")
                    print(res_ori[name])
                    print("----------------")
                return False
    return True