2
# Module caffe2.python.onnx.tests.ssa_test
11
from caffe2.proto import caffe2_pb2
12
from caffe2.python import core
13
from onnx import TensorProto
15
import caffe2.python.onnx.frontend as c2_onnx
16
from caffe2.python.onnx.helper import c2_native_run_net
17
from caffe2.python.onnx.tests.test_utils import TestCase
20
class TestFrontendSSAConversion(TestCase):
22
X = np.random.randn(4, 2).astype(np.float32)
23
W = np.random.randn(3, 2).astype(np.float32)
24
b = np.random.randn(3).astype(np.float32)
25
s = np.random.randn(1).astype(np.float32)
26
np_result = X.dot(W.transpose()) + b + s
28
net = caffe2_pb2.NetDef()
30
net.external_input[:] = ['W', 'X', 'b', 's']
44
net.external_output[:] = ['Y']
46
init_net = caffe2_pb2.NetDef()
47
init_net.name = 'test-ssa-init'
71
init_net.external_output[:] = ['W', 'b', 's']
73
_, orig_output = c2_native_run_net(
78
value_info = {'X': (TensorProto.FLOAT, X.shape)}
79
c2_onnx.Caffe2Frontend._ssa_rewrite(
84
self.assertEqual(net.external_input, ['W', 'X', 'b', 's'])
85
self.assertEqual(net.op[0].input, ['X', 'W', 'b'])
86
self.assertEqual(net.op[0].output, ['Y_1'])
87
self.assertEqual(net.op[1].input, ['Y_1', 's'])
88
self.assertEqual(net.op[1].output, ['Y_2'])
89
self.assertEqual(net.external_output, ['Y_2'])
91
self.assertEqual(init_net.external_input, [])
92
self.assertEqual(init_net.op[0].input, [])
93
self.assertEqual(init_net.op[0].output, ['W'])
94
self.assertEqual(init_net.op[1].input, [])
95
self.assertEqual(init_net.op[1].output, ['b'])
96
self.assertEqual(init_net.op[2].input, [])
97
self.assertEqual(init_net.op[2].output, ['s'])
98
self.assertEqual(init_net.external_output, ['W', 'b', 's'])
99
self.assertEqual(value_info, {'X': (TensorProto.FLOAT, X.shape)})
101
_, ssa_output = c2_native_run_net(
106
self.assertSameOutputs(ssa_output, orig_output)
107
self.assertSameOutputs(ssa_output, [np_result])
109
def test_idempotence(self):
110
net = caffe2_pb2.NetDef()
111
net.name = 'test-idempotence'
112
net.external_input[:] = ['W', 'X', 'b', 's']
126
net.external_output[:] = ['Z']
128
value_info = {'X': (TensorProto.FLOAT, [4, 2])}
129
net_copy = copy.deepcopy(net)
130
c2_onnx.Caffe2Frontend._ssa_rewrite(
134
self.assertEqual(net, net_copy)