pytorch
46 строк · 1.3 Кб
1
2
3
4
5
6from caffe2.python import core7import caffe2.python.hypothesis_test_util as hu8import caffe2.python.serialized_test.serialized_test_util as serial9import hypothesis.strategies as st10import numpy as np11
12
class TestLengthsTileOp(serial.SerializedTestCase):
    """Tests for the Caffe2 ``LengthsTile`` operator."""

    @serial.given(
        inputs=st.integers(min_value=1, max_value=20).flatmap(
            lambda size: st.tuples(
                hu.arrays([size], dtype=np.float32),
                hu.arrays([size], dtype=np.int32,
                          elements=st.integers(min_value=0, max_value=20)),
            )
        ),
        **hu.gcs)
    def test_lengths_tile(self, inputs, gc, dc):
        """Compare LengthsTile with a NumPy reference and check its gradient.

        ``inputs`` is a pair of equally sized 1-D arrays: float32 ``data``
        and int32 ``lengths`` with values in [0, 20]. The expected output
        contains each ``data[i]`` repeated ``lengths[i]`` times, in order.
        Zero lengths are generated, so the expected output may be empty.
        """
        data, lengths = inputs

        def lengths_tile_op(data, lengths):
            # np.repeat keeps data's dtype (float32). Building the result from
            # Python lists ([d] * l) would turn every zero length into an
            # empty list, which NumPy treats as float64, promoting the whole
            # concatenated reference to float64.
            return [np.repeat(data, lengths, axis=0)]

        op = core.CreateOperator(
            "LengthsTile",
            ["data", "lengths"],
            ["output"],
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[data, lengths],
            reference=lengths_tile_op,
        )

        # Gradient is checked for input index 0 (``data``); ``lengths`` is an
        # int32 input.
        self.assertGradientChecks(
            device_option=gc,
            op=op,
            inputs=[data, lengths],
            outputs_to_check=0,
            outputs_with_grads=[0],
        )