# caffe_translator.py — caffe2.python.caffe_translator (from the pytorch repository).
# (Web-viewer header removed: "pytorch / Форк / 0 / caffe_translator.py / 937 строк · 34.4 Кб")
1
## @package caffe_translator
2
# Module caffe2.python.caffe_translator
3

4
import argparse
5
import copy
6
import logging
7
import re
8
import numpy as np  # noqa
9

10
from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
11
from caffe.proto import caffe_pb2
12
from caffe2.python import core, utils, workspace
13
from google.protobuf import text_format
14

15
logging.basicConfig()
16
log = logging.getLogger("caffe_translator")
17
log.setLevel(logging.INFO)
18

19

20
def _StateMeetsRule(state, rule):
21
    """A function that reproduces Caffe's StateMeetsRule functionality."""
22
    if rule.HasField('phase') and rule.phase != state.phase:
23
        return False
24
    if rule.HasField('min_level') and state.level < rule.min_level:
25
        return False
26
    if rule.HasField('max_level') and state.level > rule.max_level:
27
        return False
28
    curr_stages = set(list(state.stage))
29
    # all stages in rule.stages should be in, otherwise it's not a match.
30
    if len(rule.stage) and any([s not in curr_stages for s in rule.stage]):
31
        return False
32
    # none of the stage in rule.stages should be in, otherwise it's not a match.
33
    if len(rule.not_stage) and any([s in curr_stages for s in rule.not_stage]):
34
        return False
35
    # If none of the nonmatch happens, return True.
36
    return True
37

38

39
def _ShouldInclude(net_state, layer):
    """Reproduce Caffe's include/exclude rule for deciding layer membership."""
    # A layer with no include rules is in by default, unless some exclude
    # rule matches the current net state.
    excluded = any(_StateMeetsRule(net_state, rule) for rule in layer.exclude)
    ret = (not layer.include) and (not excluded)
    # When include rules exist, a single matching one (re-)admits the layer.
    if layer.include:
        ret = ret or any(
            _StateMeetsRule(net_state, rule) for rule in layer.include)
    return ret
48

49

50
def _GetLegacyDims(net, net_params, dummy_input, legacy_pad_ops):
    """Run `net` once and record the output shapes of the legacy-pad ops.

    Feeds all pretrained parameter blobs plus `dummy_input` into a scratch
    workspace, executes every operator in order, and for each operator index
    listed in `legacy_pad_ops` records the shape of its first output.

    Returns:
        dict mapping operator index -> output shape (with legacy padding
        still in effect).
    """
    dim_map = {}
    ws = workspace.C.Workspace()
    # Materialize the pretrained parameter blobs in the scratch workspace.
    for param in net_params.protos:
        ws.create_blob(param.name) \
            .feed(utils.Caffe2TensorToNumpyArray(param))
    # The first operator's first input is taken to be the network input.
    external_input = net.op[0].input[0]
    ws.create_blob(external_input).feed(dummy_input)
    # Get dimensions with legacy pad
    for i in range(len(net.op)):
        op_def = net.op[i]
        ws._run_operator(op_def.SerializeToString())
        if i in legacy_pad_ops:
            output = op_def.output[0]
            blob_legacy = ws.fetch_blob(output)
            dim_map[i] = blob_legacy.shape
    return dim_map
67

68

69
def _GetLegacyPadArgs(op_def, arg_map):
70
    pads = {}
71
    keys = ['pad_l', 'pad_t', 'pad_r', 'pad_b']
72
    is_pad = 'pad' in arg_map
73
    if is_pad:
74
        for k in keys:
75
            pads[k] = arg_map['pad'].i
76
    else:
77
        pads = {x: arg_map[x].i for x in keys}
78
    return pads
79

80

81
def _AdjustDims(op_def, arg_map, pads, dim1, dim2):
    """Grow pad arguments so the no-legacy-pad op reproduces legacy shapes.

    Args:
        op_def: operator whose pad arguments may be rewritten in place.
        arg_map: dict of argument name -> Argument proto for `op_def`.
        pads: per-side pads from _GetLegacyPadArgs; mutated in place.
        dim1: NCHW output shape produced WITH legacy padding.
        dim2: NCHW output shape produced WITHOUT legacy padding.

    Legacy (ceil-mode) padding can make the output at most one element
    larger per spatial dimension; that difference is compensated by bumping
    the bottom/right pads. Any other mismatch raises.
    """
    n1, c1, h1, w1 = dim1
    n2, c2, h2, w2 = dim2
    # Padding cannot affect batch or channel dimensions.
    assert(n1 == n2)
    assert(c1 == c2)
    is_pad = 'pad' in arg_map
    if h1 != h2 or w1 != w2:
        if h1 == h2 + 1:
            pads['pad_b'] += 1
        elif h1 != h2:
            raise Exception("Unexpected dimensions for height:", h1, h2)
        if w1 == w2 + 1:
            pads['pad_r'] += 1
        elif w1 != w2:
            raise Exception("Unexpected dimensions for width:", w1, w2)
        if is_pad:
            # Replace the single symmetric 'pad' argument with four explicit
            # per-side arguments carrying the adjusted values.
            op_def.arg.remove(arg_map['pad'])
            args = []
            for name in pads.keys():
                arg = caffe2_pb2.Argument()
                arg.name = name
                arg.i = pads[name]
                args.append(arg)
            op_def.arg.extend(args)
        else:
            # Per-side arguments already exist: update them in place.
            for name in pads.keys():
                arg_map[name].i = pads[name]
108

109

110
def _RemoveLegacyPad(net, net_params, input_dims):
    """Strip 'legacy_pad' arguments from conv/pool ops, adjusting pads.

    Runs the net twice on random NCHW data of shape `input_dims` — once with
    legacy (ceil-mode) padding in place, once with the legacy_pad argument
    removed — and bumps the explicit pad arguments wherever the two output
    shapes differ, so the result matches the legacy behavior.

    Returns:
        The modified `net` (also mutated in place).
    """
    # Indices of ops that are conv/pool-like AND carry a legacy_pad argument.
    legacy_pad_ops = []
    for i in range(len(net.op)):
        op_def = net.op[i]
        if re.match(r'^(Conv|ConvTranspose|MaxPool|AveragePool)(\dD)?$',
                    op_def.type):
            for arg in op_def.arg:
                if arg.name == 'legacy_pad':
                    legacy_pad_ops.append(i)
                    break
    if legacy_pad_ops:
        n, c, h, w = input_dims
        dummy_input = np.random.randn(n, c, h, w).astype(np.float32)
        # Reference shapes produced with legacy padding still in effect.
        dim_map = _GetLegacyDims(net, net_params, dummy_input, legacy_pad_ops)

        # Running with the legacy pad argument removed
        # compare the dimensions and adjust pad argument when necessary
        ws = workspace.C.Workspace()

        external_input = net.op[0].input[0]
        # NOTE(review): sibling helpers (_GetLegacyDims, _GetBlobDimMap) call
        # blob.feed() here, while this one calls feed_blob() — confirm both
        # methods exist on this workspace binding.
        ws.create_blob(external_input).feed_blob(dummy_input)
        for param in net_params.protos:
            ws.create_blob(param.name) \
              .feed_blob(utils.Caffe2TensorToNumpyArray(param))

        for i in range(len(net.op)):
            op_def = net.op[i]
            if i in legacy_pad_ops:
                arg_map = {}
                for arg in op_def.arg:
                    arg_map[arg.name] = arg
                pads = _GetLegacyPadArgs(op_def, arg_map)
                # remove legacy pad arg
                for j in range(len(op_def.arg)):
                    arg = op_def.arg[j]
                    if arg.name == 'legacy_pad':
                        del op_def.arg[j]
                        break
                output = op_def.output[0]
                # use a new name to avoid the interference with inplace
                nonlegacy_output = output + '_nonlegacy'
                op_def.output[0] = nonlegacy_output
                ws._run_operator(op_def.SerializeToString())
                blob_nonlegacy = ws.fetch_blob(nonlegacy_output)
                # reset output name
                op_def.output[0] = output

                dim1 = dim_map[i]
                dim2 = blob_nonlegacy.shape
                # Bump pad_b / pad_r so shapes match the legacy output.
                _AdjustDims(op_def, arg_map, pads, dim1, dim2)

            ws._run_operator(op_def.SerializeToString())
    return net
163

164

165
def _GetBlobDimMap(net, net_params, dummy_input):
    """Run `net` on `dummy_input` and map every output blob name to its shape.

    Feeds the pretrained parameters and the dummy input into a scratch
    workspace, executes all operators in order, and records the shape of
    every operator output.

    Returns:
        dict mapping blob name -> shape.
    """
    dim_map = {}
    ws = workspace.C.Workspace()
    for param in net_params.protos:
        ws.create_blob(param.name) \
          .feed(utils.Caffe2TensorToNumpyArray(param))
    # The first operator's first input is taken to be the network input.
    external_input = net.op[0].input[0]
    ws.create_blob(external_input).feed(dummy_input)
    # Get dimensions with legacy pad
    for i in range(len(net.op)):
        op_def = net.op[i]
        ws._run_operator(op_def.SerializeToString())
        for output in op_def.output:
            blob = ws.fetch_blob(output)
            dim_map[output] = blob.shape
    return dim_map
181

182

183
def _GetInputDims(caffe_net):
184
    input_dims = []
185
    if caffe_net.input_dim:
186
        input_dims = caffe_net.input_dim
187
    elif caffe_net.input_shape:
188
        input_dims = caffe_net.input_shape[0].dim
189
    elif caffe_net.layer[0].input_param.shape:
190
        # getting input dimension from first layer
191
        input_dims = caffe_net.layer[0].input_param.shape[0].dim
192
    return input_dims
193

194

195
class TranslatorRegistry:
    """Registry mapping Caffe layer type names to translator functions.

    Translators are registered with the @Register decorator and invoked by
    TranslateLayer/TranslateModel to build a Caffe2 NetDef plus parameters.
    """

    # Maps Caffe layer type string -> translator callable.
    registry_ = {}

    @classmethod
    def Register(cls, op_name):
        """A decorator for registering gradient mappings."""

        def Wrapper(func):
            cls.registry_[op_name] = func
            return func

        return Wrapper

    @classmethod
    def TranslateLayer(cls, layer, pretrained_blobs, is_test, **kwargs):
        """Dispatch one Caffe layer to its registered translator.

        Returns:
            (caffe_ops, params): list of Caffe2 OperatorDefs (normalized to a
            list, possibly empty) and the layer's parameter TensorProtos.

        Raises:
            KeyError: when no translator is registered for layer.type.
        """
        try:
            caffe_ops, params = cls.registry_[layer.type](
                layer, pretrained_blobs, is_test, **kwargs)
        except KeyError as e:
            raise KeyError('No translator registered for layer: %s yet.' %
                           str(layer)) from e
        # Normalize the translator's return value to a list of operators.
        if caffe_ops is None:
            caffe_ops = []
        if type(caffe_ops) is not list:
            caffe_ops = [caffe_ops]
        return caffe_ops, params

    @classmethod
    def TranslateModel(
        cls,
        caffe_net,
        pretrained_net,
        is_test=False,
        net_state=None,
        remove_legacy_pad=False,
        input_dims=None
    ):
        """Translate a whole Caffe network into a Caffe2 (net, params) pair.

        Args:
            caffe_net: Caffe NetParameter describing the architecture.
            pretrained_net: Caffe NetParameter carrying trained blobs.
            is_test: translate in inference mode (affects dropout, BN, ...).
            net_state: NetState used to evaluate include/exclude rules.
            remove_legacy_pad: rewrite legacy (ceil-mode) pooling pads;
                requires input_dims.
            input_dims: NCHW input dimensions; inferred from caffe_net when
                not given.

        Returns:
            (net, net_params): a caffe2_pb2.NetDef and TensorProtos.
        """
        net_state = caffe_pb2.NetState() if net_state is None else net_state
        net = caffe2_pb2.NetDef()
        net.name = caffe_net.name
        net_params = caffe2_pb2.TensorProtos()
        # Old-style (V1) layers live in the 'layers' field and are rejected.
        if len(caffe_net.layers) > 0:
            raise ValueError(
                'I think something is wrong. This translation script '
                'only accepts new style layers that are stored in the '
                'layer field.'
            )
        if not input_dims:
            input_dims = _GetInputDims(caffe_net)
        for layer in caffe_net.layer:
            # Honor Caffe's include/exclude rules for the current net state.
            if not _ShouldInclude(net_state, layer):
                log.info('Current net state does not need layer {}'
                            .format(layer.name))
                continue
            log.info('Translate layer {}'.format(layer.name))
            # Get pretrained one
            pretrained_layers = (
                [l for l in pretrained_net.layer
                 if l.name == layer.name] + [l
                                             for l in pretrained_net.layers
                                             if l.name == layer.name]
            )
            if len(pretrained_layers) > 1:
                raise ValueError(
                    'huh? more than one pretrained layer of one name?')
            elif len(pretrained_layers) == 1:
                pretrained_blobs = [
                    utils.CaffeBlobToNumpyArray(blob)
                    for blob in pretrained_layers[0].blobs
                ]
            else:
                # No pretrained layer for the given layer name. We'll just pass
                # no parameter blobs.
                # print 'No pretrained layer for layer', layer.name
                pretrained_blobs = []
            operators, params = cls.TranslateLayer(
                layer, pretrained_blobs, is_test, net=net,
                net_params=net_params, input_dims=input_dims)
            net.op.extend(operators)
            net_params.protos.extend(params)
        if remove_legacy_pad:
            assert input_dims, \
                   'Please specify input_dims to remove legacy_pad'
            net = _RemoveLegacyPad(net, net_params, input_dims)
        return net, net_params
280

281

282
def TranslateModel(*args, **kwargs):
    """Module-level convenience wrapper for TranslatorRegistry.TranslateModel."""
    return TranslatorRegistry.TranslateModel(*args, **kwargs)
284

285

286
def ConvertTensorProtosToInitNet(net_params, input_name):
    """Wrap the TensorProtos returned by TranslateModel into an init net.

    Emits one GivenTensorFill per parameter tensor plus a placeholder
    ConstantFill for the external input blob. Only float tensors are
    supported; for more complex cases, store the parameters in a db.
    """
    init_net = caffe2_pb2.NetDef()
    for tensor in net_params.protos:
        if not tensor.float_data:
            raise RuntimeError(
                "Only float tensors are supported in this util.")
        fill_args = [
            utils.MakeArgument("shape", list(tensor.dims)),
            utils.MakeArgument("values", tensor.float_data),
        ]
        fill_op = core.CreateOperator(
            "GivenTensorFill", [], [tensor.name], arg=fill_args)
        init_net.op.append(fill_op)
    # Dummy fill so the input blob exists; real data is fed at run time.
    input_fill = core.CreateOperator(
        "ConstantFill", [], [input_name], shape=[1])
    init_net.op.append(input_fill)
    return init_net
308

309

310
def BaseTranslate(layer, caffe2_type):
    """A simple translate interface that maps the layer input and output."""
    op = caffe2_pb2.OperatorDef()
    op.type = caffe2_type
    # Caffe "bottom" blobs become inputs; "top" blobs become outputs.
    for bottom in layer.bottom:
        op.input.append(bottom)
    for top in layer.top:
        op.output.append(top)
    return op
317

318

319
def AddArgument(op, key, value):
    """Append a typed argument (built via utils.MakeArgument) to `op`."""
    op.arg.append(utils.MakeArgument(key, value))
322

323
################################################################################
324
# Common translators for layers.
325
################################################################################
326

327

328
@TranslatorRegistry.Register("Input")
329
def TranslateInput(layer, pretrained_blobs, is_test, **kwargs):
330
    return [], []
331

332

333
@TranslatorRegistry.Register("VideoData")
334
def TranslateVideoData(layer, pretrained_blobs, is_test, **kwargs):
335
    return [], []
336

337

338
@TranslatorRegistry.Register("Data")
339
def TranslateData(layer, pretrained_blobs, is_test, **kwargs):
340
    return [], []
341

342

343
# A function used in convolution, pooling and deconvolution to deal with
# conv pool specific parameters.
def _TranslateStridePadKernelHelper(param, caffe_op):
    """Copy stride/pad/kernel settings from a Caffe param proto onto `caffe_op`.

    Handles both repeated fields (ConvolutionParameter) and scalar fields
    (PoolingParameter, detected via TypeError on len()), and prefers the
    explicit *_h / *_w variants when they are set.
    """
    try:
        if (len(param.stride) > 1 or len(param.kernel_size) > 1 or
                len(param.pad) > 1):
            raise NotImplementedError(
                "Translator currently does not support non-conventional "
                "pad/kernel/stride settings."
            )
        # Repeated fields may be empty; fall back to Caffe's defaults.
        stride = param.stride[0] if len(param.stride) else 1
        pad = param.pad[0] if len(param.pad) else 0
        kernel = param.kernel_size[0] if len(param.kernel_size) else 0
    except TypeError:
        # This catches the case of a PoolingParameter, in which case we are
        # having non-repeating pad, stride and kernel.
        stride = param.stride
        pad = param.pad
        kernel = param.kernel_size
    # Get stride
    if param.HasField("stride_h") or param.HasField("stride_w"):
        AddArgument(caffe_op, "stride_h", param.stride_h)
        AddArgument(caffe_op, "stride_w", param.stride_w)
    else:
        AddArgument(caffe_op, "stride", stride)
    # Get pad
    if param.HasField("pad_h") or param.HasField("pad_w"):
        if param.pad_h == param.pad_w:
            # Equal h/w pads collapse to a single symmetric 'pad' argument.
            AddArgument(caffe_op, "pad", param.pad_h)
        else:
            AddArgument(caffe_op, "pad_t", param.pad_h)
            AddArgument(caffe_op, "pad_b", param.pad_h)
            AddArgument(caffe_op, "pad_l", param.pad_w)
            AddArgument(caffe_op, "pad_r", param.pad_w)
    else:
        AddArgument(caffe_op, "pad", pad)
    # Get kernel
    if param.HasField("kernel_h") or param.HasField("kernel_w"):
        AddArgument(caffe_op, "kernel_h", param.kernel_h)
        AddArgument(caffe_op, "kernel_w", param.kernel_w)
    else:
        AddArgument(caffe_op, "kernel", kernel)
385

386

387
@TranslatorRegistry.Register("Convolution3D")
388
def TranslateConvNd(layer, pretrained_blobs, is_test, **kwargs):
389
    param = layer.convolution3d_param
390
    caffe_op = BaseTranslate(layer, "Conv")
391
    output = caffe_op.output[0]
392
    caffe_op.input.append(output + '_w')
393

394
    AddArgument(
395
        caffe_op,
396
        "kernels",
397
        [param.kernel_depth, param.kernel_size, param.kernel_size])
398
    AddArgument(
399
        caffe_op,
400
        "strides",
401
        [param.temporal_stride, param.stride, param.stride])
402
    temporal_pad = 0
403
    spatial_pad = 0
404
    if hasattr(param, 'temporal_pad'):
405
        temporal_pad = param.temporal_pad
406
    if hasattr(param, 'pad'):
407
        spatial_pad = param.pad
408
    AddArgument(caffe_op, "pads", [temporal_pad, spatial_pad, spatial_pad] * 2)
409

410
    # weight
411
    params = [
412
        utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], output + '_w')]
413
    # bias
414
    if len(pretrained_blobs) == 2:
415
        caffe_op.input.append(output + '_b')
416
        params.append(
417
            utils.NumpyArrayToCaffe2Tensor(
418
                pretrained_blobs[1].flatten(), output + '_b'))
419
    return caffe_op, params
420

421

422
@TranslatorRegistry.Register("Convolution")
423
def TranslateConv(layer, pretrained_blobs, is_test, **kwargs):
424
    param = layer.convolution_param
425
    caffe_op = BaseTranslate(layer, "Conv")
426
    output = caffe_op.output[0]
427
    caffe_op.input.append(output + '_w')
428
    _TranslateStridePadKernelHelper(param, caffe_op)
429
    # weight
430
    params = [
431
        utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], output + '_w')]
432
    # bias
433
    if len(pretrained_blobs) == 2:
434
        caffe_op.input.append(output + '_b')
435
        params.append(
436
            utils.NumpyArrayToCaffe2Tensor(
437
                pretrained_blobs[1].flatten(), output + '_b'))
438
    # Group convolution option
439
    if param.group != 1:
440
        AddArgument(caffe_op, "group", param.group)
441
    # Get dilation - not tested. If you have a model and this checks out,
442
    # please provide a test and uncomment this.
443
    if len(param.dilation) > 0:
444
        if len(param.dilation) == 1:
445
            AddArgument(caffe_op, "dilation", param.dilation[0])
446
        elif len(param.dilation) == 2:
447
            AddArgument(caffe_op, "dilation_h", param.dilation[0])
448
            AddArgument(caffe_op, "dilation_w", param.dilation[1])
449
    return caffe_op, params
450

451

452
@TranslatorRegistry.Register("Deconvolution")
453
def TranslateDeconv(layer, pretrained_blobs, is_test, **kwargs):
454
    param = layer.convolution_param
455
    if param.group > 1:
456
        raise NotImplementedError(
457
            "Translator currently does not support group deconvolution."
458
        )
459
    caffe_op = BaseTranslate(layer, "ConvTranspose")
460
    output = caffe_op.output[0]
461
    _TranslateStridePadKernelHelper(param, caffe_op)
462
    caffe_op.input.extend([output + '_w'])
463
    AddArgument(caffe_op, "order", "NCHW")
464
    weight = utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], output + '_w')
465
    if param.bias_term:
466
        bias = utils.NumpyArrayToCaffe2Tensor(
467
            pretrained_blobs[1].flatten(), output + '_b'
468
        )
469
        caffe_op.input.extend([output + '_b'])
470
        return caffe_op, [weight, bias]
471
    else:
472
        return caffe_op, [weight]
473

474

475
@TranslatorRegistry.Register("Crop")
476
def TranslateCrop(layer, pretrained_blobs, is_test, **kwargs):
477
    net, net_params, input_dims = kwargs['net'], kwargs['net_params'], kwargs['input_dims']
478
    n, c, h, w = input_dims
479
    dummy_input = np.random.randn(n, c, h, w).astype(np.float32)
480
    dim_map = _GetBlobDimMap(net, net_params, dummy_input)
481
    param = layer.crop_param
482
    axis, offsets = param.axis, param.offset
483
    caffe_op = BaseTranslate(layer, "Slice")
484
    input_1 = caffe_op.input[1]
485
    input_1_dim = dim_map[input_1]
486
    starts, ends = [], []
487
    dims = len(dim_map[input_1])
488
    assert len(offsets) == 1, 'Caffe Translator for Crop only works for offset \
489
    of 1 for now'
490
    for _ in range(axis):
491
        starts.append(0)
492
        ends.append(-1)
493
    end_offset = [int(offsets[0] + input_1_dim[i]) for i in range(axis, dims)]
494
    ends.extend(end_offset)
495
    starts.extend([offsets[0]] * len(end_offset))
496
    op = caffe2_pb2.OperatorDef()
497
    op.input.extend([caffe_op.input[0]])
498
    op.output.extend(caffe_op.output)
499
    op.arg.extend(caffe_op.arg)
500
    op.type = caffe_op.type
501
    AddArgument(op, "starts", starts)
502
    AddArgument(op, "ends", ends)
503
    return op, []
504

505
@TranslatorRegistry.Register("ReLU")
506
def TranslateRelu(layer, pretrained_blobs, is_test, **kwargs):
507
    return BaseTranslate(layer, "Relu"), []
508

509

510
@TranslatorRegistry.Register("Pooling")
511
def TranslatePool(layer, pretrained_blobs, is_test, **kwargs):
512
    param = layer.pooling_param
513
    if param.pool == caffe_pb2.PoolingParameter.MAX:
514
        caffe_op = BaseTranslate(layer, "MaxPool")
515
    elif param.pool == caffe_pb2.PoolingParameter.AVE:
516
        caffe_op = BaseTranslate(layer, "AveragePool")
517
    _TranslateStridePadKernelHelper(param, caffe_op)
518
    AddArgument(caffe_op, "order", "NCHW")
519
    try:
520
        # In the Facebook port of Caffe, a torch_pooling field was added to
521
        # map the pooling computation of Torch. Essentially, it uses
522
        #   floor((height + 2 * padding - kernel) / stride) + 1
523
        # instead of
524
        #   ceil((height + 2 * padding - kernel) / stride) + 1
525
        # which is Caffe's version.
526
        # Torch pooling is actually the same as Caffe2 pooling, so we don't
527
        # need to do anything.
528
        is_torch_pooling = param.torch_pooling
529
    except AttributeError:
530
        is_torch_pooling = False
531
    if not is_torch_pooling:
532
        AddArgument(caffe_op, "legacy_pad",
533
                    caffe2_legacy_pb2.CAFFE_LEGACY_POOLING)
534
    if param.global_pooling:
535
        AddArgument(caffe_op, "global_pooling", 1)
536
    return caffe_op, []
537

538

539
@TranslatorRegistry.Register("Pooling3D")
540
def TranslatePool3D(layer, pretrained_blobs, is_test, **kwargs):
541
    param = layer.pooling3d_param
542
    if param.pool == caffe_pb2.Pooling3DParameter.MAX:
543
        caffe_op = BaseTranslate(layer, "MaxPool")
544

545
    elif param.pool == caffe_pb2.Pooling3DParameter.AVE:
546
        caffe_op = BaseTranslate(layer, "AveragePool")
547
    AddArgument(caffe_op, "order", "NCHW")
548
    AddArgument(
549
        caffe_op,
550
        "kernels",
551
        [param.kernel_depth, param.kernel_size, param.kernel_size])
552

553
    AddArgument(
554
        caffe_op,
555
        "strides",
556
        [param.temporal_stride, param.stride, param.stride])
557
    temporal_pad = 0
558
    spatial_pad = 0
559
    if hasattr(param, 'temporal_pad'):
560
        temporal_pad = param.temporal_pad
561
    if hasattr(param, 'pad'):
562
        spatial_pad = param.pad
563
    AddArgument(caffe_op, "pads", [temporal_pad, spatial_pad, spatial_pad] * 2)
564
    return caffe_op, []
565

566

567
@TranslatorRegistry.Register("LRN")
568
def TranslateLRN(layer, pretrained_blobs, is_test, **kwargs):
569
    caffe_op = BaseTranslate(layer, "LRN")
570
    caffe_op.output.extend(['_' + caffe_op.output[0] + '_scale'])
571
    param = layer.lrn_param
572
    if param.norm_region != caffe_pb2.LRNParameter.ACROSS_CHANNELS:
573
        raise ValueError(
574
            "Does not support norm region other than across channels.")
575
    AddArgument(caffe_op, "size", int(param.local_size))
576
    AddArgument(caffe_op, "alpha", float(param.alpha))
577
    AddArgument(caffe_op, "beta", float(param.beta))
578
    AddArgument(caffe_op, "bias", float(param.k))
579
    AddArgument(caffe_op, "order", "NCHW")
580
    return caffe_op, []
581

582

583
@TranslatorRegistry.Register("InnerProduct")
584
def TranslateInnerProduct(layer, pretrained_blobs, is_test, **kwargs):
585
    param = layer.inner_product_param
586
    try:
587
        if param.axis != 1 or param.transpose:
588
            raise ValueError(
589
                "We don't have testing case for non-default axis and transpose "
590
                "cases yet so we are disabling it for now. If you have a model "
591
                "with this, please do send us your model for us to update this "
592
                "support, and you are more than welcome to send a PR for this.")
593
    except AttributeError:
594
        # We might be using an historic Caffe protobuf that does not have axis
595
        # and transpose arguments, so we will silently pass.
596
        pass
597
    caffe_op = BaseTranslate(layer, "FC")
598
    output = caffe_op.output[0]
599
    caffe_op.input.extend([output + '_w', output + '_b'])
600
    # To provide the old-style 4-dimensional blob (1, 1, dim_output, dim_input)
601
    # case, we always explicitly reshape the pretrained blob.
602
    if pretrained_blobs[0].ndim not in [2, 4]:
603
        raise ValueError("Unexpected weight ndim.")
604
    if (pretrained_blobs[0].ndim == 4 and
605
            list(pretrained_blobs[0].shape[:2]) != [1, 1]):
606
        raise ValueError(
607
            "If pretrained blob has 4 dims (old-style Caffe), the first two "
608
            "should be of value 1, but I got " + str(pretrained_blobs[0].shape))
609
    weight = utils.NumpyArrayToCaffe2Tensor(
610
        pretrained_blobs[0].reshape(-1, pretrained_blobs[0].shape[-1]),
611
        output + '_w'
612
    )
613
    bias = utils.NumpyArrayToCaffe2Tensor(
614
        pretrained_blobs[1].flatten(), output + '_b'
615
    )
616
    return caffe_op, [weight, bias]
617

618

619
@TranslatorRegistry.Register("Dropout")
620
def TranslateDropout(layer, pretrained_blobs, is_test, **kwargs):
621
    caffe_op = BaseTranslate(layer, "Dropout")
622
    caffe_op.output.extend(['_' + caffe_op.output[0] + '_mask'])
623
    param = layer.dropout_param
624
    AddArgument(caffe_op, "ratio", param.dropout_ratio)
625
    if (is_test):
626
        AddArgument(caffe_op, "is_test", 1)
627
    return caffe_op, []
628

629

630
@TranslatorRegistry.Register("Softmax")
631
def TranslateSoftmax(layer, pretrained_blobs, is_test, **kwargs):
632
    caffe_op = BaseTranslate(layer, "Softmax")
633
    return caffe_op, []
634

635

636
@TranslatorRegistry.Register("SoftmaxWithLoss")
637
def TranslateSoftmaxWithLoss(layer, pretrained_blobs, is_test, **kwargs):
638
    softmax_op = core.CreateOperator(
639
        "Softmax", [layer.bottom[0]],
640
        layer.bottom[0] + "_translator_autogen_softmax")
641
    xent_op = core.CreateOperator(
642
        "LabelCrossEntropy",
643
        [softmax_op.output[0], layer.bottom[1]],
644
        layer.bottom[0] + "_translator_autogen_xent")
645
    loss_op = core.CreateOperator(
646
        "AveragedLoss",
647
        xent_op.output[0],
648
        layer.top[0])
649
    return [softmax_op, xent_op, loss_op], []
650

651

652
@TranslatorRegistry.Register("Accuracy")
653
def TranslateAccuracy(layer, pretrained_blobs, is_test, **kwargs):
654
    caffe_op = BaseTranslate(layer, "Accuracy")
655
    if layer.accuracy_param.top_k != 1:
656
        AddArgument(caffe_op, "top_k", layer.accuracy_param.top_k)
657
    return caffe_op, []
658

659

660
@TranslatorRegistry.Register("Concat")
661
def TranslateConcat(layer, pretrained_blobs, is_test, **kwargs):
662
    caffe_op = BaseTranslate(layer, "Concat")
663
    caffe_op.output.extend(['_' + caffe_op.output[0] + '_dims'])
664
    AddArgument(caffe_op, "order", "NCHW")
665
    return caffe_op, []
666

667

668
@TranslatorRegistry.Register("TanH")
669
def TranslateTanH(layer, pretrained_blobs, is_test, **kwargs):
670
    caffe_op = BaseTranslate(layer, "Tanh")
671
    return caffe_op, []
672

673

674
@TranslatorRegistry.Register("InstanceNorm")
675
def TranslateInstanceNorm(layer, pretrained_blobs, is_test, **kwargs):
676
    caffe_op = BaseTranslate(layer, "InstanceNorm")
677
    output = caffe_op.output[0]
678
    weight = utils.NumpyArrayToCaffe2Tensor(
679
        pretrained_blobs[0].flatten(), output + '_w')
680
    bias = utils.NumpyArrayToCaffe2Tensor(
681
        pretrained_blobs[1].flatten(), output + '_b')
682
    caffe_op.input.extend([output + '_w', output + '_b'])
683
    AddArgument(caffe_op, "order", "NCHW")
684
    return caffe_op, [weight, bias]
685

686

687
@TranslatorRegistry.Register("BatchNorm")
688
def TranslateBatchNorm(layer, pretrained_blobs, is_test, **kwargs):
689
    caffe_op = BaseTranslate(layer, "SpatialBN")
690
    output = caffe_op.output[0]
691
    param = layer.batch_norm_param
692
    AddArgument(caffe_op, "is_test", is_test)
693
    AddArgument(caffe_op, "epsilon", param.eps)
694
    AddArgument(caffe_op, "order", "NCHW")
695

696
    caffe_op.input.extend(
697
        [output + "_scale",
698
         output + "_bias",
699
         output + "_mean",
700
         output + "_var"])
701
    if not is_test:
702
        caffe_op.output.extend(
703
            [output + "_mean",
704
             output + "_var",
705
             output + "_saved_mean",
706
             output + "_saved_var"])
707

708
    n_channels = pretrained_blobs[0].shape[0]
709
    if pretrained_blobs[2][0] != 0:
710
        mean = utils.NumpyArrayToCaffe2Tensor(
711
            (1. / pretrained_blobs[2][0]) * pretrained_blobs[0],
712
            output + '_mean')
713
        var = utils.NumpyArrayToCaffe2Tensor(
714
            (1. / pretrained_blobs[2][0]) * pretrained_blobs[1],
715
            output + '_var')
716
    else:
717
        raise RuntimeError("scalar is zero.")
718
    if len(pretrained_blobs) > 3:
719
        # IntelCaffe and NVCaffe uses fused BN+Scale,
720
        # three blobs for BN and two blobs for Scale,
721
        # so that the total number of blobs becomes five (including scale and bias).
722
        scale = utils.NumpyArrayToCaffe2Tensor(
723
            pretrained_blobs[3].flatten(),
724
            output + '_scale')
725
        bias = utils.NumpyArrayToCaffe2Tensor(
726
            pretrained_blobs[4].flatten(),
727
            output + '_bias')
728
    else:
729
        pretrained_blobs[2][0] = 1
730
        pretrained_blobs[2] = np.tile(pretrained_blobs[2], (n_channels, ))
731
        scale = utils.NumpyArrayToCaffe2Tensor(
732
            pretrained_blobs[2],
733
            output + '_scale')
734
        bias = utils.NumpyArrayToCaffe2Tensor(
735
            np.zeros_like(pretrained_blobs[2]),
736
            output + '_bias')
737

738
    return caffe_op, [scale, bias, mean, var]
739

740

741
@TranslatorRegistry.Register("Eltwise")
742
def TranslateElementWise(layer, pretrained_blobs, is_test, **kwargs):
743
    param = layer.eltwise_param
744
    # TODO(jiayq): if we have a protobuf that uses this, lift this constraint
745
    # and verify that we can correctly translate.
746
    if len(param.coeff) or param.operation != 1:
747
        raise RuntimeError("This eltwise layer is not yet supported.")
748
    caffe_op = BaseTranslate(layer, "Sum")
749
    return caffe_op, []
750

751

752
@TranslatorRegistry.Register("Scale")
def TranslateScale(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a Caffe Scale layer.

    A Scale layer whose parameters come from pretrained blobs becomes a
    broadcasting Mul, followed by a broadcasting Add when a bias term is
    also present. Returns (ops, weight tensors), one weight per op.
    """
    mul_op = BaseTranslate(layer, "Mul")
    scale_param = layer.scale_param
    AddArgument(mul_op, "axis", scale_param.axis)
    AddArgument(mul_op, "broadcast", True)

    n_inputs = len(mul_op.input)
    if n_inputs == 2:
        # TODO(jiayq): find a protobuf that uses this and verify.
        raise RuntimeError("This path has not been verified yet.")
    if n_inputs != 1:
        raise RuntimeError("Unexpected number of inputs.")

    # Single data input: the scale parameter is in pretrained blobs.
    if scale_param.num_axes != 1:
        raise RuntimeError("This path has not been verified yet.")

    output = mul_op.output[0]
    mul_op_param = output + 'scale_w'
    mul_op.input.append(mul_op_param)
    weights = [utils.NumpyArrayToCaffe2Tensor(
        pretrained_blobs[0].flatten(), mul_op_param)]

    add_op = None
    n_blobs = len(pretrained_blobs)
    if n_blobs == 2:
        # Caffe Scale layer supports a bias term such that it computes
        # (scale_param * X + bias), whereas Caffe2 Mul op doesn't.
        # Include a separate Add op for the bias followed by Mul.
        # Clone BEFORE rewiring so add_op keeps the original output name.
        add_op = copy.deepcopy(mul_op)
        add_op.type = "Add"
        add_op_param = output + 'scale_b'
        internal_blob = output + "_internal"
        # Mul now writes to an internal blob; Add consumes it plus the bias.
        del mul_op.output[:]
        mul_op.output.append(internal_blob)
        del add_op.input[:]
        add_op.input.append(internal_blob)
        add_op.input.append(add_op_param)
        weights.append(utils.NumpyArrayToCaffe2Tensor(
            pretrained_blobs[1].flatten(), add_op_param))
    elif n_blobs != 1:
        raise RuntimeError("Unexpected number of pretrained blobs in Scale")

    caffe_ops = [mul_op] if add_op is None else [mul_op, add_op]
    assert len(caffe_ops) == len(weights)
    return caffe_ops, weights
802

803

804
@TranslatorRegistry.Register("Reshape")
def TranslateReshape(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a Caffe Reshape layer into a Caffe2 Reshape op."""
    caffe_op = BaseTranslate(layer, "Reshape")
    # Caffe2's Reshape emits a second output holding the old input dims.
    old_dims_blob = "_" + caffe_op.input[0] + "_dims"
    caffe_op.output.append(old_dims_blob)
    AddArgument(caffe_op, 'shape', layer.reshape_param.shape.dim)
    return caffe_op, []
811

812

813
@TranslatorRegistry.Register("Flatten")
def TranslateFlatten(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a Caffe Flatten layer.

    Only end_axis == -1 with axis 0 (FlattenToVec) or axis 1 (Flatten)
    is supported; anything else raises NotImplementedError.
    """
    param = layer.flatten_param
    if param.end_axis != -1:
        raise NotImplementedError("flatten_param.end_axis not supported yet.")

    if param.axis not in (0, 1):
        # This could be a Reshape op, but dim size is not known here.
        raise NotImplementedError(
            "Not supported yet for flatten_param.axis {}.".format(param.axis))

    op_type = "FlattenToVec" if param.axis == 0 else "Flatten"
    return BaseTranslate(layer, op_type), []
829

830

831
@TranslatorRegistry.Register("Sigmoid")
def TranslateSigmoid(layer, pretrained_blobs, is_test, **kwargs):
    """Sigmoid maps one-to-one onto the Caffe2 Sigmoid op; no weights."""
    return BaseTranslate(layer, "Sigmoid"), []
835

836

837
@TranslatorRegistry.Register("ROIPooling")
def TranslateROIPooling(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a Caffe ROIPooling layer into a Caffe2 RoIPool op (NCHW)."""
    caffe_op = BaseTranslate(layer, "RoIPool")
    AddArgument(caffe_op, "order", "NCHW")

    if not is_test:
        # Only used for gradient computation
        caffe_op.output.append(caffe_op.output[0] + '_argmaxes')
    else:
        AddArgument(caffe_op, "is_test", is_test)

    # Forward whichever optional pooling fields the prototxt actually set.
    param = layer.roi_pooling_param
    for field_name in ('pooled_h', 'pooled_w', 'spatial_scale'):
        if param.HasField(field_name):
            AddArgument(caffe_op, field_name, getattr(param, field_name))

    return caffe_op, []
857

858

859
@TranslatorRegistry.Register("PReLU")
def TranslatePRelu(layer, pretrained_blobs, is_test, **kwargs):
    """Translate Caffe PReLU; the learned slope blob becomes a second input."""
    caffe_op = BaseTranslate(layer, "PRelu")
    output = caffe_op.output[0]
    slope_blob = output + '_Slope'
    caffe_op.input.extend([slope_blob])
    slope = utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], slope_blob)
    return caffe_op, [slope]
867

868

869
@TranslatorRegistry.Register("Reduction")
def TranslateReduction(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a Caffe Reduction layer (SUM/MEAN, non-positive axis only)."""
    param = layer.reduction_param
    op_by_reduction = {
        caffe_pb2.ReductionParameter.SUM: "ReduceBackSum",
        caffe_pb2.ReductionParameter.MEAN: "ReduceBackMean",
    }
    if param.operation not in op_by_reduction:
        raise NotImplementedError("Not yet supported")
    caffe_op = BaseTranslate(layer, op_by_reduction[param.operation])

    if param.axis > 0:
        # We can't figure out the number of dims to reduce from positive axis
        # for back reduction since the shape info is not known here.
        raise NotImplementedError("Not yet supported")
    AddArgument(caffe_op, "num_reduce_dim", -param.axis)

    return caffe_op, []
887

888

889
if __name__ == '__main__':
    # Command-line entry point: translate a Caffe prototxt + caffemodel pair
    # into a Caffe2 predict net and init net.
    parser = argparse.ArgumentParser(
        # Fixed typo: "Utilitity" -> "Utility".
        description="Utility to convert pretrained caffe models to Caffe2 models.")
    parser.add_argument("prototext", help="Caffe prototext.")
    parser.add_argument("caffemodel", help="Caffe trained model.")
    parser.add_argument("--init_net", help="Caffe2 initialization net.",
                        default="init_net.pb")
    parser.add_argument("--predict_net", help="Caffe2 prediction net.",
                        default="predict_net.pb")
    parser.add_argument("--remove_legacy_pad", help="Remove legacy pad \
                        (Only works for nets with one input blob)",
                        action="store_true",
                        default=False)
    parser.add_argument("--input_dims", help="Dimension of input blob", nargs='+',
                        type=int, default=[])
    args = parser.parse_args()

    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    input_proto = args.prototext
    input_caffemodel = args.caffemodel
    output_init_net = args.init_net
    output_predict_net = args.predict_net

    # The prototxt is text protobuf; the caffemodel is binary protobuf.
    with open(input_proto) as f:
        text_format.Merge(f.read(), caffenet)
    with open(input_caffemodel, 'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    net, pretrained_params = TranslateModel(
        caffenet, caffenet_pretrained, is_test=True,
        remove_legacy_pad=args.remove_legacy_pad,
        input_dims=args.input_dims
    )

    # Assume there is one input and one output
    external_input = net.op[0].input[0]
    external_output = net.op[-1].output[0]

    net.external_input.extend([external_input])
    net.external_input.extend([param.name for param in pretrained_params.protos])
    net.external_output.extend([external_output])
    init_net = ConvertTensorProtosToInitNet(pretrained_params, external_input)

    with open(output_predict_net, 'wb') as f:
        f.write(net.SerializeToString())
    # Intentionally appends 'txt' to the .pb name, yielding a readable
    # text-format sibling (e.g. predict_net.pbtxt).
    with open(output_predict_net + 'txt', 'w') as f:
        f.write(str(net))
    with open(output_init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
938

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.