pytorch

Форк
0
/
split_op_cost_test.py 
246 строк · 8.4 Кб
1
import numpy as np
2
from caffe2.python import core, workspace
3
from caffe2.python.test_util import TestCase
4

5

6
class TestSplitOpCost(TestCase):
    """Tests for the Caffe2 ``Split`` operator and its cost inference.

    Every test runs ``Split`` once on a small int32 input and checks each
    output blob against the expected slice. It then checks that
    ``workspace.GetOperatorCost`` reports zero FLOPs and byte counts equal to
    the total size of the operator's input and output blobs.
    """

    def _verify_cost(self, workspace, split_op):
        """Check the cost reported for ``split_op`` after it has been run.

        Args:
            workspace: the Caffe2 ``workspace`` module holding the op's blobs.
                Every caller passes it explicitly; it shadows the module-level
                import of the same name inside this method.
            split_op: the ``Split`` ``OperatorDef`` that was just executed.
        """
        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
            split_op, split_op.input
        )
        # Split only moves data, so no arithmetic cost is expected.
        self.assertEqual(flops, 0)
        # All input blobs are counted as read, including the optional
        # "split"-sizes blob used by the *_blobSplit tests.
        self.assertEqual(
            bytes_read,
            sum(workspace.FetchBlob(b).nbytes for b in split_op.input),
        )
        self.assertEqual(
            bytes_written,
            sum(workspace.FetchBlob(b).nbytes for b in split_op.output),
        )

    def _assert_output(self, name, expected):
        """Fetch blob ``name`` and check its shape, dtype and values.

        Args:
            name: name of the output blob to fetch from the workspace.
            expected: array-like with the expected contents. It is converted
                to int32 because every test in this class feeds int32 input.
        """
        expected = np.asarray(expected, dtype=np.int32)
        actual = workspace.FetchBlob(name)
        self.assertTupleEqual(actual.shape, expected.shape)
        self.assertEqual(actual.dtype, expected.dtype)
        np.testing.assert_array_equal(actual, expected)

    def test_columnwise_equal_outputSplit(self):
        """No axis or sizes given: a 2x3 input splits into three 2x1 columns."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [[1], [4]])
        self._assert_output("output_2", [[2], [5]])
        self._assert_output("output_3", [[3], [6]])

        self._verify_cost(workspace, split_op)

    def test_rowwise_equal_outputSplit(self):
        """axis=0 with no sizes: a 2x3 input splits into two 1x3 rows."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [[1, 2, 3]])
        self._assert_output("output_2", [[4, 5, 6]])

        self._verify_cost(workspace, split_op)

    def test_columnwise_equal_outputSplit_columnRemoved(self):
        """axis=1 with add_axis=1: each 1-wide column piece loses that axis."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        # 'add_axis' (which should really be called 'remove_axis') drops 'axis'
        # from the outputs. To use it, the split pieces must match in size
        # along 'axis'; here each piece is 1 wide, so outputs have shape (2,).
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [1, 4])
        self._assert_output("output_2", [2, 5])
        self._assert_output("output_3", [3, 6])

        self._verify_cost(workspace, split_op)

    def test_rowwise_equal_outputSplit_rowRemoved(self):
        """axis=0 with add_axis=1: each 1-high row piece loses that axis."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [1, 2, 3])
        self._assert_output("output_2", [4, 5, 6])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_argSplit(self):
        """axis=0 with sizes from the 'split' argument: rows split as [1, 2]."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=0,
            split=[1, 2],
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [[1, 2, 3]])
        self._assert_output("output_2", [[4, 5, 6], [7, 8, 9]])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_argSplit_rowRemoved(self):
        """axis=0, split=[1, 1, 1], add_axis=1: each row becomes a 1-D output."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            axis=0,
            split=[1, 1, 1],
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [1, 2, 3])
        self._assert_output("output_2", [4, 5, 6])
        self._assert_output("output_3", [7, 8, 9])

        self._verify_cost(workspace, split_op)

    def test_rowwise_unequal_blobSplit(self):
        """axis=0 with sizes from a second input blob: rows split as [1, 2]."""
        workspace.ResetWorkspace()
        workspace.FeedBlob(
            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
        )
        workspace.FeedBlob("split", np.array([1, 2], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input", "split"],
            ["output_1", "output_2"],
            axis=0,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [[1, 2, 3]])
        self._assert_output("output_2", [[4, 5, 6], [7, 8, 9]])

        self._verify_cost(workspace, split_op)

    def test_columnwise_unequal_argSplit(self):
        """axis=1 with split=[1, 2]: columns split into widths 1 and 2."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2"],
            axis=1,
            split=[1, 2],
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [[1], [4]])
        self._assert_output("output_2", [[2, 3], [5, 6]])

        self._verify_cost(workspace, split_op)

    def test_columnWise_unequal_blobSplit_columnRemoved(self):
        """axis=1, sizes blob [1, 1, 1], add_axis=1: columns become 1-D outputs."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        workspace.FeedBlob("split", np.array([1, 1, 1], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input", "split"],
            ["output_1", "output_2", "output_3"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        self._assert_output("output_1", [1, 4])
        self._assert_output("output_2", [2, 5])
        self._assert_output("output_3", [3, 6])

        self._verify_cost(workspace, split_op)

    def test_equal_outputSplit_NHWC(self):
        """order="NHWC" with no axis: the last axis (9) splits into 3 x 3."""
        workspace.ResetWorkspace()
        # Deterministic, distinct values so slice contents can be verified.
        # np.random.rand(...).astype(np.int32) would truncate every value to 0.
        data = np.arange(2 * 5 * 7 * 9, dtype=np.int32).reshape(2, 5, 7, 9)
        workspace.FeedBlob("input", data)
        split_op = core.CreateOperator(
            "Split",
            ["input"],
            ["output_1", "output_2", "output_3"],
            order="NHWC",
        )
        workspace.RunOperatorOnce(split_op)

        # Each expected piece has shape (2, 5, 7, 3).
        for name, expected in zip(split_op.output, np.split(data, 3, axis=3)):
            self._assert_output(name, expected)

        self._verify_cost(workspace, split_op)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.