pytorch
246 строк · 8.4 Кб
1import numpy as np
2from caffe2.python import core, workspace
3from caffe2.python.test_util import TestCase
4
5
class TestSplitOpCost(TestCase):
    """Tests for the caffe2 ``Split`` operator and its reported cost.

    Each test feeds an input blob, runs ``Split`` once, checks the contents
    (and in several cases the shapes) of the produced outputs, and finally
    checks the cost reported by ``workspace.GetOperatorCost`` through
    ``_verify_cost``.
    """

    def _verify_cost(self, workspace, split_op):
        """Check the cost estimate reported for an already-run ``split_op``.

        Args:
            workspace: the caffe2 ``workspace`` module; every caller passes
                the module imported at the top of this file.
            split_op: the ``Split`` operator definition. It must already have
                been run, so all of its input and output blobs can be fetched.

        Asserts that the reported FLOP count is zero, that bytes read equal
        the summed ``nbytes`` of the input blobs, and that bytes written
        equal the summed ``nbytes`` of the output blobs.
        """
        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
            split_op, split_op.input
        )
        self.assertEqual(flops, 0)
        self.assertEqual(
            bytes_read,
            sum(workspace.FetchBlob(b).nbytes for b in split_op.input),
        )
        self.assertEqual(
            bytes_written,
            sum(workspace.FetchBlob(b).nbytes for b in split_op.output),
        )

    def test_columnwise_equal_outputSplit(self):
        """With no ``axis`` argument, a 2x3 input splits into three (2, 1) columns."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (2, 1))
        np.testing.assert_array_equal(output_1, [[1], [4]])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [[2], [5]])

        output_3 = workspace.FetchBlob("output_3")
        np.testing.assert_array_equal(output_3, [[3], [6]])

        self._verify_cost(workspace, split_op)

    def test_rowwise_equal_outputSplit(self):
        """Split a 2x3 input along axis 0 into two equal (1, 3) rows."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (1, 3))
        np.testing.assert_array_equal(output_1, [[1, 2, 3]])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [[4, 5, 6]])

        self._verify_cost(workspace, split_op)

    def test_columnwise_equal_outputSplit_columnRemoved(self):
        """Split along axis 1 with ``add_axis=1``; each output has shape (2,)."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        # To be able to use 'add_axis' (which should have been called
        # 'remove_axis') on 'axis', the dimensions of the split tensors must
        # match on 'axis'.
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (2,))
        np.testing.assert_array_equal(output_1, [1, 4])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [2, 5])

        output_3 = workspace.FetchBlob("output_3")
        np.testing.assert_array_equal(output_3, [3, 6])

        self._verify_cost(workspace, split_op)

    def test_rowwise_equal_outputSplit_rowRemoved(self):
        """Split along axis 0 with ``add_axis=1``; each output has shape (3,)."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (3,))
        np.testing.assert_array_equal(output_1, [1, 2, 3])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [4, 5, 6])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_argSplit(self):
        """Split a 3x3 input along axis 0 into 1 and 2 rows via the ``split`` argument."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
            split=[1, 2],
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (1, 3))
        np.testing.assert_array_equal(output_1, [[1, 2, 3]])

        output_2 = workspace.FetchBlob("output_2")
        self.assertTupleEqual(output_2.shape, (2, 3))
        np.testing.assert_array_equal(output_2, [[4, 5, 6], [7, 8, 9]])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_argSplit_rowRemoved(self):
        """Split along axis 0 with ``split=[1, 1, 1]`` and ``add_axis=1``."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            axis=0,
            split=[1, 1, 1],
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (3,))
        np.testing.assert_array_equal(output_1, [1, 2, 3])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [4, 5, 6])

        output_3 = workspace.FetchBlob("output_3")
        np.testing.assert_array_equal(output_3, [7, 8, 9])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_blobSplit(self):
        """Split a 3x3 input along axis 0 with sizes [1, 2] supplied as a second input blob."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        workspace.FeedBlob("split", np.array([1, 2], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input", "split"],
            ["output_1", "output_2"],
            axis=0,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (1, 3))
        np.testing.assert_array_equal(output_1, [[1, 2, 3]])

        output_2 = workspace.FetchBlob("output_2")
        self.assertTupleEqual(output_2.shape, (2, 3))
        np.testing.assert_array_equal(output_2, [[4, 5, 6], [7, 8, 9]])

        self._verify_cost(workspace, split_op)

    def test_columnwise_unequal_argSplit(self):
        """Split a 2x3 input along axis 1 into widths 1 and 2 via the ``split`` argument."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=1,
            split=[1, 2],
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (2, 1))
        np.testing.assert_array_equal(output_1, [[1], [4]])

        output_2 = workspace.FetchBlob("output_2")
        self.assertTupleEqual(output_2.shape, (2, 2))
        np.testing.assert_array_equal(output_2, [[2, 3], [5, 6]])

        self._verify_cost(workspace, split_op)

    def test_columnWise_unequal_blobSplit_columnRemoved(self):
        """Split along axis 1 with sizes from a ``split`` blob and ``add_axis=1``."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        workspace.FeedBlob("split", np.array([1, 1, 1], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input", "split"],
            ["output_1", "output_2", "output_3"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        output_1 = workspace.FetchBlob("output_1")
        self.assertTupleEqual(output_1.shape, (2,))
        np.testing.assert_array_equal(output_1, [1, 4])

        output_2 = workspace.FetchBlob("output_2")
        np.testing.assert_array_equal(output_2, [2, 5])

        output_3 = workspace.FetchBlob("output_3")
        np.testing.assert_array_equal(output_3, [3, 6])

        self._verify_cost(workspace, split_op)

    def test_equal_outputSplit_NHWC(self):
        """With ``order="NHWC"`` and no ``axis``, split along the last axis.

        A (2, 5, 7, 9) input is expected to yield three (2, 5, 7, 3) outputs.
        The input holds distinct values, so the test checks that each output
        receives the correct slice, not just the correct shape.
        """
        workspace.ResetWorkspace()
        # Deterministic, distinct values. The previous
        # np.random.rand(...).astype(np.int32) truncated every element to 0,
        # so misplaced output data could never be detected.
        input_data = np.arange(2 * 5 * 7 * 9, dtype=np.int32).reshape(2, 5, 7, 9)
        workspace.FeedBlob("input", input_data)
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            order="NHWC",
        )
        workspace.RunOperatorOnce(split_op)

        expected_outputs = np.split(input_data, 3, axis=3)
        self.assertEqual(len(split_op.output), len(expected_outputs))
        for blob, expected in zip(split_op.output, expected_outputs):
            output = workspace.FetchBlob(blob)
            self.assertTupleEqual(output.shape, (2, 5, 7, 3))
            np.testing.assert_array_equal(output, expected)

        self._verify_cost(workspace, split_op)
247