import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import numpy as np
from caffe2.python import core, workspace
from hypothesis import given, settings, strategies as st


def batched_boarders_and_data(
    data_min_size=5,
    data_max_size=10,
    examples_min_number=1,
    examples_max_number=4,
    example_min_size=1,
    example_max_size=3,
    dtype=np.float32,
    elements=None,
):
    """Strategy generating (boarders, data) pairs for GatherRanges tests.

    `boarders` is an int32 array of shape [examples, ranges_per_example, 2]
    holding unsorted offsets into `data`; `data` is a 1-D array of `dtype`.
    """
    dims_ = st.tuples(
        st.integers(min_value=data_min_size, max_value=data_max_size),
        st.integers(min_value=examples_min_number, max_value=examples_max_number),
        st.integers(min_value=example_min_size, max_value=example_max_size),
    )
    return dims_.flatmap(
        lambda dims: st.tuples(
            hu.arrays(
                [dims[1], dims[2], 2],
                dtype=np.int32,
                elements=st.integers(min_value=0, max_value=dims[0]),
            ),
            hu.arrays([dims[0]], dtype, elements),
        )
    )
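
# Each boarder pair above is unsorted; test_gather_ranges below sorts it and
# converts it into a (start, length) range before invoking the operator.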


@st.composite
def _tensor_splits(draw):
    """Generate (data, ranges, lengths, key) for GatherRangesToDense tests.

    `ranges[batch][i]` is an (offset, length) pair into `data`. One randomly
    chosen (batch, range) pair is skipped and left as (0, 0), so the operator
    also sees an empty range. `key` is a permutation of the data indices used
    for key-sorted gathering.
    """
    lengths = draw(st.lists(st.integers(1, 5), min_size=1, max_size=10))
    batch_size = draw(st.integers(1, 5))
    element_pairs = [
        (batch, r) for batch in range(batch_size) for r in range(len(lengths))
    ]
    perm = draw(st.permutations(element_pairs))
    perm = perm[:-1]  # skip one range
    ranges = [[(0, 0)] * len(lengths) for _ in range(batch_size)]
    offset = 0
    for pair in perm:
        ranges[pair[0]][pair[1]] = (offset, lengths[pair[1]])
        offset += lengths[pair[1]]

    data = draw(
        st.lists(
            st.floats(min_value=-1.0, max_value=1.0), min_size=offset, max_size=offset
        )
    )

    key = draw(st.permutations(range(offset)))

    return (
        np.array(data).astype(np.float32),
        np.array(ranges),
        np.array(lengths),
        np.array(key).astype(np.int64),
    )


@st.composite
def _bad_tensor_splits(draw):
    """Like _tensor_splits, but with deliberately malformed ranges.

    Batch 2 gets empty ranges and batches 0-1 get ranges of half the expected
    length, exercising the operator's empty/mismatch thresholds.
    """
    lengths = draw(st.lists(st.integers(4, 6), min_size=4, max_size=4))
    batch_size = 4
    element_pairs = [
        (batch, r) for batch in range(batch_size) for r in range(len(lengths))
    ]
    perm = draw(st.permutations(element_pairs))
    ranges = [[(0, 0)] * len(lengths) for _ in range(batch_size)]
    offset = 0

    # Inject some bad samples depending on the batch.
    # Batch 2: length is set to 0. This way, 25% of the samples are empty.
    # Batch 0-1: length is set to half the original length. This way, 50% of
    # the samples are of mismatched length.
    for pair in perm:
        if pair[0] == 2:
            length = 0
        elif pair[0] <= 1:
            length = lengths[pair[1]] // 2
        else:
            length = lengths[pair[1]]
        ranges[pair[0]][pair[1]] = (offset, length)
        offset += length

    data = draw(
        st.lists(
            st.floats(min_value=-1.0, max_value=1.0), min_size=offset, max_size=offset
        )
    )

    key = draw(st.permutations(range(offset)))

    return (
        np.array(data).astype(np.float32),
        np.array(ranges),
        np.array(lengths),
        np.array(key).astype(np.int64),
    )


def gather_ranges(data, ranges):
    """Reference implementation of the GatherRanges operator."""
    lengths = []
    output = []
    for example_ranges in ranges:
        length = 0
        # Each `r` is a (start, length) pair; renamed from `range` to avoid
        # shadowing the builtin.
        for r in example_ranges:
            assert len(r) == 2
            output.extend(data[r[0] : r[0] + r[1]])
            length += r[1]
        lengths.append(length)
    return output, lengths
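
# Illustrative trace (not executed): with data = [1, 2, 3, 4, 5] and
# ranges = [[(0, 2), (3, 1)]], gather_ranges returns output = [1, 2, 4]
# and lengths = [3].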


def gather_ranges_to_dense(data, ranges, lengths):
    """Reference implementation of the GatherRangesToDense operator.

    Returns one dense [batch_size, lengths[i]] array per range index i;
    empty ranges produce zero-filled rows.
    """
    outputs = []
    assert len(ranges)
    batch_size = len(ranges)
    assert len(ranges[0])
    num_ranges = len(ranges[0])
    assert ranges.shape[2] == 2
    for i in range(num_ranges):
        out = []
        for j in range(batch_size):
            start, length = ranges[j][i]
            if not length:
                out.append([0] * lengths[i])
            else:
                assert length == lengths[i]
                out.append(data[start : start + length])
        outputs.append(np.array(out))
    return outputs
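
# Illustrative trace (not executed): with data = [1, 2, 3, 4, 5, 6],
# ranges = [[(0, 2), (2, 2)], [(4, 2), (0, 0)]] and lengths = [2, 2],
# the result is [array([[1, 2], [5, 6]]), array([[3, 4], [0, 0]])];
# the empty range in the second batch is zero-filled.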
140

141

142
def gather_ranges_to_dense_with_key(data, ranges, key, lengths):
143
    outputs = []
144
    assert len(ranges)
145
    batch_size = len(ranges)
146
    assert len(ranges[0])
147
    num_ranges = len(ranges[0])
148
    assert ranges.shape[2] == 2
149
    for i in range(num_ranges):
150
        out = []
151
        for j in range(batch_size):
152
            start, length = ranges[j][i]
153
            if not length:
154
                out.append([0] * lengths[i])
155
            else:
156
                assert length == lengths[i]
157
                key_data_list = zip(
158
                    key[start : start + length], data[start : start + length]
159
                )
160
                sorted_key_data_list = sorted(key_data_list, key=lambda x: x[0])
161
                sorted_data = [d for (k, d) in sorted_key_data_list]
162
                out.append(sorted_data)
163
        outputs.append(np.array(out))
164
    return outputs
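
# Illustrative trace (not executed): if a gathered row is [0.3, 0.1] with
# keys [2, 1], the key-sorted row becomes [0.1, 0.3].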


class TestGatherRanges(serial.SerializedTestCase):
    @given(boarders_and_data=batched_boarders_and_data(), **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_gather_ranges(self, boarders_and_data, gc, dc):
        boarders, data = boarders_and_data

        def boarders_to_range(boarders):
            assert len(boarders) == 2
            boarders = sorted(boarders)
            return [boarders[0], boarders[1] - boarders[0]]

        ranges = np.apply_along_axis(boarders_to_range, 2, boarders)

        self.assertReferenceChecks(
            device_option=gc,
            op=core.CreateOperator(
                "GatherRanges", ["data", "ranges"], ["output", "lengths"]
            ),
            inputs=[data, ranges],
            reference=gather_ranges,
        )

    @given(tensor_splits=_tensor_splits(), **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_gather_ranges_split(self, tensor_splits, gc, dc):
        data, ranges, lengths, _ = tensor_splits

        self.assertReferenceChecks(
            device_option=gc,
            op=core.CreateOperator(
                "GatherRangesToDense",
                ["data", "ranges"],
                ["X_{}".format(i) for i in range(len(lengths))],
                lengths=lengths,
            ),
            inputs=[data, ranges, lengths],
            reference=gather_ranges_to_dense,
        )

    @given(tensor_splits=_tensor_splits(), **hu.gcs_cpu_only)
    def test_gather_ranges_with_key_split(self, tensor_splits, gc, dc):
        data, ranges, lengths, key = tensor_splits

        self.assertReferenceChecks(
            device_option=gc,
            op=core.CreateOperator(
                "GatherRangesToDense",
                ["data", "ranges", "key"],
                ["X_{}".format(i) for i in range(len(lengths))],
                lengths=lengths,
            ),
            inputs=[data, ranges, key, lengths],
            reference=gather_ranges_to_dense_with_key,
        )

    def test_shape_and_type_inference(self):
        with hu.temp_workspace("shape_type_inf_int32"):
            net = core.Net("test_net")
            net.ConstantFill([], "ranges", shape=[3, 5, 2], dtype=core.DataType.INT32)
            net.ConstantFill([], "values", shape=[64], dtype=core.DataType.INT64)
            net.GatherRanges(["values", "ranges"], ["values_output", "lengths_output"])
            (shapes, types) = workspace.InferShapesAndTypes([net], {})

            self.assertEqual(shapes["values_output"], [64])
            self.assertEqual(types["values_output"], core.DataType.INT64)
            self.assertEqual(shapes["lengths_output"], [3])
            self.assertEqual(types["lengths_output"], core.DataType.INT32)

    @given(tensor_splits=_bad_tensor_splits(), **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_empty_range_check(self, tensor_splits, gc, dc):
        data, ranges, lengths, key = tensor_splits

        workspace.FeedBlob("data", data)
        workspace.FeedBlob("ranges", ranges)
        workspace.FeedBlob("key", key)

        def getOpWithThreshold(
            min_observation=2, max_mismatched_ratio=0.5, max_empty_ratio=None
        ):
            return core.CreateOperator(
                "GatherRangesToDense",
                ["data", "ranges", "key"],
                ["X_{}".format(i) for i in range(len(lengths))],
                lengths=lengths,
                min_observation=min_observation,
                max_mismatched_ratio=max_mismatched_ratio,
                max_empty_ratio=max_empty_ratio,
            )

        # _bad_tensor_splits yields 50% length-mismatched and 25% empty
        # samples, so these thresholds should pass...
        workspace.RunOperatorOnce(getOpWithThreshold())

        workspace.RunOperatorOnce(
            getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=50)
        )

        # ...while these stricter ones should raise.
        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(
                getOpWithThreshold(max_mismatched_ratio=0.3, min_observation=5)
            )

        with self.assertRaises(RuntimeError):
            workspace.RunOperatorOnce(
                getOpWithThreshold(min_observation=50, max_empty_ratio=0.01)
            )


if __name__ == "__main__":
    import unittest

    unittest.main()