annoy

Форк
0
/
manhattan_index_test.py 
163 строки · 4.4 Кб
1
# Copyright (c) 2013 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14

15
import random
16

17
import numpy
18
import pytest
19

20
from annoy import AnnoyIndex
21

22

23
def test_get_nns_by_vector():
    """Querying a tiny 2-D Manhattan index by vector returns ids nearest-first.

    Stored items: 0 -> (2, 2), 1 -> (3, 2), 2 -> (3, 3).
    """
    dim = 2
    index = AnnoyIndex(dim, "manhattan")
    for item_id, vec in enumerate(([2, 2], [3, 2], [3, 3])):
        index.add_item(item_id, vec)
    index.build(10)

    # L1 distances from (4, 4): item 2 -> 2, item 1 -> 3, item 0 -> 4.
    assert index.get_nns_by_vector([4, 4], 3) == [2, 1, 0]
    # From (1, 1): item 0 -> 2, item 1 -> 3, item 2 -> 4.
    assert index.get_nns_by_vector([1, 1], 3) == [0, 1, 2]
    # From (5, 3): item 2 -> 2, item 1 -> 3, item 0 -> 4.
    assert index.get_nns_by_vector([5, 3], 3) == [2, 1, 0]
34

35

36
def test_get_nns_by_item():
    """Querying a tiny 2-D Manhattan index by stored item returns ids nearest-first."""
    index = AnnoyIndex(2, "manhattan")
    points = [[2, 2], [3, 2], [3, 3]]
    for item_id, vec in enumerate(points):
        index.add_item(item_id, vec)
    index.build(10)

    # Item 0 is at L1 distance 0, 1, 2 from items 0, 1, 2 respectively.
    assert index.get_nns_by_item(0, 3) == [0, 1, 2]
    # Item 2 is at L1 distance 0, 1, 2 from items 2, 1, 0 respectively.
    assert index.get_nns_by_item(2, 3) == [2, 1, 0]
46

47

48
def test_dist():
    """get_distance reports the L1 (Manhattan) distance between stored items."""
    index = AnnoyIndex(2, "manhattan")
    coords = {0: [0, 1], 1: [1, 1], 2: [0, 0]}
    for item_id, vec in coords.items():
        index.add_item(item_id, vec)

    # |0 - 1| + |1 - 1| = 1
    assert index.get_distance(0, 1) == pytest.approx(1.0)
    # |1 - 0| + |1 - 0| = 2
    assert index.get_distance(1, 2) == pytest.approx(2.0)
57

58

59
def test_large_index():
    """Each item's nearest non-self neighbour is its near-duplicate twin.

    Items 2k and 2k + 1 are independent tiny perturbations (sigma 1e-2) of the
    same shifted Gaussian base point, so each pair should be much closer to
    each other than to any other item.
    """
    dim = 10
    n_items = 10000
    index = AnnoyIndex(dim, "manhattan")
    for even in range(0, n_items, 2):
        base = [random.gauss(0, 1) for _ in range(dim)]
        first = [1 + c + random.gauss(0, 1e-2) for c in base]
        second = [1 + c + random.gauss(0, 1e-2) for c in base]
        index.add_item(even, first)
        index.add_item(even + 1, second)

    index.build(10)
    for even in range(0, n_items, 2):
        odd = even + 1
        assert index.get_nns_by_item(even, 2) == [even, odd]
        assert index.get_nns_by_item(odd, 2) == [odd, even]
74

75

76
def precision(n, n_trees=10, n_points=10000, n_rounds=10):
    """Return the average recall of a top-``n`` query against the origin.

    Each round builds a fresh 10-dimensional Manhattan index. Item ``j`` is a
    random unit vector shifted by ``j`` in every coordinate. Its L1 distance
    from the origin therefore grows by about ``f`` per step of ``j``, while
    the unit-vector jitter contributes at most ``sqrt(f)``. As a result, the
    exact top-``n`` neighbours of the origin are items ``0 .. n - 1``, in
    that order.

    :param n: number of neighbours requested per query.
    :param n_trees: number of trees passed to ``build``.
    :param n_points: number of items indexed per round.
    :param n_rounds: number of independently generated indexes.
    :returns: fraction of the ``n * n_rounds`` true neighbours that were
        returned by the approximate search.
    """
    f = 10
    origin = [0] * f
    found = 0
    for _ in range(n_rounds):
        index = AnnoyIndex(f, "manhattan")
        for j in range(n_points):
            p = [random.gauss(0, 1) for _ in range(f)]
            norm = sum(pi**2 for pi in p) ** 0.5
            index.add_item(j, [pi / norm + j for pi in p])

        index.build(n_trees)

        nns = index.get_nns_by_vector(origin, n)
        # Distance strictly increases with item id, so results must be id-ordered.
        assert nns == sorted(nns)
        # Count how many of the true top-n (ids 0 .. n - 1) were retrieved.
        found += sum(1 for item in nns if item < n)

    return found / (n * n_rounds)
96

97

98
def test_precision_1():
    """The single nearest neighbour is recovered at least 98% of the time."""
    recall = precision(1)
    assert recall >= 0.98
100

101

102
def test_precision_10():
    """At least 98% of the true top-10 neighbours are recovered on average."""
    recall = precision(10)
    assert recall >= 0.98
104

105

106
def test_precision_100():
    """At least 98% of the true top-100 neighbours are recovered on average."""
    recall = precision(100)
    assert recall >= 0.98
108

109

110
def test_precision_1000():
    """At least 98% of the true top-1000 neighbours are recovered on average."""
    recall = precision(1000)
    assert recall >= 0.98
112

113

114
def test_get_nns_with_distances():
    """include_distances=True yields (ids, L1 distances), both nearest-first."""
    index = AnnoyIndex(3, "manhattan")
    for item_id, vec in enumerate(([0, 0, 2], [0, 1, 1], [1, 0, 0])):
        index.add_item(item_id, vec)
    index.build(10)

    # From item 0 = (0, 0, 2): self -> 0, item 1 -> 0+1+1 = 2, item 2 -> 1+0+2 = 3.
    ids, dists = index.get_nns_by_item(0, 3, -1, True)
    assert ids == [0, 1, 2]
    for k, want in enumerate((0, 2, 3)):
        assert dists[k] == pytest.approx(want)

    # From (2, 2, 1): item 1 -> 2+1+0 = 3, item 2 -> 1+2+1 = 4, item 0 -> 2+2+1 = 5.
    ids, dists = index.get_nns_by_vector([2, 2, 1], 3, -1, True)
    assert ids == [1, 2, 0]
    for k, want in enumerate((3, 4, 5)):
        assert dists[k] == pytest.approx(want)
133

134

135
def test_include_dists():
    """include_distances=True reports both the self-distance and the far item.

    Item 1 is the exact negation of item 0. It must therefore be the second
    result, and its reported distance must equal the L1 distance between the
    two stored vectors.
    """
    f = 40
    i = AnnoyIndex(f, "manhattan")
    v = numpy.random.normal(size=f)
    i.add_item(0, v)
    i.add_item(1, -v)
    i.build(10)

    indices, dists = i.get_nns_by_item(0, 2, 10, True)
    assert indices == [0, 1]
    assert dists[0] == pytest.approx(0)
    # Compare against the vectors read back from the index rather than ``v``.
    # The index may store them at reduced precision (presumably float32 —
    # confirm). The loose tolerance absorbs accumulation error over 40 terms.
    u = numpy.array(i.get_item_vector(0))
    w = numpy.array(i.get_item_vector(1))
    assert dists[1] == pytest.approx(numpy.sum(numpy.fabs(u - w)), rel=1e-5)
146

147

148
def test_distance_consistency():
    """Distances from get_nns_by_item agree with get_distance and a reference L1.

    For 100 random query items, every returned (neighbour, distance) pair is
    checked three ways: against ``get_distance``, and against the Manhattan
    distance between the stored vectors computed both with NumPy and in pure
    Python.
    """
    n, f = 1000, 3
    i = AnnoyIndex(f, "manhattan")
    for j in range(n):
        i.add_item(j, numpy.random.normal(size=f))
    i.build(10)
    for a in random.sample(range(n), 100):
        indices, dists = i.get_nns_by_item(a, 100, include_distances=True)
        # The query vector is the same for every neighbour; fetch it once.
        u = numpy.array(i.get_item_vector(a))
        for b, dist in zip(indices, dists):
            assert dist == pytest.approx(i.get_distance(a, b))
            v = numpy.array(i.get_item_vector(b))
            assert dist == pytest.approx(numpy.sum(numpy.fabs(u - v)))
            assert dist == pytest.approx(
                sum(abs(float(x) - float(y)) for x, y in zip(u, v))
            )
164

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.