annoy

Форк
0
/
hamming_index_test.py 
123 строки · 4.1 Кб
# Copyright (c) 2013 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.


import numpy
17
import pytest
18

19
from annoy import AnnoyIndex
20

21

22
def test_basic_conversion():
    """Stored bit vectors round-trip exactly and distances count differing bits.

    Two random 0/1 vectors of length 100 are added to a "hamming" index.
    Reading them back with ``get_item_vector`` must reproduce them, each item
    must be at distance 0 from itself, and the distance between the two items
    (queried in either order) must equal ``dot(u - v, u - v)``.
    """

    def squared_norm(delta):
        # When ``delta`` is the difference of two 0/1 vectors, this counts
        # the positions where they differ.
        return numpy.dot(delta, delta)

    dim = 100
    index = AnnoyIndex(dim, "hamming")
    first, second = (numpy.random.binomial(1, 0.5, dim) for _ in range(2))
    for item_id, vector in enumerate((first, second)):
        index.add_item(item_id, vector)
    restored_first = index.get_item_vector(0)
    restored_second = index.get_item_vector(1)

    assert squared_norm(first - restored_first) == pytest.approx(0.0)
    assert squared_norm(second - restored_second) == pytest.approx(0.0)
    for item_id in (0, 1):
        assert index.get_distance(item_id, item_id) == pytest.approx(0.0)

    expected = squared_norm(first - second)
    assert index.get_distance(0, 1) == pytest.approx(expected)
    assert index.get_distance(1, 0) == pytest.approx(expected)


def test_basic_nns():
    """Nearest-neighbour queries on a two-item Hamming index.

    Querying either item must return itself first and the other item second.
    When distances are requested for item 0, they must be 0 for itself and
    ``dot(u - v, u - v)`` for the other item.
    """
    dim = 100
    n_trees = 10
    max_results = 99  # requests more neighbours than there are items
    index = AnnoyIndex(dim, "hamming")
    vectors = [numpy.random.binomial(1, 0.5, dim) for _ in range(2)]
    for item_id, vector in enumerate(vectors):
        index.add_item(item_id, vector)
    index.build(n_trees)

    assert index.get_nns_by_item(0, max_results) == [0, 1]
    assert index.get_nns_by_item(1, max_results) == [1, 0]

    neighbours, distances = index.get_nns_by_item(
        0, max_results, include_distances=True
    )
    assert neighbours == [0, 1]
    assert distances[0] == pytest.approx(0)
    diff = vectors[0] - vectors[1]
    assert distances[1] == pytest.approx(numpy.dot(diff, diff))


def test_save_load():
    """A saved Hamming index reloads and answers queries like the original.

    The index file is written into a private temporary directory rather than
    the current working directory. The test therefore leaves no ``blah.ann``
    behind and cannot collide with other runs that use the same file name.
    """
    import os
    import tempfile

    f = 100
    i = AnnoyIndex(f, "hamming")
    u = numpy.random.binomial(1, 0.5, f)
    v = numpy.random.binomial(1, 0.5, f)
    i.add_item(0, u)
    i.add_item(1, v)
    i.build(10)

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "blah.ann")
        j = AnnoyIndex(f, "hamming")
        try:
            i.save(path)
            j.load(path)
            rs, ds = j.get_nns_by_item(0, 99, include_distances=True)
        finally:
            # `load` memory-maps the file into `j`, and annoy's `save` also
            # re-loads the saved file into `i`. Unload both before the
            # directory is removed, because a still-mapped file cannot be
            # deleted on some platforms (e.g. Windows).
            j.unload()
            i.unload()

    # The results are plain Python lists, so they stay valid after unloading.
    assert rs == [0, 1]
    assert ds[0] == pytest.approx(0)
    assert ds[1] == pytest.approx(numpy.dot(u - v, u - v))


def test_many_vectors():
    """Queries on a large, dense 10-bit Hamming index return sensible distances.

    100,000 random 10-bit vectors are indexed. Every distance returned for the
    all-zeros query must lie in ``[0, f]``. The mean distance from 1,000
    random queries to their approximate nearest neighbour must be at most
    0.42.

    All random data comes from a locally seeded generator, for two reasons:
    the statistical threshold is always checked against the same input, and
    the test neither depends on nor disturbs NumPy's global random state.
    """
    f = 10
    n_items = 100000
    n_queries = 1000
    rng = numpy.random.RandomState(42)

    i = AnnoyIndex(f, "hamming")
    for x, vector in enumerate(rng.binomial(1, 0.5, size=(n_items, f))):
        i.add_item(x, vector)
    i.build(10)

    _, ds = i.get_nns_by_vector([0] * f, 10000, include_distances=True)
    assert min(ds) >= 0
    assert max(ds) <= f

    dists = []
    for query in rng.binomial(1, 0.5, size=(n_queries, f)):
        _, ds = i.get_nns_by_vector(query, 1, search_k=1000, include_distances=True)
        dists.append(ds[0])
    avg_dist = numpy.mean(dists)
    assert avg_dist <= 0.42


@pytest.mark.skip(reason="will fix later")
def test_zero_vectors():
    """Regression case for 64-bit Hamming vectors reported on annoy-user.

    After a save/load round trip, item 0 must be its own nearest neighbour,
    and the first four distances from it must be ``[0, 1, 1, 22]``. The test
    is currently skipped.

    The index is saved into a private temporary directory rather than as
    ``idx.ann`` in the current working directory, so no file is left behind.
    """
    import os
    import tempfile

    # Mentioned on the annoy-user list
    bitstrings = [
        "0000000000011000001110000011111000101110111110000100000100000000",
        "0000000000011000001110000011111000101110111110000100000100000001",
        "0000000000011000001110000011111000101110111110000100000100000010",
        "0010010100011001001000010001100101011110000000110000011110001100",
        "1001011010000110100101101001111010001110100001101000111000001110",
        "0111100101111001011110010010001100010111000111100001101100011111",
        "0011000010011101000011010010111000101110100101111000011101001011",
        "0011000010011100000011010010111000101110100101111000011101001011",
        "1001100000111010001010000010110000111100100101001001010000000111",
        "0000000000111101010100010001000101101001000000011000001101000000",
        "1000101001010001011100010111001100110011001100110011001111001100",
        "1110011001001111100110010001100100001011000011010010111100100111",
    ]
    vectors = [[int(bit) for bit in bitstring] for bitstring in bitstrings]

    f = 64
    idx = AnnoyIndex(f, "hamming")
    for i, v in enumerate(vectors):
        idx.add_item(i, v)
    idx.build(10)

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "idx.ann")
        loaded = AnnoyIndex(f, "hamming")
        try:
            idx.save(path)
            loaded.load(path)
            js, ds = loaded.get_nns_by_item(0, 5, include_distances=True)
        finally:
            # Release the memory-mapped file from both indexes; annoy's
            # `save` re-loads the saved file into `idx`. This must happen
            # before the temporary directory is deleted.
            loaded.unload()
            idx.unload()

    assert js[0] == 0
    assert ds[:4] == [0, 1, 1, 22]
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.