{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb)\n",
"\n",
"# Fine-tuning with MarginMSE Loss\n",
"\n",
"Now that we have our margin-labeled *(Q, P<sup>+</sup>, P<sup>-</sup>)* triplets, we can begin fine-tuning a bi-encoder model with MarginMSE loss. We will start by loading the data into the standard `InputExample` format used by *sentence-transformers*."
]
},
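{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line of `triplets_margin.tsv` is expected to hold a query, a positive passage, a negative passage, and the cross-encoder margin score, separated by tabs. An illustrative (made-up) line looks like:\n",
"\n",
"```\n",
"what is covid?\tCOVID-19 is a respiratory disease caused by...\tInfluenza is a common viral infection...\t5.21\n",
"```"
]
},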
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4549bb0e64f495b91db5f6eea7a17f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"  0%|          | 0/200000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"200000"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tqdm.auto import tqdm\n",
"from sentence_transformers import InputExample\n",
"\n",
"training_data = []\n",
"\n",
"with open('data/triplets_margin.tsv', 'r', encoding='utf-8') as fp:\n",
"    lines = fp.read().split('\\n')\n",
"\n",
"# loop through each line and build an InputExample\n",
"for line in tqdm(lines):\n",
"    if not line:\n",
"        continue  # skip blank lines, e.g. a trailing newline\n",
"    q, p, n, margin = line.split('\\t')\n",
"    training_data.append(InputExample(\n",
"        texts=[q, p, n],\n",
"        label=float(margin)\n",
"    ))\n",
"\n",
"len(training_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We load these triplets into a PyTorch `DataLoader`. MarginMSE works best with a large `batch_size`; the `32` used here is a reasonable choice."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# free any cached GPU memory before building the training pipeline\n",
"torch.cuda.empty_cache()\n",
"\n",
"batch_size = 32\n",
"\n",
"loader = torch.utils.data.DataLoader(\n",
"    training_data, batch_size=batch_size, shuffle=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we initialize the bi-encoder model that we will fine-tune via domain adaptation."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SentenceTransformer(\n",
"  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
"  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
"model.max_seq_length = 256\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we initialize the MarginMSE loss function."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import losses\n",
"\n",
"loss = losses.MarginMSELoss(model)"
]
},
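{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of what this loss optimizes (notation ours): for each *(Q, P<sup>+</sup>, P<sup>-</sup>)* triplet, the bi-encoder predicts a margin between its similarity scores for the positive and negative passages, and the loss is the squared error against the gold margin $\\hat{\\delta}$ produced by the cross-encoder:\n",
"\n",
"$$\\mathcal{L} = \\Big( \\text{sim}(Q, P^+) - \\text{sim}(Q, P^-) - \\hat{\\delta} \\Big)^2$$\n",
"\n",
"averaged over the batch. These $\\hat{\\delta}$ values are the margin labels we loaded earlier."
]
},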
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"  FutureWarning,\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2055a68e819c4a2295d62029458f7c7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e6a24e3909645e299b2e5ec192b2ef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Iteration:   0%|          | 0/6250 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs = 1\n",
"# warm up the learning rate over the first 10% of training steps\n",
"warmup_steps = int(len(loader) * epochs * 0.1)\n",
"\n",
"model.fit(\n",
"    train_objectives=[(loader, loss)],\n",
"    epochs=epochs,\n",
"    warmup_steps=warmup_steps,\n",
"    output_path='msmarco-distilbert-base-tas-b-covid',\n",
"    show_progress_bar=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fine-tuned model is saved in the `msmarco-distilbert-base-tas-b-covid` directory."
]
},
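{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, assuming the training run above completed), we can reload the fine-tuned model from that directory and embed an example query:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"# reload the domain-adapted bi-encoder from the output directory\n",
"tuned = SentenceTransformer('msmarco-distilbert-base-tas-b-covid')\n",
"\n",
"# encode a (hypothetical) query; the result is a 768-d dense vector\n",
"embedding = tuned.encode('how does covid-19 spread?')\n",
"embedding.shape"
]
}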
],
"metadata": {
"environment": {
"kernel": "conda-root-py",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
},
"vscode": {
"interpreter": {
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}