pytorch

Форк
0
241 строка · 8.9 Кб
1

2

3

4

5

6
import hypothesis.strategies as st
7
import numpy as np
8

9
from caffe2.python import core
10
from hypothesis import given, settings
11
import caffe2.python.hypothesis_test_util as hu
12
import caffe2.python.serialized_test.serialized_test_util as serial
13

14

15
class TestTopK(serial.SerializedTestCase):
    """Checks the Caffe2 ``TopK`` operator against a NumPy reference."""

    def top_k_ref(self, X, k, flatten_indices, axis=-1):
        """Compute the reference top-k result of ``X`` along ``axis``.

        Entries are ranked by descending value; equal values are ordered by
        ascending index. When ``k`` exceeds the size of ``axis`` the unused
        trailing slots keep their fill value (0 for values, -1 for indices).

        Args:
            X: input array.
            k: number of entries to keep along ``axis``.
            flatten_indices: when true, also return the positions of the
                selected entries in the flattened input.
            axis: axis to select along; negative values count from the end.

        Returns:
            ``(values, indices)`` or, if ``flatten_indices`` is true,
            ``(values, indices, flat_indices)``. ``values`` is float32, the
            index arrays are int64, and ``flat_indices`` is one-dimensional.

        Raises:
            IndexError: if ``axis`` is out of range for ``X``.
        """
        in_dims = X.shape
        ndim = len(in_dims)
        if not -ndim <= axis < ndim:
            raise IndexError(
                "axis {} is out of range for a {}-d input".format(axis, ndim))
        # Normalise every negative axis, not only -1, so the products below
        # are taken over the right dimensions.
        axis %= ndim

        out_dims = list(in_dims)
        out_dims[axis] = k
        out_dims = tuple(out_dims)

        # View X as (prev_dims, n, next_dims) so the selection axis is the
        # middle one; the product of an empty slice is 1.
        prev_dims = int(np.prod(in_dims[:axis], dtype=np.int64))
        next_dims = int(np.prod(in_dims[axis + 1:], dtype=np.int64))
        n = in_dims[axis]
        X_flat = X.reshape((prev_dims, n, next_dims))

        ref_shape = (prev_dims, k, next_dims)
        values_ref = np.zeros(ref_shape, dtype=np.float32)
        indices_ref = np.full(ref_shape, -1, dtype=np.int64)
        flatten_indices_ref = np.full(ref_shape, -1, dtype=np.int64)

        for i in range(prev_dims):
            for j in range(next_dims):
                kv = [(X_flat[i, x, j], x) for x in range(n)]
                # Largest value first; on ties the smaller index wins.
                kv.sort(key=lambda item: (item[0], -item[1]), reverse=True)
                for rank, (val, x) in enumerate(kv[:k]):
                    values_ref[i, rank, j] = val
                    indices_ref[i, rank, j] = x
                    # Offset of element (i, x, j) in the flattened input.
                    flatten_indices_ref[i, rank, j] = (
                        (i * n + x) * next_dims + j)

        values_ref = values_ref.reshape(out_dims)
        indices_ref = indices_ref.reshape(out_dims)
        flatten_indices_ref = flatten_indices_ref.flatten()

        if flatten_indices:
            return (values_ref, indices_ref, flatten_indices_ref)
        return (values_ref, indices_ref)

    def _check_top_k(self, X, k, flatten_indices, gc, dc, axis=None):
        """Build a TopK operator and run the reference and device checks.

        ``axis`` is forwarded to the operator only when it is given, so
        callers that omit it create the operator with ``k`` alone.
        """
        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")

        op_args = {"k": k}
        if axis is not None:
            op_args["axis"] = axis
        op = core.CreateOperator(
            "TopK", ["X"], output_list, device_option=gc, **op_args)

        ref_axis = -1 if axis is None else axis

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices, ref_axis)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @serial.given(
        X=hu.tensor(),
        flatten_indices=st.booleans(),
        seed=st.integers(0, 10),
        **hu.gcs
    )
    def test_top_k(self, X, flatten_indices, seed, gc, dc):
        """Generated tensors with a seeded random ``k``."""
        X = X.astype(dtype=np.float32)
        np.random.seed(seed)
        # `k` can be larger than the total size
        k = np.random.randint(1, X.shape[-1] + 4)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(1, 1), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_1(self, bs, n, k, flatten_indices, gc, dc):
        """Single-column input with ``k == 1``."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_2(self, bs, n, k, flatten_indices, gc, dc):
        """Rows of up to 10000 elements with ``k == 1``."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000),
           k=st.integers(1, 1024), flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_3(self, bs, n, k, flatten_indices, gc, dc):
        """``k`` up to 1024, drawn independently of ``n`` (may exceed it)."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(100, 10000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_4(self, bs, n, flatten_indices, gc, dc):
        """Random ``k`` in ``[n // 3, 3 * n // 4)``."""
        k = np.random.randint(n // 3, 3 * n // 4)
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(1, 1024),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_5(self, bs, n, flatten_indices, gc, dc):
        """``k == n`` for rows of up to 1024 elements."""
        k = n
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(bs=st.integers(1, 3), n=st.integers(1, 5000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_6(self, bs, n, flatten_indices, gc, dc):
        """``k == n`` for rows of up to 5000 elements."""
        k = n
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        self._check_top_k(X, k, flatten_indices, gc, dc)

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), flatten_indices=st.booleans(),
           **hu.gcs)
    def test_top_k_axis(self, X, k, axis, flatten_indices, gc, dc):
        """Selection along an explicit axis of a generated tensor."""
        dims = X.shape
        # The drawn axis may exceed the tensor rank; wrap it into range.
        if axis >= len(dims):
            axis %= len(dims)
        self._check_top_k(X, k, flatten_indices, gc, dc, axis=axis)

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_grad(self, X, k, axis, gc, dc):
        """Gradient check of the ``Values`` output with respect to ``X``."""
        dims = X.shape
        # The drawn axis may exceed the tensor rank; wrap it into range.
        if axis >= len(dims):
            axis %= len(dims)

        input_axis = len(dims) - 1 if axis == -1 else axis
        prev_dims = int(np.prod(dims[:input_axis], dtype=np.int64))
        next_dims = int(np.prod(dims[input_axis + 1:], dtype=np.int64))
        n = dims[input_axis]

        X_flat = X.reshape((prev_dims, n, next_dims))
        for i in range(prev_dims):
            for j in range(next_dims):
                # Space the values 0.2 apart so that perturbing one by the
                # step size (0.05) cannot change which entries TopK selects.
                X_flat[i, :, j] = np.arange(n, dtype=np.float32) / 5
                np.random.shuffle(X_flat[i, :, j])
        X = X_flat.reshape(dims)

        op = core.CreateOperator(
            "TopK", ["X"], ["Values", "Indices"], k=k, axis=axis,
            device_option=gc)

        self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.05)
247

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.