# PyTorch example: exporting a model containing ATen fallback ops to ONNX,
# then running the exported graph with the Caffe2 backend.
import tempfile

import caffe2.python.onnx.backend
import numpy as np
import onnx
import torch.onnx
from torch import nn
from torch.autograd import Variable, Function
class MyFunction(Function):
    """Custom autograd Function computing ``x * x + y``.

    ``forward`` runs the eager computation; ``symbolic`` tells the ONNX
    exporter how to represent the op, here using ATen fallback nodes.
    """

    @staticmethod
    def forward(ctx, x, y):
        # Eager-mode computation; ctx is unused because nothing is saved
        # for backward in this example.
        return x * x + y

    @staticmethod
    def symbolic(graph, x, y):
        # Emit two ATen fallback nodes instead of native ONNX ops.
        x2 = graph.at("mul", x, x)
        r = graph.at("add", x2, y)
        # x, y, x2, and r are 'Node' objects.
        # print(r) or print(graph) will print out a textual representation
        # for debugging. This representation will be converted to ONNX
        # protobufs on export.
        return r
class MyModule(nn.Module):
    """Module combining a standard ONNX-exportable op (ReLU) with the
    custom ATen-backed ``MyFunction``."""

    def forward(self, x, y):
        # You can combine your ATen ops with standard onnx ones.
        x = nn.ReLU()(x)
        return MyFunction.apply(x, y)
# Export the module to ONNX, then run the exported graph with the Caffe2
# backend and check the result against a NumPy reference implementation.
f = tempfile.NamedTemporaryFile()
torch.onnx.export(MyModule(),
                  (Variable(torch.ones(3, 4)), Variable(torch.ones(3, 4))),
                  f, verbose=True)

# prints the graph for debugging:
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
#       %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
#   %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
#   %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
#   %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
#   return (%4)

graph = onnx.load(f.name)

# Random float32 inputs matching the (3, 4) export shape.
a = np.random.randn(3, 4).astype(np.float32)
b = np.random.randn(3, 4).astype(np.float32)

# Run the exported graph through the Caffe2 ONNX backend; feed inputs by
# the names recorded in the ONNX graph.
prepared_backend = caffe2.python.onnx.backend.prepare(graph)
W = {graph.graph.input[0].name: a, graph.graph.input[1].name: b}
c2_out = prepared_backend.run(W)[0]

# NumPy reference: ReLU(a)**2 + b, matching MyModule's computation.
x = np.maximum(a, 0)
r = x * x + b
np.testing.assert_array_almost_equal(r, c2_out)