# Tests for the caffe2 "TopK" operator (part of the pytorch/caffe2 test suite).
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from caffe2.python import core


class TestTopK(serial.SerializedTestCase):
    """Tests for the caffe2 ``TopK`` operator against a NumPy reference."""

    def top_k_ref(self, X, k, flatten_indices, axis=-1):
        """Pure-NumPy reference implementation of TopK.

        Collapses ``X`` to shape ``(prev_dims, n, next_dims)`` around
        ``axis`` and, for each (prev, next) slice, selects the ``k``
        largest values; equal values are ordered by ascending axis index.
        When ``k > n`` the trailing slots stay at value 0 / index -1.

        Returns ``(values, indices)`` or, if ``flatten_indices`` is true,
        ``(values, indices, flatten_indices)`` where the last output holds
        positions into the flattened input tensor.
        """
        in_dims = X.shape
        out_dims = list(in_dims)
        out_dims[axis] = k
        out_dims = tuple(out_dims)
        if axis == -1:
            axis = len(in_dims) - 1
        # Product of the dimensions before / after the reduction axis.
        prev_dims = 1
        next_dims = 1
        for i in range(axis):
            prev_dims *= in_dims[i]
        for i in range(axis + 1, len(in_dims)):
            next_dims *= in_dims[i]
        n = in_dims[axis]
        X_flat = X.reshape((prev_dims, n, next_dims))

        values_ref = np.ndarray(
            shape=(prev_dims, k, next_dims), dtype=np.float32)
        values_ref.fill(0)
        indices_ref = np.ndarray(
            shape=(prev_dims, k, next_dims), dtype=np.int64)
        indices_ref.fill(-1)
        flatten_indices_ref = np.ndarray(
            shape=(prev_dims, k, next_dims), dtype=np.int64)
        flatten_indices_ref.fill(-1)
        for i in range(prev_dims):
            for j in range(next_dims):
                kv = []
                for x in range(n):
                    val = X_flat[i, x, j]
                    # Position of this element in the flattened input.
                    y = x * next_dims + i * in_dims[axis] * next_dims + j
                    kv.append((val, x, y))
                cnt = 0
                # Sort by value descending; the ``-x[1]`` key component makes
                # ties resolve to the smaller axis index, matching the op.
                for val, x, y in sorted(
                        kv, key=lambda x: (x[0], -x[1]), reverse=True):
                    values_ref[i, cnt, j] = val
                    indices_ref[i, cnt, j] = x
                    flatten_indices_ref[i, cnt, j] = y
                    cnt += 1
                    if cnt >= k or cnt >= n:
                        break

        values_ref = values_ref.reshape(out_dims)
        indices_ref = indices_ref.reshape(out_dims)
        flatten_indices_ref = flatten_indices_ref.flatten()

        if flatten_indices:
            return (values_ref, indices_ref, flatten_indices_ref)
        else:
            return (values_ref, indices_ref)

    @serial.given(
        X=hu.tensor(),
        flatten_indices=st.booleans(),
        seed=st.integers(0, 10),
        **hu.gcs
    )
    def test_top_k(self, X, flatten_indices, seed, gc, dc):
        """TopK on an arbitrary tensor; `k` may exceed the axis size."""
        X = X.astype(dtype=np.float32)
        np.random.seed(seed)
        # `k` can be larger than the total size
        k = np.random.randint(1, X.shape[-1] + 4)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 1), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_1(self, bs, n, k, flatten_indices, gc, dc):
        """Degenerate case: n == k == 1."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_2(self, bs, n, k, flatten_indices, gc, dc):
        """k == 1 over a potentially large axis."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000),
           k=st.integers(1, 1024), flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_3(self, bs, n, k, flatten_indices, gc, dc):
        """Larger `k` values (up to 1024), possibly exceeding `n`."""
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(100, 10000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_4(self, bs, n, flatten_indices, gc, dc):
        """`k` chosen as a sizable fraction (1/3 to 3/4) of `n`."""
        k = np.random.randint(n // 3, 3 * n // 4)
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 1024),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_5(self, bs, n, flatten_indices, gc, dc):
        """Boundary case k == n (full sort)."""
        k = n
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 5000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_6(self, bs, n, flatten_indices, gc, dc):
        """k == n with a larger axis range and a relaxed deadline."""
        k = n
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator("TopK", ["X"], output_list,
                                 k=k, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), flatten_indices=st.booleans(),
           **hu.gcs)
    def test_top_k_axis(self, X, k, axis, flatten_indices, gc, dc):
        """TopK along an arbitrary (possibly non-last) axis."""
        dims = X.shape
        # Fold out-of-range axis draws back into the valid range.
        if axis >= len(dims):
            axis %= len(dims)

        output_list = ["Values", "Indices"]
        if flatten_indices:
            output_list.append("FlattenIndices")
        op = core.CreateOperator(
            "TopK", ["X"], output_list, k=k, axis=axis, device_option=gc)

        def bind_ref(X_loc):
            return self.top_k_ref(X_loc, k, flatten_indices, axis)

        self.assertReferenceChecks(gc, op, [X], bind_ref)
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_grad(self, X, k, axis, gc, dc):
        """Numerical gradient check for TopK along an arbitrary axis."""
        dims = X.shape
        if axis >= len(dims):
            axis %= len(dims)

        input_axis = len(dims) - 1 if axis == -1 else axis
        prev_dims = 1
        next_dims = 1
        for i in range(input_axis):
            prev_dims *= dims[i]
        for i in range(input_axis + 1, len(dims)):
            next_dims *= dims[i]

        # Fill each slice along the axis with distinct values spaced 0.2
        # apart so that perturbing by the stepsize (0.05) cannot change
        # which elements TopK selects.
        X_flat = X.reshape((prev_dims, dims[input_axis], next_dims))
        for i in range(prev_dims):
            for j in range(next_dims):
                X_flat[i, :, j] = np.arange(dims[axis], dtype=np.float32) / 5
                np.random.shuffle(X_flat[i, :, j])
        X = X_flat.reshape(dims)

        op = core.CreateOperator(
            "TopK", ["X"], ["Values", "Indices"], k=k, axis=axis,
            device_option=gc)

        self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.05)
247