{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/02-negative-mining.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/02-negative-mining.ipynb)\n",
    "\n",
    "To perform the negative mining step, we need a vector database that stores the encoded passages and lets us search for passages that are similar to a query but are not its positive match. This requires two things:\n",
    "\n",
    "* a pre-existing retriever model to build the encodings; for this we will use a model from the *sentence-transformers* library\n",
    "* a vector DB to store the encodings; for this we will use Pinecone, as it has a free tier, is easy to set up, and is fast at scale\n",
    "\n",
    "Let's load the model first."
   ]
  },
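  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before loading it, here is a minimal sketch of the mining idea itself, assuming everything fits in memory: a brute-force NumPy dot-product search stands in for the vector DB, and the 3-dimensional vectors are made-up stand-ins for real embeddings. For each query it keeps the highest-scoring passage among the top-k that is not the query's positive passage; the pipeline in this notebook instead picks a random non-positive passage from the top 10. `mine_negatives` is a hypothetical helper, not part of any library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# toy passage embeddings (made-up 3-d vectors, one row per passage)\n",
    "passage_embs = np.array([\n",
    "    [1.0, 0.0, 0.0],\n",
    "    [0.9, 0.1, 0.0],\n",
    "    [0.0, 1.0, 0.0],\n",
    "    [0.0, 0.0, 1.0],\n",
    "])\n",
    "# toy query embeddings, each paired with the index of its positive passage\n",
    "query_embs = np.array([\n",
    "    [1.0, 0.05, 0.0],\n",
    "    [0.0, 0.9, 0.1],\n",
    "])\n",
    "positive_ids = [0, 2]\n",
    "\n",
    "def mine_negatives(query_embs, passage_embs, positive_ids, top_k=2):\n",
    "    # dot-product similarity, shape (n_queries, n_passages)\n",
    "    scores = query_embs @ passage_embs.T\n",
    "    negatives = []\n",
    "    for row, pos_id in zip(scores, positive_ids):\n",
    "        top = np.argsort(-row)[:top_k]  # indices of the top_k highest scores\n",
    "        # hardest negative: best-scoring passage that is not the positive (None if there is none)\n",
    "        negatives.append(next((int(j) for j in top if j != pos_id), None))\n",
    "    return negatives\n",
    "\n",
    "print(mine_negatives(query_embs, passage_embs, positive_ids))  # [1, 3]"
   ]
  },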
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SentenceTransformer(\n",
       "  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
       "  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       ")"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
    "model.max_seq_length = 256\n",
    "model"
   ]
  },
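  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed configuration shows that this model builds its sentence embedding from the CLS token (`pooling_mode_cls_token: True`) rather than by averaging token embeddings, and truncates inputs to `max_seq_length` = 256 tokens. A rough NumPy sketch of the two pooling strategies, using made-up token embeddings and assuming the `[CLS]` token sits at position 0 as in BERT-style inputs (`cls_pooling` and `mean_pooling` are hypothetical helpers, not the sentence-transformers implementation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# made-up token embeddings for one sequence: 4 tokens x 3 dims\n",
    "token_embs = np.array([\n",
    "    [0.5, 0.1, 0.0],  # [CLS] token at position 0\n",
    "    [1.0, 0.0, 2.0],\n",
    "    [0.0, 3.0, 1.0],\n",
    "    [9.0, 9.0, 9.0],  # padding token\n",
    "])\n",
    "attention_mask = np.array([1, 1, 1, 0])  # 0 marks padding\n",
    "\n",
    "def cls_pooling(token_embs):\n",
    "    # sentence embedding = hidden state of the first ([CLS]) token\n",
    "    return token_embs[0]\n",
    "\n",
    "def mean_pooling(token_embs, attention_mask):\n",
    "    # average over real (non-padding) tokens only\n",
    "    mask = attention_mask[:, None].astype(token_embs.dtype)\n",
    "    return (token_embs * mask).sum(axis=0) / mask.sum()\n",
    "\n",
    "print(cls_pooling(token_embs))\n",
    "print(mean_pooling(token_embs, attention_mask))"
   ]
  },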
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now initialize a Pinecone index for storing the encoded passage vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pinecone import Pinecone, PodSpec  # pip install pinecone-client\n",
    "\n",
    "index_name = \"negative-mining\"\n",
    "\n",
    "# v3+ client: create a Pinecone instance instead of calling pinecone.init()\n",
    "pc = Pinecone(api_key=\"YOUR_API_KEY\")  # app.pinecone.io\n",
    "\n",
    "# create a new negative mining index if it does not already exist\n",
    "if index_name not in pc.list_indexes().names():\n",
    "    pc.create_index(\n",
    "        name=index_name,\n",
    "        dimension=model.get_sentence_embedding_dimension(),\n",
    "        metric='dotproduct',\n",
    "        spec=PodSpec(\n",
    "            environment=\"YOUR_ENV\",  # find next to API key in console\n",
    "            pods=1\n",
    "        )\n",
    "    )\n",
    "# connect\n",
    "index = pc.Index(index_name)"
   ]
  },
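  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index uses `metric='dotproduct'`; the sentence-transformers documentation lists `msmarco-distilbert-base-tas-b` among the models tuned for dot-product similarity, so the index metric should match how the model scores passages. Because such embeddings are not normalized, dot product and cosine similarity can rank the same passages differently, as this toy NumPy example with made-up 2-d vectors illustrates:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "query = np.array([1.0, 1.0])\n",
    "passages = np.array([\n",
    "    [3.0, 0.0],  # long vector, less aligned with the query\n",
    "    [1.0, 1.0],  # short vector, perfectly aligned with the query\n",
    "])\n",
    "\n",
    "dot_scores = passages @ query\n",
    "cos_scores = dot_scores / (np.linalg.norm(passages, axis=1) * np.linalg.norm(query))\n",
    "\n",
    "print('dot product ranking:', np.argsort(-dot_scores))  # passage 0 ranks first\n",
    "print('cosine ranking:', np.argsort(-cos_scores))  # passage 1 ranks first"
   ]
  },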
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we encode the passages and store them in the `negative-mining` index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def get_text():\n",
    "    with open('data/pairs.tsv', 'r', encoding='utf-8') as fp:\n",
    "        lines = fp.read().split('\\n')\n",
    "    for line in tqdm(lines):\n",
    "        try:\n",
    "            query, passage = line.split('\\t')\n",
    "            yield query, passage\n",
    "        except ValueError:\n",
    "            # in case of malformed data, pass onto next row\n",
    "            pass"
   ]
  },
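  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_text` expects one tab-separated query/passage pair per line and silently skips any line that does not split into exactly two fields, including the empty string left by a trailing newline. A minimal sketch of the same rule on an in-memory string (`parse_pairs` and the sample rows are hypothetical) shows which rows survive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_pairs(text):\n",
    "    # same rule as get_text: keep lines that split into exactly (query, passage)\n",
    "    for line in text.split('\\n'):\n",
    "        try:\n",
    "            query, passage = line.split('\\t')\n",
    "            yield query, passage\n",
    "        except ValueError:\n",
    "            continue  # malformed: no tab, or more than one tab\n",
    "\n",
    "sample = 'q1\\tp1\\nno tab here\\na\\tb\\tc\\n'\n",
    "print(list(parse_pairs(sample)))  # [('q1', 'p1')]"
   ]
  },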
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f123d57309b042eca4ce279ec0aff06e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'dimension': 768,\n",
       " 'index_fullness': 0.0,\n",
       " 'namespaces': {'': {'vector_count': 67840}}}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pair_gen = get_text()\n",
    "\n",
    "pairs = []\n",
    "seen_passages = set()  # every passage already queued for upsert\n",
    "passage_batch = []\n",
    "id_batch = []\n",
    "batch_size = 64\n",
    "\n",
    "def upsert_batch(passages, ids):\n",
    "    embeds = model.encode(passages).tolist()\n",
    "    # upload (id, vector) tuples to the index\n",
    "    index.upsert(vectors=list(zip(ids, embeds)))\n",
    "\n",
    "for i, (query, passage) in enumerate(pair_gen):\n",
    "    pairs.append((query, passage))\n",
    "    # skip passages we have already stored to avoid duplicates in the vector DB;\n",
    "    # the ID is the index of the first pair containing the passage\n",
    "    if passage not in seen_passages:\n",
    "        seen_passages.add(passage)\n",
    "        passage_batch.append(passage)\n",
    "        id_batch.append(str(i))\n",
    "    # on reaching batch_size, we encode and upsert\n",
    "    if len(passage_batch) == batch_size:\n",
    "        upsert_batch(passage_batch, id_batch)\n",
    "        # refresh batches\n",
    "        passage_batch = []\n",
    "        id_batch = []\n",
    "\n",
    "# encode and upsert the final, partially filled batch\n",
    "if passage_batch:\n",
    "    upsert_batch(passage_batch, id_batch)\n",
    "\n",
    "# check number of vectors in the index\n",
    "index.describe_index_stats()"
   ]
  },
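  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same batching logic can also be written as a small reusable generator. This is a sketch rather than part of the original pipeline: `dedup_batches` is a hypothetical helper that stores each distinct passage once across the whole dataset, keys it by the index of the first pair containing it (so `pairs[int(id)]` still recovers the passage), and also yields the final, partially filled batch so that no passage is left out of the index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dedup_batches(pairs, batch_size):\n",
    "    # yield (ids, passages) batches, storing each distinct passage only once;\n",
    "    # the ID is the index of the first pair that contains the passage\n",
    "    seen = set()\n",
    "    ids, passages = [], []\n",
    "    for i, (_, passage) in enumerate(pairs):\n",
    "        if passage in seen:\n",
    "            continue\n",
    "        seen.add(passage)\n",
    "        ids.append(str(i))\n",
    "        passages.append(passage)\n",
    "        if len(passages) == batch_size:\n",
    "            yield ids, passages\n",
    "            ids, passages = [], []\n",
    "    if passages:  # flush the last, partially filled batch\n",
    "        yield ids, passages\n",
    "\n",
    "toy_pairs = [('q0', 'p0'), ('q1', 'p1'), ('q2', 'p0'), ('q3', 'p2'), ('q4', 'p3'), ('q5', 'p4')]\n",
    "for ids, passages in dedup_batches(toy_pairs, batch_size=2):\n",
    "    print(ids, passages)\n",
    "    # with the real data, each batch could be upserted with e.g.\n",
    "    # index.upsert(vectors=list(zip(ids, model.encode(passages).tolist())))"
   ]
  },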
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The database is now set up, so we can begin the *negative mining* step. We loop through each query in `pairs`, retrieve the *10* most similar passages, and randomly pick one that is not the query's positive passage as its negative."
   ]
  },
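  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection rule for a single query can be isolated into a small function. A minimal sketch, assuming hits are plain `(id, score)` tuples rather than Pinecone match objects (`pick_negative` is a hypothetical helper): shuffle the top hits, return the first passage that differs from the positive, and return `None` if every hit is the positive passage. Passing a seeded `random.Random` makes the choice reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def pick_negative(hits, pairs, pos_passage, rng=random):\n",
    "    # hits: (id, score) tuples whose ids index into pairs\n",
    "    hits = list(hits)  # copy so the caller's list is not reordered\n",
    "    rng.shuffle(hits)  # random order, so we don't always take the top-scoring hit\n",
    "    for hit_id, _score in hits:\n",
    "        neg_passage = pairs[int(hit_id)][1]\n",
    "        if neg_passage != pos_passage:\n",
    "            return neg_passage\n",
    "    return None  # every hit was the positive passage\n",
    "\n",
    "toy_pairs = [('q0', 'p0'), ('q1', 'p1'), ('q2', 'p2')]\n",
    "hits = [('0', 0.9), ('1', 0.8), ('2', 0.7)]\n",
    "print(pick_negative(hits, toy_pairs, pos_passage='p0', rng=random.Random(0)))  # either p1 or p2"
   ]
  },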
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33e05700fb8a43daadf39f5c2f2166d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "batch_size = 100\n",
    "triplets = []\n",
    "\n",
    "for i in tqdm(range(0, len(pairs), batch_size)):\n",
    "    # embed queries in batches\n",
    "    i_end = min(i+batch_size, len(pairs))\n",
    "    queries = [pair[0] for pair in pairs[i:i_end]]\n",
    "    pos_passages = [pair[1] for pair in pairs[i:i_end]]\n",
    "    # create query embeddings\n",
    "    query_embs = model.encode(queries, show_progress_bar=False).tolist()\n",
    "    # iterate through queries and find negatives\n",
    "    for query, pos_passage, query_emb in zip(queries, pos_passages, query_embs):\n",
    "        # search for the top_k most similar passages (the v3+ client takes one vector per query)\n",
    "        res = index.query(vector=query_emb, top_k=10)\n",
    "        top_results = res['matches']\n",
    "        # shuffle results so they are in random order\n",
    "        random.shuffle(top_results)\n",
    "        for hit in top_results:\n",
    "            neg_passage = pairs[int(hit['id'])][1]\n",
    "            # check that we're not just returning the positive passage\n",
    "            if neg_passage != pos_passage:\n",
    "                # if not, we can add this to our (Q, P+, P-) triplets\n",
    "                triplets.append(query+'\\t'+pos_passage+'\\t'+neg_passage)\n",
    "                break\n",
    "\n",
    "with open('data/triplets.tsv', 'w', encoding='utf-8') as fp:\n",
    "    fp.write('\\n'.join(triplets))"
   ]
  },
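  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above writes one triplet per line, with the query, positive passage, and negative passage separated by tabs. A minimal sketch for reading `data/triplets.tsv` back in the next step (`load_triplets` is a hypothetical helper; the demo writes a temporary file so it runs on its own):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "def load_triplets(path):\n",
    "    # each valid line holds query, positive passage, negative passage, tab-separated\n",
    "    with open(path, 'r', encoding='utf-8') as fp:\n",
    "        lines = fp.read().split('\\n')\n",
    "    return [tuple(line.split('\\t')) for line in lines if line.count('\\t') == 2]\n",
    "\n",
    "# demo on a temporary file written the same way as data/triplets.tsv\n",
    "path = os.path.join(tempfile.mkdtemp(), 'triplets.tsv')\n",
    "with open(path, 'w', encoding='utf-8') as fp:\n",
    "    fp.write('\\n'.join(['q1\\tpos1\\tneg1', 'q2\\tpos2\\tneg2']))\n",
    "\n",
    "print(load_triplets(path))"
   ]
  },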
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "pc.delete_index(index_name)  # delete the index when done to avoid higher charges (if using multiple pods)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we now have even more *(query, passage) pairs*: each query is paired with both a positive and a negative passage, saved as *(Q, P+, P-)* triplets in `data/triplets.tsv`. The next step in GPL will see us scoring all of these pairs using a cross-encoder model."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}