examples

Форк
0
556 строк · 22.7 Кб
1
{
2
 "cells": [
3
  {
4
   "attachments": {},
5
   "cell_type": "markdown",
6
   "id": "11dae564",
7
   "metadata": {},
8
   "source": [
9
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/training-with-wandb/02-encode.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/training-with-wandb/02-encode.ipynb)"
10
   ]
11
  },
12
  {
13
   "cell_type": "code",
14
   "execution_count": null,
15
   "id": "857f8968",
16
   "metadata": {},
17
   "outputs": [],
18
   "source": [
19
    "!pip install -qq wandb datasets pinecone-client sentence-transformers transformers"
20
   ]
21
  },
22
  {
23
   "attachments": {},
24
   "cell_type": "markdown",
25
   "id": "29398b9e-d57b-4917-b59a-467e059f6bfa",
26
   "metadata": {},
27
   "source": [
28
    "## Encoding arXiv Abstracts\n",
29
    "\n",
30
    "This is part *three* of a four-part notebook series on fine-tuning encoder models with Weights & Biases for use with Pinecone. Find the [full set of notebooks on Github here](https://github.com/pinecone-io/examples/blob/master/analytics-and-ml/model-training/training-with-wandb).\n",
31
    "\n",
32
    "We start by loading two datasets from WandB created in the very first [W&B notebook](https://github.com/pinecone-io/examples/blob/master/analytics-and-ml/model-training/training-with-wandb/00-intro-and-summarizer-train.ipynb)."
33
   ]
34
  },
35
  {
36
   "cell_type": "code",
37
   "execution_count": 1,
38
   "id": "c1c06cb6-2e3d-423f-9323-3061494f87d6",
39
   "metadata": {},
40
   "outputs": [
41
    {
42
     "name": "stderr",
43
     "output_type": "stream",
44
     "text": [
45
      "/opt/conda/lib/python3.7/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated since IPython 4.0. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.\n",
46
      "  \"`IPython.html.widgets` has moved to `ipywidgets`.\", ShimWarning)\n",
47
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjamesbriggs\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
48
     ]
49
    },
50
    {
51
     "data": {
52
      "text/html": [
53
       "Tracking run with wandb version 0.13.5"
54
      ],
55
      "text/plain": [
56
       "<IPython.core.display.HTML object>"
57
      ]
58
     },
59
     "metadata": {},
60
     "output_type": "display_data"
61
    },
62
    {
63
     "data": {
64
      "text/html": [
65
       "Run data is saved locally in <code>/home/jupyter/wandb/wandb/run-20221110_061456-26ugx53n</code>"
66
      ],
67
      "text/plain": [
68
       "<IPython.core.display.HTML object>"
69
      ]
70
     },
71
     "metadata": {},
72
     "output_type": "display_data"
73
    },
74
    {
75
     "data": {
76
      "text/html": [
77
       "Syncing run <strong><a href=\"https://wandb.ai/jamesbriggs/arxiv-searching/runs/26ugx53n\" target=\"_blank\">glowing-surf-35</a></strong> to <a href=\"https://wandb.ai/jamesbriggs/arxiv-searching\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
78
      ],
79
      "text/plain": [
80
       "<IPython.core.display.HTML object>"
81
      ]
82
     },
83
     "metadata": {},
84
     "output_type": "display_data"
85
    },
86
    {
87
     "name": "stderr",
88
     "output_type": "stream",
89
     "text": [
90
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact arxiv-papers:latest, 3388.22MB. 1 files... \n",
91
      "\u001b[34m\u001b[1mwandb\u001b[0m:   1 of 1 files downloaded.  \n",
92
      "Done. 0:0:0.1\n"
93
     ]
94
    },
95
    {
96
     "name": "stdout",
97
     "output_type": "stream",
98
     "text": [
99
      "2151137\n",
100
      "{'id': 'supr-con/9609004', 'submitter': 'Masanori Ichioka', 'authors': 'Naoki Enomoto, Masanori Ichioka and Kazushige Machida (Okayama Univ.)', 'title': 'Ginzburg Landau theory for d-wave pairing and fourfold symmetric vortex\\n  core structure', 'comments': '12 pages including 8 eps figs, LaTeX with jpsj.sty & epsfig', 'journal-ref': 'J. Phys. Soc. Jpn. 66, 204 (1997).', 'doi': '10.1143/JPSJ.66.204', 'report-no': None, 'categories': 'supr-con cond-mat.supr-con', 'license': None, 'abstract': \"  The Ginzburg Landau theory for d_{x^2-y^2}-wave superconductors is\\nconstructed, by starting from the Gor'kov equation with including correction\\nterms up to the next order of ln(T_c/T). Some of the non-local correction terms\\nare found to break the cylindrical symmetry and lead to the fourfold symmetric\\ncore structure, reflecting the internal degree of freedom in the pair\\npotential. Using this extended Ginzburg Landau theory, we investigate the\\nfourfold symmetric structure of the pair potential, current and magnetic field\\naround an isolated single vortex, and clarify concretely how the vortex core\\nstructure deviates from the cylindrical symmetry in the d_{x^2-y^2}-wave\\nsuperconductors.\\n\", 'versions': [{'version': 'v1', 'created': 'Wed, 25 Sep 1996 14:17:09 GMT'}], 'update_date': '2009-10-30', 'authors_parsed': [['Enomoto', 'Naoki', '', 'Okayama Univ.'], ['Ichioka', 'Masanori', '', 'Okayama Univ.'], ['Machida', 'Kazushige', '', 'Okayama Univ.']]}\n"
101
     ]
102
    }
103
   ],
104
   "source": [
105
    "import wandb\n",
106
    "import json\n",
107
    "\n",
108
    "run = wandb.init(project=\"arxiv-searching\")\n",
109
    "# download\n",
110
    "artifact = run.use_artifact('events/arxiv-searching/arxiv-papers:latest', type='dataset')\n",
111
    "artifact_dir = artifact.download()\n",
112
    "\n",
113
    "# open file generator\n",
114
    "path = artifact_dir+'/arxiv-snapshot'\n",
115
    "def arxiv_metadata():\n",
116
    "    with open(path, 'r') as f:\n",
117
    "        for line in f:\n",
118
    "            doc_dict = json.loads(line)\n",
119
    "            yield doc_dict\n",
120
    "metadata = arxiv_metadata()\n",
121
    "# get count of items\n",
122
    "count = 0\n",
123
    "for row in metadata:\n",
124
    "    count += 1\n",
125
    "# refresh generator\n",
126
    "metadata = arxiv_metadata()\n",
127
    "print(count)\n",
128
    "print(row)"
129
   ]
130
  },
131
  {
132
   "cell_type": "markdown",
133
   "id": "5875dc95-cf5f-4110-b4ae-79e385472012",
134
   "metadata": {},
135
   "source": [
136
    "We will encode all of the `'abstract'` values with the `minilm-arxiv-encoder` model we previously trained and stored as an artifact on W&B.\n",
137
    "\n",
138
    "First download the artifact files:"
139
   ]
140
  },
141
  {
142
   "cell_type": "code",
143
   "execution_count": 2,
144
   "id": "93f4eed3-808a-452d-b6c9-183f0440206e",
145
   "metadata": {},
146
   "outputs": [
147
    {
148
     "name": "stderr",
149
     "output_type": "stream",
150
     "text": [
151
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact minilm-arxiv:latest, 128.23MB. 6 files... \n",
152
      "\u001b[34m\u001b[1mwandb\u001b[0m:   6 of 6 files downloaded.  \n",
153
      "Done. 0:0:0.1\n"
154
     ]
155
    },
156
    {
157
     "data": {
158
      "text/plain": [
159
       "'./artifacts/minilm-arxiv:v1'"
160
      ]
161
     },
162
     "execution_count": 2,
163
     "metadata": {},
164
     "output_type": "execute_result"
165
    }
166
   ],
167
   "source": [
168
    "artifact = run.use_artifact(\n",
169
    "    'jamesbriggs/arxiv-searching/minilm-arxiv:latest', type='model'\n",
170
    ")\n",
171
    "artifact_dir = artifact.download()\n",
172
    "artifact_dir"
173
   ]
174
  },
175
  {
176
   "cell_type": "markdown",
177
   "id": "3d88fd7a-9650-4d83-924b-1b0d1f0ae67b",
178
   "metadata": {},
179
   "source": [
180
    "In here we will find all of the model files needed to initialize our fine-tuned sentence transformer:"
181
   ]
182
  },
183
  {
184
   "cell_type": "code",
185
   "execution_count": 3,
186
   "id": "1d2a7d0c-bc07-474e-85b5-f3baa2aee1f1",
187
   "metadata": {},
188
   "outputs": [
189
    {
190
     "data": {
191
      "text/plain": [
192
       "['vocab.txt',\n",
193
       " 'tokenizer.json',\n",
194
       " 'pytorch_model.bin',\n",
195
       " 'special_tokens_map.json',\n",
196
       " 'tokenizer_config.json',\n",
197
       " 'config.json']"
198
      ]
199
     },
200
     "execution_count": 3,
201
     "metadata": {},
202
     "output_type": "execute_result"
203
    }
204
   ],
205
   "source": [
206
    "import os\n",
207
    "\n",
208
    "os.listdir(artifact_dir)"
209
   ]
210
  },
211
  {
212
   "cell_type": "markdown",
213
   "id": "76901f93-21dc-420e-8af7-616b7a61346a",
214
   "metadata": {},
215
   "source": [
216
    "We can do that like so:"
217
   ]
218
  },
219
  {
220
   "cell_type": "code",
221
   "execution_count": 4,
222
   "id": "4a739794-dd13-4232-944b-9a0a88c2ade1",
223
   "metadata": {},
224
   "outputs": [
225
    {
226
     "name": "stderr",
227
     "output_type": "stream",
228
     "text": [
229
      "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
230
      "  from .autonotebook import tqdm as notebook_tqdm\n"
231
     ]
232
    },
233
    {
234
     "data": {
235
      "text/plain": [
236
       "SentenceTransformer(\n",
237
       "  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
238
       "  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
239
       ")"
240
      ]
241
     },
242
     "execution_count": 4,
243
     "metadata": {},
244
     "output_type": "execute_result"
245
    }
246
   ],
247
   "source": [
248
    "from sentence_transformers import models, SentenceTransformer\n",
249
    "import torch\n",
250
    "\n",
251
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
252
    "\n",
253
    "minilm = models.Transformer(artifact_dir)\n",
254
    "pooling = models.Pooling(\n",
255
    "    minilm.get_word_embedding_dimension(),\n",
256
    "    pooling_mode_mean_tokens=True\n",
257
    ")\n",
258
    "\n",
259
    "model = SentenceTransformer(\n",
260
    "    modules=[minilm, pooling],\n",
261
    "    device=device\n",
262
    ")\n",
263
    "model"
264
   ]
265
  },
266
  {
267
   "cell_type": "code",
268
   "execution_count": 5,
269
   "id": "586248d1-0049-4d5a-96f7-f8017fb2036d",
270
   "metadata": {
271
    "tags": []
272
   },
273
   "outputs": [
274
    {
275
     "data": {
276
      "text/plain": [
277
       "array([[-1.78127125e-01,  2.79968888e-01, -1.10255487e-01,\n",
278
       "         6.38929894e-03, -2.01225877e-01,  1.51545078e-01,\n",
279
       "         1.00171259e-02,  3.63548957e-02, -1.11633286e-01,\n",
280
       "        -1.86091021e-01, -1.80149123e-01,  1.60711005e-01,\n",
281
       "         1.74418956e-01,  4.35423665e-02,  2.56152838e-01,\n",
282
       "         1.36271194e-01, -2.20926031e-01,  1.72951341e-01,\n",
283
       "         1.11775815e-01,  3.24884406e-03, -3.29344273e-02,\n",
284
       "         7.01274052e-02,  5.62849566e-02,  5.84686697e-02,\n",
285
       "         6.33804947e-02,  1.59427132e-02,  1.90770373e-01,\n",
286
       "         3.12118349e-03,  1.37064531e-01, -5.43970279e-02,\n",
287
       "         4.25677076e-02,  1.51187047e-01, -4.71253663e-01,\n",
288
       "        -1.16020828e-01, -1.14065185e-01, -1.56056330e-01,\n",
289
       "         1.60679163e-03,  2.31568962e-02, -3.18167359e-02,\n",
290
       "        -1.16474792e-01, -5.54910935e-02,  2.14383662e-01,\n",
291
       "         3.19849402e-02, -5.36291003e-02, -5.07795922e-02,\n",
292
       "         1.01412281e-01,  1.14509769e-01,  1.15304410e-01,\n",
293
       "        -2.12687105e-02,  1.00387290e-01,  7.25929290e-02,\n",
294
       "         5.13379760e-02, -7.07295612e-02,  1.10100530e-01,\n",
295
       "         2.16299221e-02, -6.17225952e-02, -2.12949961e-01,\n",
296
       "         1.01176966e-02,  1.85544211e-02,  3.22577078e-03,\n",
297
       "         8.24356824e-02,  1.62155926e-01,  6.39806986e-02,\n",
298
       "         1.28711343e-01,  1.26675731e-02, -1.37379421e-02,\n",
299
       "        -8.26888978e-02,  4.14049849e-02, -6.12147013e-03,\n",
300
       "         7.80780986e-02,  1.22293532e-01,  5.99697642e-02,\n",
301
       "         1.15361720e-01, -6.55396283e-02,  6.20190501e-02,\n",
302
       "        -1.74107254e-01, -1.44642219e-01, -1.12532802e-01,\n",
303
       "        -9.24786776e-02, -8.02832171e-02,  8.34115818e-02,\n",
304
       "        -7.14966878e-02,  8.17610249e-02,  1.01000164e-02,\n",
305
       "         2.07653691e-04,  3.71094160e-02,  7.52397999e-02,\n",
306
       "         3.94362360e-02, -2.90415064e-02,  8.29361975e-02,\n",
307
       "         1.77226022e-01, -1.46473601e-01, -4.14880253e-02,\n",
308
       "         3.12972143e-02, -1.35279566e-01,  5.92804886e-02,\n",
309
       "        -5.87741770e-02,  1.08800747e-01, -1.22568272e-01,\n",
310
       "        -1.91151381e-01,  1.12020858e-01,  5.13652079e-02,\n",
311
       "         2.37159252e-01, -5.57411276e-02, -6.28360957e-02,\n",
312
       "        -2.37326287e-02,  1.69684872e-01,  1.59426272e-01,\n",
313
       "         4.77304980e-02, -1.67335540e-01, -1.01463072e-01,\n",
314
       "        -4.48939763e-02,  2.38010045e-02, -1.23074487e-01,\n",
315
       "        -2.02281028e-01, -1.75339788e-01,  8.34432840e-02,\n",
316
       "        -6.81631416e-02, -1.74563471e-02, -2.26193089e-02,\n",
317
       "        -1.68089364e-02,  1.36754466e-02,  3.46720554e-02,\n",
318
       "         1.10969320e-01, -6.79774359e-02,  1.85720205e-01,\n",
319
       "        -3.38043511e-01,  2.84836106e-02,  2.28992343e-01,\n",
320
       "         1.82887554e-01, -6.06962480e-02,  3.13410372e-01,\n",
321
       "         5.42417215e-03, -3.14366929e-02,  1.97462086e-02,\n",
322
       "        -3.07794008e-02, -1.77915275e-01,  1.17499366e-01,\n",
323
       "        -2.44052969e-02, -4.43571098e-02,  1.16025403e-01,\n",
324
       "        -1.77904785e-01, -1.05790654e-02, -1.69786140e-01,\n",
325
       "        -1.37123525e-01, -2.37312749e-01,  1.39733568e-01,\n",
326
       "        -1.90371811e-01,  6.40053228e-02,  2.61818111e-01,\n",
327
       "        -5.71417473e-02,  7.86288381e-02,  1.05672821e-01,\n",
328
       "        -4.95498218e-02,  1.14241913e-01, -4.57595550e-02,\n",
329
       "        -3.39903414e-01, -7.32076392e-02,  1.20728187e-01,\n",
330
       "         3.61286551e-02, -1.27044454e-01,  9.13235918e-02,\n",
331
       "         5.36567494e-02, -1.23653367e-01,  1.09173447e-01,\n",
332
       "        -4.37955260e-02, -1.50184159e-03, -7.60437176e-02,\n",
333
       "         4.86165099e-02, -2.50281155e-01, -1.74011782e-01,\n",
334
       "         2.47733947e-02,  1.83616176e-01, -9.84238014e-02,\n",
335
       "        -1.11607119e-01, -5.85185103e-02, -5.21076992e-02,\n",
336
       "        -2.13656023e-01,  6.92998394e-02,  1.08075552e-01,\n",
337
       "         5.98740950e-02, -7.83068463e-02, -2.56594848e-02,\n",
338
       "         1.29869208e-01, -1.71856564e-02,  1.66503429e-01,\n",
339
       "        -3.05819064e-02,  5.02898544e-02, -1.36713356e-01,\n",
340
       "         1.16758183e-01, -8.95673409e-02,  1.26468435e-01,\n",
341
       "         4.38416302e-02, -1.55453250e-01,  7.19079226e-02,\n",
342
       "        -1.06323056e-01, -1.07946269e-01, -6.80187643e-02,\n",
343
       "        -8.38701148e-03, -4.11822088e-02,  2.76909620e-01,\n",
344
       "         4.68547083e-02,  4.93463390e-02,  2.90763099e-02,\n",
345
       "        -1.30036846e-01, -7.74710029e-02,  1.63585737e-01,\n",
346
       "        -1.45333767e-01, -6.11055903e-02,  3.95771749e-02,\n",
347
       "         1.12327918e-01,  1.52646273e-01, -1.09469732e-02,\n",
348
       "        -1.57306686e-01,  1.21623144e-01, -7.70391002e-02,\n",
349
       "        -2.28713498e-01, -6.51240721e-03, -4.17144895e-02,\n",
350
       "         4.87575904e-02,  1.52817070e-01,  1.07581370e-01,\n",
351
       "        -5.30700460e-02,  5.00147082e-02, -1.04325451e-01,\n",
352
       "         7.74948820e-02,  2.57168468e-02, -8.06147531e-02,\n",
353
       "         3.87148522e-02,  2.13774636e-01, -1.26998127e-01,\n",
354
       "        -2.90410995e-01, -1.58638414e-02,  2.02428520e-01,\n",
355
       "         1.62639588e-01, -4.74067964e-03,  8.18896592e-02,\n",
356
       "         1.25175849e-01,  2.63762742e-01, -2.06704602e-01,\n",
357
       "        -9.14238766e-02, -1.96620092e-01,  1.09160148e-01,\n",
358
       "        -2.17940882e-02,  2.03905568e-01, -2.21943393e-01,\n",
359
       "        -2.11968988e-01, -7.20419586e-02, -5.21163940e-02,\n",
360
       "        -1.60035223e-01,  1.82005912e-01,  6.49777651e-02,\n",
361
       "         2.06461266e-01,  2.14683488e-01, -6.27292842e-02,\n",
362
       "         2.51692664e-02,  1.93498597e-01, -9.34920982e-02,\n",
363
       "        -1.54694840e-02, -1.30474761e-01, -1.77416548e-01,\n",
364
       "         1.15094319e-01,  1.29716650e-01, -2.87248701e-01,\n",
365
       "        -1.57648146e-01,  4.44391221e-02,  1.40908122e-01,\n",
366
       "         7.32596265e-03,  1.27783626e-01, -4.62960973e-02,\n",
367
       "         1.25900418e-01, -1.39598131e-01, -7.91362002e-02,\n",
368
       "        -4.22106944e-02, -4.32713069e-02, -6.43170327e-02,\n",
369
       "        -7.27615356e-02, -1.17521748e-01,  1.69333458e-01,\n",
370
       "        -1.27414629e-01, -2.29594290e-01, -1.46220028e-01,\n",
371
       "         9.28544328e-02, -2.80166622e-02,  1.14858069e-01,\n",
372
       "         1.38829008e-01,  1.74213678e-01,  9.00694653e-02,\n",
373
       "         6.96216011e-03, -3.68438698e-02, -3.28931748e-03,\n",
374
       "         2.23214030e-01,  6.37881011e-02, -1.14255890e-01,\n",
375
       "        -2.64940672e-02,  1.70708954e-01,  9.25933644e-02,\n",
376
       "        -1.36656925e-01,  7.83628076e-02,  1.18782088e-01,\n",
377
       "         1.60760820e-01,  1.30176514e-01,  8.49534050e-02,\n",
378
       "        -1.18924476e-01,  8.29576775e-02,  1.34217842e-02,\n",
379
       "        -1.86802924e-01, -1.33491144e-01, -7.09138587e-02,\n",
380
       "        -2.18602978e-02,  5.48370890e-02,  1.99575439e-01,\n",
381
       "        -6.68747425e-02, -1.79825410e-01,  4.49441820e-02,\n",
382
       "        -1.64867163e-01,  1.02575898e-01, -1.06179774e-01,\n",
383
       "         1.69204120e-02,  2.20158786e-01,  3.67492549e-02,\n",
384
       "        -7.78725743e-02,  1.15782209e-01,  1.73470229e-02,\n",
385
       "        -7.98001960e-02,  4.79919948e-02, -7.50755593e-02,\n",
386
       "        -1.93090454e-01, -2.67650094e-02, -2.04100497e-02,\n",
387
       "         3.35861258e-02,  7.13374540e-02,  1.55757129e-01,\n",
388
       "         2.43938509e-02,  2.40554623e-02,  7.94551671e-02,\n",
389
       "        -5.06589413e-02, -2.45769247e-01,  2.63830006e-01,\n",
390
       "         1.06279410e-01, -1.43656954e-01,  1.64090663e-01,\n",
391
       "         8.19550902e-02,  7.53268078e-02,  2.59129722e-02,\n",
392
       "         8.72157589e-02, -1.44945517e-01, -3.51325721e-01,\n",
393
       "         1.17137462e-01,  1.26396015e-01, -7.43836239e-02,\n",
394
       "         4.06302027e-02, -1.51913702e-01, -9.97206792e-02,\n",
395
       "        -2.61555821e-01, -1.17414176e-01, -7.79822320e-02,\n",
396
       "        -2.47816928e-02, -3.99118029e-02, -3.76398070e-03,\n",
397
       "         1.90613508e-01, -7.92257637e-02, -1.22871017e-02,\n",
398
       "         6.29720688e-02,  6.36950284e-02, -5.14467955e-02,\n",
399
       "         8.70334730e-02, -1.31731883e-01, -7.81667978e-02,\n",
400
       "         1.78837497e-02, -1.56852648e-01, -4.04477492e-02,\n",
401
       "         1.57394841e-01, -1.04244456e-01,  1.66083761e-02,\n",
402
       "         5.76242879e-02, -9.76261199e-02,  2.56204046e-02,\n",
403
       "        -3.02246939e-02, -1.48694322e-01, -1.53342336e-01,\n",
404
       "         9.41478685e-02,  2.22977489e-01,  9.73301306e-02]], dtype=float32)"
405
      ]
406
     },
407
     "execution_count": 5,
408
     "metadata": {},
409
     "output_type": "execute_result"
410
    }
411
   ],
412
   "source": [
413
    "model.encode([row['abstract']])"
414
   ]
415
  },
416
  {
417
   "cell_type": "markdown",
418
   "id": "63071b29-bba0-4b47-ab44-6a8afc13c69d",
419
   "metadata": {},
420
   "source": [
421
    "We must encode and then `upsert` our encoded vectors to Pinecone. For this we need to initialize a Pinecone index. First we connect to Pinecone using a [free API key](https://app.pinecone.io)."
422
   ]
423
  },
424
  {
425
   "cell_type": "code",
426
   "execution_count": 6,
427
   "id": "c8c0fe8d-265b-40f9-ae43-d6177b13d54f",
428
   "metadata": {},
429
   "outputs": [],
430
   "source": [
431
    "from pinecone import Pinecone\n",
432
    "\n",
433
    "pinecone.init(\n",
434
    "    api_key='YOUR_API_KEY',  # app.pinecone.io\n",
435
    "    environment='YOUR_ENV'  # find next to API key in console\n",
436
    ")"
437
   ]
438
  },
439
  {
440
   "cell_type": "code",
441
   "execution_count": 7,
442
   "id": "ca659d74-08f0-48f1-948d-c457589ff73d",
443
   "metadata": {},
444
   "outputs": [],
445
   "source": [
446
    "index_id = 'arxiv-search'\n",
447
    "\n",
448
    "# create index if doesn't exist\n",
449
    "if not index_id in pinecone.list_indexes().names():\n",
450
    "    pinecone.create_index(\n",
451
    "        index_id,\n",
452
    "        dimension=model.get_sentence_embedding_dimension(),\n",
453
    "        metric='cosine',\n",
454
    "        pod_type='s1'\n",
455
    "    )\n",
456
    "\n",
457
    "# connect to index\n",
458
    "index = pinecone.Index(index_id)"
459
   ]
460
  },
461
  {
462
   "cell_type": "markdown",
463
   "id": "91a6ff61-24f8-4242-adb5-414abb48239e",
464
   "metadata": {},
465
   "source": [
466
    "Now index everything in Pinecone..."
467
   ]
468
  },
469
  {
470
   "cell_type": "code",
471
   "execution_count": null,
472
   "id": "9b2ada8e-a871-4ab9-8382-8e9dd3e608b9",
473
   "metadata": {},
474
   "outputs": [
475
    {
476
     "name": "stderr",
477
     "output_type": "stream",
478
     "text": [
479
      "  1%|          | 13140/2151137 [00:47<2:08:27, 277.38it/s]"
480
     ]
481
    }
482
   ],
483
   "source": [
484
    "from tqdm.auto import tqdm\n",
485
    "\n",
486
    "batch_size = 90\n",
487
    "\n",
488
    "batch_i = 0\n",
489
    "batch = []\n",
490
    "\n",
491
    "for row in tqdm(metadata, total=count):\n",
492
    "    batch_i += 1\n",
493
    "    batch.append({'id': row['id'], 'abstract': row['abstract']})\n",
494
    "    if batch_i == batch_size:\n",
495
    "        embeds = model.encode([x['abstract'] for x in batch]).tolist()\n",
496
    "        meta = [{'abstract': x['abstract']} for x in batch]\n",
497
    "        ids = [x['id'] for x in batch]\n",
498
    "        # add to pinecone\n",
499
    "        to_upsert = list(zip(ids, embeds, meta))\n",
500
    "        index.upsert(vectors=to_upsert)\n",
501
    "        # reset batch\n",
502
    "        batch = []\n",
503
    "        batch_i = 0\n",
504
    "        \n",
505
    "# add final items if any left\n",
506
    "if len(batch) > 0:\n",
507
    "    embeds = model.encode([x['abstract'] for x in batch]).tolist()\n",
508
    "    meta = [{'abstract': x['abstract']} for x in batch]\n",
509
    "    ids = [x['id'] for x in batch]\n",
510
    "    # add to pinecone\n",
511
    "    to_upsert = list(zip(ids, embeds, meta))\n",
512
    "    index.upsert(vectors=to_upsert)"
513
   ]
514
  },
515
  {
516
   "cell_type": "markdown",
517
   "id": "4a3d7c0a-0b1f-482e-b434-fc815a341869",
518
   "metadata": {},
519
   "source": [
520
    "All that is left after this is to begin making queries, we'll do this in the final notebook [`03-query.ipynb`](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/projects/training-with-wandb/03-query.ipynb)"
521
   ]
522
  }
523
 ],
524
 "metadata": {
525
  "environment": {
526
   "kernel": "python3",
527
   "name": "common-cu110.m95",
528
   "type": "gcloud",
529
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m95"
530
  },
531
  "kernelspec": {
532
   "display_name": "Python 3.10.7 64-bit",
533
   "language": "python",
534
   "name": "python3"
535
  },
536
  "language_info": {
537
   "codemirror_mode": {
538
    "name": "ipython",
539
    "version": 3
540
   },
541
   "file_extension": ".py",
542
   "mimetype": "text/x-python",
543
   "name": "python",
544
   "nbconvert_exporter": "python",
545
   "pygments_lexer": "ipython3",
546
   "version": "3.10.9"
547
  },
548
  "vscode": {
549
   "interpreter": {
550
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
551
   }
552
  }
553
 },
554
 "nbformat": 4,
555
 "nbformat_minor": 5
556
}
557

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.