6
from caffe2.python.core import BlobReference, Net
7
from caffe2.python.onnx.backend import Caffe2Backend
10
# Net.Clone takes a dict (for blob-name remapping) instead of a lambda.
11
# It should probably take a lambda, since that is more flexible.
16
def __init__(self, fn):
19
def get(self, name, _):
23
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
25
Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.
28
helper (caffe2.python.core.ModelHelper): the model helper where
29
this imported network should be inserted
30
model (torch.nn.Module): the model to be exported
31
sample_arguments (tuple of arguments): the inputs to
32
the model, e.g., such that ``model(*sample_arguments)`` is a valid
33
invocation of the model. Any non-Variable arguments will
34
be hard-coded into the exported model; any Variable arguments
35
will become inputs of the exported model, in the order they
36
occur in sample_arguments. If sample_arguments is a Variable, this is equivalent
37
to having called it with a 1-ary tuple of that Variable.
38
(Note: passing keyword arguments to the model is not currently
39
supported. Give us a shout if you need it.)
40
caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
41
caffe2 Blobs that should be inputs to this network. Must be
42
equal in number to the Variable arguments in sample_arguments
43
prefix_name: prefix prepended to every imported blob name other than the caffe2_inputs; if None then
44
a fresh prefix pytorch_import_N/ is used
46
A tuple of caffe2.python.core.BlobReference objects referring to the
47
model's outputs, or a single BlobReference when the model returns a single
50
if prefix_name is None:
52
prefix_name = "pytorch_import_" + str(_next_idx) + "/"
55
# TODO: handle the case where model cannot be exported
56
# and embed as a Python op in Caffe2
58
torch.onnx.export(model, sample_arguments, f, export_params=True)
59
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
60
init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
62
initialized = {x.name for x in onnx_model.graph.initializer}
63
uninitialized_inputs = {
65
for i, x in enumerate(onnx_model.graph.input)
66
if x.name not in initialized
69
if len(uninitialized_inputs) != len(caffe2_inputs):
71
f"Expected {len(uninitialized_inputs)} inputs but found {len(caffe2_inputs)}"
74
def remap_blob_name(name):
75
if name in uninitialized_inputs:
76
idx = uninitialized_inputs[name]
77
return str(caffe2_inputs[idx])
78
return prefix_name + name
80
predict_net = Net(predict_net).Clone("anon", _FakeDict(remap_blob_name))
81
helper.net.AppendNet(predict_net)
83
init_net = Net(init_net).Clone("anon", _FakeDict(remap_blob_name))
84
helper.param_init_net.AppendNet(init_net)
87
BlobReference(remap_blob_name(x.name), helper.net)
88
for x in onnx_model.graph.output