onnx-simplifier / model_checking.py

import os
from typing import List, Dict, Optional, Union
from collections import OrderedDict

import onnx
import onnx.checker
import numpy as np
import onnxruntime as rt

Tensors = Dict[str, np.ndarray]
TensorShape = List[int]
TensorShapes = Dict[Optional[str], TensorShape]


def compare(
    model_opt: Union[str, onnx.ModelProto],
    model_ori: Union[str, onnx.ModelProto],
    n_times: int = 5,
    input_shapes: Optional[TensorShapes] = None,
    input_data: Optional[Tensors] = None,
    custom_lib: Optional[str] = None,
    verbose: bool = True,
) -> bool:
    """
    :param model_opt: The simplified ONNX model
    :param model_ori: The original ONNX model
    :param n_times: Number of random inputs to generate
    :param input_shapes: Shapes of the generated random inputs
    :param input_data: User-given input data used instead of randomly generated data
    :param custom_lib: ONNX Runtime custom-op library for custom ops
    :param verbose: Print the differing tensors when a mismatch is found
    :return: True if all outputs of the two models match within tolerance
    """

    def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]:
        return [dim.dim_value for dim in v.type.tensor_type.shape.dim]

    def get_value_info_all(
        m: onnx.ModelProto, name: str
    ) -> Optional[onnx.ValueInfoProto]:
        for v in m.graph.value_info:
            if v.name == name:
                return v

        for v in m.graph.input:
            if v.name == name:
                return v

        for v in m.graph.output:
            if v.name == name:
                return v

        return None

    def get_shape(m: onnx.ModelProto, name: str) -> TensorShape:
        """
        Note: this function relies on ONNX shape inference, which is not always
        reliable, so only use it on input or output tensors.
        """
        v = get_value_info_all(m, name)
        if v is not None:
            return get_shape_from_value_info_proto(v)
        raise RuntimeError('Cannot get shape of "{}"'.format(name))

    def get_elem_type(m: onnx.ModelProto, name: str) -> Optional[int]:
        v = get_value_info_all(m, name)
        if v is not None:
            return v.type.tensor_type.elem_type
        return None

    def get_np_type_from_elem_type(elem_type: int) -> type:
        # Maps ONNX TensorProto elem_type values (0-16) to numpy/Python types.
        # Index 16 is BFLOAT16, which has no native numpy type and is
        # approximated here as float16.
        types = (
            None,
            np.float32,
            np.uint8,
            np.int8,
            np.uint16,
            np.int16,
            np.int32,
            np.int64,
            str,
            bool,
            np.float16,
            np.double,
            np.uint32,
            np.uint64,
            np.complex64,
            np.complex128,
            np.float16,
        )
        assert len(types) == 17
        np_type = types[elem_type]
        assert np_type is not None
        return np_type

    def get_input_names(model: onnx.ModelProto) -> List[str]:
        # Real model inputs are the graph inputs that are not initializers.
        input_names = list(
            set([ipt.name for ipt in model.graph.input])
            - set([x.name for x in model.graph.initializer])
        )
        return input_names

    def generate_rand_input(
        model: Union[str, onnx.ModelProto],
        input_shapes: Optional[TensorShapes] = None
    ):
        if input_shapes is None:
            input_shapes = {}
        if isinstance(model, str):
            model = onnx.load(model, load_external_data=False)
        input_names = get_input_names(model)
        full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names}
        assert None not in input_shapes
        full_input_shapes.update(input_shapes)  # type: ignore
        for name, shape in full_input_shapes.items():
            if any([dim <= 0 for dim in shape[1:]]):
                raise RuntimeError(
                    'The shape of input "{}" has dynamic size, '
                    "please set an input shape manually with --test-input-shape".format(name)
                )
            if len(shape) > 0 and shape[0] <= 0:
                print(
                    f'shape[0] of input "{name}" is dynamic; we assume it is the '
                    'batch size and set it to 1 for testing. If this is not what '
                    'you want, please set it manually with --test-input-shape '
                    '(see `onnxsim -h` for details).'
                )
                shape[0] = 1

        inputs = {
            ipt: np.array(
                np.random.rand(*full_input_shapes[ipt]),
                dtype=get_np_type_from_elem_type(get_elem_type(model, ipt)),
            )
            for ipt in input_names
        }
        return inputs

    def forward(
            model: Union[str, onnx.ModelProto],
            inputs: Tensors,
            custom_lib: Optional[str] = None
    ) -> Dict[str, np.ndarray]:
        sess_options = rt.SessionOptions()
        if custom_lib is not None:
            if os.path.exists(custom_lib):
                sess_options.register_custom_ops_library(custom_lib)
            else:
                raise ValueError("No such file '{}'".format(custom_lib))
        # Disable ONNX Runtime graph optimizations so the graph is executed
        # as-is and the comparison reflects the model itself.
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0)
        sess_options.log_severity_level = 3
        if isinstance(model, onnx.ModelProto):
            model = model.SerializeToString()
        sess = rt.InferenceSession(
            model,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        outputs = [x.name for x in sess.get_outputs()]
        run_options = rt.RunOptions()
        run_options.log_severity_level = 3
        res = OrderedDict(
            zip(outputs, sess.run(outputs, inputs, run_options=run_options))
        )
        return res

    if input_shapes is None:
        input_shapes = {}
    onnx.checker.check_model(model_opt)
    for i in range(n_times):
        print(f'Checking {i}/{n_times}...')
        if input_data is None:
            inputs = generate_rand_input(model_opt, input_shapes=input_shapes)
        else:
            inputs = input_data
        # Run both models on the same inputs and compare every output tensor.
        res_ori = forward(model_ori, inputs, custom_lib)
        res_opt = forward(model_opt, inputs, custom_lib)

        for name in res_opt.keys():
            if not np.allclose(res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5):
                if verbose:
                    print(
                        "Tensor {} changes after optimization. The max diff is {}.".format(
                            name, np.max(np.abs(res_opt[name] - res_ori[name]))
                        )
                    )
                    print("After optimization:")
                    print(res_opt[name])
                    print("Before optimization:")
                    print(res_ori[name])
                    print("----------------")
                return False
    return True
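

# Usage sketch (illustrative, not part of the original file): compares a
# hypothetical simplified model against its original with three random inputs.
# The file names and the input name/shape below are placeholder assumptions.
if __name__ == "__main__":
    ok = compare(
        "model_simplified.onnx",   # hypothetical path to the simplified model
        "model_original.onnx",     # hypothetical path to the original model
        n_times=3,
        input_shapes={"input": [1, 3, 224, 224]},  # assumed input name and shape
    )
    print("Outputs match" if ok else "Outputs differ")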
