pytorch

Форк
0
/
filler_ops_test.py 
266 строк · 8.3 Кб
1

2

3

4

5

6
from caffe2.proto import caffe2_pb2
7
from caffe2.python import core, workspace
8
import caffe2.python.hypothesis_test_util as hu
9
import caffe2.python.serialized_test.serialized_test_util as serial
10

11
from hypothesis import given, settings
12
import hypothesis.strategies as st
13
import numpy as np
14

15

16
def _fill_diagonal(shape, value):
17
    result = np.zeros(shape)
18
    np.fill_diagonal(result, value)
19
    return (result,)
20

21

22
class TestFillerOperator(serial.SerializedTestCase):
    """Hypothesis-driven tests for the caffe2 filler operators.

    Covers GaussianFill, ConstantFill, UniformFill, UniformIntFill,
    DiagonalFill, LengthsRangeFill, MSRAFill and Float16UniformFill.
    """

    def _assert_random_fill_nonzero(self, op, dc):
        """Run ``op`` once per device option in ``dc`` and sanity-check it.

        ``op`` must write its result to the blob ``'out'``. Fails if the op
        reports an unsuccessful run or if every generated element is zero.
        """
        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            self.assertTrue(
                workspace.RunOperatorOnce(op),
                "{} op did not run successfully".format(op.type),
            )

            blob_out = workspace.FetchBlob('out')
            # Adjacent literals are joined at compile time, so the whole
            # message is kept (a bare second line would be a no-op statement).
            self.assertGreater(
                np.count_nonzero(blob_out),
                0,
                "All generated elements are zeros. "
                "Is the random generator functioning correctly?",
            )

    @given(**hu.gcs)
    @settings(deadline=10000)
    def test_shape_error(self, gc, dc):
        """An integer ``shape`` is rejected; an empty list yields a scalar."""
        op = core.CreateOperator(
            'GaussianFill',
            [],
            'out',
            shape=32,  # illegal parameter: shape must be a list
            mean=0.0,
            std=1.0,
        )
        with self.assertRaises(
            Exception, msg="Did not throw exception on illegal shape"
        ):
            workspace.RunOperatorOnce(op)

        op = core.CreateOperator(
            'ConstantFill',
            [],
            'out',
            shape=[],  # scalar
            value=2.0,
        )
        self.assertTrue(workspace.RunOperatorOnce(op))
        # Array-aware comparison instead of relying on the truthiness of an
        # elementwise ``==`` between an ndarray and a list.
        np.testing.assert_array_equal(workspace.FetchBlob('out'), [2.0])

    @given(**hu.gcs)
    @settings(deadline=10000)
    def test_int64_shape(self, gc, dc):
        """UniformFill accepts a dimension larger than 2**31 (empty output)."""
        large_dim = 2 ** 31 + 1
        net = core.Net("test_shape_net")
        net.UniformFill(
            [],
            'out',
            shape=[0, large_dim],
            min=0.0,
            max=1.0,
        )
        self.assertTrue(workspace.CreateNet(net))
        self.assertTrue(workspace.RunNet(net.Name()))
        self.assertEqual(workspace.blobs['out'].shape, (0, large_dim))

    @given(
        shape=hu.dims().flatmap(
            lambda dims: hu.arrays(
                [dims], dtype=np.int64,
                elements=st.integers(min_value=0, max_value=20)
            )
        ),
        a=st.integers(min_value=0, max_value=100),
        b=st.integers(min_value=0, max_value=100),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_uniform_int_fill_op_blob_input(self, shape, a, b, gc, dc):
        """UniformIntFill with shape/min/max supplied as input blobs.

        When ``b < a`` the test expects the first output dimension to be 0;
        otherwise the output has ``shape`` and every value lies in [a, b].
        """
        net = core.Net('test_net')

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            shape_blob = net.Const(shape, dtype=np.int64)
        a_blob = net.Const(a, dtype=np.int32)
        b_blob = net.Const(b, dtype=np.int32)
        uniform_fill = net.UniformIntFill([shape_blob, a_blob, b_blob],
                                          1, input_as_shape=1)

        workspace.RunNetOnce(net)

        blob_out = workspace.FetchBlob(uniform_fill)
        if b < a:
            # Real copy: ``shape[:]`` on an ndarray is a view, and writing to
            # it would mutate the hypothesis-generated ``shape``.
            new_shape = np.array(shape)
            new_shape[0] = 0
            np.testing.assert_array_equal(new_shape, blob_out.shape)
        else:
            np.testing.assert_array_equal(shape, blob_out.shape)
            self.assertTrue((blob_out >= a).all())
            self.assertTrue((blob_out <= b).all())

    @given(
        **hu.gcs
    )
    def test_uniform_fill_using_arg(self, gc, dc):
        """UniformIntFill with shape/min/max supplied as operator arguments."""
        net = core.Net('test_net')
        shape = [2**3, 5]
        # uncomment this to test filling large blob
        # shape = [2**30, 5]
        min_v = -100
        max_v = 100
        output_blob = net.UniformIntFill(
            [],
            ['output_blob'],
            shape=shape,
            min=min_v,
            max=max_v,
        )

        workspace.RunNetOnce(net)
        output_data = workspace.FetchBlob(output_blob)

        np.testing.assert_array_equal(shape, output_data.shape)
        min_data = np.min(output_data)
        max_data = np.max(output_data)

        self.assertGreaterEqual(min_data, min_v)
        self.assertLessEqual(max_data, max_v)

        # 40 draws from [-100, 100] should not all be identical.
        self.assertNotEqual(min_data, max_data)

    @serial.given(
        shape=st.sampled_from(
            [
                [3, 3],
                [5, 5, 5],
                [7, 7, 7, 7],
            ]
        ),
        **hu.gcs
    )
    def test_diagonal_fill_op_float(self, shape, gc, dc):
        """DiagonalFill with a float value matches the numpy reference."""
        value = 2.5
        op = core.CreateOperator(
            'DiagonalFill',
            [],
            'out',
            shape=shape,  # square 2-D or equal-sided N-D shape
            value=value,
        )

        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            # Check against numpy reference.
            # NOTE(review): the check is invoked with ``gc``; confirm whether
            # assertReferenceChecks overrides op.device_option, in which case
            # this loop only repeats the same check.
            self.assertReferenceChecks(gc, op, [shape, value], _fill_diagonal)

    @given(**hu.gcs)
    def test_diagonal_fill_op_int(self, gc, dc):
        """DiagonalFill with dtype INT32 matches the numpy reference."""
        value = 2
        shape = [3, 3]
        op = core.CreateOperator(
            'DiagonalFill',
            [],
            'out',
            shape=shape,
            dtype=core.DataType.INT32,
            value=value,
        )

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [shape, value], _fill_diagonal)

    @serial.given(
        lengths=st.lists(
            st.integers(min_value=0, max_value=10),
            min_size=0,
            max_size=10,
        ),
        **hu.gcs
    )
    def test_lengths_range_fill(self, lengths, gc, dc):
        """LengthsRangeFill emits 0..len-1 for each length, concatenated."""
        op = core.CreateOperator(
            "LengthsRangeFill",
            ["lengths"],
            ["increasing_seq"])

        def _len_range_fill(lengths):
            """Reference: concatenate range(length) for every length."""
            sids = [i for length in lengths for i in range(length)]
            return (np.array(sids, dtype=np.int32), )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[np.array(lengths, dtype=np.int32)],
            reference=_len_range_fill)

    @given(**hu.gcs)
    def test_gaussian_fill_op(self, gc, dc):
        """GaussianFill runs on every device and produces non-zero values."""
        op = core.CreateOperator(
            'GaussianFill',
            [],
            'out',
            shape=[17, 3, 3],  # sample odd dimensions
            mean=0.0,
            std=1.0,
        )
        self._assert_random_fill_nonzero(op, dc)

    @given(**hu.gcs)
    def test_msra_fill_op(self, gc, dc):
        """MSRAFill runs on every device and produces non-zero values."""
        op = core.CreateOperator(
            'MSRAFill',
            [],
            'out',
            shape=[15, 5, 3],  # sample odd dimensions
        )
        self._assert_random_fill_nonzero(op, dc)

    @given(
        low=st.integers(min_value=0, max_value=5),
        span=st.integers(min_value=1, max_value=10),
        emb_size=st.sampled_from((10000, 20000, 30000)),
        dim_size=st.sampled_from((16, 32, 64)),
        **hu.gcs)
    @settings(deadline=None)
    def test_fp16_uniformfill_op(self, low, span, emb_size, dim_size, gc, dc):
        """Float16UniformFill output has the right dtype, shape and moments.

        ``low`` / ``span`` define the sampling interval [low, low + span].
        The mean and variance are compared to those of a continuous uniform
        distribution (``low + span / 2`` and ``span**2 / 12``).
        """
        high = low + span
        op = core.CreateOperator(
            'Float16UniformFill',
            [],
            'out',
            shape=[emb_size, dim_size],
            min=float(low),
            max=float(high),
        )
        for device_option in dc:
            op.device_option.CopyFrom(device_option)
            self.assertTrue(
                workspace.RunOperatorOnce(op),
                "Float16UniformFill op did not run successfully",
            )

            self.assertEqual(workspace.blobs['out'].shape, (emb_size, dim_size))

            blob_out = workspace.FetchBlob('out')

            expected_type = "float16"
            expected_mean = low + span / 2.0
            expected_var = span * span / 12.0

            self.assertEqual(blob_out.dtype.name, expected_type)
            # Accumulate in float32 to avoid float16 overflow/precision loss.
            self.assertAlmostEqual(np.mean(blob_out, dtype=np.float32), expected_mean, delta=0.1)
            self.assertAlmostEqual(np.var(blob_out, dtype=np.float32), expected_var, delta=0.1)
            self.assertGreaterEqual(np.min(blob_out), low)
            self.assertLessEqual(np.max(blob_out), high)
268

269
if __name__ == "__main__":
    # Running the file directly hands control to unittest's command-line runner.
    import unittest as _unittest
    _unittest.main()
272

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.