pytorch

Форк
0
/
pack_ops_test.py 
372 строки · 12.3 КБ
1

2

3

4

5

6
from caffe2.python import core, workspace
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9

10
from hypothesis import given, settings
11
from hypothesis import strategies as st
12
import numpy as np
13
import time
14

15

16
class TestTensorPackOps(serial.SerializedTestCase):
    """Tests for the ``PackSegments`` and ``UnpackSegments`` operators."""

    def pack_segments_ref(self, return_presence_mask=False, max_length=None):
        """Return a NumPy reference implementation of ``PackSegments``.

        Args:
            return_presence_mask: if True, the reference additionally returns,
                per segment, a boolean row that is True for copied rows and
                False for padding.
            max_length: default padded length; segments longer than this are
                truncated. ``None`` pads to the longest segment.

        Returns:
            A callable ``(lengths, data, max_length=max_length)`` that returns
            ``[packed]`` or ``[packed, presence_mask]``, each a list with one
            entry per segment.
        """
        def pack_segments_ref(lengths, data, max_length=max_length):
            # Byte-string data cannot be padded with the numeric 0.
            constant_values = '' if data.dtype.char == 'S' else 0
            if max_length is None:
                # Pad to the longest segment; an empty batch pads to 0.
                max_length = np.max(lengths) if np.size(lengths) else 0

            packed = []
            start = 0
            for seg_len in lengths:
                kept = min(seg_len, max_length)
                chunk = data[start:start + kept]
                # ((0, pad), (0, 0)): append `pad` padding rows below the
                # chunk and no padding along the second axis.
                packed.append(np.pad(
                    chunk,
                    ((0, max_length - kept), (0, 0)),
                    mode=str("constant"),
                    constant_values=constant_values,
                ))
                # Advance by the full segment length even when truncated.
                start += seg_len
            result = [packed]

            if return_presence_mask:
                # Row i is True for its first min(len_i, max_length) slots.
                positions = np.arange(max_length)
                result.append(
                    [positions < min(seg_len, max_length) for seg_len in lengths]
                )
            return result

        return pack_segments_ref

    @staticmethod
    def _create_test_data(num_seq, cell_size):
        """Build ``(lengths, data)`` for the randomized pack/unpack tests.

        ``lengths`` is int32 ``[1, 2, ..., num_seq]``; ``data`` is float32 of
        shape ``(sum(lengths), cell_size)`` where every row belonging to
        segment ``i`` is filled with ``i + 1.0``.
        """
        lengths = np.arange(num_seq, dtype=np.int32) + 1
        segment_values = np.arange(1, num_seq + 1, dtype=np.float32)
        # Repeat each segment's value once per row, then across all columns.
        data = np.repeat(
            np.repeat(segment_values, lengths)[:, np.newaxis], cell_size, axis=1
        )
        print("\nnum seq:{},    num cell: {},   cell size:{}\n".format(
            num_seq, data.shape[0], cell_size)
            + "=" * 60
        )
        return lengths, data

    @staticmethod
    def _run_pack_unpack(gc, **op_args):
        """Run ``PackSegments(l, d) -> t`` then ``UnpackSegments(l, t) -> newd``.

        Both operators run with device option ``gc`` and the same ``op_args``.
        """
        workspace.RunOperatorOnce(core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'], device_option=gc, **op_args))
        workspace.RunOperatorOnce(core.CreateOperator(
            'UnpackSegments', ['l', 't'], ['newd'], device_option=gc, **op_args))

    @staticmethod
    def _truncate_segments(data, lengths, max_length):
        """Return ``data`` keeping at most ``max_length`` leading rows per segment."""
        chunks = []
        start = 0
        for seg_len in lengths:
            chunks.append(data[start:start + min(seg_len, max_length)])
            start += seg_len
        return np.concatenate(chunks, axis=0) if chunks else data[:0]

    @staticmethod
    def _segment_data():
        """6x2 float32 rows for segments of lengths [1, 2, 3] (values 1, 2, 3)."""
        return np.array([
            [1.0, 1.0],
            [2.0, 2.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [3.0, 3.0],
            [3.0, 3.0]], dtype=np.float32)

    @staticmethod
    def _pad_test_data():
        """6x2 float32 rows used by the padding tests with lengths [1, 2, 3]."""
        return np.array([
            [1.0, 1.1],
            [2.0, 2.1],
            [2.2, 2.2],
            [3.0, 3.1],
            [3.2, 3.3],
            [3.4, 3.5]], dtype=np.float32)

    @given(
        num_seq=st.integers(10, 100),
        cell_size=st.integers(1, 10),
        max_length_buffer=st.integers(-5, 5),
        **hu.gcs
    )
    @settings(deadline=None, max_examples=50)
    def test_pack_with_max_length_ops(
        self, num_seq, cell_size, max_length_buffer, gc, dc
    ):
        lengths, data = self._create_test_data(num_seq, cell_size)
        # The longest segment has num_seq rows, so a negative buffer forces
        # truncation and a positive one forces extra padding.
        max_length = num_seq + max_length_buffer
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'], max_length=max_length)
        workspace.FeedBlob('l', lengths)
        workspace.FeedBlob('d', data)
        start = time.time()
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            # NOTE(review): the op declares only two inputs; max_length is
            # presumably just forwarded to the reference's third parameter —
            # confirm against hu.assertReferenceChecks.
            inputs=[lengths, data, max_length],
            reference=self.pack_segments_ref(max_length=max_length),
        )
        end = time.time()
        print("{} used time: {}".format(gc, end - start).replace('\n', ' '))

        with core.DeviceScope(gc):
            workspace.FeedBlob('l', lengths)
            workspace.FeedBlob('d', data)
        self._run_pack_unpack(gc, max_length=max_length)
        self.assertEqual(workspace.FetchBlob('t').shape[1], max_length)

        # Unpacking can only restore the rows that survived truncation.
        expected = self._truncate_segments(
            workspace.FetchBlob('d'), lengths, max_length)
        np.testing.assert_array_equal(workspace.FetchBlob('newd'), expected)

    @given(
        num_seq=st.integers(10, 500),
        cell_size=st.integers(1, 10),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_pack_ops(self, num_seq, cell_size, gc, dc):
        lengths, data = self._create_test_data(num_seq, cell_size)
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'])
        workspace.FeedBlob('l', lengths)
        workspace.FeedBlob('d', data)

        start = time.time()
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, data],
            reference=self.pack_segments_ref(),
        )
        end = time.time()
        print("{} used time: {}".format(gc, end - start).replace('\n', ' '))

        with core.DeviceScope(gc):
            workspace.FeedBlob('l', lengths)
            workspace.FeedBlob('d', data)
        self._run_pack_unpack(gc)
        # Without max_length, pack -> unpack must round-trip exactly.
        np.testing.assert_array_equal(
            workspace.FetchBlob('newd'), workspace.FetchBlob('d'))

    @given(
        **hu.gcs_cpu_only
    )
    def test_pack_ops_str(self, gc, dc):
        # GPU does not support string. Test CPU implementation only.
        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int64))
        strs = np.array([
            ["a", "a"],
            ["b", "b"],
            ["bb", "bb"],
            ["c", "c"],
            ["cc", "cc"],
            ["ccc", "ccc"]],
            dtype='|S')
        workspace.FeedBlob('d', strs)
        self._run_pack_unpack(gc)
        np.testing.assert_array_equal(
            workspace.FetchBlob('newd'), workspace.FetchBlob('d'))

    def test_pad_minf(self):
        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))
        workspace.FeedBlob('d', self._pad_test_data())
        workspace.RunOperatorOnce(core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'], pad_minf=True))
        workspace.RunOperatorOnce(core.CreateOperator(
            'Exp', ['t'], ['r']
        ))
        # Segment 0 has one row, so its last slot is padding.
        result = workspace.FetchBlob('t')
        self.assertLess(result[0, -1, 0], -1000.0)

        # The whole point of padding with -inf is that when we exponentiate it
        # then it should be zero.
        exponentiated = workspace.FetchBlob('r')
        self.assertEqual(exponentiated[0, -1, 0], 0.0)

    def test_pad_no_minf(self):
        workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))
        workspace.FeedBlob('d', self._pad_test_data())
        workspace.RunOperatorOnce(
            core.CreateOperator(
                'PackSegments', ['l', 'd'], ['t'], pad_minf=False),
        )
        result = workspace.FetchBlob('t')
        self.assertEqual(result[0, -1, 0], 0.0)

        # Integer data must also be zero-padded.
        workspace.FeedBlob(
            'i',
            np.array([
                [1, 1],
                [2, 2],
                [2, 2],
                [3, 3],
                [3, 3],
                [3, 3]],
                dtype=np.int32))
        workspace.RunOperatorOnce(
            core.CreateOperator(
                'PackSegments', ['l', 'i'], ['t2'], pad_minf=False),
        )
        result = workspace.FetchBlob('t2')
        self.assertEqual(result[0, -1, 0], 0)

    @given(**hu.gcs)
    def test_presence_mask(self, gc, dc):
        lengths = np.array([1, 2, 3], dtype=np.int32)
        data = self._segment_data()

        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True
        )
        workspace.FeedBlob('l', lengths)
        workspace.FeedBlob('d', data)
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, data],
            reference=self.pack_segments_ref(return_presence_mask=True),
        )

        # Run a freshly created op directly (no device option) to inspect
        # the blobs it writes.
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True
        )
        workspace.RunOperatorOnce(op)

        output = workspace.FetchBlob('t')
        self.assertEqual(output.shape, (3, 3, 2))

        presence_mask = workspace.FetchBlob('p')
        expected_presence_mask = np.array(
            [[True, False, False], [True, True, False], [True, True, True]],
            dtype=bool
        )
        self.assertEqual(presence_mask.shape, expected_presence_mask.shape)
        np.testing.assert_array_equal(presence_mask, expected_presence_mask)

    def test_presence_mask_empty(self):
        lengths = np.array([], dtype=np.int32)
        data = np.array([], dtype=np.float32)

        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True
        )
        workspace.FeedBlob('l', lengths)
        workspace.FeedBlob('d', data)
        workspace.RunOperatorOnce(op)

        output = workspace.FetchBlob('p')
        self.assertEqual(output.shape, (0, 0))

    def _assert_pack_raises(self, gc, lengths):
        """Assert PackSegments raises RuntimeError for ``lengths`` over the 6-row data."""
        op = core.CreateOperator(
            'PackSegments', ['l', 'd'], ['t'])
        self.assertRunOpRaises(
            device_option=gc,
            op=op,
            inputs=[np.array(lengths, dtype=np.int32), self._segment_data()],
            exception=RuntimeError
        )

    @given(**hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_out_of_bounds(self, gc, dc):
        # Lengths sum to 7 but data has only 6 rows.
        self._assert_pack_raises(gc, [1, 2, 4])

    @given(**hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_under_bounds(self, gc, dc):
        # Lengths sum to 5 but data has 6 rows.
        self._assert_pack_raises(gc, [1, 2, 2])
373

374

375
if __name__ == "__main__":
    # Run the tests in this module through the stdlib unittest runner.
    import unittest

    unittest.main()
378

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.