1
## @package predictor_py_utils
2
# Module caffe2.python.predictor.predictor_py_utils
5
from caffe2.python import core, scope
8
def create_predict_net(predictor_export_meta):
10
Return the input prediction net.
12
# Construct a new net to clear the existing settings.
13
net = core.Net(predictor_export_meta.predict_net.name or "predict")
14
net.Proto().op.extend(predictor_export_meta.predict_net.op)
15
net.Proto().partition_info.extend(predictor_export_meta.predict_net.partition_info)
16
net.Proto().external_input.extend(
17
predictor_export_meta.inputs + predictor_export_meta.parameters
19
net.Proto().external_output.extend(predictor_export_meta.outputs)
20
net.Proto().arg.extend(predictor_export_meta.predict_net.arg)
21
if predictor_export_meta.net_type is not None:
22
net.Proto().type = predictor_export_meta.net_type
23
if predictor_export_meta.num_workers is not None:
24
net.Proto().num_workers = predictor_export_meta.num_workers
28
def create_predict_init_net(ws, predictor_export_meta):
30
Return an initialization net that zero-fill all the input and
31
output blobs, using the shapes from the provided workspace. This is
32
necessary as there is no shape inference functionality in Caffe2.
34
net = core.Net("predict-init")
37
shape = predictor_export_meta.shapes.get(blob)
39
if blob not in ws.blobs:
41
"{} not in workspace but needed for shape: {}".format(
46
shape = ws.blobs[blob].fetch().shape
48
# Explicitly null-out the scope so users (e.g. PredictorGPU)
49
# can control (at a Net-global level) the DeviceOption of
50
# these filling operators.
51
with scope.EmptyDeviceScope():
52
net.ConstantFill([], blob, shape=shape, value=0.0)
54
external_blobs = predictor_export_meta.inputs + predictor_export_meta.outputs
55
for blob in external_blobs:
58
net.Proto().external_input.extend(external_blobs)
59
if predictor_export_meta.extra_init_net:
60
net.AppendNet(predictor_export_meta.extra_init_net)
62
# Add the model_id in the predict_net to the init_net
63
AddModelIdArg(predictor_export_meta, net.Proto())
68
def get_comp_name(string, name):
70
return string + "_" + name
74
def to_first_match_dict(kv_list):
76
Construct dict from kv_list
81
d[item.key] = item.value
85
def _ProtoMapGet(field, key):
87
Given the key, get the value of the repeated field.
88
Helper function used by protobuf since it doesn't have map construct
96
def GetPlan(meta_net_def, key):
    """Return the plan registered under ``key`` in ``meta_net_def.plans``.

    ``plans`` is a repeated key/value field; the lookup itself is delegated
    to ``_ProtoMapGet``. NOTE(review): a missing key presumably yields None
    (ReplaceBlobs asserts the helper's result is not None) -- confirm.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
100
def GetPlanOriginal(meta_net_def, key):
    """Return the plan registered under ``key`` in ``meta_net_def.plans``.

    Performs the same ``_ProtoMapGet`` lookup over ``meta_net_def.plans``
    as ``GetPlan``. NOTE(review): a missing key presumably yields None --
    confirm against ``_ProtoMapGet``.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
104
def GetBlobs(meta_net_def, key):
105
blobs = _ProtoMapGet(meta_net_def.blobs, key)
111
def GetBlobsByTypePrefix(meta_net_def, blob_type_prefix):
113
for b in meta_net_def.blobs:
114
if b.key.startswith(blob_type_prefix):
116
if blob not in blob_map:
117
blob_map[blob] = len(blob_map)
118
return sorted(blob_map, key=lambda blob: blob_map[blob])
121
def GetNet(meta_net_def, key):
    """Return the net registered under ``key`` in ``meta_net_def.nets``.

    ``nets`` is a repeated key/value field; the lookup is delegated to
    ``_ProtoMapGet``. NOTE(review): a missing key presumably yields None --
    confirm against ``_ProtoMapGet``.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
125
def GetNetOriginal(meta_net_def, key):
    """Return the net registered under ``key`` in ``meta_net_def.nets``.

    Performs the same ``_ProtoMapGet`` lookup over ``meta_net_def.nets`` as
    ``GetNet``. NOTE(review): a missing key presumably yields None --
    confirm against ``_ProtoMapGet``.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
129
def GetApplicationSpecificInfo(meta_net_def, key):
    """Return the ``applicationSpecificInfo`` value stored under ``key``.

    The repeated key/value field is searched via ``_ProtoMapGet``.
    NOTE(review): a missing key presumably yields None -- confirm against
    ``_ProtoMapGet``.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
133
def GetApplicationSpecificInfoDict(meta_net_def):
    """Convert all ``applicationSpecificInfo`` entries into a plain dict.

    The conversion is done by ``to_first_match_dict``. NOTE(review): judging
    by that helper's name, the first entry wins when a key is repeated --
    confirm against its implementation.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return to_first_match_dict(info_entries)
137
def AddBlobs(meta_net_def, blob_name, blob_def):
138
blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
140
blobs = meta_net_def.blobs.add()
141
blobs.key = blob_name
143
for blob in blob_def:
147
def ReplaceBlobs(meta_net_def, blob_name, blob_def):
148
blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
149
assert blobs is not None, "The blob_name:{} does not exist".format(blob_name)
151
for blob in blob_def:
155
def AddPlan(meta_net_def, plan_name, plan_def):
    """Add a ``plans`` entry whose key is ``plan_name`` and value is ``plan_def``.

    Nothing here checks for an existing entry with the same name. The entry
    object created by ``add`` is not returned.
    """
    plan_entries = meta_net_def.plans
    plan_entries.add(key=plan_name, value=plan_def)
159
def AddNet(meta_net_def, net_name, net_def):
    """Add a ``nets`` entry whose key is ``net_name`` and value is ``net_def``.

    Nothing here checks for an existing entry with the same name. The entry
    object created by ``add`` is not returned.
    """
    net_entries = meta_net_def.nets
    net_entries.add(key=net_name, value=net_def)
163
def SetBlobsOrder(meta_net_def, blobs_order):
    """Append every blob name in ``blobs_order`` to ``meta_net_def.blobsOrder``.

    Existing entries are kept; the new names follow them in iteration order.

    Args:
        meta_net_def: object whose ``blobsOrder`` is a list-like container
            supporting ``extend`` (presumably a protobuf repeated field).
        blobs_order: iterable of blob names.
    """
    # One bulk extend instead of a manual per-item append loop.
    meta_net_def.blobsOrder.extend(blobs_order)
168
def SetPreLoadBlobs(meta_net_def, pre_load_blobs):
    """Append every blob name in ``pre_load_blobs`` to ``meta_net_def.preLoadBlobs``.

    Existing entries are kept; the new names follow them in iteration order.

    Args:
        meta_net_def: object whose ``preLoadBlobs`` is a list-like container
            supporting ``extend`` (presumably a protobuf repeated field).
        pre_load_blobs: iterable of blob names.
    """
    # One bulk extend instead of a manual per-item append loop.
    meta_net_def.preLoadBlobs.extend(pre_load_blobs)
173
def SetRequestOnlyEmbeddings(meta_net_def, request_only_embeddings):
    """Append every name in ``request_only_embeddings`` to
    ``meta_net_def.requestOnlyEmbeddings``.

    Existing entries are kept; the new names follow them in iteration order.

    Args:
        meta_net_def: object whose ``requestOnlyEmbeddings`` is a list-like
            container supporting ``extend`` (presumably a protobuf repeated
            field).
        request_only_embeddings: iterable of blob names.
    """
    # One bulk extend instead of a manual per-item append loop.
    meta_net_def.requestOnlyEmbeddings.extend(request_only_embeddings)
178
def GetBlobsOrder(meta_net_def):
    """Return ``meta_net_def.blobsOrder`` itself (the same object, not a copy)."""
    order = meta_net_def.blobsOrder
    return order
182
def SetTensorBoundShapes(meta_net_def, tensor_bound_shapes):
    """Copy ``tensor_bound_shapes`` into ``meta_net_def.tensorBoundShapes``.

    Uses the field's ``CopyFrom``; for a protobuf message this replaces the
    field's previous contents. Returns None.
    """
    target = meta_net_def.tensorBoundShapes
    target.CopyFrom(tensor_bound_shapes)
186
def SetAOTConfig(meta_net_def, aot_config):
    """Copy ``aot_config`` into ``meta_net_def.aotConfig``.

    Uses the field's ``CopyFrom``; for a protobuf message this replaces the
    field's previous contents. Returns None.
    """
    target = meta_net_def.aotConfig
    target.CopyFrom(aot_config)
190
def GetArgumentByName(net_def, arg_name):
191
for arg in net_def.arg:
192
if arg.name == arg_name:
197
def AddModelIdArg(meta_net_def, net_def):
198
"""Takes the model_id from the predict_net of meta_net_def (if it is
199
populated) and adds it to the net_def passed in. This is intended to be
200
called on init_nets, as their model_id is not populated by default, but
201
should be the same as that of the predict_net
203
# Get model_id from the predict_net, assuming it's an integer
204
model_id = GetArgumentByName(meta_net_def.predict_net, "model_id")
207
model_id = model_id.i
209
# If there's another model_id on the net, replace it with the new one
210
old_id = GetArgumentByName(net_def, "model_id")
211
if old_id is not None:
215
# Add as an integer argument, this is also assumed above
216
arg = net_def.arg.add()
217
arg.name = "model_id"