{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Pjmz-RORV8E"
      },
      "source": [
        "# Train a text labeler\n",
        "\n",
        "The [Hugging Face Model Hub](https://huggingface.co/models) has a wide range of models that can handle many tasks. While these models perform well, the best performance is often found when fine-tuning a model with task-specific data.\n",
        "\n",
        "Hugging Face provides a [number of full-featured examples](https://github.com/huggingface/transformers/tree/master/examples) to assist with training task-specific models. When building models from the command line, these scripts are a great way to get started.\n",
        "\n",
        "txtai provides a training pipeline that can be used to train new models programmatically using the Transformers Trainer framework. The training pipeline supports the following:\n",
        "\n",
        "- Building transient models without requiring an output directory\n",
        "- Loading training data from Hugging Face datasets, pandas DataFrames and lists of dicts\n",
        "- Text sequence classification tasks (single/multi label classification and regression) including all GLUE tasks\n",
        "- All training arguments\n",
        "\n",
        "This notebook shows examples of how to use txtai to train/fine-tune new models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk31rbYjSTYm"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XMQuuun2R06J"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline-train] datasets pandas"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PNPJ95cdTKSS"
      },
      "source": [
        "# Train a model\n",
        "\n",
        "Let's get right to it! The following example fine-tunes a tiny BERT model with the sst2 dataset.\n",
        "\n",
        "The trainer pipeline is basically a one-liner that fine-tunes any text classification/regression model available (locally and/or from the Hugging Face Hub).\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "USb4JXZHxqTA"
      },
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "from txtai.pipeline import HFTrainer\n",
        "\n",
        "trainer = HFTrainer()\n",
        "\n",
        "# Hugging Face dataset\n",
        "ds = load_dataset(\"glue\", \"sst2\")\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", ds[\"train\"], columns=(\"sentence\", \"label\"))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CubsNAbpEWQg"
      },
      "source": [
        "By default, the trainer pipeline does not store any logs, checkpoints or models to disk. The trainer can take any of the standard TrainingArguments to enable persistent models, as sketched below.\n",
        "\n",
        "The section after the sketch creates a Labels pipeline using the newly built model and runs it against the sst2 validation set. Setting dynamic=False tells the Labels pipeline to use the model's trained labels rather than zero-shot classification."
      ]
    },
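    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of persistence, assuming extra keyword arguments on the trainer call pass through to the underlying TrainingArguments. The `labeler` output directory is a hypothetical name; this re-runs the same fine-tuning, now writing the model to disk."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: same training call, persisted to a (hypothetical) output directory\n",
        "# via standard TrainingArguments passed as keyword arguments\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", ds[\"train\"],\n",
        "                           columns=(\"sentence\", \"label\"), output_dir=\"labeler\")"
      ],
      "execution_count": null,
      "outputs": []
    },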
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xw2y2C5Mg11_",
        "outputId": "78400e45-ea5c-4cd9-d205-b55ee7a9f005"
      },
      "source": [
        "from txtai.pipeline import Labels\n",
        "\n",
        "labels = Labels((model, tokenizer), dynamic=False)\n",
        "\n",
        "# Determine accuracy on validation set\n",
        "results = [row[\"label\"] == labels(row[\"sentence\"])[0][0] for row in ds[\"validation\"]]\n",
        "sum(results) / len(ds[\"validation\"])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8268348623853211"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZAHSwaB3Ex49"
      },
      "source": [
        "82.68% accuracy - not bad for a tiny BERT model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f3GkY4JNEhhE"
      },
      "source": [
        "# Train a model with Lists\n",
        "\n",
        "As mentioned earlier, the trainer pipeline supports Hugging Face datasets, pandas DataFrames and lists of dicts. The example below trains a model using lists."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QkApw1b2hfZq",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 182
        },
        "outputId": "8c3dceae-49fb-4b63-837d-5944e63c768e"
      },
      "source": [
        "data = [{\"text\": \"This is a test sentence\", \"label\": 0}, {\"text\": \"This is not a test\", \"label\": 1}]\n",
        "\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", data)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']\n",
            "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3' max='3' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3/3 00:00, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ]
    },
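    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check, wrap the newly trained model in a Labels pipeline and classify one of the training sentences. With only two training examples, the score is illustrative rather than meaningful."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Classify a training sentence with the model trained on the list data\n",
        "labels = Labels((model, tokenizer), dynamic=False)\n",
        "labels(\"This is a test sentence\")"
      ],
      "execution_count": null,
      "outputs": []
    },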
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cjYTxm7sFKyZ"
      },
      "source": [
        "# Train a model with DataFrames\n",
        "\n",
        "The next section builds a new model using data stored in a pandas DataFrame."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0XaKKQ32wqbs",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 182
        },
        "outputId": "edb82a45-6c2a-4718-ce0b-56030f95ffbf"
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "df = pd.DataFrame(data)\n",
        "\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", df)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']\n",
            "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3' max='3' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3/3 00:00, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ]
    },
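    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same call pattern works for data loaded from disk. A minimal sketch, assuming a CSV file with text and label columns; the file name here is arbitrary."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Round-trip the toy data through a CSV file and train from the loaded DataFrame\n",
        "df.to_csv(\"train.csv\", index=False)\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", pd.read_csv(\"train.csv\"))"
      ],
      "execution_count": null,
      "outputs": []
    },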
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QH3D8PQSFvQO"
      },
      "source": [
        "# Train a regression model\n",
        "\n",
        "The previous models were classification tasks. The following example trains a sentence similarity model that outputs a regression score between 0 (dissimilar) and 1 (similar) for each sentence pair."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1rXuz4ncw9G-"
      },
      "source": [
        "ds = load_dataset(\"glue\", \"stsb\")\n",
        "model, tokenizer = trainer(\"google/bert_uncased_L-2_H-128_A-2\", ds[\"train\"], columns=(\"sentence1\", \"sentence2\", \"label\"))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fyvAslSP6j0F",
        "outputId": "ec46a6aa-25a7-4777-e226-d53aeb37899b"
      },
      "source": [
        "labels = Labels((model, tokenizer), dynamic=False)\n",
        "labels([[(\"Sailing to the arctic\", \"Dogs and cats don't get along\")],\n",
        "        [(\"Walking down the road\", \"Walking down the street\")]])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[[(0, 0.5648878216743469)], [(0, 0.97544926404953)]]"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
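    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With a regression model, each result is a (label, score) tuple where the score is the predicted similarity. As expected, \"Walking down the road\" vs \"Walking down the street\" scores much higher (~0.98) than the unrelated pair (~0.56)."
      ]
    }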
  ]
}