{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Training Your Own \"Dense Passage Retrieval\" Model\n",
    "\n",
    "Haystack contains all the tools needed to train your own Dense Passage Retrieval model.\n",
    "This tutorial will guide you through the steps required to create a retriever that is specifically tailored to your domain."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installing Haystack\n",
    "\n",
    "To start, let's install the latest release of Haystack with `pip`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "pip install --upgrade pip\n",
    "pip install farm-haystack[colab,inference,metrics]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Enabling Telemetry\n",
    "Knowing you're using this tutorial helps us decide where to invest our efforts to build a better product, but you can always opt out by commenting out the following line. See [Telemetry](https://docs.haystack.deepset.ai/docs/telemetry) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from haystack.telemetry import tutorial_running\n",
    "\n",
    "tutorial_running(9)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Logging\n",
    "\n",
    "We configure how logging messages should be displayed and which log level should be used before importing Haystack.\n",
    "Example log message:\n",
    "`INFO - haystack.utils.preprocessing -  Converting data/tutorial1/218_Olenna_Tyrell.txt`\n",
    "The default log level in `basicConfig` is WARNING, so the explicit parameter is not necessary, but it can be changed easily:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(format=\"%(levelname)s - %(name)s -  %(message)s\", level=logging.WARNING)\n",
    "logging.getLogger(\"haystack\").setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Here are some imports that we'll need\n",
    "\n",
    "from haystack.nodes import DensePassageRetriever\n",
    "from haystack.utils import fetch_archive_from_http\n",
    "from haystack.document_stores import InMemoryDocumentStore"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training Data\n",
    "\n",
    "DPR training is performed using Information Retrieval data.\n",
    "More specifically, you want to feed in pairs of queries and relevant documents.\n",
    "\n",
    "To train a model, we will need a dataset that has the same format as the original DPR training data.\n",
    "Each data point in the dataset should have the following dictionary structure.\n",
    "\n",
    "``` python\n",
    "    {\n",
    "        \"dataset\": str,\n",
    "        \"question\": str,\n",
    "        \"answers\": list of str,\n",
    "        \"positive_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
    "        \"negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
    "        \"hard_negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
    "    }\n",
    "```\n",
    "\n",
    "`positive_ctxs` are context passages that are relevant to the query.\n",
    "In some datasets, queries might have more than one positive context,\n",
    "in which case you can set the `num_positives` parameter to be higher than the default 1.\n",
    "Note that `num_positives` needs to be lower than or equal to the minimum number of `positive_ctxs` for queries in your data.\n",
    "If you have an unequal number of positive contexts per example,\n",
    "you might want to generate some soft labels by retrieving similar contexts that contain the answer.\n",
    "\n",
    "DPR is typically trained using a method known as in-batch negatives.\n",
    "This means that the positive contexts for a given query are treated as negative contexts for the other queries in the batch.\n",
    "Doing so allows for a high degree of computational efficiency, thus allowing the model to be trained on large amounts of data.\n",
    "\n",
    "`negative_ctxs` is not actually used in Haystack's DPR training, so we recommend you set it to an empty list.\n",
    "These were used by the original DPR authors in an experiment comparing them against the in-batch negatives method.\n",
    "\n",
    "`hard_negative_ctxs` are passages that are not relevant to the query.\n",
    "In the original DPR paper, these are fetched using a retriever to find the most relevant passages to the query.\n",
    "Passages that contain the answer text are then filtered out.\n",
    "\n",
    "If you'd like to convert your SQuAD format data into something that can train a DPR model,\n",
    "check out the utility script at [`haystack/utils/squad_to_dpr.py`](https://github.com/deepset-ai/haystack/blob/main/haystack/utils/squad_to_dpr.py)."
   ]
  },
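  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the format concrete, here is a minimal, hand-written example of a single data point. Every title, text, score, and passage ID below is invented purely for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal example of a single DPR training data point.\n",
    "# All titles, texts, scores, and passage IDs are invented for illustration.\n",
    "\n",
    "example_data_point = {\n",
    "    \"dataset\": \"my_domain_dataset\",\n",
    "    \"question\": \"Who wrote the novel Dune?\",\n",
    "    \"answers\": [\"Frank Herbert\"],\n",
    "    \"positive_ctxs\": [\n",
    "        {\n",
    "            \"title\": \"Dune (novel)\",\n",
    "            \"text\": \"Dune is a 1965 science fiction novel by American author Frank Herbert.\",\n",
    "            \"score\": 1000,\n",
    "            \"title_score\": 1,\n",
    "            \"passage_id\": \"12345\",\n",
    "        }\n",
    "    ],\n",
    "    # Not used by Haystack's DPR training, so we leave it empty\n",
    "    \"negative_ctxs\": [],\n",
    "    \"hard_negative_ctxs\": [\n",
    "        {\n",
    "            \"title\": \"Dune (1984 film)\",\n",
    "            \"text\": \"Dune is a 1984 science fiction film directed by David Lynch.\",\n",
    "            \"score\": 0,\n",
    "            \"title_score\": 0,\n",
    "            \"passage_id\": \"67890\",\n",
    "        }\n",
    "    ],\n",
    "}"
   ]
  },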
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Using Question Answering Data\n",
    "\n",
    "Question Answering datasets can sometimes be used as training data.\n",
    "Google's Natural Questions dataset is sufficiently large, and contains enough unique passages, that it can be converted into a DPR training set.\n",
    "This is done simply by treating answer-containing passages as documents that are relevant to the query.\n",
    "\n",
    "The SQuAD dataset, however, is not as well suited to this use case since its question and answer pairs\n",
    "are created on only a very small slice of Wikipedia documents."
   ]
  },
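  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your own annotated data happens to be in SQuAD format, you can still convert it with the `squad_to_dpr.py` utility mentioned earlier. The sketch below shows what an invocation might look like; the flag names and file names are assumptions for illustration, so check the script's `--help` output for its actual interface:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: convert SQuAD-format data into DPR training format.\n",
    "# The flag names and file names below are assumptions for illustration;\n",
    "# run `python squad_to_dpr.py --help` to confirm the script's actual interface.\n",
    "\n",
    "# !python squad_to_dpr.py --squad_input_filename my_squad_data.json --dpr_output_filename my_dpr_data.json"
   ]
  },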
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Download Original DPR Training Data\n",
    "\n",
    "WARNING: These files are large! The train set is 7.4GB and the dev set is 800MB.\n",
    "\n",
    "We can download the original DPR training data with the following cell.\n",
    "Note that this data is probably only useful if you are trying to train from scratch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Download original DPR data\n",
    "# WARNING: the train set is 7.4GB and the dev set is 800MB\n",
    "\n",
    "doc_dir = \"data/tutorial9\"\n",
    "\n",
    "s3_url_train = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz\"\n",
    "s3_url_dev = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz\"\n",
    "\n",
    "fetch_archive_from_http(s3_url_train, output_dir=doc_dir + \"/train\")\n",
    "fetch_archive_from_http(s3_url_dev, output_dir=doc_dir + \"/dev\")"
   ]
  },
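  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, you can inspect the first record of the dev set. This sketch assumes the archive above was extracted to `biencoder-nq-dev.json` inside the dev directory (matching the `dev_filename` variable used later); note that `json.load` reads the whole file into memory, which takes several GB of RAM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the structure of the first dev record (assumes the extracted\n",
    "# file name below; loading the full file needs several GB of RAM)\n",
    "import json\n",
    "\n",
    "with open(doc_dir + \"/dev/biencoder-nq-dev.json\") as f:\n",
    "    dev_data = json.load(f)\n",
    "\n",
    "first = dev_data[0]\n",
    "print(first.keys())\n",
    "print(first[\"question\"])\n",
    "print(len(first[\"positive_ctxs\"]), len(first[\"hard_negative_ctxs\"]))"
   ]
  },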
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Option 1: Training DPR from Scratch\n",
    "\n",
    "The default variables that we provide below are chosen to train a DPR model from scratch.\n",
    "Here, both passage and query embedding models are initialized using BERT base\n",
    "and the model is trained using Google's Natural Questions dataset (in a format specialised for DPR).\n",
    "\n",
    "If you are working in a language other than English,\n",
    "you will want to initialize the passage and query embedding models with a language model that supports your language\n",
    "and also provide a dataset in your language."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Here are the variables to specify our training data, the models that we use to initialize DPR\n",
    "# and the directory where we'll be saving the model\n",
    "\n",
    "train_filename = \"train/biencoder-nq-train.json\"\n",
    "dev_filename = \"dev/biencoder-nq-dev.json\"\n",
    "\n",
    "query_model = \"bert-base-uncased\"\n",
    "passage_model = \"bert-base-uncased\"\n",
    "\n",
    "save_dir = \"../saved_models/dpr\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Option 2: Finetuning DPR\n",
    "\n",
    "If you have your own domain-specific question answering or information retrieval dataset,\n",
    "you might instead be interested in finetuning a pretrained DPR model.\n",
    "In this case, you would initialize both query and passage models using the original pretrained model.\n",
    "You will want to load something like this set of variables instead of the ones above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Here are the variables you might want to use instead of the set above\n",
    "# in order to perform finetuning\n",
    "\n",
    "doc_dir = \"PATH_TO_YOUR_DATA_DIR\"\n",
    "train_filename = \"TRAIN_FILENAME\"\n",
    "dev_filename = \"DEV_FILENAME\"\n",
    "\n",
    "query_model = \"facebook/dpr-question_encoder-single-nq-base\"\n",
    "passage_model = \"facebook/dpr-ctx_encoder-single-nq-base\"\n",
    "\n",
    "save_dir = \"../saved_models/dpr\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Initialization\n",
    "\n",
    "Here we want to initialize our model either with plain language model weights for training from scratch\n",
    "or else with pretrained DPR weights for finetuning.\n",
    "We follow the [original DPR parameters](https://github.com/facebookresearch/DPR#best-hyperparameter-settings)\n",
    "for their max passage length but set max query length to 64 since queries are very rarely longer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Initialize DPR model\n",
    "\n",
    "retriever = DensePassageRetriever(\n",
    "    document_store=InMemoryDocumentStore(),\n",
    "    query_embedding_model=query_model,\n",
    "    passage_embedding_model=passage_model,\n",
    "    max_seq_len_query=64,\n",
    "    max_seq_len_passage=256,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training\n",
    "\n",
    "Let's start training and save our trained model!\n",
    "\n",
    "On a V100 GPU, you can fit up to batch size 16, so we set gradient accumulation steps to 8 in order\n",
    "to simulate the effective batch size of 128 (16 x 8) used in the original DPR experiment.\n",
    "\n",
    "When `embed_title=True`, the document title is prepended to the input text sequence with a `[SEP]` token\n",
    "between it and the document text."
   ]
  },
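  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside, the batch size matters because of the in-batch negatives objective described in the Training Data section: each query is scored against every passage in the batch, and only the diagonal entries are positives. The sketch below illustrates the idea with random tensors standing in for BERT embeddings; it is an illustration only, not Haystack's actual training code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration of the in-batch negatives objective, NOT Haystack's actual\n",
    "# training code: random tensors stand in for the BERT encoder outputs.\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "batch_size, dim = 4, 768\n",
    "query_emb = torch.randn(batch_size, dim)  # one query per row\n",
    "passage_emb = torch.randn(batch_size, dim)  # the matching positive passage per row\n",
    "\n",
    "# Score every query against every passage in the batch:\n",
    "# entry [i, j] is the dot-product similarity of query i and passage j\n",
    "scores = query_emb @ passage_emb.T  # shape: (batch_size, batch_size)\n",
    "\n",
    "# The correct passage for query i sits at column i, so the labels are the\n",
    "# diagonal; every off-diagonal passage acts as an in-batch negative\n",
    "labels = torch.arange(batch_size)\n",
    "loss = F.cross_entropy(scores, labels)\n",
    "print(loss.item())"
   ]
  },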
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "When training from scratch with the above variables, 1 epoch takes around an hour and we reached the following performance:\n",
    "\n",
    "```\n",
    "loss: 0.046580662854042276\n",
    "task_name: text_similarity\n",
    "acc: 0.992524064068483\n",
    "f1: 0.8804297774366846\n",
    "acc_and_f1: 0.9364769207525838\n",
    "average_rank: 0.19631619339984652\n",
    "report:\n",
    "                precision    recall  f1-score   support\n",
    "\n",
    "hard_negative     0.9961    0.9961    0.9961    201887\n",
    "     positive     0.8804    0.8804    0.8804      6515\n",
    "\n",
    "     accuracy                         0.9925    208402\n",
    "    macro avg     0.9383    0.9383    0.9383    208402\n",
    " weighted avg     0.9925    0.9925    0.9925    208402\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Start training our model and save it when it is finished\n",
    "\n",
    "retriever.train(\n",
    "    data_dir=doc_dir,\n",
    "    train_filename=train_filename,\n",
    "    dev_filename=dev_filename,\n",
    "    test_filename=dev_filename,\n",
    "    n_epochs=1,\n",
    "    batch_size=16,\n",
    "    grad_acc_steps=8,\n",
    "    save_dir=save_dir,\n",
    "    evaluate_every=3000,\n",
    "    embed_title=True,\n",
    "    num_positives=1,\n",
    "    num_hard_negatives=1,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Loading\n",
    "\n",
    "Loading our newly trained model is simple!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load the trained model from the save directory; pass document_store=None\n",
    "# if you don't need to retrieve documents right away\n",
    "reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)"
   ]
  },
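  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To actually retrieve documents with the reloaded model, attach it to a document store that contains your documents. The following is a minimal sketch using `InMemoryDocumentStore` with a single toy document whose text is invented for illustration; this time we reload the retriever with the populated store:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal usage sketch; the document text below is invented for illustration\n",
    "from haystack import Document\n",
    "\n",
    "document_store = InMemoryDocumentStore()\n",
    "document_store.write_documents([Document(content=\"Dense Passage Retrieval is a method for open-domain question answering.\")])\n",
    "\n",
    "# Reload the retriever with the populated store and index the documents\n",
    "reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=document_store)\n",
    "document_store.update_embeddings(reloaded_retriever)\n",
    "\n",
    "print(reloaded_retriever.retrieve(query=\"What is Dense Passage Retrieval?\", top_k=1))"
   ]
  }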
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}