pytorch

Форк
0
/
mobile_exporter.py 
106 строк · 3.6 Кб
1
## @package mobile_exporter
2
# Module caffe2.python.mobile_exporter
3

4

5

6

7

8
from caffe2.python import core, utils
9
from caffe2.proto import caffe2_pb2
10
import numpy as np
11

12

13
def add_tensor(net, name, blob):
14
    ''' Create an operator to store the tensor 'blob',
15
        run the operator to put the blob to workspace.
16
        uint8 is stored as an array of string with one element.
17
    '''
18
    kTypeNameMapper = {
19
        np.dtype('float32'): "GivenTensorFill",
20
        np.dtype('int32'): "GivenTensorIntFill",
21
        np.dtype('int64'): "GivenTensorInt64Fill",
22
        np.dtype('uint8'): "GivenTensorByteStringToUInt8Fill",
23
        np.dtype('O'): "GivenTensorStringFill"
24
    }
25

26
    shape = blob.shape
27
    values = blob
28
    # pass array of uint8 as a string to save storage
29
    # storing uint8_t has a large overhead for now
30
    if blob.dtype == np.dtype('uint8'):
31
        shape = blob.shape
32
        values = [blob.tobytes()]
33
    # Only allow string arrays as objects.
34
    # The only intended use case for this is to store arrays of strings in the
35
    # model which can be used for post processing results in subsequent ops.
36
    if blob.dtype == np.dtype('O'):
37
        for blob_val in blob:
38
            assert(isinstance(blob_val, bytes))
39

40
    op = core.CreateOperator(
41
        kTypeNameMapper[blob.dtype],
42
        [], [name],
43
        arg=[
44
            utils.MakeArgument("shape", shape),
45
            utils.MakeArgument("values", values),
46
        ]
47
    )
48
    net.op.extend([op])
49

50

51
def Export(workspace, net, params):
52
    """Returns init_net and predict_net suitable for writing to disk
53
       and loading into a Predictor"""
54
    proto = net if isinstance(net, caffe2_pb2.NetDef) else net.Proto()
55
    predict_net = caffe2_pb2.NetDef()
56
    predict_net.CopyFrom(proto)
57
    init_net = caffe2_pb2.NetDef()
58
    # Populate the init_net.
59
    ssa, blob_versions = core.get_ssa(net)
60
    inputs = []
61
    for versioned_inputs, _ in ssa:
62
        inputs += [name for name, _ in versioned_inputs]
63

64
    input_blobs = [blob_name for blob_name, version in
65
                   blob_versions.items()
66
                   if version == 0 and blob_name not in params]
67
    # Blobs that are never used as an input to another layer,
68
    # i.e. strictly output blobs.
69
    output_blobs = [blob_name for blob_name, version in
70
                    blob_versions.items()
71
                    if version != 0 and blob_name not in inputs]
72

73
    for blob_ref in params:
74
        blob_name = str(blob_ref)
75
        blob = workspace.FetchBlob(blob_name)
76
        add_tensor(init_net, blob_name, blob)
77
    # We have to make sure the blob exists in the namespace
78
    # and we can do so with fake data. (Which is immediately overwritten
79
    # by any typical usage)
80
    for blob_name in input_blobs:
81
        init_net.op.extend(
82
            [
83
                core.CreateOperator(
84
                    "GivenTensorFill", [], [blob_name],
85
                    arg=[
86
                        utils.MakeArgument("shape", [1, 1]),
87
                        utils.MakeArgument("values", [0.0])
88
                    ]
89
                )
90
            ]
91
        )
92

93
    # Now we make input/output_blobs line up with what Predictor expects.
94
    del predict_net.external_input[:]
95

96
    new_external_inputs = input_blobs
97
    for external_input in proto.external_input:
98
        if external_input not in new_external_inputs:
99
            new_external_inputs.append(external_input)
100

101
    # For populating weights
102
    predict_net.external_input.extend(new_external_inputs)
103
    # Ensure the output is also consistent with what we want
104
    del predict_net.external_output[:]
105
    predict_net.external_output.extend(output_blobs)
106
    return init_net, predict_net
107

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.