# pytorch
# 372 lines · 12.3 KB
1
2
3
4
5
6from caffe2.python import core, workspace7import caffe2.python.hypothesis_test_util as hu8import caffe2.python.serialized_test.serialized_test_util as serial9
10from hypothesis import given, settings11from hypothesis import strategies as st12import numpy as np13import time14
15
16class TestTensorPackOps(serial.SerializedTestCase):17
18def pack_segments_ref(self, return_presence_mask=False, max_length=None):19def pack_segments_ref(lengths, data, max_length=max_length):20arr = []21constant_values = 022if data.dtype.char == 'S':23constant_values = ''24if max_length is None:25max_length = np.max(lengths)26start = 027for idx in range(np.size(lengths)):28len = lengths[idx] if max_length >= lengths[idx] else max_length29chunk = data[start : start + len]30pad_length = max_length - len31
32# ((0, pad_length), (0, 0)) says add pad_length rows of padding33# below chunk and 0 rows of padding elsewhere34arr.append(35np.pad(36chunk, ((0, pad_length), (0, 0)),37mode=str("constant"),38constant_values=constant_values39)40)41start += lengths[idx]42result = [arr]43if return_presence_mask:44presence_arr = []45for length in lengths:46length = length if max_length >= length else max_length47pad_length = max_length - length48presence_arr.append(49np.pad(50np.ones((length), dtype=bool), ((0, pad_length)),51mode=str("constant")52)53)54result.append(presence_arr)55return result56
57return pack_segments_ref58
59@given(60num_seq=st.integers(10, 100),61cell_size=st.integers(1, 10),62max_length_buffer=st.integers(-5, 5),63**hu.gcs64)65@settings(deadline=None, max_examples=50)66def test_pack_with_max_length_ops(67self, num_seq, cell_size, max_length_buffer, gc, dc68):69# create data70lengths = np.arange(num_seq, dtype=np.int32) + 171num_cell = np.sum(lengths)72data = np.zeros(num_cell * cell_size, dtype=np.float32)73left = np.cumsum(np.arange(num_seq) * cell_size)74right = np.cumsum(lengths * cell_size)75for i in range(num_seq):76data[left[i]:right[i]] = i + 1.077data.resize(num_cell, cell_size)78print("\nnum seq:{}, num cell: {}, cell size:{}\n".format(79num_seq, num_cell, cell_size)80+ "=" * 6081)82# run test83max_length = num_seq + max_length_buffer84op = core.CreateOperator(85'PackSegments', ['l', 'd'], ['t'], max_length=max_length)86workspace.FeedBlob('l', lengths)87workspace.FeedBlob('d', data)88start = time.time()89self.assertReferenceChecks(90device_option=gc,91op=op,92inputs=[lengths, data, max_length],93reference=self.pack_segments_ref(max_length=max_length),94)95end = time.time()96print("{} used time: {}".format(gc, end - start).replace('\n', ' '))97
98with core.DeviceScope(gc):99workspace.FeedBlob('l', lengths)100workspace.FeedBlob('d', data)101workspace.RunOperatorOnce(core.CreateOperator(102'PackSegments',103['l', 'd'],104['t'],105max_length=max_length,106device_option=gc))107workspace.RunOperatorOnce(core.CreateOperator(108'UnpackSegments',109['l', 't'],110['newd'],111max_length=max_length,112device_option=gc))113assert(workspace.FetchBlob('t').shape[1] == max_length)114
115def _cal_unpacked_data(data):116if max_length >= num_seq:117return data118output = None119start = 0120for i, length in enumerate(lengths):121new_len = max_length if length > max_length else length122chunk = data[start: start + new_len]123if output is None:124output = chunk125else:126output = np.concatenate((output, chunk), axis=0)127start += length128return output129
130true_newd = _cal_unpacked_data(workspace.FetchBlob('d'))131assert((workspace.FetchBlob('newd') == true_newd).all())132
133@given(134num_seq=st.integers(10, 500),135cell_size=st.integers(1, 10),136**hu.gcs137)138@settings(deadline=10000)139def test_pack_ops(self, num_seq, cell_size, gc, dc):140# create data141lengths = np.arange(num_seq, dtype=np.int32) + 1142num_cell = np.sum(lengths)143data = np.zeros(num_cell * cell_size, dtype=np.float32)144left = np.cumsum(np.arange(num_seq) * cell_size)145right = np.cumsum(lengths * cell_size)146for i in range(num_seq):147data[left[i]:right[i]] = i + 1.0148data.resize(num_cell, cell_size)149print("\nnum seq:{}, num cell: {}, cell size:{}\n".format(150num_seq, num_cell, cell_size)151+ "=" * 60152)153# run test154op = core.CreateOperator(155'PackSegments', ['l', 'd'], ['t'])156workspace.FeedBlob('l', lengths)157workspace.FeedBlob('d', data)158
159start = time.time()160self.assertReferenceChecks(161device_option=gc,162op=op,163inputs=[lengths, data],164reference=self.pack_segments_ref(),165)166end = time.time()167print("{} used time: {}".format(gc, end - start).replace('\n', ' '))168
169with core.DeviceScope(gc):170workspace.FeedBlob('l', lengths)171workspace.FeedBlob('d', data)172workspace.RunOperatorOnce(core.CreateOperator(173'PackSegments',174['l', 'd'],175['t'],176device_option=gc))177workspace.RunOperatorOnce(core.CreateOperator(178'UnpackSegments',179['l', 't'],180['newd'],181device_option=gc))182assert((workspace.FetchBlob('newd') == workspace.FetchBlob('d')).all())183
184@given(185**hu.gcs_cpu_only186)187def test_pack_ops_str(self, gc, dc):188# GPU does not support string. Test CPU implementation only.189workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int64))190strs = np.array([191["a", "a"],192["b", "b"],193["bb", "bb"],194["c", "c"],195["cc", "cc"],196["ccc", "ccc"]],197dtype='|S')198workspace.FeedBlob('d', strs)199workspace.RunOperatorOnce(core.CreateOperator(200'PackSegments',201['l', 'd'],202['t'],203device_option=gc))204workspace.RunOperatorOnce(core.CreateOperator(205'UnpackSegments',206['l', 't'],207['newd'],208device_option=gc))209assert((workspace.FetchBlob('newd') == workspace.FetchBlob('d')).all())210
211def test_pad_minf(self):212workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))213workspace.FeedBlob(214'd',215np.array([216[1.0, 1.1],217[2.0, 2.1],218[2.2, 2.2],219[3.0, 3.1],220[3.2, 3.3],221[3.4, 3.5]],222dtype=np.float32))223workspace.RunOperatorOnce(core.CreateOperator(224'PackSegments', ['l', 'd'], ['t'], pad_minf=True))225workspace.RunOperatorOnce(core.CreateOperator(226'Exp', ['t'], ['r']227))228result = workspace.FetchBlob('t')229assert(result[0, -1, 0] < -1000.0)230
231# The whole point of padding with -inf is that when we exponentiate it232# then it should be zero.233exponentiated = workspace.FetchBlob('r')234assert(exponentiated[0, -1, 0] == 0.0)235
236def test_pad_no_minf(self):237workspace.FeedBlob('l', np.array([1, 2, 3], dtype=np.int32))238workspace.FeedBlob(239'd',240np.array([241[1.0, 1.1],242[2.0, 2.1],243[2.2, 2.2],244[3.0, 3.1],245[3.2, 3.3],246[3.4, 3.5]],247dtype=np.float32))248workspace.RunOperatorOnce(249core.CreateOperator(250'PackSegments', ['l', 'd'], ['t'], pad_minf=False),251)252result = workspace.FetchBlob('t')253assert(result[0, -1, 0] == 0.0)254
255workspace.FeedBlob(256'i',257np.array([258[1, 1],259[2, 2],260[2, 2],261[3, 3],262[3, 3],263[3, 3]],264dtype=np.int32))265workspace.RunOperatorOnce(266core.CreateOperator(267'PackSegments', ['l', 'i'], ['t2'], pad_minf=False),268)269result = workspace.FetchBlob('t2')270assert(result[0, -1, 0] == 0)271
272@given(**hu.gcs)273def test_presence_mask(self, gc, dc):274lengths = np.array([1, 2, 3], dtype=np.int32)275data = np.array(276[277[1.0, 1.0], [2.0, 2.0], [2.0, 2.0], [3.0, 3.0], [3.0, 3.0],278[3.0, 3.0]279],280dtype=np.float32281)282
283op = core.CreateOperator(284'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True285)286workspace.FeedBlob('l', lengths)287workspace.FeedBlob('d', data)288inputs = [lengths, data]289self.assertReferenceChecks(290device_option=gc,291op=op,292inputs=inputs,293reference=self.pack_segments_ref(return_presence_mask=True),294)295
296op = core.CreateOperator(297'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True298)299workspace.RunOperatorOnce(op)300
301output = workspace.FetchBlob('t')302expected_output_shape = (3, 3, 2)303self.assertEqual(output.shape, expected_output_shape)304
305presence_mask = workspace.FetchBlob('p')306expected_presence_mask = np.array(307[[True, False, False], [True, True, False], [True, True, True]],308dtype=bool309)310self.assertEqual(presence_mask.shape, expected_presence_mask.shape)311np.testing.assert_array_equal(presence_mask, expected_presence_mask)312
313def test_presence_mask_empty(self):314lengths = np.array([], dtype=np.int32)315data = np.array([], dtype=np.float32)316
317op = core.CreateOperator(318'PackSegments', ['l', 'd'], ['t', 'p'], return_presence_mask=True319)320workspace.FeedBlob('l', lengths)321workspace.FeedBlob('d', data)322workspace.RunOperatorOnce(op)323
324output = workspace.FetchBlob('p')325expected_output_shape = (0, 0)326self.assertEqual(output.shape, expected_output_shape)327
328@given(**hu.gcs_cpu_only)329@settings(deadline=10000)330def test_out_of_bounds(self, gc, dc):331# Copy pasted from test_pack_ops but with 3 changed to 4332lengths = np.array([1, 2, 4], dtype=np.int32)333data = np.array([334[1.0, 1.0],335[2.0, 2.0],336[2.0, 2.0],337[3.0, 3.0],338[3.0, 3.0],339[3.0, 3.0]], dtype=np.float32)340op = core.CreateOperator(341'PackSegments', ['l', 'd'], ['t'])342
343inputs = [lengths, data]344self.assertRunOpRaises(345device_option=gc,346op=op,347inputs=inputs,348exception=RuntimeError349)350
351@given(**hu.gcs_cpu_only)352@settings(deadline=10000)353def test_under_bounds(self, gc, dc):354# Copy pasted from test_pack_ops but with 3 changed to 2355lengths = np.array([1, 2, 2], dtype=np.int32)356data = np.array([357[1.0, 1.0],358[2.0, 2.0],359[2.0, 2.0],360[3.0, 3.0],361[3.0, 3.0],362[3.0, 3.0]], dtype=np.float32)363op = core.CreateOperator(364'PackSegments', ['l', 'd'], ['t'])365
366inputs = [lengths, data]367self.assertRunOpRaises(368device_option=gc,369op=op,370inputs=inputs,371exception=RuntimeError372)373
374
375if __name__ == "__main__":376import unittest377unittest.main()378