pytorch

Форк
0
/
onnxifi.py 
68 строк · 2.2 Кб
1
## @package onnx
2
#Module caffe2.python.onnx.onnxifi
3

4
"""
5
ONNXIFI a Caffe2 net
6
"""
7

8
from caffe2.proto import caffe2_pb2
9
import caffe2.python._import_c_extension as C
10

11

12
def onnxifi_set_option(option_name, option_value):
13
    """
14
    Set onnxifi option
15
    """
16
    return C.onnxifi_set_option(option_name, str(option_value))
17

18

19
def onnxifi_get_option(option_name):
20
    """
21
    Get onnxifi option
22
    """
23
    return C.onnxifi_get_option(option_name)
24

25
def onnxifi_caffe2_net(
26
        pred_net,
27
        input_shapes,
28
        max_batch_size=1,
29
        max_seq_size=1,
30
        debug=False,
31
        use_onnx=True,
32
        merge_fp32_inputs_into_fp16=False,
33
        adjust_batch=True,
34
        block_list=None,
35
        weight_names=None,
36
        net_ssa_rewritten=False,
37
        timeout=0):
38
    """
39
    Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops
40
    """
41
    shape_hints = caffe2_pb2.TensorBoundShapes()
42
    if type(input_shapes) is caffe2_pb2.TensorBoundShapes:
43
        shape_hints = input_shapes
44
    elif type(input_shapes) is dict:
45
        for k, v in input_shapes.items():
46
            tbs = caffe2_pb2.TensorBoundShape()
47
            tbs.name = k
48
            tbs.shape.dims.extend(v)
49
            tbs.dim_type.extend([caffe2_pb2.TensorBoundShape.CONSTANT] * len(tbs.shape.dims))
50
            tbs.dim_type[0] = caffe2_pb2.TensorBoundShape.BATCH
51
            shape_hints.shapes.extend([tbs])
52
        shape_hints.max_batch_size = max_batch_size
53
        shape_hints.max_feature_len = max_seq_size
54
    pred_net_str = C.onnxifi(pred_net.SerializeToString(),
55
                             shape_hints.SerializeToString(),
56
                             block_list if block_list else [],
57
                             weight_names if weight_names is not None else [],
58
                             max_batch_size,
59
                             max_seq_size,
60
                             timeout,
61
                             adjust_batch,
62
                             debug,
63
                             merge_fp32_inputs_into_fp16,
64
                             net_ssa_rewritten,
65
                             use_onnx)
66
    pred_net_cut = caffe2_pb2.NetDef()
67
    pred_net_cut.ParseFromString(pred_net_str)
68
    return pred_net_cut
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.