pytorch

Форк
0
/
pytorch_helper.py 
90 строк · 3.3 Кб
1
import io
2

3
import onnx
4

5
import torch.onnx
6
from caffe2.python.core import BlobReference, Net
7
from caffe2.python.onnx.backend import Caffe2Backend
8

9
_next_idx = 0
10
# Clone net takes a dict instead of a lambda
11
# It should probably take a lambda, it is more flexible
12
# We fake dict here
13

14

15
class _FakeDict:
16
    def __init__(self, fn):
17
        self.fn = fn
18

19
    def get(self, name, _):
20
        return self.fn(name)
21

22

23
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
24
    """
25
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.
26

27
    Args:
28
        helper (caffe2.python.core.ModelHelder): the model helper where
29
            this imported network should be inserted
30
        model (torch.nn.Module): the model to be exported
31
        sample_arguments (tuple of arguments): the inputs to
32
            the model, e.g., such that ``model(*args)`` is a valid
33
            invocation of the model.  Any non-Variable arguments will
34
            be hard-coded into the exported model; any Variable arguments
35
            will become inputs of the exported model, in the order they
36
            occur in args.  If args is a Variable, this is equivalent
37
            to having called it with a 1-ary tuple of that Variable.
38
            (Note: passing keyword arguments to the model is not currently
39
            supported.  Give us a shout if you need it.)
40
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
41
           caffe2 Blobs that should be inputs to this network. Must be
42
           the same length as sample_arguments
43
        prefix_name: prefix name to add to each member of the blob, if None then
44
           a fresh prefix pytorch_input_N/ is used
45
    Returns:
46
        A tuple of caffe2.python.core.BlobReference objects referring to the
47
        models outputs, or a single BlobReference when the model returns a single
48
        value.
49
    """
50
    if prefix_name is None:
51
        global _next_idx
52
        prefix_name = "pytorch_import_" + str(_next_idx) + "/"
53
        _next_idx += 1
54

55
    # TODO: handle the case where model cannot be exported
56
    # and embed as a Python op in Caffe2
57
    f = io.BytesIO()
58
    torch.onnx.export(model, sample_arguments, f, export_params=True)
59
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
60
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
61

62
    initialized = {x.name for x in onnx_model.graph.initializer}
63
    uninitialized_inputs = {
64
        x.name: i
65
        for i, x in enumerate(onnx_model.graph.input)
66
        if x.name not in initialized
67
    }
68

69
    if len(uninitialized_inputs) != len(caffe2_inputs):
70
        raise ValueError(
71
            f"Expected {len(uninitialized_inputs)} inputs but found {len(caffe2_inputs)}"
72
        )
73

74
    def remap_blob_name(name):
75
        if name in uninitialized_inputs:
76
            idx = uninitialized_inputs[name]
77
            return str(caffe2_inputs[idx])
78
        return prefix_name + name
79

80
    predict_net = Net(predict_net).Clone("anon", _FakeDict(remap_blob_name))
81
    helper.net.AppendNet(predict_net)
82

83
    init_net = Net(init_net).Clone("anon", _FakeDict(remap_blob_name))
84
    helper.param_init_net.AppendNet(init_net)
85

86
    results = tuple(
87
        BlobReference(remap_blob_name(x.name), helper.net)
88
        for x in onnx_model.graph.output
89
    )
90
    return results
91

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.