1
# Copyright (c) 2013 Spotify AB
3
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4
# use this file except in compliance with the License. You may obtain a copy of
# the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12
# License for the specific language governing permissions and limitations under
# the License.
20
from annoy import AnnoyIndex
23
# Check that a "manhattan" index returns vector-query neighbours in the
# expected order (the asserts compare ordered lists, so order matters).
#
# NOTE(review): this block is incomplete in this copy. `f` is never assigned,
# and no add_item()/build() calls are visible before the queries, so as shown
# it cannot pass. The queries are 2-component vectors and ids 0..2 are
# expected, so presumably three 2-d items were added — restore the missing
# setup from upstream rather than re-deriving it. The body is also unindented
# and interleaved with bare line-number lines.
def test_get_nns_by_vector():
25
i = AnnoyIndex(f, "manhattan")
31
# Ask for the 3 nearest ids to each query vector.
assert i.get_nns_by_vector([4, 4], 3) == [2, 1, 0]
32
assert i.get_nns_by_vector([1, 1], 3) == [0, 1, 2]
33
assert i.get_nns_by_vector([5, 3], 3) == [2, 1, 0]
36
# Check neighbour ordering for item-id queries on a "manhattan" index.
#
# NOTE(review): incomplete in this copy — `f` is never assigned and no
# add_item()/build() calls are visible before the queries. The body is
# unindented and interleaved with bare line-number lines.
def test_get_nns_by_item():
38
i = AnnoyIndex(f, "manhattan")
44
# Each queried item is expected first in its own neighbour list.
assert i.get_nns_by_item(0, 3) == [0, 1, 2]
45
assert i.get_nns_by_item(2, 3) == [2, 1, 0]
50
# NOTE(review): from here a fresh index is created and pairwise
# get_distance() values are checked. This looks like a separate test whose
# `def` line and item setup are missing from this copy — confirm upstream.
i = AnnoyIndex(f, "manhattan")
55
assert i.get_distance(0, 1) == pytest.approx(1.0)
56
assert i.get_distance(1, 2) == pytest.approx(2.0)
59
# Build many pairs of near-identical points and check that each point's
# nearest other neighbour is its partner.
#
# NOTE(review): incomplete in this copy — `f` is never assigned, the
# generated x/y vectors are never passed to add_item(), and build() is not
# called before querying. Presumably x is stored as item j and y as item
# j + 1 (that is what the assertions expect) — confirm against upstream.
def test_large_index():
60
# Generate pairs of random points where the pair is super close
62
i = AnnoyIndex(f, "manhattan")
63
# 5000 iterations: j = 0, 2, ..., 9998.
for j in range(0, 10000, 2):
64
# x and y share the base point p, both shifted by 1, each with its own
# tiny (sd 1e-2) Gaussian noise, so they lie very close to each other.
p = [random.gauss(0, 1) for z in range(f)]
65
x = [1 + pi + random.gauss(0, 1e-2) for pi in p]
66
y = [1 + pi + random.gauss(0, 1e-2) for pi in p]
71
# Each id should list itself first and its partner second.
for j in range(0, 10000, 2):
72
assert i.get_nns_by_item(j, 2) == [j, j + 1]
73
assert i.get_nns_by_item(j + 1, 2) == [j + 1, j]
76
# Estimate how many of the n ids nearest the origin a vector query recovers.
# The estimate is averaged over `n_rounds` fresh indexes of `n_points` items
# each. Ids < n are treated as the true top-n.
# Returns found / (n * n_rounds), a float in [0, 1].
#
# NOTE(review): incomplete in this copy:
#   - `found` is incremented with += but never initialised;
#   - `f` is never assigned;
#   - the vector `x` is never added to the index;
#   - build() is not called before querying;
#   - `n_trees` is unused in the visible lines.
# Restore the missing setup from upstream.
def precision(n, n_trees=10, n_points=10000, n_rounds=10):
78
for r in range(n_rounds):
79
# create random points at distance x
81
i = AnnoyIndex(f, "manhattan")
82
for j in range(n_points):
83
p = [random.gauss(0, 1) for z in range(f)]
84
# Scale p to unit Euclidean length, then shift every coordinate by j.
# Presumably this makes items with larger j sit further from the origin,
# which the id-based checks below rely on.
norm = sum([pi**2 for pi in p]) ** 0.5
85
x = [pi / norm + j for pi in p]
90
# Query from the origin; ids are expected back in ascending order.
nns = i.get_nns_by_vector([0] * f, n)
91
assert nns == sorted(nns) # should be in order
92
# The number of gaps should be equal to the last item minus n-1
93
found += len([x for x in nns if x < n])
95
return 1.0 * found / (n * n_rounds)
98
def test_precision_1():
    """Require precision() for a single neighbour to be at least 0.98."""
    assert precision(1) >= 0.98
102
def test_precision_10():
    """Require precision() for 10 neighbours to be at least 0.98."""
    assert precision(10) >= 0.98
106
def test_precision_100():
    """Require precision() for 100 neighbours to be at least 0.98."""
    assert precision(100) >= 0.98
110
def test_precision_1000():
    """Require precision() for 1000 neighbours to be at least 0.98."""
    assert precision(1000) >= 0.98
114
def test_get_nns_with_distances():
    """Check ids and Manhattan distances returned when include_distances is set.

    Three hand-placed 3-d items are queried by item id and by vector. The
    expected L1 distances are computed by hand in the comments below.
    """
    f = 3  # every vector added below has exactly three components
    i = AnnoyIndex(f, "manhattan")
    i.add_item(0, [0, 0, 2])
    i.add_item(1, [0, 1, 1])
    i.add_item(2, [1, 0, 0])
    # Annoy only answers queries once the forest has been built.
    i.build(10)

    # From item 0: itself = 0, item 1 = 0+1+1 = 2, item 2 = 1+0+2 = 3.
    ids, dists = i.get_nns_by_item(0, 3, -1, True)
    assert ids == [0, 1, 2]
    assert dists[0] == pytest.approx(0)
    assert dists[1] == pytest.approx(2)
    assert dists[2] == pytest.approx(3)

    # From [2, 2, 1]: item 1 = 2+1+0 = 3, item 2 = 1+2+1 = 4, item 0 = 2+2+1 = 5.
    ids, dists = i.get_nns_by_vector([2, 2, 1], 3, -1, True)
    assert ids == [1, 2, 0]
    assert dists[0] == pytest.approx(3)
    assert dists[1] == pytest.approx(4)
    assert dists[2] == pytest.approx(5)
135
# Check that get_nns_by_item with include_distances=True returns the
# queried item first, at distance 0.
#
# NOTE(review): incomplete in this copy — `f` is never assigned, and the
# random vector `v` is never added to the index. No add_item()/build() calls
# are visible before the query below, so the content of item 1 is unknown
# here.
def test_include_dists():
137
i = AnnoyIndex(f, "manhattan")
138
v = numpy.random.normal(size=f)
143
# 2 nearest neighbours of item 0, search_k=10, with distances.
indices, dists = i.get_nns_by_item(0, 2, 10, True)
144
assert indices == [0, 1]
145
assert dists[0] == pytest.approx(0)
148
def test_distance_consistency():
150
i = AnnoyIndex(f, "manhattan")
152
i.add_item(j, numpy.random.normal(size=f))
154
for a in random.sample(range(n), 100):
155
indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
156
for b, dist in zip(indices, dists):
157
assert dist == pytest.approx(i.get_distance(a, b))
158
u = numpy.array(i.get_item_vector(a))
159
v = numpy.array(i.get_item_vector(b))
160
assert dist == pytest.approx(numpy.sum(numpy.fabs(u - v)))
161
assert dist == pytest.approx(
162
sum([abs(float(x) - float(y)) for x, y in zip(u, v)])