pytorch

Форк
0
/
predictor_py_utils.py 
218 строк · 6.4 Кб
## @package predictor_py_utils
# Module caffe2.python.predictor.predictor_py_utils

from caffe2.python import core, scope
6

7

8
def create_predict_net(predictor_export_meta):
    """
    Build and return the prediction NetDef described by the export meta.

    A brand-new ``core.Net`` is created so that no settings carry over from
    the original net object. It is populated with the ops, partition info and
    args of ``predictor_export_meta.predict_net``. Its external inputs are the
    exported inputs followed by the parameters, and its external outputs are
    the exported outputs. ``net_type`` and ``num_workers`` are applied only
    when they are not None.
    """
    source = predictor_export_meta.predict_net
    proto = core.Net(source.name or "predict").Proto()

    proto.op.extend(source.op)
    proto.partition_info.extend(source.partition_info)
    proto.external_input.extend(
        predictor_export_meta.inputs + predictor_export_meta.parameters
    )
    proto.external_output.extend(predictor_export_meta.outputs)
    proto.arg.extend(source.arg)

    net_type = predictor_export_meta.net_type
    if net_type is not None:
        proto.type = net_type

    num_workers = predictor_export_meta.num_workers
    if num_workers is not None:
        proto.num_workers = num_workers

    return proto
26

27

28
def create_predict_init_net(ws, predictor_export_meta):
    """
    Build and return an init NetDef that zero-fills every input and output
    blob of the predictor.

    The fill shape of each blob comes from ``predictor_export_meta.shapes``
    when recorded there. Otherwise it comes from the blob's current value in
    the workspace ``ws``; the original author notes this is needed because
    Caffe2 offers no shape inference for this. All filled blobs are declared
    as external inputs. An ``extra_init_net``, if set, is appended. The
    predict net's ``model_id`` argument, if any, is copied onto the result.

    Raises:
        Exception: if a blob has no recorded shape and is absent from ``ws``.
    """
    init_net = core.Net("predict-init")
    known_shapes = predictor_export_meta.shapes
    io_blobs = predictor_export_meta.inputs + predictor_export_meta.outputs

    for blob_name in io_blobs:
        fill_shape = known_shapes.get(blob_name)
        if fill_shape is None:
            if blob_name not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob_name, ws.blobs
                    )
                )
            fill_shape = ws.blobs[blob_name].fetch().shape

        # Clear the device scope so users (e.g. PredictorGPU) can choose the
        # DeviceOption of these fill ops at the whole-net level.
        with scope.EmptyDeviceScope():
            init_net.ConstantFill([], blob_name, shape=fill_shape, value=0.0)

    init_net.Proto().external_input.extend(io_blobs)
    if predictor_export_meta.extra_init_net:
        init_net.AppendNet(predictor_export_meta.extra_init_net)

    # Carry the predict net's model_id over to this init net.
    AddModelIdArg(predictor_export_meta, init_net.Proto())

    return init_net.Proto()
66

67

68
def get_comp_name(string, name):
    """Return ``string`` suffixed with ``"_" + name``, or ``string`` alone when ``name`` is falsy."""
    return string + "_" + name if name else string
72

73

74
def to_first_match_dict(kv_list):
    """
    Build a dict from items exposing ``.key`` and ``.value``.

    When a key occurs more than once, the value of its first occurrence wins.
    Keys appear in the dict in first-occurrence order.
    """
    first_seen = {}
    for entry in kv_list:
        if entry.key in first_seen:
            continue
        first_seen[entry.key] = entry.value
    return first_seen
83

84

85
def _ProtoMapGet(field, key):
86
    """
87
    Given the key, get the value of the repeated field.
88
    Helper function used by protobuf since it doesn't have map construct
89
    """
90
    for v in field:
91
        if v.key == key:
92
            return v.value
93
    return None
94

95

96
def GetPlan(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``, or None."""
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
98

99

100
def GetPlanOriginal(meta_net_def, key):
    """
    Return the plan stored under ``key`` in ``meta_net_def.plans``, or None.

    Currently performs exactly the same lookup as ``GetPlan``.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
102

103

104
def GetBlobs(meta_net_def, key):
    """Return the blob list stored under ``key`` in ``meta_net_def.blobs``, or ``[]`` if absent."""
    found = _ProtoMapGet(meta_net_def.blobs, key)
    return [] if found is None else found
109

110

111
def GetBlobsByTypePrefix(meta_net_def, blob_type_prefix):
    """
    Collect the blob names of every ``blobs`` entry whose key starts with
    ``blob_type_prefix``.

    Duplicates are dropped, and names are returned as a list in order of
    first appearance across the matching entries.
    """
    ordered = []
    seen = set()
    for entry in meta_net_def.blobs:
        if not entry.key.startswith(blob_type_prefix):
            continue
        for name in entry.value:
            if name in seen:
                continue
            seen.add(name)
            ordered.append(name)
    return ordered
119

120

121
def GetNet(meta_net_def, key):
    """Return the net stored under ``key`` in ``meta_net_def.nets``, or None."""
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
123

124

125
def GetNetOriginal(meta_net_def, key):
    """
    Return the net stored under ``key`` in ``meta_net_def.nets``, or None.

    Currently performs exactly the same lookup as ``GetNet``.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
127

128

129
def GetApplicationSpecificInfo(meta_net_def, key):
    """Return the ``applicationSpecificInfo`` value stored under ``key``, or None."""
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
131

132

133
def GetApplicationSpecificInfoDict(meta_net_def):
    """Return ``applicationSpecificInfo`` as a dict; the first entry wins for repeated keys."""
    info_entries = meta_net_def.applicationSpecificInfo
    return to_first_match_dict(info_entries)
135

136

137
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append the names in ``blob_def`` to the ``blobs`` entry keyed ``blob_name``.

    If no such entry exists yet, a new one is added to ``meta_net_def.blobs``
    first. Names already in the entry are kept; new names go at the end.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        new_entry = meta_net_def.blobs.add()
        new_entry.key = blob_name
        target = new_entry.value
    for name in blob_def:
        target.append(name)
145

146

147
def ReplaceBlobs(meta_net_def, blob_name, blob_def):
    """
    Replace the contents of an existing ``blobs`` entry of ``meta_net_def``.

    Args:
        meta_net_def: message whose repeated ``blobs`` field holds key/value
            entries.
        blob_name: key of the entry to replace; it must already exist.
        blob_def: iterable of blob names that become the entry's new value.

    Raises:
        AssertionError: if no entry keyed ``blob_name`` exists.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
    # An explicit raise instead of ``assert`` keeps this check active under
    # ``python -O``. Without it, ``del blobs[:]`` would fail on None with an
    # unhelpful TypeError. AssertionError is kept for existing callers.
    if blobs is None:
        raise AssertionError("The blob_name:{} does not exist".format(blob_name))
    del blobs[:]
    for blob in blob_def:
        blobs.append(blob)
153

154

155
def AddPlan(meta_net_def, plan_name, plan_def):
    """Add a new ``plans`` entry with key ``plan_name`` and value ``plan_def``."""
    plan_entries = meta_net_def.plans
    plan_entries.add(key=plan_name, value=plan_def)
157

158

159
def AddNet(meta_net_def, net_name, net_def):
    """Add a new ``nets`` entry with key ``net_name`` and value ``net_def``."""
    net_entries = meta_net_def.nets
    net_entries.add(key=net_name, value=net_def)
161

162

163
def SetBlobsOrder(meta_net_def, blobs_order):
    """
    Append each name in ``blobs_order`` to ``meta_net_def.blobsOrder``.

    Existing entries are kept; the field is not cleared first.
    """
    for name in blobs_order:
        meta_net_def.blobsOrder.append(name)
166

167

168
def SetPreLoadBlobs(meta_net_def, pre_load_blobs):
    """
    Append each name in ``pre_load_blobs`` to ``meta_net_def.preLoadBlobs``.

    Existing entries are kept; the field is not cleared first.
    """
    for name in pre_load_blobs:
        meta_net_def.preLoadBlobs.append(name)
171

172

173
def SetRequestOnlyEmbeddings(meta_net_def, request_only_embeddings):
    """
    Append each name in ``request_only_embeddings`` to
    ``meta_net_def.requestOnlyEmbeddings``.

    Existing entries are kept; the field is not cleared first.
    """
    for name in request_only_embeddings:
        meta_net_def.requestOnlyEmbeddings.append(name)
176

177

178
def GetBlobsOrder(meta_net_def):
    """Return the ``blobsOrder`` field of ``meta_net_def`` as-is (not copied)."""
    blobs_order = meta_net_def.blobsOrder
    return blobs_order
180

181

182
def SetTensorBoundShapes(meta_net_def, tensor_bound_shapes):
    """Copy ``tensor_bound_shapes`` into ``meta_net_def.tensorBoundShapes`` via ``CopyFrom``."""
    target = meta_net_def.tensorBoundShapes
    target.CopyFrom(tensor_bound_shapes)
184

185

186
def SetAOTConfig(meta_net_def, aot_config):
    """Copy ``aot_config`` into ``meta_net_def.aotConfig`` via ``CopyFrom``."""
    target = meta_net_def.aotConfig
    target.CopyFrom(aot_config)
188

189

190
def GetArgumentByName(net_def, arg_name):
    """Return the first entry of ``net_def.arg`` whose ``.name`` equals ``arg_name``, or None."""
    return next((candidate for candidate in net_def.arg if candidate.name == arg_name), None)
195

196

197
def AddModelIdArg(meta_net_def, net_def):
    """
    Copy the integer ``model_id`` argument of ``meta_net_def.predict_net``
    onto ``net_def``.

    Intended for init nets, whose model_id is not populated by default but
    should match the predict net's. Does nothing if the predict net has no
    ``model_id`` argument. If ``net_def`` already has one, its ``.i`` is
    overwritten; otherwise a new ``model_id`` argument is appended. The value
    is read from and written to the argument's integer slot (``.i``).
    """
    source_arg = GetArgumentByName(meta_net_def.predict_net, "model_id")
    if source_arg is None:
        return
    model_id = source_arg.i

    # Reuse an existing model_id argument when present, otherwise add one.
    target_arg = GetArgumentByName(net_def, "model_id")
    if target_arg is None:
        target_arg = net_def.arg.add()
        target_arg.name = "model_id"
    target_arg.i = model_id
219

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.