1
# Copyright (c) 2013 Spotify AB
3
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4
# use this file except in compliance with the License. You may obtain a copy of
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12
# License for the specific language governing permissions and limitations under
20
from annoy import AnnoyIndex
23
def test_get_nns_by_vector():
    """Query by vector: neighbours come back ordered by angular proximity.

    The index holds the three axis-aligned unit vectors (item 2 on x, item 1
    on y, item 0 on z), so the expected ranking for a query with distinct
    positive coordinates follows its coordinates from largest to smallest.
    """
    f = 3  # dimensionality; must equal the length of every vector below
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [0, 0, 1])
    i.add_item(1, [0, 1, 0])
    i.add_item(2, [1, 0, 0])
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    assert i.get_nns_by_vector([3, 2, 1], 3) == [2, 1, 0]
    assert i.get_nns_by_vector([1, 2, 3], 3) == [0, 1, 2]
    assert i.get_nns_by_vector([2, 0, 1], 3) == [2, 0, 1]
36
def test_get_nns_by_item():
    """Query by item id: each item ranks itself first, then the others by angle."""
    f = 3  # dimensionality; must equal the length of every vector below
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [2, 1, 0])
    i.add_item(1, [1, 2, 0])
    i.add_item(2, [0, 0, 1])
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    assert i.get_nns_by_item(0, 3) == [0, 1, 2]
    assert i.get_nns_by_item(1, 3) == [1, 0, 2]
    # Item 2 is orthogonal to both other items, so their relative order is a tie.
    assert i.get_nns_by_item(2, 3) in [[2, 0, 1], [2, 1, 0]]  # could be either
51
# NOTE(review): the statements from here down to the next `def` have no
# visible enclosing function header, `f` is not bound anywhere in view, and the
# first and last groups query items 0 and 1 without any visible add_item call.
# The bare integers between statements look like line-number residue from an
# extraction step. Restore the missing lines before relying on these checks.
#
# Group 1: the expected distance is sqrt(2 * (1 - 2**-0.5)).
i = AnnoyIndex(f, "angular")
55
assert i.get_distance(0, 1) == pytest.approx((2 * (1.0 - 2**-0.5)) ** 0.5)
60
# Group 2: both vectors point along the first axis (lengths 1000 and 10) and
# the expected distance is 0.
i = AnnoyIndex(f, "angular")
61
i.add_item(0, [1000, 0])
62
i.add_item(1, [10, 0])
64
assert i.get_distance(0, 1) == pytest.approx(0)
69
# Group 3: `dist` is the Euclidean distance between (1, 0) and
# (2**-0.5, 2**-0.5), i.e. the two vectors below scaled to unit length.
i = AnnoyIndex(f, "angular")
70
i.add_item(0, [97, 0])
71
i.add_item(1, [42, 42])
73
dist = ((1 - 2**-0.5) ** 2 + (2**-0.5) ** 2) ** 0.5
75
assert i.get_distance(0, 1) == pytest.approx(dist)
80
# Group 4: the expected distance is sqrt(2); the vectors are not in view.
i = AnnoyIndex(f, "angular")
84
assert i.get_distance(0, 1) == pytest.approx(2.0**0.5)
87
def test_large_index():
    """Each point's two nearest neighbours are itself and then its paired twin.

    Builds 5000 pairs of nearly parallel vectors: both members of a pair are
    positive multiples of the same random direction plus a small jitter.
    """
    # Generate pairs of random points where the pair is super close
    # NOTE(review): nothing else in this test fixes the dimension; 10 keeps
    # unrelated random directions far apart compared with the 1e-2 jitter.
    f = 10
    i = AnnoyIndex(f, "angular")
    for j in range(0, 10000, 2):
        p = [random.gauss(0, 1) for _ in range(f)]
        f1 = random.random() + 1
        f2 = random.random() + 1
        x = [f1 * pi + random.gauss(0, 1e-2) for pi in p]
        y = [f2 * pi + random.gauss(0, 1e-2) for pi in p]
        # Even id j and odd id j + 1 form one pair.
        i.add_item(j, x)
        i.add_item(j + 1, y)

    # Queries are answered from the tree forest, so build it first.
    i.build(10)
    for j in range(0, 10000, 2):
        assert i.get_nns_by_item(j, 2) == [j, j + 1]
        assert i.get_nns_by_item(j + 1, 2) == [j + 1, j]
106
def precision(n, n_trees=10, n_points=10000, n_rounds=10, search_k=100000):
    """Return the fraction of true nearest neighbours found, averaged over rounds.

    Each round builds a fresh index in which item ``j`` is offset from the
    query point ``(1000, 0, 0, ...)`` by exactly ``j`` in a random direction
    orthogonal to the first axis, so the true ``n`` nearest neighbours of the
    query are the items with ids ``0 .. n - 1``.

    Args:
        n: number of neighbours requested per query (must be positive).
        n_trees: number of trees passed to ``build``.
        n_points: number of items added to the index in each round.
        n_rounds: number of independent index builds to average over.
        search_k: search budget forwarded to ``get_nns_by_vector``.

    Returns:
        A float in ``[0, 1]``: hits divided by ``n * n_rounds``.
    """
    # NOTE(review): nothing else in this helper fixes the dimension; 10 is a
    # choice made here, not a value taken from the surrounding code.
    f = 10
    query = [1000] + [0] * (f - 1)
    found = 0
    for _ in range(n_rounds):
        # create random points at distance x from (1000, 0, 0, ...)
        i = AnnoyIndex(f, "angular")
        for j in range(n_points):
            p = [random.gauss(0, 1) for _ in range(f - 1)]
            norm = sum(pi**2 for pi in p) ** 0.5
            x = [1000] + [pi / norm * j for pi in p]
            i.add_item(j, x)

        i.build(n_trees)

        nns = i.get_nns_by_vector(query, n, search_k)
        # Ids grow with distance from the query, so results must be ascending.
        assert nns == sorted(nns)  # should be in order
        # A returned id counts as a hit when it is one of the n closest items.
        found += sum(1 for item in nns if item < n)

    return 1.0 * found / (n * n_rounds)
128
def test_precision_1():
    """precision() for a single requested neighbour must reach 0.98."""
    score = precision(1)
    assert score >= 0.98
132
def test_precision_10():
    """precision() for 10 requested neighbours must reach 0.98."""
    score = precision(10)
    assert score >= 0.98
136
def test_precision_100():
    """precision() for 100 requested neighbours must reach 0.98."""
    score = precision(100)
    assert score >= 0.98
140
def test_precision_1000():
    """precision() for 1000 requested neighbours must reach 0.98."""
    score = precision(1000)
    assert score >= 0.98
144
def test_load_save_get_item_vector():
    """Item vectors are readable before building, after saving and after loading.

    Writes the index to ``blah.ann`` in the current working directory.
    """
    f = 3  # dimensionality; must equal the length of every vector below
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [1.1, 2.2, 3.3])
    i.add_item(1, [4.4, 5.5, 6.6])
    i.add_item(2, [7.7, 8.8, 9.9])

    # Stored vectors can be read back before the forest exists.
    numpy.testing.assert_array_almost_equal(i.get_item_vector(0), [1.1, 2.2, 3.3])
    # An index has to be built before it can be saved.
    i.build(10)
    assert i.save("blah.ann")
    numpy.testing.assert_array_almost_equal(i.get_item_vector(1), [4.4, 5.5, 6.6])

    # A second index object loading the same file must see the same vectors.
    j = AnnoyIndex(f, "angular")
    assert j.load("blah.ann")
    numpy.testing.assert_array_almost_equal(j.get_item_vector(2), [7.7, 8.8, 9.9])
160
def test_get_nns_search_k():
    """Both query methods accept an explicit search_k as third positional argument."""
    f = 3  # dimensionality; must equal the length of every vector below
    i = AnnoyIndex(f, "angular")
    i.add_item(0, [0, 0, 1])
    i.add_item(1, [0, 1, 0])
    i.add_item(2, [1, 0, 0])
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    assert i.get_nns_by_item(0, 3, 10) == [0, 1, 2]
    assert i.get_nns_by_vector([3, 2, 1], 3, 10) == [2, 1, 0]
172
def test_include_dists():
    """Returned distances are 0 to the item itself and 2 to its antipode."""
    # Double checking issue 112
    # NOTE(review): the dimension is arbitrary here; the antipodal
    # relationship below holds for any positive f.
    f = 40
    i = AnnoyIndex(f, "angular")
    v = numpy.random.normal(size=f)
    # Item 1 points exactly opposite to item 0, which is what makes the
    # expected distance between them 2.0.
    i.add_item(0, v)
    i.add_item(1, -v)
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    indices, dists = i.get_nns_by_item(0, 2, 10, True)
    assert indices == [0, 1]
    assert dists[0] == pytest.approx(0.0)
    assert dists[1] == pytest.approx(2.0)
187
def test_include_dists_check_ranges():
    """Across a large random index, returned distances stay within [0, 2]."""
    # NOTE(review): the dimension is a choice made here; the bounds checked
    # below do not depend on it.
    f = 3
    i = AnnoyIndex(f, "angular")
    for j in range(100000):
        i.add_item(j, numpy.random.normal(size=f))
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    indices, dists = i.get_nns_by_item(0, 100000, include_distances=True)
    assert max(dists) <= 2.0
    # The smallest distance is the one from item 0 to itself.
    assert min(dists) == pytest.approx(0.0)
198
def test_distance_consistency():
    """Query distances agree with get_distance and with the vectors themselves.

    For 100 randomly sampled items, every distance returned by a neighbour
    query is compared against ``get_distance`` and against the Euclidean
    distance between the two item vectors after scaling them to unit length.
    """
    # NOTE(review): n and f are choices made here; n must be at least 100
    # because 100 items are sampled and 100 neighbours are requested.
    n, f = 1000, 3
    i = AnnoyIndex(f, "angular")
    for j in range(n):
        # Redraw vectors whose squared norm is tiny: they are scaled to unit
        # length further down, which is unstable for near-zero vectors.
        v = numpy.random.normal(size=f)
        while not numpy.dot(v, v) > 0.1:
            v = numpy.random.normal(size=f)
        i.add_item(j, v)
    # Queries are answered from the tree forest, so build it first.
    i.build(10)

    for a in random.sample(range(n), 100):
        indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
        # The query item is the same for every neighbour; fetch it once.
        u = i.get_item_vector(a)
        u_norm = numpy.array(u) * numpy.dot(u, u) ** -0.5
        for b, dist in zip(indices, dists):
            v = i.get_item_vector(b)
            assert dist == pytest.approx(i.get_distance(a, b), rel=1e-3, abs=1e-3)
            v_norm = numpy.array(v) * numpy.dot(v, v) ** -0.5
            # The squared distance is checked twice: once via numpy.dot and
            # once via a plain Python sum over the coordinates.
            assert dist**2 == pytest.approx(
                numpy.dot(u_norm - v_norm, u_norm - v_norm), rel=1e-3, abs=1e-3
            )
            assert dist**2 == pytest.approx(
                sum([(x - y) ** 2 for x, y in zip(u_norm, v_norm)]),
                rel=1e-3,
                abs=1e-3,
            )
228
def test_only_one_item():
    """An index holding one item survives a save/load round trip.

    Writes the index to ``only_one_item.ann`` in the current working
    directory. Asking for 50 neighbours must return just the single item.
    """
    # reported to annoy-user by Kireet Reddy
    path = "only_one_item.ann"
    idx = AnnoyIndex(100, "angular")
    idx.add_item(0, numpy.random.randn(100))
    idx.build(n_trees=10)
    idx.save(path)

    # A fresh index object only holds the item once the file is loaded.
    idx = AnnoyIndex(100, "angular")
    idx.load(path)
    assert idx.get_n_items() == 1
    assert idx.get_nns_by_vector(
        vector=numpy.random.randn(100), n=50, include_distances=False
    ) == [0]
243
# NOTE(review): these statements have no visible enclosing `def` header, and
# the closing parenthesis of the final get_nns_by_vector call is not in view.
# The bare integers between statements look like line-number residue from an
# extraction step. What is visible: an index is built with no items, a fresh
# index object is created, its item count is asserted to be 0, and it is then
# queried for 50 neighbours. Nothing in view connects the two index objects
# (no save/load) -- restore the missing lines before relying on this check.
idx = AnnoyIndex(100, "angular")
244
idx.build(n_trees=10)
246
idx = AnnoyIndex(100, "angular")
248
assert idx.get_n_items() == 0
250
idx.get_nns_by_vector(
251
vector=numpy.random.randn(100), n=50, include_distances=False
257
def test_single_vector():
    """Asking for 3 neighbours from a one-item index returns only that item."""
    # https://github.com/spotify/annoy/issues/194
    a = AnnoyIndex(3, "angular")
    a.add_item(0, [1, 0, 0])
    # Queries are answered from the tree forest, so build it first.
    a.build(10)

    indices, dists = a.get_nns_by_vector([1, 0, 0], 3, include_distances=True)
    assert indices == [0]
    # The query equals the stored vector, so the distance is (approximately) 0.
    assert dists[0] ** 2 == pytest.approx(0.0)