1{
2"cells": [
3{
4"cell_type": "code",
5"execution_count": null,
6"metadata": {},
7"outputs": [],
8"source": [
9"# Copyright 2022 Google LLC\n",
10"\n",
11"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
12"# you may not use this file except in compliance with the License.\n",
13"# You may obtain a copy of the License at\n",
14"\n",
15"# http://www.apache.org/licenses/LICENSE-2.0\n",
16"\n",
17"# Unless required by applicable law or agreed to in writing, software\n",
18"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
19"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
20"# See the License for the specific language governing permissions and\n",
21"# limitations under the License."
22]
23},
24{
25"cell_type": "markdown",
26"metadata": {},
27"source": [
28"# ScaNN Demo with GloVe Dataset"
29]
30},
31{
32"cell_type": "code",
33"execution_count": 1,
34"metadata": {
35"scrolled": true
36},
37"outputs": [],
38"source": [
39"import numpy as np\n",
40"import h5py\n",
41"import os\n",
42"import requests\n",
43"import tempfile\n",
44"import time\n",
45"\n",
46"import scann"
47]
48},
49{
50"cell_type": "markdown",
51"metadata": {},
52"source": [
53"### Download dataset"
54]
55},
56{
57"cell_type": "code",
58"execution_count": 2,
59"metadata": {},
60"outputs": [],
61"source": [
# Download the GloVe angular dataset from ann-benchmarks and open it as HDF5.
with tempfile.TemporaryDirectory() as tmp:
    response = requests.get("http://ann-benchmarks.com/glove-100-angular.hdf5")
    # Fail fast on a bad download; otherwise an HTTP error page would be
    # silently written to disk and h5py would fail later with a cryptic error.
    response.raise_for_status()
    loc = os.path.join(tmp, "glove.hdf5")
    with open(loc, 'wb') as f:
        f.write(response.content)

    # NOTE(review): the handle is opened while the temporary directory still
    # exists; on POSIX the open handle stays valid after the directory is
    # deleted, but this is fragile on Windows — confirm if portability matters.
    glove_h5py = h5py.File(loc, "r")
69]
70},
71{
72"cell_type": "code",
73"execution_count": 3,
74"metadata": {},
75"outputs": [
76{
77"data": {
78"text/plain": [
79"['distances', 'neighbors', 'test', 'train']"
80]
81},
82"execution_count": 3,
83"metadata": {},
84"output_type": "execute_result"
85}
86],
87"source": [
# ann-benchmarks HDF5 layout: 'train' (database vectors), 'test' (queries),
# plus ground-truth 'neighbors' and 'distances' for each test query.
list(glove_h5py.keys())
89]
90},
91{
92"cell_type": "code",
93"execution_count": 4,
94"metadata": {},
95"outputs": [
96{
97"name": "stdout",
98"output_type": "stream",
99"text": [
100"(1183514, 100)\n",
101"(10000, 100)\n"
102]
103}
104],
105"source": [
# 'train' is the database to index (1183514 x 100 GloVe vectors);
# 'test' holds the 10000 query vectors (shapes shown by the prints below).
dataset = glove_h5py['train']
queries = glove_h5py['test']
print(dataset.shape)
print(queries.shape)
110]
111},
112{
113"cell_type": "markdown",
114"metadata": {},
115"source": [
116"### Create ScaNN searcher"
117]
118},
119{
120"cell_type": "code",
121"execution_count": 5,
122"metadata": {},
123"outputs": [],
124"source": [
# Unit-normalize every database row so that dot-product search matches the
# angular metric of this benchmark.
# NOTE(review): assumes no row of `dataset` is all-zero — a zero norm here
# would produce NaNs; confirm before reusing with other datasets.
normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]

# Configure ScaNN as a tree - asymmetric hash hybrid with reordering,
# using anisotropic quantization as described in the paper; see README:
#   * builder(): index normalized_dataset, return 10 neighbors, "dot_product";
#   * tree(): partition into 2000 leaves, search 100 per query, train the
#     partitioning on a 250000-point sample;
#   * score_ah(): asymmetric hashing with dimension-2 blocks and the given
#     anisotropic quantization threshold;
#   * reorder(): exactly re-score the top 100 AH candidates.
# Use scann.scann_ops.build() to instead create a TensorFlow-compatible searcher.
searcher = scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product").tree(
    num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(
    2, anisotropic_quantization_threshold=0.2).reorder(100).build()
133]
134},
135{
136"cell_type": "code",
137"execution_count": 6,
138"metadata": {},
139"outputs": [],
140"source": [
def compute_recall(neighbors, true_neighbors):
    """Fraction of ground-truth neighbors recovered, averaged over all queries.

    Args:
      neighbors: 2-D array of retrieved neighbor ids, one row per query.
      true_neighbors: 2-D array of ground-truth neighbor ids, same row order.

    Returns:
      Total per-row intersection size divided by ``true_neighbors.size``.
    """
    hits = sum(
        np.intersect1d(gt_row, found_row).shape[0]
        for gt_row, found_row in zip(true_neighbors, neighbors)
    )
    return hits / true_neighbors.size
146]
147},
148{
149"cell_type": "markdown",
150"metadata": {},
151"source": [
152"### ScaNN interface features"
153]
154},
155{
156"cell_type": "code",
157"execution_count": 7,
158"metadata": {},
159"outputs": [
160{
161"name": "stdout",
162"output_type": "stream",
163"text": [
164"Recall: 0.8999\n",
165"Time: 1.3812487125396729\n"
166]
167}
168],
169"source": [
# Search with the parameters baked into the builder: probe the top 100 of
# the 2000 leaves, exactly re-score the top 100 AH candidates, return 10.
start = time.time()
neighbors, distances = searcher.search_batched(queries)
elapsed = time.time() - start

# The ground truth stores the top 100 true neighbors per query; keep only
# the first 10 since the searcher returns 10.
true_top10 = glove_h5py['neighbors'][:, :10]
print("Recall:", compute_recall(neighbors, true_top10))
print("Time:", elapsed)
180]
181},
182{
183"cell_type": "code",
184"execution_count": 8,
185"metadata": {},
186"outputs": [
187{
188"name": "stdout",
189"output_type": "stream",
190"text": [
191"Recall: 0.92327\n",
192"Time: 1.8380558490753174\n"
193]
194}
195],
196"source": [
# Probing more leaves (150 instead of 100) trades speed for recall.
t0 = time.time()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=150)
t1 = time.time()

print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", t1 - t0)
204]
205},
206{
207"cell_type": "code",
208"execution_count": 9,
209"metadata": {},
210"outputs": [
211{
212"name": "stdout",
213"output_type": "stream",
214"text": [
215"Recall: 0.93098\n",
216"Time: 2.2772152423858643\n"
217]
218}
219],
220"source": [
# Exactly re-scoring more AH candidates (250 instead of the default 100)
# likewise raises recall at the cost of extra exact dot products.
t_begin = time.time()
neighbors, distances = searcher.search_batched(
    queries, leaves_to_search=150, pre_reorder_num_neighbors=250)
t_end = time.time()

print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", t_end - t_begin)
228]
229},
230{
231"cell_type": "code",
232"execution_count": 10,
233"metadata": {},
234"outputs": [
235{
236"name": "stdout",
237"output_type": "stream",
238"text": [
239"(10000, 10) (10000, 10)\n",
240"(10000, 20) (10000, 20)\n"
241]
242}
243],
244"source": [
# The number of neighbors returned can also be configured per call.
# Currently returns 10, as configured in the builder() call above.
neighbors, distances = searcher.search_batched(queries)
print(neighbors.shape, distances.shape)

# Overriding final_num_neighbors returns 20 instead.
neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
print(neighbors.shape, distances.shape)
253]
254},
255{
256"cell_type": "code",
257"execution_count": 11,
258"metadata": {},
259"outputs": [
260{
261"name": "stdout",
262"output_type": "stream",
263"text": [
264"[ 97478 262700 846101 671078 232287]\n",
265"[2.5518737 2.542952 2.539792 2.5383418 2.519638 ]\n",
266"Latency (ms): 0.7724761962890625\n"
267]
268}
269],
270"source": [
# searcher.search() handles one query at a time with the same options as
# search_batched(); time a single call to estimate per-query latency.
tic = time.time()
neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)
toc = time.time()

print(neighbors)
print(distances)
print("Latency (ms):", 1000*(toc - tic))
279]
280}
281],
282"metadata": {
283"kernelspec": {
284"display_name": "Python 3",
285"language": "python",
286"name": "python3"
287},
288"language_info": {
289"codemirror_mode": {
290"name": "ipython",
291"version": 3
292},
293"file_extension": ".py",
294"mimetype": "text/x-python",
295"name": "python",
296"nbconvert_exporter": "python",
297"pygments_lexer": "ipython3",
298"version": "3.6.10"
299}
300},
301"nbformat": 4,
302"nbformat_minor": 4
303}
304