annoy

Форк
0
/
precision_test.cpp 
176 строк · 4.7 Кб
1
/*
 * precision_test.cpp
 *
 *  Created on: Jul 13, 2016
 *      Author: Claudio Sanhueza
 *      Contact: csanhuezalobos@gmail.com
 */
9

10
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <vector>

#include "../src/kissrandom.h"
#include "../src/annoylib.h"

using namespace Annoy;
20
int precision(int f=40, int n=1000000){
21
	std::chrono::high_resolution_clock::time_point t_start, t_end;
22

23
	std::default_random_engine generator;
24
	std::normal_distribution<double> distribution(0.0, 1.0);
25

26
	//******************************************************
27
	//Building the tree
28
	AnnoyIndex<int, double, Angular, Kiss32Random, AnnoyIndexMultiThreadedBuildPolicy> t = AnnoyIndex<int, double, Angular, Kiss32Random, AnnoyIndexMultiThreadedBuildPolicy>(f);
29

30
	std::cout << "Building index ... be patient !!" << std::endl;
31
	std::cout << "\"Trees that are slow to grow bear the best fruit\" (Moliere)" << std::endl;
32

33

34

35
	for(int i=0; i<n; ++i){
36
		double *vec = (double *) malloc( f * sizeof(double) );
37

38
		for(int z=0; z<f; ++z){
39
			vec[z] = (distribution(generator));
40
		}
41

42
		t.add_item(i, vec);
43

44
		std::cout << "Loading objects ...\t object: "<< i+1 << "\tProgress:"<< std::fixed << std::setprecision(2) << (double) i / (double)(n + 1) * 100 << "%\r";
45

46
	}
47
	std::cout << std::endl;
48
	std::cout << "Building index num_trees = 2 * num_features ...";
49
	t_start = std::chrono::high_resolution_clock::now();
50
	t.build(2 * f);
51
	t_end = std::chrono::high_resolution_clock::now();
52
	auto duration = std::chrono::duration_cast<std::chrono::seconds>( t_end - t_start ).count();
53
	std::cout << " Done in "<< duration << " secs." << std::endl;
54

55

56
	std::cout << "Saving index ...";
57
	t.save("precision.tree");
58
	std::cout << " Done" << std::endl;
59

60

61

62
	//******************************************************
63
	std::vector<int> limits = {10, 100, 1000, 10000};
64
	int K=10;
65
	int prec_n = 1000;
66

67
	std::map<int, double> prec_sum;
68
	std::map<int, double> time_sum;
69
	std::vector<int> closest;
70

71
	//init precision and timers map
72
	for(std::vector<int>::iterator it = limits.begin(); it!=limits.end(); ++it){
73
		prec_sum[(*it)] = 0.0;
74
		time_sum[(*it)] = 0.0;
75
	}
76

77
	// doing the work
78
	for(int i=0; i<prec_n; ++i){
79

80
		//select a random node
81
		int j = rand() % n;
82

83
		std::cout << "finding nbs for " << j << std::endl;
84

85
		// getting the K closest
86
		t.get_nns_by_item(j, K, n, &closest, nullptr);
87

88
		std::vector<int> toplist;
89
		std::vector<int> intersection;
90

91
		for(std::vector<int>::iterator limit = limits.begin(); limit!=limits.end(); ++limit){
92

93
			t_start = std::chrono::high_resolution_clock::now();
94
			t.get_nns_by_item(j, (*limit), (size_t) -1, &toplist, nullptr); //search_k defaults to "n_trees * n" if not provided.
95
			t_end = std::chrono::high_resolution_clock::now();
96
			auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( t_end - t_start ).count();
97

98
			//intersecting results
99
			std::sort(closest.begin(), closest.end(), std::less<int>());
100
			std::sort(toplist.begin(), toplist.end(), std::less<int>());
101
			intersection.resize(std::max(closest.size(), toplist.size()));
102
			std::vector<int>::iterator it_set = std::set_intersection(closest.begin(), closest.end(), toplist.begin(), toplist.end(), intersection.begin());
103
			intersection.resize(it_set-intersection.begin());
104

105
			// storing metrics
106
			int found = intersection.size();
107
			double hitrate = found / (double) K;
108
			prec_sum[(*limit)] += hitrate;
109

110
			time_sum[(*limit)] += duration;
111

112

113
			//deallocate memory
114
			vector<int>().swap(intersection);
115
			vector<int>().swap(toplist);
116
		}
117

118
		//print resulting metrics
119
		for(std::vector<int>::iterator limit = limits.begin(); limit!=limits.end(); ++limit){
120
			std::cout << "limit: " << (*limit) << "\tprecision: "<< std::fixed << std::setprecision(2) << (100.0 * prec_sum[(*limit)] / (i + 1)) << "% \tavg. time: "<< std::fixed<< std::setprecision(6) << (time_sum[(*limit)] / (i + 1)) * 1e-04 << "s" << std::endl;
121
		}
122

123
		closest.clear(); vector<int>().swap(closest);
124

125
	}
126

127
	std::cout << "\nDone" << std::endl;
128
	return 0;
129
}
130

131

132
/**
 * Writes the command-line usage text of the precision example to stdout,
 * one entry per line, followed by a blank line.
 */
void help(){
	const char* const usage_lines[] = {
		"Annoy Precision C++ example",
		"Usage:",
		"(default)		./precision",
		"(using parameters)	./precision num_features num_nodes",
		""
	};
	for (const char* line : usage_lines) {
		std::cout << line << std::endl;
	}
}
139

140
/**
 * Echoes the benchmark configuration to stdout before the run starts.
 *
 * @param f  number of features (vector dimension).
 * @param n  number of items that will be indexed.
 */
void feedback(int f, int n){
	std::cout << "Runing precision example with:" << std::endl
	          << "num. features: " << f << std::endl
	          << "num. nodes: " << n << std::endl
	          << std::endl;
}
146

147

148
int main(int argc, char **argv) {
149
	int f, n;
150

151

152
	if(argc == 1){
153
		f = 40;
154
		n = 1000000;
155

156
		feedback(f,n);
157

158
		precision(40, 1000000);
159
	}
160
	else if(argc == 3){
161

162
		f = atoi(argv[1]);
163
		n = atoi(argv[2]);
164

165
		feedback(f,n);
166

167
		precision(f, n);
168
	}
169
	else {
170
		help();
171
		return EXIT_FAILURE;
172
	}
173

174

175
	return EXIT_SUCCESS;
176
}
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.