{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb)\n",
    "\n",
    "# Fine-tuning with Margin MSE Loss\n",
    "\n",
    "Now that we have our margin-labeled *(Q, P<sup>+</sup>, P<sup>-</sup>)* triplets, we can begin fine-tuning a bi-encoder model with Margin MSE loss. We will start by loading the data into the standard `InputExample` format of *sentence-transformers*."
   ]
  },
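  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, each line of `data/triplets_margin.tsv` from the previous step holds one tab-separated *(query, positive, negative)* triplet followed by its margin score, e.g. (values hypothetical):\n",
    "\n",
    "```\n",
    "how does covid spread\\tCOVID-19 spreads mainly through respiratory droplets...\\tInfluenza is a seasonal illness...\\t4.37\n",
    "```"
   ]
  },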
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4549bb0e64f495b91db5f6eea7a17f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200000"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from sentence_transformers import InputExample\n",
    "\n",
    "training_data = []\n",
    "\n",
    "with open('data/triplets_margin.tsv', 'r', encoding='utf-8') as fp:\n",
    "    lines = fp.read().split('\\n')\n",
    "\n",
    "# loop through each line and build an InputExample\n",
    "for line in tqdm(lines):\n",
    "    if not line:\n",
    "        continue  # skip empty lines (e.g. a trailing newline)\n",
    "    q, p, n, margin = line.split('\\t')\n",
    "    training_data.append(InputExample(\n",
    "        texts=[q, p, n],\n",
    "        label=float(margin)\n",
    "    ))\n",
    "\n",
    "len(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load these examples into a PyTorch `DataLoader`, which will yield shuffled batches during training. Margin MSE works best with a larger `batch_size`; the `32` used here is a reasonable choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# free any cached GPU memory before training\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    training_data, batch_size=batch_size, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we initialize the bi-encoder model that we will fine-tune for domain adaptation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SentenceTransformer(\n",
       "  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
       "  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
    "model.max_seq_length = 256\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we initialize the Margin MSE loss function. It trains the bi-encoder so that the *difference* between its query–positive and query–negative similarity scores matches the margin label produced by the cross-encoder."
   ]
  },
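  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what this loss optimizes (assuming the default dot-product similarity of `MarginMSELoss`): for each triplet, the student bi-encoder $f$ predicts a margin as the difference of its two similarity scores, and the loss is the mean squared error against the teacher margin $\\hat{\\delta}$ from our labeling step:\n",
    "\n",
    "$$\\mathcal{L} = \\frac{1}{|B|} \\sum_{(Q,\\,P^+,\\,P^-) \\in B} \\Big( \\big(f(Q) \\cdot f(P^+) - f(Q) \\cdot f(P^-)\\big) - \\hat{\\delta} \\Big)^2$$"
   ]
  },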
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import losses\n",
    "\n",
    "loss = losses.MarginMSELoss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2055a68e819c4a2295d62029458f7c7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e6a24e3909645e299b2e5ec192b2ef8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/6250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "epochs = 1\n",
    "# warm up the learning rate over the first 10% of training steps\n",
    "warmup_steps = int(len(loader) * epochs * 0.1)\n",
    "\n",
    "model.fit(\n",
    "    train_objectives=[(loader, loss)],\n",
    "    epochs=epochs,\n",
    "    warmup_steps=warmup_steps,\n",
    "    output_path='msmarco-distilbert-base-tas-b-covid',\n",
    "    show_progress_bar=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fine-tuned model is saved in the `msmarco-distilbert-base-tas-b-covid` directory."
   ]
  },
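  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can reload the fine-tuned model from that directory and score a query against two passages (a minimal sketch, not executed here; the query and passages are hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, util\n",
    "\n",
    "# load the fine-tuned bi-encoder from the output directory\n",
    "tuned = SentenceTransformer('msmarco-distilbert-base-tas-b-covid')\n",
    "\n",
    "query_emb = tuned.encode('how does covid-19 spread?')\n",
    "passage_embs = tuned.encode([\n",
    "    'COVID-19 spreads mainly through respiratory droplets.',\n",
    "    'The Eiffel Tower is located in Paris.'\n",
    "])\n",
    "\n",
    "# TAS-B models use dot-product similarity, so score with dot_score\n",
    "util.dot_score(query_emb, passage_embs)"
   ]
  }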
 ],
 "metadata": {
  "environment": {
   "kernel": "conda-root-py",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}