pytorch

Форк
0
56 строк · 1.7 Кб
1
import tempfile
2

3
import numpy as np
4

5
from torch import nn
6
from torch.autograd import Variable, Function
7
import torch.onnx
8

9
import onnx
10
import caffe2.python.onnx.backend
11

12
class MyFunction(Function):
13
    @staticmethod
14
    def forward(ctx, x, y):
15
        return x * x + y
16

17
    @staticmethod
18
    def symbolic(graph, x, y):
19
        x2 = graph.at("mul", x, x)
20
        r = graph.at("add", x2, y)
21
        # x, y, x2, and r are 'Node' objects
22
        # print(r) or print(graph) will print out a textual representation for debugging.
23
        # this representation will be converted to ONNX protobufs on export.
24
        return r
25

26
class MyModule(nn.Module):
27
    def forward(self, x, y):
28
        # you can combine your ATen ops with standard onnx ones
29
        x = nn.ReLU()(x)
30
        return MyFunction.apply(x, y)
31

32
f = tempfile.NamedTemporaryFile()
33
torch.onnx.export(MyModule(),
34
                  (Variable(torch.ones(3, 4)), Variable(torch.ones(3, 4))),
35
                  f, verbose=True)
36

37
# prints the graph for debugging:
38
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
39
#       %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
40
#   %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
41
#   %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
42
#   %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
43
#   return (%4)
44

45
graph = onnx.load(f.name)
46

47
a = np.random.randn(3, 4).astype(np.float32)
48
b = np.random.randn(3, 4).astype(np.float32)
49

50
prepared_backend = caffe2.python.onnx.backend.prepare(graph)
51
W = {graph.graph.input[0].name: a, graph.graph.input[1].name: b}
52
c2_out = prepared_backend.run(W)[0]
53

54
x = np.maximum(a, 0)
55
r = x * x + b
56
np.testing.assert_array_almost_equal(r, c2_out)
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.