8
from caffe2.proto import caffe2_pb2
9
import caffe2.python._import_c_extension as C
12
def onnxifi_set_option(option_name, option_value):
    """Set an ONNXIFI option through the Caffe2 C extension.

    Args:
        option_name: Name of the option to set; forwarded unchanged.
        option_value: New value for the option. It is converted with
            ``str()`` before being forwarded, so callers may pass ints,
            bools, etc. as well as strings.

    Returns:
        Whatever ``C.onnxifi_set_option`` returns, passed through unchanged.
    """
    # The value is always stringified before it crosses into the C extension.
    return C.onnxifi_set_option(option_name, str(option_value))
19
def onnxifi_get_option(option_name):
    """Query an ONNXIFI option through the Caffe2 C extension.

    Args:
        option_name: Name of the option to read; forwarded unchanged.

    Returns:
        The result of ``C.onnxifi_get_option(option_name)``, passed
        through unchanged.
    """
    return C.onnxifi_get_option(option_name)
25
def onnxifi_caffe2_net(
32
merge_fp32_inputs_into_fp16=False,
36
net_ssa_rewritten=False,
39
Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops
41
shape_hints = caffe2_pb2.TensorBoundShapes()
42
if type(input_shapes) is caffe2_pb2.TensorBoundShapes:
43
shape_hints = input_shapes
44
elif type(input_shapes) is dict:
45
for k, v in input_shapes.items():
46
tbs = caffe2_pb2.TensorBoundShape()
48
tbs.shape.dims.extend(v)
49
tbs.dim_type.extend([caffe2_pb2.TensorBoundShape.CONSTANT] * len(tbs.shape.dims))
50
tbs.dim_type[0] = caffe2_pb2.TensorBoundShape.BATCH
51
shape_hints.shapes.extend([tbs])
52
shape_hints.max_batch_size = max_batch_size
53
shape_hints.max_feature_len = max_seq_size
54
pred_net_str = C.onnxifi(pred_net.SerializeToString(),
55
shape_hints.SerializeToString(),
56
block_list if block_list else [],
57
weight_names if weight_names is not None else [],
63
merge_fp32_inputs_into_fp16,
66
pred_net_cut = caffe2_pb2.NetDef()
67
pred_net_cut.ParseFromString(pred_net_str)