{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Pjmz-RORV8E"
      },
      "source": [
        "# Train without labels\n",
        "\n",
        "Almost all available data is unlabeled. Labeled data takes effort to manually review and/or time to collect. Zero-shot classification takes an existing large language model and runs a similarity comparison between candidate text and a list of labels. This has been shown to perform surprisingly well; a minimal sketch of the mechanism appears after the install step below.\n",
        "\n",
        "The problem with zero-shot classifiers is that they need a large number of parameters (400M+) to perform well against general tasks, which comes with sizable hardware requirements.\n",
        "\n",
        "This notebook explores using zero-shot classifiers to build training data for smaller models, a simple form of [knowledge distillation](https://en.wikipedia.org/wiki/Knowledge_distillation)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk31rbYjSTYm"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XMQuuun2R06J"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline-train] datasets pandas"
      ],
      "execution_count": null,
      "outputs": []
    },
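    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before applying the classifier at scale, here is a minimal sketch of what zero-shot classification does under the hood. The `Labels` pipeline used below wraps the Hugging Face zero-shot pipeline, which poses each candidate label as an entailment hypothesis for an NLI model. The cell below is illustrative only (the example sentence is arbitrary) and can be skipped."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "from transformers import pipeline\n",
        "\n",
        "# Illustrative sketch: score candidate labels with an NLI model.\n",
        "# Each label is framed as a hypothesis (e.g. \"This example is positive.\")\n",
        "# and the entailment score becomes the label score.\n",
        "classifier = pipeline(\"zero-shot-classification\", model=\"microsoft/deberta-large-mnli\")\n",
        "classifier(\"I loved this movie\", candidate_labels=[\"negative\", \"positive\"])"
      ],
      "execution_count": null,
      "outputs": []
    },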
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3PUe1OW8IZR5"
      },
      "source": [
        "# Apply zero-shot classifier to unlabeled text\n",
        "\n",
        "The following section takes a small 1000 record random sample of the sst2 dataset and applies a zero-shot classifier to the text. The existing labels are ignored. This dataset was chosen only so that accuracy can be evaluated at the end."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GlrOnS4cmkih"
      },
      "source": [
        "import random\n",
        "\n",
        "from datasets import load_dataset\n",
        "\n",
        "from txtai.pipeline import Labels\n",
        "\n",
        "def batch(texts, size):\n",
        "    return [texts[x : x + size] for x in range(0, len(texts), size)]\n",
        "\n",
        "# Set random seed for repeatable sampling\n",
        "random.seed(42)\n",
        "\n",
        "ds = load_dataset(\"glue\", \"sst2\")\n",
        "\n",
        "sentences = random.sample(ds[\"train\"][\"sentence\"], 1000)\n",
        "\n",
        "# Load a zero-shot classifier - txtai provides this through the Labels pipeline\n",
        "labels = Labels(\"microsoft/deberta-large-mnli\")\n",
        "\n",
        "train = []\n",
        "\n",
        "# Zero-shot prediction using [\"negative\", \"positive\"] labels\n",
        "for chunk in batch(sentences, 32):\n",
        "    train.extend([{\"text\": chunk[x], \"label\": label[0][0]} for x, label in enumerate(labels(chunk, [\"negative\", \"positive\"]))])"
      ],
      "execution_count": null,
      "outputs": []
    },
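    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before training, it's worth a quick sanity check of the generated training data. The following sketch uses pandas to peek at a few rows and the label distribution (labels are indexes into `[\"negative\", \"positive\"]`)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Inspect the zero-shot generated training data\n",
        "df = pd.DataFrame(train)\n",
        "print(df.head())\n",
        "\n",
        "# Label distribution - 0 = negative, 1 = positive\n",
        "print(df[\"label\"].value_counts())"
      ],
      "execution_count": null,
      "outputs": []
    },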
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLsZmRpHJGav"
      },
      "source": [
        "Next, we'll use the training set we just built to train a smaller Electra model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 214
        },
        "id": "nAt42TIHnfTN",
        "outputId": "7080b21d-ecf4-459a-c818-11c748e28bb7"
      },
      "source": [
        "from txtai.pipeline import HFTrainer\n",
        "\n",
        "trainer = HFTrainer()\n",
        "model, tokenizer = trainer(\"google/electra-base-discriminator\", train, num_train_epochs=5)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']\n",
            "- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='625' max='625' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [625/625 02:51, Epoch 5/5]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>500</td>\n",
              "      <td>0.282800</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ]
    },
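    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check on model sizes (a sketch - exact counts depend on the checkpoints), compare the parameter counts of the zero-shot model and the newly trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "from transformers import AutoModelForSequenceClassification\n",
        "\n",
        "# Compare parameter counts: the zero-shot model is roughly 4x larger\n",
        "teacher = AutoModelForSequenceClassification.from_pretrained(\"microsoft/deberta-large-mnli\")\n",
        "print(f\"Zero-shot model: {sum(p.numel() for p in teacher.parameters()):,}\")\n",
        "print(f\"Trained model:   {sum(p.numel() for p in model.parameters()):,}\")"
      ],
      "execution_count": null,
      "outputs": []
    },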
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J9pugqJSJRn6"
      },
      "source": [
        "# Evaluating accuracy\n",
        "\n",
        "Recall that the training set is only 1000 records. To be clear, training an Electra model on the full sst2 dataset would perform better than what's shown below. But for this exercise, we're not using the training labels, simulating a setting where labeled data isn't available.\n",
        "\n",
        "First, let's see what the baseline accuracy of the zero-shot model is against the sst2 validation set. As a reminder, this model has not seen any of the sst2 training data.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RbgIrkgMvJS4",
        "outputId": "69287790-e01c-4c17-dfd5-0dc6afd73c98"
      },
      "source": [
        "labels = Labels(\"microsoft/deberta-large-mnli\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at microsoft/deberta-large-mnli were not used when initializing DebertaForSequenceClassification: ['config']\n",
            "- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-36UBMILpKYh",
        "outputId": "3a340b9f-57c5-4c4c-d975-0fcc47df4930"
      },
      "source": [
        "results = [row[\"label\"] == labels(row[\"sentence\"], [\"negative\", \"positive\"])[0][0] for row in ds[\"validation\"]]\n",
        "sum(results) / len(ds[\"validation\"])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8818807339449541"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uJVnWHZZKFIN"
      },
      "source": [
        "88.19% accuracy - not bad for a model that has not been trained on this dataset at all! It shows the power of zero-shot classification.\n",
        "\n",
        "Next, let's test the model trained on the 1000 zero-shot labeled records."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Kr5IZqZtvXlP",
        "outputId": "1faeb0d6-349b-4982-e9e8-cdbbde9e9a09"
      },
      "source": [
        "labels = Labels((model, tokenizer), dynamic=False)\n",
        "\n",
        "results = [row[\"label\"] == labels(row[\"sentence\"])[0][0] for row in ds[\"validation\"]]\n",
        "sum(results) / len(ds[\"validation\"])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8738532110091743"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDw-Zh43KVdX"
      },
      "source": [
        "87.39% accuracy! We wouldn't get too carried away with the exact percentages, but this nearly matches the accuracy of the zero-shot classifier.\n",
        "\n",
        "This model is highly tuned for a specific task, but it had the opportunity to learn from the combined 1000 records, whereas the zero-shot classifier views each record independently. It's also much more performant, as the sketch below illustrates."
      ]
    },
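    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To put a rough number on the performance difference, the sketch below times both pipelines over a small sample. Timings are illustrative only and will vary by hardware; the `zeroshot`, `trained` and `sample` names are introduced here just for this comparison."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import time\n",
        "\n",
        "# Small sample from the validation set for a rough timing comparison\n",
        "sample = ds[\"validation\"][\"sentence\"][:100]\n",
        "\n",
        "# Time the zero-shot classifier\n",
        "zeroshot = Labels(\"microsoft/deberta-large-mnli\")\n",
        "start = time.perf_counter()\n",
        "zeroshot(sample, [\"negative\", \"positive\"])\n",
        "print(f\"Zero-shot: {time.perf_counter() - start:.2f}s\")\n",
        "\n",
        "# Time the trained model\n",
        "trained = Labels((model, tokenizer), dynamic=False)\n",
        "start = time.perf_counter()\n",
        "trained(sample)\n",
        "print(f\"Trained:   {time.perf_counter() - start:.2f}s\")"
      ],
      "execution_count": null,
      "outputs": []
    },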
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QEAwki2lLM2A"
      },
      "source": [
        "# Conclusion\n",
        "\n",
        "This notebook explored a method of building trained text classifiers when no labeled training data is available. Given the resources needed to run large-scale zero-shot classifiers, this method is a simple way to build smaller models tuned for specific tasks. In this example, the zero-shot classifier has 400M parameters and the trained text classifier has 110M."
      ]
    }
  ]
}