google-research

Форк
0
/
example.ipynb 
303 строки · 7.6 Кб
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "# Copyright 2022 Google LLC\n",
10
    "\n",
11
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
12
    "# you may not use this file except in compliance with the License.\n",
13
    "# You may obtain a copy of the License at\n",
14
    "\n",
15
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
16
    "\n",
17
    "# Unless required by applicable law or agreed to in writing, software\n",
18
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
19
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
20
    "# See the License for the specific language governing permissions and\n",
21
    "# limitations under the License."
22
   ]
23
  },
24
  {
25
   "cell_type": "markdown",
26
   "metadata": {},
27
   "source": [
28
    "# ScaNN Demo with GloVe Dataset"
29
   ]
30
  },
31
  {
32
   "cell_type": "code",
33
   "execution_count": 1,
34
   "metadata": {
35
    "scrolled": true
36
   },
37
   "outputs": [],
38
   "source": [
39
    "import numpy as np\n",
40
    "import h5py\n",
41
    "import os\n",
42
    "import requests\n",
43
    "import tempfile\n",
44
    "import time\n",
45
    "\n",
46
    "import scann"
47
   ]
48
  },
49
  {
50
   "cell_type": "markdown",
51
   "metadata": {},
52
   "source": [
53
    "### Download dataset"
54
   ]
55
  },
56
  {
57
   "cell_type": "code",
58
   "execution_count": 2,
59
   "metadata": {},
60
   "outputs": [],
61
   "source": [
62
    "with tempfile.TemporaryDirectory() as tmp:\n",
63
    "    response = requests.get(\"http://ann-benchmarks.com/glove-100-angular.hdf5\")\n",
64
    "    loc = os.path.join(tmp, \"glove.hdf5\")\n",
65
    "    with open(loc, 'wb') as f:\n",
66
    "        f.write(response.content)\n",
67
    "    \n",
68
    "    glove_h5py = h5py.File(loc, \"r\")"
69
   ]
70
  },
71
  {
72
   "cell_type": "code",
73
   "execution_count": 3,
74
   "metadata": {},
75
   "outputs": [
76
    {
77
     "data": {
78
      "text/plain": [
79
       "['distances', 'neighbors', 'test', 'train']"
80
      ]
81
     },
82
     "execution_count": 3,
83
     "metadata": {},
84
     "output_type": "execute_result"
85
    }
86
   ],
87
   "source": [
88
    "list(glove_h5py.keys())"
89
   ]
90
  },
91
  {
92
   "cell_type": "code",
93
   "execution_count": 4,
94
   "metadata": {},
95
   "outputs": [
96
    {
97
     "name": "stdout",
98
     "output_type": "stream",
99
     "text": [
100
      "(1183514, 100)\n",
101
      "(10000, 100)\n"
102
     ]
103
    }
104
   ],
105
   "source": [
106
    "dataset = glove_h5py['train']\n",
107
    "queries = glove_h5py['test']\n",
108
    "print(dataset.shape)\n",
109
    "print(queries.shape)"
110
   ]
111
  },
112
  {
113
   "cell_type": "markdown",
114
   "metadata": {},
115
   "source": [
116
    "### Create ScaNN searcher"
117
   ]
118
  },
119
  {
120
   "cell_type": "code",
121
   "execution_count": 5,
122
   "metadata": {},
123
   "outputs": [],
124
   "source": [
125
    "normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]\n",
126
    "# configure ScaNN as a tree - asymmetric hash hybrid with reordering\n",
127
    "# anisotropic quantization as described in the paper; see README\n",
128
    "\n",
129
    "# use scann.scann_ops.build() to instead create a TensorFlow-compatible searcher\n",
130
    "searcher = scann.scann_ops_pybind.builder(normalized_dataset, 10, \"dot_product\").tree(\n",
131
    "    num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(\n",
132
    "    2, anisotropic_quantization_threshold=0.2).reorder(100).build()"
133
   ]
134
  },
135
  {
136
   "cell_type": "code",
137
   "execution_count": 6,
138
   "metadata": {},
139
   "outputs": [],
140
   "source": [
141
    "def compute_recall(neighbors, true_neighbors):\n",
142
    "    total = 0\n",
143
    "    for gt_row, row in zip(true_neighbors, neighbors):\n",
144
    "        total += np.intersect1d(gt_row, row).shape[0]\n",
145
    "    return total / true_neighbors.size"
146
   ]
147
  },
148
  {
149
   "cell_type": "markdown",
150
   "metadata": {},
151
   "source": [
152
    "### ScaNN interface features"
153
   ]
154
  },
155
  {
156
   "cell_type": "code",
157
   "execution_count": 7,
158
   "metadata": {},
159
   "outputs": [
160
    {
161
     "name": "stdout",
162
     "output_type": "stream",
163
     "text": [
164
      "Recall: 0.8999\n",
165
      "Time: 1.3812487125396729\n"
166
     ]
167
    }
168
   ],
169
   "source": [
170
    "# this will search the top 100 of the 2000 leaves, and compute\n",
171
    "# the exact dot products of the top 100 candidates from asymmetric\n",
172
    "# hashing to get the final top 10 candidates.\n",
173
    "start = time.time()\n",
174
    "neighbors, distances = searcher.search_batched(queries)\n",
175
    "end = time.time()\n",
176
    "\n",
177
    "# we are given top 100 neighbors in the ground truth, so select top 10\n",
178
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
179
    "print(\"Time:\", end - start)"
180
   ]
181
  },
182
  {
183
   "cell_type": "code",
184
   "execution_count": 8,
185
   "metadata": {},
186
   "outputs": [
187
    {
188
     "name": "stdout",
189
     "output_type": "stream",
190
     "text": [
191
      "Recall: 0.92327\n",
192
      "Time: 1.8380558490753174\n"
193
     ]
194
    }
195
   ],
196
   "source": [
197
    "# increasing the leaves to search increases recall at the cost of speed\n",
198
    "start = time.time()\n",
199
    "neighbors, distances = searcher.search_batched(queries, leaves_to_search=150)\n",
200
    "end = time.time()\n",
201
    "\n",
202
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
203
    "print(\"Time:\", end - start)"
204
   ]
205
  },
206
  {
207
   "cell_type": "code",
208
   "execution_count": 9,
209
   "metadata": {},
210
   "outputs": [
211
    {
212
     "name": "stdout",
213
     "output_type": "stream",
214
     "text": [
215
      "Recall: 0.93098\n",
216
      "Time: 2.2772152423858643\n"
217
     ]
218
    }
219
   ],
220
   "source": [
221
    "# increasing reordering (the exact scoring of top AH candidates) has a similar effect.\n",
222
    "start = time.time()\n",
223
    "neighbors, distances = searcher.search_batched(queries, leaves_to_search=150, pre_reorder_num_neighbors=250)\n",
224
    "end = time.time()\n",
225
    "\n",
226
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
227
    "print(\"Time:\", end - start)"
228
   ]
229
  },
230
  {
231
   "cell_type": "code",
232
   "execution_count": 10,
233
   "metadata": {},
234
   "outputs": [
235
    {
236
     "name": "stdout",
237
     "output_type": "stream",
238
     "text": [
239
      "(10000, 10) (10000, 10)\n",
240
      "(10000, 20) (10000, 20)\n"
241
     ]
242
    }
243
   ],
244
   "source": [
245
    "# we can also dynamically configure the number of neighbors returned\n",
246
    "# currently returns 10 as configued in ScannBuilder()\n",
247
    "neighbors, distances = searcher.search_batched(queries)\n",
248
    "print(neighbors.shape, distances.shape)\n",
249
    "\n",
250
    "# now returns 20\n",
251
    "neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)\n",
252
    "print(neighbors.shape, distances.shape)"
253
   ]
254
  },
255
  {
256
   "cell_type": "code",
257
   "execution_count": 11,
258
   "metadata": {},
259
   "outputs": [
260
    {
261
     "name": "stdout",
262
     "output_type": "stream",
263
     "text": [
264
      "[ 97478 262700 846101 671078 232287]\n",
265
      "[2.5518737 2.542952  2.539792  2.5383418 2.519638 ]\n",
266
      "Latency (ms): 0.7724761962890625\n"
267
     ]
268
    }
269
   ],
270
   "source": [
271
    "# we have been exclusively calling batch search so far; the single-query call has the same API\n",
272
    "start = time.time()\n",
273
    "neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)\n",
274
    "end = time.time()\n",
275
    "\n",
276
    "print(neighbors)\n",
277
    "print(distances)\n",
278
    "print(\"Latency (ms):\", 1000*(end - start))"
279
   ]
280
  }
281
 ],
282
 "metadata": {
283
  "kernelspec": {
284
   "display_name": "Python 3",
285
   "language": "python",
286
   "name": "python3"
287
  },
288
  "language_info": {
289
   "codemirror_mode": {
290
    "name": "ipython",
291
    "version": 3
292
   },
293
   "file_extension": ".py",
294
   "mimetype": "text/x-python",
295
   "name": "python",
296
   "nbconvert_exporter": "python",
297
   "pygments_lexer": "ipython3",
298
   "version": "3.6.10"
299
  }
300
 },
301
 "nbformat": 4,
302
 "nbformat_minor": 4
303
}
304

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.