pytorch

Форк
0
134 строки · 3.9 Кб
1
## @package onnx
2
# Module caffe2.python.onnx.tests.ssa_test
3

4

5

6

7

8

9
import copy
10
import numpy as np
11
from caffe2.proto import caffe2_pb2
12
from caffe2.python import core
13
from onnx import TensorProto
14

15
import caffe2.python.onnx.frontend as c2_onnx
16
from caffe2.python.onnx.helper import c2_native_run_net
17
from caffe2.python.onnx.tests.test_utils import TestCase
18

19

20
class TestFrontendSSAConversion(TestCase):
21
    def test_ssa(self):
22
        X = np.random.randn(4, 2).astype(np.float32)
23
        W = np.random.randn(3, 2).astype(np.float32)
24
        b = np.random.randn(3).astype(np.float32)
25
        s = np.random.randn(1).astype(np.float32)
26
        np_result = X.dot(W.transpose()) + b + s
27

28
        net = caffe2_pb2.NetDef()
29
        net.name = 'test-ssa'
30
        net.external_input[:] = ['W', 'X', 'b', 's']
31
        net.op.extend([
32
            core.CreateOperator(
33
                'FC',
34
                ['X', 'W', 'b'],
35
                ['Y']
36
            ),
37
            core.CreateOperator(
38
                'Add',
39
                ['Y', 's'],
40
                ['Y'],
41
                broadcast=True,
42
            )
43
        ])
44
        net.external_output[:] = ['Y']
45

46
        init_net = caffe2_pb2.NetDef()
47
        init_net.name = 'test-ssa-init'
48
        init_net.op.extend([
49
            core.CreateOperator(
50
                'GivenTensorFill',
51
                [],
52
                ['W'],
53
                values=W,
54
                shape=W.shape,
55
            ),
56
            core.CreateOperator(
57
                'GivenTensorFill',
58
                [],
59
                ['b'],
60
                values=b,
61
                shape=b.shape,
62
            ),
63
            core.CreateOperator(
64
                'GivenTensorFill',
65
                [],
66
                ['s'],
67
                values=s,
68
                shape=s.shape,
69
            )
70
        ])
71
        init_net.external_output[:] = ['W', 'b', 's']
72

73
        _, orig_output = c2_native_run_net(
74
            predict_net=net,
75
            init_net=init_net,
76
            inputs=[X])
77

78
        value_info = {'X': (TensorProto.FLOAT, X.shape)}
79
        c2_onnx.Caffe2Frontend._ssa_rewrite(
80
            net,
81
            init_net,
82
            value_info)
83

84
        self.assertEqual(net.external_input, ['W', 'X', 'b', 's'])
85
        self.assertEqual(net.op[0].input, ['X', 'W', 'b'])
86
        self.assertEqual(net.op[0].output, ['Y_1'])
87
        self.assertEqual(net.op[1].input, ['Y_1', 's'])
88
        self.assertEqual(net.op[1].output, ['Y_2'])
89
        self.assertEqual(net.external_output, ['Y_2'])
90

91
        self.assertEqual(init_net.external_input, [])
92
        self.assertEqual(init_net.op[0].input, [])
93
        self.assertEqual(init_net.op[0].output, ['W'])
94
        self.assertEqual(init_net.op[1].input, [])
95
        self.assertEqual(init_net.op[1].output, ['b'])
96
        self.assertEqual(init_net.op[2].input, [])
97
        self.assertEqual(init_net.op[2].output, ['s'])
98
        self.assertEqual(init_net.external_output, ['W', 'b', 's'])
99
        self.assertEqual(value_info, {'X': (TensorProto.FLOAT, X.shape)})
100

101
        _, ssa_output = c2_native_run_net(
102
            predict_net=net,
103
            init_net=init_net,
104
            inputs=[X])
105

106
        self.assertSameOutputs(ssa_output, orig_output)
107
        self.assertSameOutputs(ssa_output, [np_result])
108

109
    def test_idempotence(self):
110
        net = caffe2_pb2.NetDef()
111
        net.name = 'test-idempotence'
112
        net.external_input[:] = ['W', 'X', 'b', 's']
113
        net.op.extend([
114
            core.CreateOperator(
115
                'FC',
116
                ['X', 'W', 'b'],
117
                ['Y']
118
            ),
119
            core.CreateOperator(
120
                'Add',
121
                ['Y', 's'],
122
                ['Z'],
123
                broadcast=True,
124
            )
125
        ])
126
        net.external_output[:] = ['Z']
127

128
        value_info = {'X': (TensorProto.FLOAT, [4, 2])}
129
        net_copy = copy.deepcopy(net)
130
        c2_onnx.Caffe2Frontend._ssa_rewrite(
131
            net_copy,
132
            None,
133
            value_info)
134
        self.assertEqual(net, net_copy)
135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.