annoy

Форк
0
/
angular_index_test.py 
265 строк · 7.4 Кб
1
# Copyright (c) 2013 Spotify AB
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4
# use this file except in compliance with the License. You may obtain a copy of
5
# the License at
6
#
7
# http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12
# License for the specific language governing permissions and limitations under
13
# the License.
14

15
import os
import random
import tempfile

import numpy
import pytest

from annoy import AnnoyIndex
21

22

23
def test_get_nns_by_vector():
    """Query by vector against the three unit axes.

    Neighbours must come back ordered by decreasing cosine similarity to the
    query, i.e. the axis with the largest query component comes first.
    """
    f = 3
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [0, 0, 1])
    i.add_item(1, [0, 1, 0])
    i.add_item(2, [1, 0, 0])
    i.build(10)

    assert i.get_nns_by_vector([3, 2, 1], 3) == [2, 1, 0]
    assert i.get_nns_by_vector([1, 2, 3], 3) == [0, 1, 2]
    assert i.get_nns_by_vector([2, 0, 1], 3) == [2, 0, 1]


def test_get_nns_by_item():
    """Query by stored item id; each item is its own nearest neighbour.

    Item 2 is orthogonal to both items 0 and 1, so its remaining two
    neighbours are tied and either order is accepted.
    """
    f = 3
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [2, 1, 0])
    i.add_item(1, [1, 2, 0])
    i.add_item(2, [0, 0, 1])
    i.build(10)

    assert i.get_nns_by_item(0, 3) == [0, 1, 2]
    assert i.get_nns_by_item(1, 3) == [1, 0, 2]
    assert i.get_nns_by_item(2, 3) in [[2, 0, 1], [2, 1, 0]]  # could be either


def test_dist():
    """Angular distance between vectors 45 degrees apart.

    The expected value is ``sqrt(2 * (1 - cos))`` with ``cos = 2**-0.5``.
    """
    f = 2
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [0, 1])
    i.add_item(1, [1, 1])

    assert i.get_distance(0, 1) == pytest.approx((2 * (1.0 - 2**-0.5)) ** 0.5)


def test_dist_2():
    """Parallel vectors of very different magnitude are at angular distance 0."""
    f = 2
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [1000, 0])
    i.add_item(1, [10, 0])

    assert i.get_distance(0, 1) == pytest.approx(0)


def test_dist_3():
    """Angular distance is independent of the vectors' magnitudes.

    ``[97, 0]`` and ``[42, 42]`` are 45 degrees apart. The expected value is
    the Euclidean distance between their unit-length versions ``(1, 0)`` and
    ``(2**-0.5, 2**-0.5)``.
    """
    f = 2
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [97, 0])
    i.add_item(1, [42, 42])

    dist = ((1 - 2**-0.5) ** 2 + (2**-0.5) ** 2) ** 0.5

    assert i.get_distance(0, 1) == pytest.approx(dist)


def test_dist_degen():
    """Distance to an all-zero vector (which has no direction).

    This pins the current behaviour: the index reports ``sqrt(2)``, the same
    value as for two orthogonal unit vectors.
    """
    f = 2
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [1, 0])
    i.add_item(1, [0, 0])

    assert i.get_distance(0, 1) == pytest.approx(2.0**0.5)


def test_large_index():
    """Each item's nearest neighbour (after itself) is its paired twin.

    Items are generated in pairs. Both members are the same random direction
    ``p`` scaled by a factor in [1, 2) plus small Gaussian noise, so they are
    almost parallel. Every pair draws a fresh ``p``, so unrelated pairs point
    in independent directions.
    """
    f = 10
    n_items = 10000
    i = AnnoyIndex(f, "angular")
    for j in range(0, n_items, 2):
        p = [random.gauss(0, 1) for _ in range(f)]
        f1 = random.random() + 1
        f2 = random.random() + 1
        x = [f1 * c + random.gauss(0, 1e-2) for c in p]
        y = [f2 * c + random.gauss(0, 1e-2) for c in p]
        i.add_item(j, x)
        i.add_item(j + 1, y)

    i.build(10)
    for j in range(0, n_items, 2):
        assert i.get_nns_by_item(j, 2) == [j, j + 1]
        assert i.get_nns_by_item(j + 1, 2) == [j + 1, j]


def precision(n, n_trees=10, n_points=10000, n_rounds=10, search_k=100000):
    """Return the fraction of the true ``n`` nearest neighbours Annoy finds.

    Item ``j`` is ``[1000] + j * u_j``, where ``u_j`` is a random unit vector
    in the remaining ``f - 1`` dimensions. Its angle to the query
    ``[1000, 0, ..., 0]`` is ``atan(j / 1000)``, which grows strictly with
    ``j``. So the exact ``n`` nearest neighbours are ids ``0 .. n - 1``, and
    the returned ids must be in ascending order.

    Args:
        n: number of neighbours requested per query.
        n_trees: number of trees passed to ``build``.
        n_points: number of items added to each index.
        n_rounds: number of independent indices built and queried.
        search_k: ``search_k`` passed to ``get_nns_by_vector``.

    Returns:
        Hits divided by ``n * n_rounds``, a float in [0, 1].
    """
    f = 10
    query = [1000] + [0] * (f - 1)
    found = 0
    for _round in range(n_rounds):
        i = AnnoyIndex(f, "angular")
        for j in range(n_points):
            p = [random.gauss(0, 1) for _ in range(f - 1)]
            norm = sum(c**2 for c in p) ** 0.5
            # Offset of length j, orthogonal to the shared first axis.
            i.add_item(j, [1000] + [c / norm * j for c in p])

        i.build(n_trees)

        nns = i.get_nns_by_vector(query, n, search_k)
        # Ids are ranked by angle to the query, so results must be ascending.
        assert nns == sorted(nns)
        # Every returned id below n is one of the true n nearest neighbours.
        found += sum(1 for idx in nns if idx < n)

    return found / (n * n_rounds)


def test_precision_1():
    """At least 98% of the true top-1 neighbours must be retrieved."""
    assert precision(1) >= 0.98


def test_precision_10():
    """At least 98% of the true top-10 neighbours must be retrieved."""
    assert precision(10) >= 0.98


def test_precision_100():
    """At least 98% of the true top-100 neighbours must be retrieved."""
    assert precision(100) >= 0.98


def test_precision_1000():
    """At least 98% of the true top-1000 neighbours must be retrieved."""
    assert precision(1000) >= 0.98


def test_load_save_get_item_vector():
    """Item vectors survive build, save, and a reload into a fresh index.

    The index file is written to a temporary directory, so the test leaves no
    artifacts in the working directory.
    """
    f = 3
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [1.1, 2.2, 3.3])
    i.add_item(1, [4.4, 5.5, 6.6])
    i.add_item(2, [7.7, 8.8, 9.9])

    numpy.testing.assert_array_almost_equal(i.get_item_vector(0), [1.1, 2.2, 3.3])
    assert i.build(10)
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "blah.ann")
        assert i.save(path)
        numpy.testing.assert_array_almost_equal(i.get_item_vector(1), [4.4, 5.5, 6.6])
        j = AnnoyIndex(f, "angular")
        assert j.load(path)
        numpy.testing.assert_array_almost_equal(j.get_item_vector(2), [7.7, 8.8, 9.9])
        # Per Annoy's README, save() reloads the written file and load()
        # memory-maps it. Release both mappings before the temporary
        # directory is removed (an open mapping can block deletion, e.g. on
        # Windows).
        i.unload()
        j.unload()


def test_get_nns_search_k():
    """An explicit ``search_k`` (third positional argument) gives exact results.

    Checks both the by-item and by-vector query paths.
    """
    f = 3
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [0, 0, 1])
    i.add_item(1, [0, 1, 0])
    i.add_item(2, [1, 0, 0])
    i.build(10)

    assert i.get_nns_by_item(0, 3, 10) == [0, 1, 2]
    assert i.get_nns_by_vector([3, 2, 1], 3, 10) == [2, 1, 0]


def test_include_dists():
    """Returned distances span the full angular range (issue 112).

    An item is at distance 0 from itself. Its exact negation is at distance
    2, the maximum ``sqrt(2 * (1 - cos))`` reaches when ``cos == -1``.
    """
    # Double checking issue 112
    f = 40
    i = AnnoyIndex(f, "angular")
    v = numpy.random.normal(size=f)
    i.add_item(0, v)
    i.add_item(1, -v)
    i.build(10)

    indices, dists = i.get_nns_by_item(0, 2, 10, True)
    assert indices == [0, 1]
    assert dists[0] == pytest.approx(0.0)
    assert dists[1] == pytest.approx(2.0)


def test_include_dists_check_ranges():
    """All returned angular distances lie within [0, 2].

    Querying item 0 against every item includes item 0 itself, so the
    minimum distance must be 0.
    """
    f = 3
    i = AnnoyIndex(f, "angular")
    for j in range(100000):
        i.add_item(j, numpy.random.normal(size=f))
    i.build(10)
    indices, dists = i.get_nns_by_item(0, 100000, include_distances=True)
    assert max(dists) <= 2.0
    assert min(dists) == pytest.approx(0.0)


def test_distance_consistency():
    """Query distances agree with ``get_distance`` and with a numpy reference.

    The angular distance checked here is the Euclidean distance between the
    two vectors after scaling them to unit length. Equivalently, its square
    is ``2 * (1 - cos)``, where ``cos`` is the cosine similarity.
    """
    n, f = 1000, 3
    i = AnnoyIndex(f, "angular")
    for j in range(n):
        # Resample until the squared norm exceeds 0.1, so no item is (near-)
        # zero; presumably this keeps normalisation well-conditioned.
        while True:
            v = numpy.random.normal(size=f)
            if numpy.dot(v, v) > 0.1:
                break
        i.add_item(j, v)
    i.build(10)
    for a in random.sample(range(n), 100):
        indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
        # The query item's vector is the same for every neighbour; fetch once.
        u = numpy.array(i.get_item_vector(a))
        u_norm = u * numpy.dot(u, u) ** -0.5
        for b, dist in zip(indices, dists):
            v = numpy.array(i.get_item_vector(b))
            v_norm = v * numpy.dot(v, v) ** -0.5
            assert dist == pytest.approx(i.get_distance(a, b), rel=1e-3, abs=1e-3)
            # Squared Euclidean distance between the unit vectors.
            diff = u_norm - v_norm
            assert dist**2 == pytest.approx(numpy.dot(diff, diff), rel=1e-3, abs=1e-3)
            # Same quantity expressed through the cosine similarity.
            cos = numpy.dot(u_norm, v_norm)
            assert dist**2 == pytest.approx(2 * (1 - cos), rel=1e-3, abs=1e-3)


def test_only_one_item():
    """A saved and reloaded single-item index returns just that item.

    Requesting more neighbours (50) than exist yields only the one stored id.
    The file lives in a temporary directory, so this test cannot collide with
    other tests that save indices.
    """
    # reported to annoy-user by Kireet Reddy
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "foo.idx")
        built = AnnoyIndex(100, "angular")
        built.add_item(0, numpy.random.randn(100))
        built.build(n_trees=10)
        built.save(path)
        built.unload()

        idx = AnnoyIndex(100, "angular")
        idx.load(path)
        assert idx.get_n_items() == 1
        assert idx.get_nns_by_vector(
            vector=numpy.random.randn(100), n=50, include_distances=False
        ) == [0]
        # Release the mapped file before the temporary directory is removed.
        idx.unload()


def test_no_items():
    """An empty index can be built, saved, reloaded and queried.

    Querying the reloaded index returns an empty list. The file lives in a
    temporary directory, so this test cannot collide with other tests that
    save indices.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "foo.idx")
        built = AnnoyIndex(100, "angular")
        built.build(n_trees=10)
        built.save(path)
        built.unload()

        idx = AnnoyIndex(100, "angular")
        idx.load(path)
        assert idx.get_n_items() == 0
        assert (
            idx.get_nns_by_vector(
                vector=numpy.random.randn(100), n=50, include_distances=False
            )
            == []
        )
        # Release the mapped file before the temporary directory is removed.
        idx.unload()


def test_single_vector():
    """A one-item index returns only that item, at distance zero, after save."""
    # https://github.com/spotify/annoy/issues/194
    a = AnnoyIndex(3, "angular")
    a.add_item(0, [1, 0, 0])
    a.build(10)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Query after save(), which per Annoy's README reloads the written
        # file, so the lookup exercises the on-disk index.
        a.save(os.path.join(tmp_dir, "1.ann"))
        indices, dists = a.get_nns_by_vector([1, 0, 0], 3, include_distances=True)
        # Release the mapped file before the temporary directory is removed.
        a.unload()
    assert indices == [0]
    assert dists[0] ** 2 == pytest.approx(0.0)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.