openai-cookbook

Multiclass_classification_for_transactions.ipynb 
2211 lines · 83.0 KB
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiclass Classification for Transactions\n",
    "\n",
    "In this notebook we classify a public dataset of transactions into a set of predefined categories. The approaches should carry over to any multiclass classification use case where transactional data has to be fitted into predefined categories, and by the end you should have a few approaches for handling both labelled and unlabelled datasets.\n",
    "\n",
    "The approaches we'll take in this notebook are:\n",
    "- **Zero-shot Classification:** First we'll use zero-shot classification to put transactions into one of five named buckets, using only a prompt for guidance\n",
    "- **Classification with Embeddings:** Next we'll create embeddings on a labelled dataset and then use a traditional classification model to test how well they identify our categories\n",
    "- **Fine-tuned Classification:** Finally we'll fine-tune a model on our labelled dataset to see how it compares to the zero-shot and embeddings-based approaches"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload\n",
    "%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import os\n",
    "\n",
    "COMPLETIONS_MODEL = \"gpt-4\"\n",
    "\n",
    "client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if you didn't set as an env var>\"))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset\n",
    "\n",
    "We're using a public dataset of the National Library of Scotland's transactions over £25k. It has three features that we'll use:\n",
    "- Supplier: The name of the supplier\n",
    "- Description: A text description of the transaction\n",
    "- Value: The value of the transaction in GBP\n",
    "\n",
    "**Source**:\n",
    "\n",
    "https://data.nls.uk/data/organisational-data/transactions-over-25k/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 312,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "359"
      ]
     },
     "execution_count": 312,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n",
    "len(transactions)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 313,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Date</th>\n",
       "      <th>Supplier</th>\n",
       "      <th>Description</th>\n",
       "      <th>Transaction value (£)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>21/04/2016</td>\n",
       "      <td>M &amp; J Ballantyne Ltd</td>\n",
       "      <td>George IV Bridge Work</td>\n",
       "      <td>35098.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>26/04/2016</td>\n",
       "      <td>Private Sale</td>\n",
       "      <td>Literary &amp; Archival Items</td>\n",
       "      <td>30000.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>30/04/2016</td>\n",
       "      <td>City Of Edinburgh Council</td>\n",
       "      <td>Non Domestic Rates</td>\n",
       "      <td>40800.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>09/05/2016</td>\n",
       "      <td>Computacenter Uk</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>72835.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>09/05/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>64361.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Date                      Supplier                 Description  \\\n",
       "0  21/04/2016          M & J Ballantyne Ltd       George IV Bridge Work   \n",
       "1  26/04/2016                  Private Sale   Literary & Archival Items   \n",
       "2  30/04/2016     City Of Edinburgh Council         Non Domestic Rates    \n",
       "3  09/05/2016              Computacenter Uk                 Kelvin Hall   \n",
       "4  09/05/2016  John Graham Construction Ltd  Causewayside Refurbishment   \n",
       "\n",
       "   Transaction value (£)  \n",
       "0                35098.0  \n",
       "1                30000.0  \n",
       "2                40800.0  \n",
       "3                72835.0  \n",
       "4                64361.0  "
      ]
     },
     "execution_count": 313,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transactions.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 335,
   "metadata": {},
   "outputs": [],
   "source": [
    "def request_completion(prompt):\n",
    "\n",
    "    # The Chat Completions API expects a list of messages rather than a bare prompt string\n",
    "    completion_response = client.chat.completions.create(\n",
    "                            model=COMPLETIONS_MODEL,\n",
    "                            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "                            temperature=0,\n",
    "                            max_tokens=5,\n",
    "                            top_p=1,\n",
    "                            frequency_penalty=0,\n",
    "                            presence_penalty=0)\n",
    "\n",
    "    return completion_response\n",
    "\n",
    "def classify_transaction(transaction, prompt):\n",
    "\n",
    "    # Interpolate the transaction's fields into the prompt template\n",
    "    prompt = prompt.replace('SUPPLIER_NAME', transaction['Supplier'])\n",
    "    prompt = prompt.replace('DESCRIPTION_TEXT', transaction['Description'])\n",
    "    prompt = prompt.replace('TRANSACTION_VALUE', str(transaction['Transaction value (£)']))\n",
    "\n",
    "    # Take the model's reply and strip newlines and surrounding whitespace\n",
    "    classification = request_completion(prompt).choices[0].message.content.replace('\\n', '').strip()\n",
    "\n",
    "    return classification\n",
    "\n",
    "# This function takes the training and validation files produced by the prepare_data tool of the fine-tuning API and\n",
    "# confirms that both contain the same number of classes.\n",
    "# If they do not, the fine-tune will fail and return an error.\n",
    "\n",
    "def check_finetune_classes(train_file, valid_file):\n",
    "\n",
    "    train_classes = set()\n",
    "    valid_classes = set()\n",
    "\n",
    "    # Each line of a JSONL file is one example; collect the distinct labels in the training file\n",
    "    with open(train_file, 'r') as json_file:\n",
    "        json_list = list(json_file)\n",
    "        print(len(json_list))\n",
    "\n",
    "    for json_str in json_list:\n",
    "        result = json.loads(json_str)\n",
    "        train_classes.add(result['completion'])\n",
    "\n",
    "    # Repeat for the validation file\n",
    "    with open(valid_file, 'r') as json_file:\n",
    "        json_list = list(json_file)\n",
    "        print(len(json_list))\n",
    "\n",
    "    for json_str in json_list:\n",
    "        result = json.loads(json_str)\n",
    "        valid_classes.add(result['completion'])\n",
    "\n",
    "    if len(train_classes) == len(valid_classes):\n",
    "        print('All good')\n",
    "\n",
    "    else:\n",
    "        print('Classes do not match, please prepare data again')\n"
   ]
  },
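  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`check_finetune_classes` only compares *how many* classes each file contains, so two files with the same count but different labels would still pass. A stricter check compares the label sets themselves. The sketch below is a hypothetical helper (`compare_finetune_classes` is not part of the OpenAI tooling) and assumes the same legacy JSONL format, with one JSON object per line and the label stored under a `completion` key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def load_classes(path):\n",
    "    # Collect the distinct 'completion' labels from a JSONL file, skipping blank lines\n",
    "    with open(path, 'r') as json_file:\n",
    "        return {json.loads(line)['completion'] for line in json_file if line.strip()}\n",
    "\n",
    "def compare_finetune_classes(train_file, valid_file):\n",
    "    # Hypothetical helper: report labels that appear in only one of the two files\n",
    "    train_classes = load_classes(train_file)\n",
    "    valid_classes = load_classes(valid_file)\n",
    "    only_train = train_classes - valid_classes\n",
    "    only_valid = valid_classes - train_classes\n",
    "    if not only_train and not only_valid:\n",
    "        print('All good')\n",
    "    else:\n",
    "        print(f'Only in training file: {sorted(only_train)}')\n",
    "        print(f'Only in validation file: {sorted(only_valid)}')\n",
    "    return only_train, only_valid"
   ]
  },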
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot Classification\n",
    "\n",
    "We'll first assess how well a base model classifies these transactions using a simple prompt. We'll give the model five categories plus a catch-all of \"Could not classify\" for transactions it cannot place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.\n",
    "You are analysing all transactions over £25,000 in value and classifying them into one of five categories.\n",
    "The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.\n",
    "If you can't tell what it is, say Could not classify\n",
    "\n",
    "Transaction:\n",
    "\n",
    "Supplier: SUPPLIER_NAME\n",
    "Description: DESCRIPTION_TEXT\n",
    "Value: TRANSACTION_VALUE\n",
    "\n",
    "The classification is:'''\n"
   ]
  },
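  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prompt asks for one of five categories or the catch-all, but nothing guarantees that the reply matches one of them exactly. As a minimal sketch, replies can be mapped back onto the allowed labels, with anything unexpected sent to the catch-all. `normalise_label`, `CATEGORIES` and `FALLBACK` are hypothetical names introduced here; the category list is copied from the prompt above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CATEGORIES = ['Building Improvement', 'Literature & Archive', 'Utility Bills',\n",
    "              'Professional Services', 'Software/IT']\n",
    "FALLBACK = 'Could not classify'\n",
    "\n",
    "def normalise_label(raw_label):\n",
    "    # Drop surrounding whitespace and a trailing full stop, then match case-insensitively\n",
    "    cleaned = raw_label.strip().rstrip('.').strip()\n",
    "    for category in CATEGORIES + [FALLBACK]:\n",
    "        if cleaned.lower() == category.lower():\n",
    "            return category\n",
    "    # Any reply outside the allowed set goes to the catch-all bucket\n",
    "    return FALLBACK\n",
    "\n",
    "print(normalise_label(' Building Improvement\\n'))  # Building Improvement\n",
    "print(normalise_label('Catering'))  # Could not classify"
   ]
  },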
  {
   "cell_type": "code",
   "execution_count": 315,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Building Improvement\n"
     ]
    }
   ],
   "source": [
    "# Get a test transaction\n",
    "transaction = transactions.iloc[0]\n",
    "\n",
    "# Interpolate the values into the prompt\n",
    "prompt = zero_shot_prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n",
    "prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n",
    "prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n",
    "\n",
    "# Use our completion function to return a prediction\n",
    "completion_response = request_completion(prompt)\n",
    "print(completion_response.choices[0].message.content)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our first attempt is correct: M & J Ballantyne Ltd are a house builder, and the work they performed is indeed Building Improvement.\n",
    "\n",
    "Let's expand the sample size to 25 and see how it performs, again with just a simple prompt to guide it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:2: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "# Take an explicit copy so that adding a column doesn't raise SettingWithCopyWarning\n",
    "test_transactions = transactions.iloc[:25].copy()\n",
    "test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x, zero_shot_prompt), axis=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 292,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " Building Improvement    14\n",
       " Could not classify       5\n",
       " Literature & Archive     3\n",
       " Software/IT              2\n",
       " Utility Bills            1\n",
       "Name: Classification, dtype: int64"
      ]
     },
     "execution_count": 292,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_transactions['Classification'].value_counts()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Date</th>\n",
       "      <th>Supplier</th>\n",
       "      <th>Description</th>\n",
       "      <th>Transaction value (£)</th>\n",
       "      <th>Classification</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>21/04/2016</td>\n",
       "      <td>M &amp; J Ballantyne Ltd</td>\n",
       "      <td>George IV Bridge Work</td>\n",
       "      <td>35098.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>26/04/2016</td>\n",
       "      <td>Private Sale</td>\n",
       "      <td>Literary &amp; Archival Items</td>\n",
       "      <td>30000.0</td>\n",
       "      <td>Literature &amp; Archive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>30/04/2016</td>\n",
       "      <td>City Of Edinburgh Council</td>\n",
       "      <td>Non Domestic Rates</td>\n",
       "      <td>40800.0</td>\n",
       "      <td>Utility Bills</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>09/05/2016</td>\n",
       "      <td>Computacenter Uk</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>72835.0</td>\n",
       "      <td>Software/IT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>09/05/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>64361.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>09/05/2016</td>\n",
       "      <td>A McGillivray</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>53690.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>16/05/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>365344.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>23/05/2016</td>\n",
       "      <td>Computacenter Uk</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>26506.0</td>\n",
       "      <td>Software/IT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>23/05/2016</td>\n",
       "      <td>ECG Facilities Service</td>\n",
       "      <td>Facilities Management Charge</td>\n",
       "      <td>32777.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>23/05/2016</td>\n",
       "      <td>ECG Facilities Service</td>\n",
       "      <td>Facilities Management Charge</td>\n",
       "      <td>32777.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>30/05/2016</td>\n",
       "      <td>ALDL</td>\n",
       "      <td>ALDL Charges</td>\n",
       "      <td>32317.0</td>\n",
       "      <td>Could not classify</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>10/06/2016</td>\n",
       "      <td>Wavetek Ltd</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>87589.0</td>\n",
       "      <td>Could not classify</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>10/06/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>381803.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>28/06/2016</td>\n",
       "      <td>ECG Facilities Service</td>\n",
       "      <td>Facilities Management Charge</td>\n",
       "      <td>32832.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>30/06/2016</td>\n",
       "      <td>Glasgow City Council</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>1700000.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>11/07/2016</td>\n",
       "      <td>Wavetek Ltd</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>65692.0</td>\n",
       "      <td>Could not classify</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>11/07/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>139845.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>15/07/2016</td>\n",
       "      <td>Sotheby'S</td>\n",
       "      <td>Literary &amp; Archival Items</td>\n",
       "      <td>28500.0</td>\n",
       "      <td>Literature &amp; Archive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>18/07/2016</td>\n",
       "      <td>Christies</td>\n",
       "      <td>Literary &amp; Archival Items</td>\n",
       "      <td>33800.0</td>\n",
       "      <td>Literature &amp; Archive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>25/07/2016</td>\n",
       "      <td>A McGillivray</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>30113.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>31/07/2016</td>\n",
       "      <td>ALDL</td>\n",
       "      <td>ALDL Charges</td>\n",
       "      <td>32317.0</td>\n",
       "      <td>Could not classify</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>08/08/2016</td>\n",
       "      <td>ECG Facilities Service</td>\n",
       "      <td>Facilities Management Charge</td>\n",
       "      <td>32795.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>15/08/2016</td>\n",
       "      <td>Creative Video Productions Ltd</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>26866.0</td>\n",
       "      <td>Could not classify</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>15/08/2016</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>196807.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>24/08/2016</td>\n",
       "      <td>ECG Facilities Service</td>\n",
       "      <td>Facilities Management Charge</td>\n",
       "      <td>32795.0</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          Date                        Supplier                   Description  \\\n",
       "0   21/04/2016            M & J Ballantyne Ltd         George IV Bridge Work   \n",
       "1   26/04/2016                    Private Sale     Literary & Archival Items   \n",
       "2   30/04/2016       City Of Edinburgh Council           Non Domestic Rates    \n",
       "3   09/05/2016                Computacenter Uk                   Kelvin Hall   \n",
       "4   09/05/2016    John Graham Construction Ltd    Causewayside Refurbishment   \n",
       "5   09/05/2016                   A McGillivray    Causewayside Refurbishment   \n",
       "6   16/05/2016    John Graham Construction Ltd    Causewayside Refurbishment   \n",
       "7   23/05/2016                Computacenter Uk                   Kelvin Hall   \n",
       "8   23/05/2016          ECG Facilities Service  Facilities Management Charge   \n",
       "9   23/05/2016          ECG Facilities Service  Facilities Management Charge   \n",
       "10  30/05/2016                            ALDL                  ALDL Charges   \n",
       "11  10/06/2016                     Wavetek Ltd                   Kelvin Hall   \n",
       "12  10/06/2016    John Graham Construction Ltd    Causewayside Refurbishment   \n",
       "13  28/06/2016          ECG Facilities Service  Facilities Management Charge   \n",
       "14  30/06/2016            Glasgow City Council                   Kelvin Hall   \n",
       "15  11/07/2016                     Wavetek Ltd                   Kelvin Hall   \n",
       "16  11/07/2016    John Graham Construction Ltd    Causewayside Refurbishment   \n",
       "17  15/07/2016                       Sotheby'S     Literary & Archival Items   \n",
       "18  18/07/2016                       Christies     Literary & Archival Items   \n",
       "19  25/07/2016                   A McGillivray    Causewayside Refurbishment   \n",
       "20  31/07/2016                            ALDL                  ALDL Charges   \n",
       "21  08/08/2016          ECG Facilities Service  Facilities Management Charge   \n",
       "22  15/08/2016  Creative Video Productions Ltd                   Kelvin Hall   \n",
       "23  15/08/2016    John Graham Construction Ltd    Causewayside Refurbishment   \n",
       "24  24/08/2016          ECG Facilities Service  Facilities Management Charge   \n",
       "\n",
       "    Transaction value (£)         Classification  \n",
       "0                 35098.0   Building Improvement  \n",
       "1                 30000.0   Literature & Archive  \n",
       "2                 40800.0          Utility Bills  \n",
       "3                 72835.0            Software/IT  \n",
       "4                 64361.0   Building Improvement  \n",
       "5                 53690.0   Building Improvement  \n",
       "6                365344.0   Building Improvement  \n",
       "7                 26506.0            Software/IT  \n",
       "8                 32777.0   Building Improvement  \n",
       "9                 32777.0   Building Improvement  \n",
       "10                32317.0     Could not classify  \n",
       "11                87589.0     Could not classify  \n",
       "12               381803.0   Building Improvement  \n",
       "13                32832.0   Building Improvement  \n",
       "14              1700000.0   Building Improvement  \n",
       "15                65692.0     Could not classify  \n",
       "16               139845.0   Building Improvement  \n",
       "17                28500.0   Literature & Archive  \n",
       "18                33800.0   Literature & Archive  \n",
       "19                30113.0   Building Improvement  \n",
       "20                32317.0     Could not classify  \n",
       "21                32795.0   Building Improvement  \n",
       "22                26866.0     Could not classify  \n",
       "23               196807.0   Building Improvement  \n",
       "24                32795.0   Building Improvement  "
      ]
     },
     "execution_count": 293,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_transactions.head(25)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initial results are promising even with no labelled examples: 20 of the 25 transactions were placed in a named category. The 5 it could not classify (two ALDL charges and three Kelvin Hall payments) were tougher cases with few clues about their topic, so cleaning up a labelled dataset with more examples may improve performance."
   ]
  },
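  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on this, the share of the sample that landed in the catch-all bucket can be computed directly from `test_transactions`. This is a small sketch; the `.str.strip()` call is there because the labels recorded above carry a leading space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the labels, since the recorded replies include a leading space\n",
    "labels = test_transactions['Classification'].str.strip()\n",
    "unclassified = labels == 'Could not classify'\n",
    "print(f'{unclassified.mean():.0%} of {len(labels)} transactions could not be classified')\n",
    "\n",
    "# Show the transactions the model could not place\n",
    "test_transactions.loc[unclassified, ['Supplier', 'Description']]"
   ]
  },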
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification with Embeddings\n",
    "\n",
    "Let's create embeddings from the small set we've classified so far. We built a set of labelled examples by running the zero-shot classifier on 101 transactions from our dataset and manually correcting the 15 **Could not classify** results we got.\n",
    "\n",
    "### Create embeddings\n",
    "\n",
    "This initial section reuses the approach from the [Get_embeddings_from_dataset Notebook](Get_embeddings_from_dataset.ipynb) to create embeddings from a combined field that concatenates all of our features."
   ]
  },
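  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this step, assuming the v1 `client` created in the setup cell. The layout of the combined string and the embedding model name (`text-embedding-3-small`) are assumptions made for illustration, not necessarily what the rest of this notebook uses, and `build_combined` and `embed_texts` are hypothetical helpers. Once the labelled data is loaded, they could be applied with `df['combined'] = df.apply(build_combined, axis=1)` followed by `embed_texts(df['combined'].tolist())`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_MODEL = 'text-embedding-3-small'  # assumption: any OpenAI embeddings model could be used\n",
    "\n",
    "def build_combined(row):\n",
    "    # Hypothetical layout: concatenate all three features into a single string\n",
    "    return 'Supplier: {}; Description: {}; Value: {}'.format(\n",
    "        row['Supplier'], row['Description'], row['Transaction value (£)'])\n",
    "\n",
    "def embed_texts(texts):\n",
    "    # A single batched request; each item in response.data carries the index of its input\n",
    "    response = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)\n",
    "    return [item.embedding for item in sorted(response.data, key=lambda item: item.index)]"
   ]
  },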
  {
   "cell_type": "code",
   "execution_count": 317,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Date</th>\n",
       "      <th>Supplier</th>\n",
       "      <th>Description</th>\n",
       "      <th>Transaction value (£)</th>\n",
       "      <th>Classification</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15/08/2016</td>\n",
       "      <td>Creative Video Productions Ltd</td>\n",
       "      <td>Kelvin Hall</td>\n",
       "      <td>26866</td>\n",
       "      <td>Other</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>29/05/2017</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>74806</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>29/05/2017</td>\n",
       "      <td>Morris &amp; Spottiswood Ltd</td>\n",
       "      <td>George IV Bridge Work</td>\n",
       "      <td>56448</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>31/05/2017</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>164691</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>24/07/2017</td>\n",
       "      <td>John Graham Construction Ltd</td>\n",
       "      <td>Causewayside Refurbishment</td>\n",
       "      <td>27926</td>\n",
       "      <td>Building Improvement</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Date                        Supplier                 Description  \\\n",
       "0  15/08/2016  Creative Video Productions Ltd                 Kelvin Hall   \n",
       "1  29/05/2017    John Graham Construction Ltd  Causewayside Refurbishment   \n",
       "2  29/05/2017        Morris & Spottiswood Ltd       George IV Bridge Work   \n",
       "3  31/05/2017    John Graham Construction Ltd  Causewayside Refurbishment   \n",
       "4  24/07/2017    John Graham Construction Ltd  Causewayside Refurbishment   \n",
       "\n",
       "   Transaction value (£)        Classification  \n",
       "0                  26866                 Other  \n",
       "1                  74806  Building Improvement  \n",
       "2                  56448  Building Improvement  \n",
       "3                 164691  Building Improvement  \n",
       "4                  27926  Building Improvement  "
      ]
     },
     "execution_count": 317,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('./data/labelled_transactions.csv')\n",
    "df.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Date</th>\n",
       "      <th>Supplier</th>\n",
       "      <th>Description</th>\n",
       "      <th>Transaction value (£)</th>\n",
       "      <th>Classification</th>\n",
       "      <th>combined</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15/08/2016</td>\n",
       "      <td>Creative Video Productions Ltd</td>\n",
836
       "      <td>Kelvin Hall</td>\n",
837
       "      <td>26866</td>\n",
838
       "      <td>Other</td>\n",
839
       "      <td>Supplier: Creative Video Productions Ltd; Desc...</td>\n",
840
       "    </tr>\n",
841
       "    <tr>\n",
842
       "      <th>1</th>\n",
843
       "      <td>29/05/2017</td>\n",
844
       "      <td>John Graham Construction Ltd</td>\n",
845
       "      <td>Causewayside Refurbishment</td>\n",
846
       "      <td>74806</td>\n",
847
       "      <td>Building Improvement</td>\n",
848
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
849
       "    </tr>\n",
850
       "  </tbody>\n",
851
       "</table>\n",
852
       "</div>"
853
      ],
854
      "text/plain": [
855
       "         Date                        Supplier                 Description  \\\n",
856
       "0  15/08/2016  Creative Video Productions Ltd                 Kelvin Hall   \n",
857
       "1  29/05/2017    John Graham Construction Ltd  Causewayside Refurbishment   \n",
858
       "\n",
859
       "   Transaction value (£)        Classification  \\\n",
860
       "0                  26866                 Other   \n",
861
       "1                  74806  Building Improvement   \n",
862
       "\n",
863
       "                                            combined  \n",
864
       "0  Supplier: Creative Video Productions Ltd; Desc...  \n",
865
       "1  Supplier: John Graham Construction Ltd; Descri...  "
866
      ]
867
     },
868
     "execution_count": 318,
869
     "metadata": {},
870
     "output_type": "execute_result"
871
    }
872
   ],
873
   "source": [
874
    "df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + df['Transaction value (£)'].astype(str).str.strip()\n",
875
    "df.head(2)\n"
876
   ]
877
  },
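Calling `str()` on the whole `Transaction value (£)` Series produces a single string (the Series' printed form, including index and dtype), and adding that string to a column repeats it on every row instead of inserting each row's own value. This would be consistent with the 136-141 `n_tokens` values recorded in the embeddings output further down, compared with 10-17 in later outputs. A minimal pandas sketch on two hypothetical rows contrasts that form with a per-row `.astype(str)` conversion:

```python
import pandas as pd

# Hypothetical two-row frame mirroring the notebook's column names
df = pd.DataFrame({
    "Supplier": ["Creative Video Productions Ltd ", "John Graham Construction Ltd"],
    "Description": ["Kelvin Hall", "Causewayside Refurbishment"],
    "Transaction value (£)": [26866, 74806],
})

# Problematic form: str() turns the whole Series into ONE string (its printed
# repr with index and dtype); adding it to a column repeats it on every row
buggy = df["Supplier"].str.strip() + "; Value: " + str(df["Transaction value (£)"]).strip()

# Per-row form: cast each value to str, then use the vectorised .str accessor
df["combined"] = (
    "Supplier: " + df["Supplier"].str.strip()
    + "; Description: " + df["Description"].str.strip()
    + "; Value: " + df["Transaction value (£)"].astype(str).str.strip()
)

print(buggy.iloc[0])
print(df["combined"].tolist())
```

Each `combined` string now carries only its own transaction value.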
878
  {
879
   "cell_type": "code",
880
   "execution_count": 319,
881
   "metadata": {},
882
   "outputs": [
883
    {
884
     "data": {
885
      "text/plain": [
886
       "101"
887
      ]
888
     },
889
     "execution_count": 319,
890
     "metadata": {},
891
     "output_type": "execute_result"
892
    }
893
   ],
894
   "source": [
895
    "from transformers import GPT2TokenizerFast\n",
896
    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
897
    "\n",
898
    "df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n",
899
    "len(df)\n"
900
   ]
901
  },
902
  {
903
   "cell_type": "code",
904
   "execution_count": 320,
905
   "metadata": {},
906
   "outputs": [],
907
   "source": [
908
    "embedding_path = './data/transactions_with_embeddings_100.csv'\n"
909
   ]
910
  },
911
  {
912
   "cell_type": "code",
913
   "execution_count": 321,
914
   "metadata": {},
915
   "outputs": [],
916
   "source": [
917
    "from utils.embeddings_utils import get_embedding\n",
918
    "\n",
919
    "df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, model='text-similarity-babbage-001'))\n",
920
    "df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, model='text-search-babbage-doc-001'))\n",
921
    "df.to_csv(embedding_path)\n"
922
   ]
923
  },
924
  {
925
   "attachments": {},
926
   "cell_type": "markdown",
927
   "metadata": {},
928
   "source": [
929
    "### Use embeddings for classification\n",
930
    "\n",
931
    "Now that we have our embeddings, let's see whether classifying them into the categories we've named gives us any more success.\n",
932
    "\n",
933
    "For this we'll use a template from the [Classification_using_embeddings](Classification_using_embeddings.ipynb) notebook"
934
   ]
935
  },
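The notebook trains a Random Forest on the embeddings. As a point of comparison only (a swapped-in technique, not what the notebook runs), an even simpler classifier is nearest-centroid: average the training embeddings of each class and assign a new vector to the class whose centroid has the highest cosine similarity. A stdlib sketch with toy 2-D vectors standing in for real embeddings; the labels and values are hypothetical:

```python
import math
from collections import defaultdict

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def fit_centroids(vectors, labels):
    """Average the vectors belonging to each label."""
    sums, counts = {}, defaultdict(int)
    for vec, label in zip(vectors, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
        sums[label] = [s + v for s, v in zip(sums[label], vec)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}

def predict(centroids, vec):
    """Return the label whose centroid is most cosine-similar to vec."""
    return max(centroids, key=lambda label: cosine(centroids[label], vec))

# Toy data: two clusters standing in for two transaction categories
X_train = [[1.0, 0.1], [0.9, 0.0], [0.1, 1.0], [0.0, 0.8]]
y_train = ["Building Improvement", "Building Improvement", "Utility Bills", "Utility Bills"]

centroids = fit_centroids(X_train, y_train)
print(predict(centroids, [0.8, 0.2]))  # → Building Improvement
```

Nearest-centroid has no hyperparameters to tune, which can make it a useful sanity baseline when, as here, only about a hundred labelled rows are available.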
936
  {
937
   "cell_type": "code",
938
   "execution_count": 309,
939
   "metadata": {},
940
   "outputs": [
941
    {
942
     "data": {
943
      "text/html": [
944
       "<div>\n",
945
       "<style scoped>\n",
946
       "    .dataframe tbody tr th:only-of-type {\n",
947
       "        vertical-align: middle;\n",
948
       "    }\n",
949
       "\n",
950
       "    .dataframe tbody tr th {\n",
951
       "        vertical-align: top;\n",
952
       "    }\n",
953
       "\n",
954
       "    .dataframe thead th {\n",
955
       "        text-align: right;\n",
956
       "    }\n",
957
       "</style>\n",
958
       "<table border=\"1\" class=\"dataframe\">\n",
959
       "  <thead>\n",
960
       "    <tr style=\"text-align: right;\">\n",
961
       "      <th></th>\n",
962
       "      <th>Unnamed: 0</th>\n",
963
       "      <th>Date</th>\n",
964
       "      <th>Supplier</th>\n",
965
       "      <th>Description</th>\n",
966
       "      <th>Transaction value (£)</th>\n",
967
       "      <th>Classification</th>\n",
968
       "      <th>combined</th>\n",
969
       "      <th>n_tokens</th>\n",
970
       "      <th>babbage_similarity</th>\n",
971
       "      <th>babbage_search</th>\n",
972
       "    </tr>\n",
973
       "  </thead>\n",
974
       "  <tbody>\n",
975
       "    <tr>\n",
976
       "      <th>0</th>\n",
977
       "      <td>0</td>\n",
978
       "      <td>15/08/2016</td>\n",
979
       "      <td>Creative Video Productions Ltd</td>\n",
980
       "      <td>Kelvin Hall</td>\n",
981
       "      <td>26866</td>\n",
982
       "      <td>Other</td>\n",
983
       "      <td>Supplier: Creative Video Productions Ltd; Desc...</td>\n",
984
       "      <td>136</td>\n",
985
       "      <td>[-0.009802100248634815, 0.022551486268639565, ...</td>\n",
986
       "      <td>[-0.00232666521333158, 0.019198870286345482, 0...</td>\n",
987
       "    </tr>\n",
988
       "    <tr>\n",
989
       "      <th>1</th>\n",
990
       "      <td>1</td>\n",
991
       "      <td>29/05/2017</td>\n",
992
       "      <td>John Graham Construction Ltd</td>\n",
993
       "      <td>Causewayside Refurbishment</td>\n",
994
       "      <td>74806</td>\n",
995
       "      <td>Building Improvement</td>\n",
996
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
997
       "      <td>140</td>\n",
998
       "      <td>[-0.009065819904208183, 0.012094118632376194, ...</td>\n",
999
       "      <td>[0.005169447045773268, 0.00473341578617692, -0...</td>\n",
1000
       "    </tr>\n",
1001
       "    <tr>\n",
1002
       "      <th>2</th>\n",
1003
       "      <td>2</td>\n",
1004
       "      <td>29/05/2017</td>\n",
1005
       "      <td>Morris &amp; Spottiswood Ltd</td>\n",
1006
       "      <td>George IV Bridge Work</td>\n",
1007
       "      <td>56448</td>\n",
1008
       "      <td>Building Improvement</td>\n",
1009
       "      <td>Supplier: Morris &amp; Spottiswood Ltd; Descriptio...</td>\n",
1010
       "      <td>141</td>\n",
1011
       "      <td>[-0.009000026620924473, 0.02405017428100109, -...</td>\n",
1012
       "      <td>[0.0028343256562948227, 0.021166473627090454, ...</td>\n",
1013
       "    </tr>\n",
1014
       "    <tr>\n",
1015
       "      <th>3</th>\n",
1016
       "      <td>3</td>\n",
1017
       "      <td>31/05/2017</td>\n",
1018
       "      <td>John Graham Construction Ltd</td>\n",
1019
       "      <td>Causewayside Refurbishment</td>\n",
1020
       "      <td>164691</td>\n",
1021
       "      <td>Building Improvement</td>\n",
1022
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
1023
       "      <td>140</td>\n",
1024
       "      <td>[-0.009065819904208183, 0.012094118632376194, ...</td>\n",
1025
       "      <td>[0.005169447045773268, 0.00473341578617692, -0...</td>\n",
1026
       "    </tr>\n",
1027
       "    <tr>\n",
1028
       "      <th>4</th>\n",
1029
       "      <td>4</td>\n",
1030
       "      <td>24/07/2017</td>\n",
1031
       "      <td>John Graham Construction Ltd</td>\n",
1032
       "      <td>Causewayside Refurbishment</td>\n",
1033
       "      <td>27926</td>\n",
1034
       "      <td>Building Improvement</td>\n",
1035
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
1036
       "      <td>140</td>\n",
1037
       "      <td>[-0.009065819904208183, 0.012094118632376194, ...</td>\n",
1038
       "      <td>[0.005169447045773268, 0.00473341578617692, -0...</td>\n",
1039
       "    </tr>\n",
1040
       "  </tbody>\n",
1041
       "</table>\n",
1042
       "</div>"
1043
      ],
1044
      "text/plain": [
1045
       "   Unnamed: 0        Date                        Supplier  \\\n",
1046
       "0           0  15/08/2016  Creative Video Productions Ltd   \n",
1047
       "1           1  29/05/2017    John Graham Construction Ltd   \n",
1048
       "2           2  29/05/2017        Morris & Spottiswood Ltd   \n",
1049
       "3           3  31/05/2017    John Graham Construction Ltd   \n",
1050
       "4           4  24/07/2017    John Graham Construction Ltd   \n",
1051
       "\n",
1052
       "                  Description  Transaction value (£)        Classification  \\\n",
1053
       "0                 Kelvin Hall                  26866                 Other   \n",
1054
       "1  Causewayside Refurbishment                  74806  Building Improvement   \n",
1055
       "2       George IV Bridge Work                  56448  Building Improvement   \n",
1056
       "3  Causewayside Refurbishment                 164691  Building Improvement   \n",
1057
       "4  Causewayside Refurbishment                  27926  Building Improvement   \n",
1058
       "\n",
1059
       "                                            combined  n_tokens  \\\n",
1060
       "0  Supplier: Creative Video Productions Ltd; Desc...       136   \n",
1061
       "1  Supplier: John Graham Construction Ltd; Descri...       140   \n",
1062
       "2  Supplier: Morris & Spottiswood Ltd; Descriptio...       141   \n",
1063
       "3  Supplier: John Graham Construction Ltd; Descri...       140   \n",
1064
       "4  Supplier: John Graham Construction Ltd; Descri...       140   \n",
1065
       "\n",
1066
       "                                  babbage_similarity  \\\n",
1067
       "0  [-0.009802100248634815, 0.022551486268639565, ...   \n",
1068
       "1  [-0.009065819904208183, 0.012094118632376194, ...   \n",
1069
       "2  [-0.009000026620924473, 0.02405017428100109, -...   \n",
1070
       "3  [-0.009065819904208183, 0.012094118632376194, ...   \n",
1071
       "4  [-0.009065819904208183, 0.012094118632376194, ...   \n",
1072
       "\n",
1073
       "                                      babbage_search  \n",
1074
       "0  [-0.00232666521333158, 0.019198870286345482, 0...  \n",
1075
       "1  [0.005169447045773268, 0.00473341578617692, -0...  \n",
1076
       "2  [0.0028343256562948227, 0.021166473627090454, ...  \n",
1077
       "3  [0.005169447045773268, 0.00473341578617692, -0...  \n",
1078
       "4  [0.005169447045773268, 0.00473341578617692, -0...  "
1079
      ]
1080
     },
1081
     "execution_count": 309,
1082
     "metadata": {},
1083
     "output_type": "execute_result"
1084
    }
1085
   ],
1086
   "source": [
1087
    "from sklearn.ensemble import RandomForestClassifier\n",
1088
    "from sklearn.model_selection import train_test_split\n",
1089
    "from sklearn.metrics import classification_report, accuracy_score\n",
1090
    "from ast import literal_eval\n",
1091
    "\n",
1092
    "fs_df = pd.read_csv(embedding_path)\n",
1093
    "fs_df[\"babbage_similarity\"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)\n",
1094
    "fs_df.head()\n"
1095
   ]
1096
  },
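`df.to_csv` stores each list-valued cell as its text representation, so after `pd.read_csv` the `babbage_similarity` column holds strings until it is parsed with `ast.literal_eval`, as in the cell above. A stdlib-only sketch of the same round trip, using made-up three-dimensional vectors in place of real embeddings:

```python
import ast
import csv
import io

# Toy "embeddings": real ones have far more dimensions
rows = [
    {"combined": "Supplier: A; Description: B; Value: 1", "embedding": [-0.0098, 0.0225, 0.5]},
    {"combined": "Supplier: C; Description: D; Value: 2", "embedding": [0.0051, 0.0047, -0.25]},
]

# Write: each list is serialised as its str() form, e.g. "[-0.0098, 0.0225, 0.5]"
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["combined", "embedding"])
writer.writeheader()
for row in rows:
    writer.writerow({"combined": row["combined"], "embedding": str(row["embedding"])})

# Read back: the cell comes back as a string and must be parsed into a list again
buf.seek(0)
loaded = list(csv.DictReader(buf))
raw = loaded[0]["embedding"]
vectors = [ast.literal_eval(r["embedding"]) for r in loaded]

print(type(raw).__name__, vectors[0])
```

`ast.literal_eval` only evaluates Python literals, which makes it a safer choice than `eval` for parsing data read from disk.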
1097
  {
1098
   "cell_type": "code",
1099
   "execution_count": 310,
1100
   "metadata": {},
1101
   "outputs": [
1102
    {
1103
     "name": "stdout",
1104
     "output_type": "stream",
1105
     "text": [
1106
      "                      precision    recall  f1-score   support\n",
1107
      "\n",
1108
      "Building Improvement       0.92      1.00      0.96        11\n",
1109
      "Literature & Archive       1.00      1.00      1.00         3\n",
1110
      "               Other       0.00      0.00      0.00         1\n",
1111
      "         Software/IT       1.00      1.00      1.00         1\n",
1112
      "       Utility Bills       1.00      1.00      1.00         5\n",
1113
      "\n",
1114
      "            accuracy                           0.95        21\n",
1115
      "           macro avg       0.78      0.80      0.79        21\n",
1116
      "        weighted avg       0.91      0.95      0.93        21\n",
1117
      "\n"
1118
     ]
1119
    },
1120
    {
1121
     "name": "stderr",
1122
     "output_type": "stream",
1123
     "text": [
1124
      "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
1125
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
1126
      "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
1127
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
1128
      "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
1129
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
1130
     ]
1131
    }
1132
   ],
1133
   "source": [
1134
    "X_train, X_test, y_train, y_test = train_test_split(\n",
1135
    "    list(fs_df.babbage_similarity.values), fs_df.Classification, test_size=0.2, random_state=42\n",
1136
    ")\n",
1137
    "\n",
1138
    "clf = RandomForestClassifier(n_estimators=100)\n",
1139
    "clf.fit(X_train, y_train)\n",
1140
    "preds = clf.predict(X_test)\n",
1141
    "probas = clf.predict_proba(X_test)\n",
1142
    "\n",
1143
    "report = classification_report(y_test, preds)\n",
1144
    "print(report)\n"
1145
   ]
1146
  },
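In the report above, `macro avg` is the unweighted mean of the five per-class scores, while `weighted avg` weights each class by its `support`. The single `Other` test example was never predicted as `Other`, so its precision is 0/0; scikit-learn reports 0.0 and emits the `UndefinedMetricWarning` shown (its `zero_division` parameter controls this). A short arithmetic check reproduces the summary rows; the Building Improvement precision of 11/12 is inferred from the report (support 11, recall 1.00, and 20 of 21 correct at 0.95 accuracy, so the misclassified `Other` row was predicted as Building Improvement):

```python
# Per-class (precision, recall, support), read from or inferred from the report above
per_class = {
    "Building Improvement": (11 / 12, 1.0, 11),
    "Literature & Archive": (1.0, 1.0, 3),
    "Other": (0.0, 0.0, 1),  # undefined precision reported as 0.0
    "Software/IT": (1.0, 1.0, 1),
    "Utility Bills": (1.0, 1.0, 5),
}

def f1(p, r):
    # Harmonic mean of precision and recall; 0 when both are 0
    return 0.0 if p + r == 0 else 2 * p * r / (p + r)

n = sum(s for _, _, s in per_class.values())
macro_p = sum(p for p, _, _ in per_class.values()) / len(per_class)
macro_f1 = sum(f1(p, r) for p, r, _ in per_class.values()) / len(per_class)
weighted_p = sum(p * s for p, _, s in per_class.values()) / n
weighted_f1 = sum(f1(p, r) * s for p, r, s in per_class.values()) / n
accuracy = sum(r * s for _, r, s in per_class.values()) / n  # correct / total

print(n, round(macro_p, 2), round(macro_f1, 2), round(weighted_p, 2),
      round(weighted_f1, 2), round(accuracy, 2))
# → 21 0.78 0.79 0.91 0.93 0.95
```

The gap between the macro (0.79) and weighted (0.93) F1 comes almost entirely from the one-example `Other` class.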
1147
  {
1148
   "attachments": {},
1149
   "cell_type": "markdown",
1150
   "metadata": {},
1151
   "source": [
1152
    "Performance for this model is strong: 0.95 accuracy on the 21 held-out transactions, although the test set is small and its single `Other` example was misclassified. Creating embeddings and training even a simple classifier therefore looks like an effective approach, with the zero-shot classifier helping us do the initial classification of the unlabelled dataset.\n",
1153
    "\n",
1154
    "Let's take it one step further and see whether a fine-tuned model trained on the same labelled dataset gives us comparable results"
1155
   ]
1156
  },
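With only 21 test rows, and `Other` and `Software/IT` represented by a single example each, the scores above can shift noticeably with a different `random_state`. One common mitigation, not used in this notebook, is a stratified split so each class keeps roughly its overall proportion in the test set (scikit-learn's `train_test_split` accepts a `stratify=` argument for this). A stdlib sketch of the idea; `stratified_split` and the labels are hypothetical:

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx), sampling test_size of each class separately."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)
        # At least one test row per class with 2+ rows; singleton classes stay in train
        n_test = max(1, round(len(idx) * test_size)) if len(idx) > 1 else 0
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Hypothetical label column: 10 rows of one class, 5 of another
labels = ["Building Improvement"] * 10 + ["Utility Bills"] * 5
train_idx, test_idx = stratified_split(labels)
print([labels[i] for i in test_idx])
```

Stratification keeps rare classes from vanishing from (or dominating) a small test set, but with classes this small, repeated splits or cross-validation would give a steadier estimate than any single split.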
1157
  {
1158
   "attachments": {},
1159
   "cell_type": "markdown",
1160
   "metadata": {},
1161
   "source": [
1162
    "## Fine-tuned Transaction Classification\n",
1163
    "\n",
1164
    "For this use case we'll try to improve on the embeddings-based classification above by training a fine-tuned model on the same labelled set of 101 transactions and applying it to a group of unseen transactions"
1165
   ]
1166
  },
1167
  {
1168
   "attachments": {},
1169
   "cell_type": "markdown",
1170
   "metadata": {},
1171
   "source": [
1172
    "### Building Fine-tuned Classifier\n",
1173
    "\n",
1174
    "We'll need to do some data prep first to get our data ready. This will take the following steps:\n",
1175
    "- First we'll list out our classes and replace them with numeric identifiers. Making the model predict a single token rather than multiple consecutive ones like 'Building Improvement' should give us better results\n",
1176
    "- We also need to add a common prefix and suffix to each example to help the model make predictions. In our case the text already starts with 'Supplier', and we'll add a suffix of '\\n\\n###\\n\\n'\n",
1177
    "- Lastly, we'll add a leading whitespace to each of our target classes, again to help the model"
1178
   ]
1179
  },
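The three preparation steps above can be sketched as follows, assuming the prompt/completion JSONL layout used by the legacy completions fine-tuning workflow. The helper `to_finetune_record` is hypothetical; the class ids match the `class_df` table below (Building Improvement → 2, Other → 4), and the two rows come from the labelled data:

```python
import json

SUFFIX = "\n\n###\n\n"  # separator appended to every prompt, as described above

def to_finetune_record(combined_text, class_id):
    """Build one prompt/completion pair (hypothetical helper)."""
    return {
        "prompt": combined_text + SUFFIX,
        # Leading space before the single-token numeric class id
        "completion": " " + str(class_id),
    }

class_ids = {"Building Improvement": 2, "Other": 4}
rows = [
    ("Supplier: John Graham Construction Ltd; Description: Causewayside Refurbishment; Value: 74806",
     "Building Improvement"),
    ("Supplier: AM Phillip; Description: Vehicle Purchase; Value: 26604", "Other"),
]

# One JSON object per line; json.dumps escapes the newlines inside the prompt
jsonl = "\n".join(json.dumps(to_finetune_record(text, class_ids[label])) for text, label in rows)
print(jsonl)
```

Because the suffix is identical on every prompt, the same suffix must also be appended to unseen transactions at inference time.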
1180
  {
1181
   "cell_type": "code",
1182
   "execution_count": 210,
1183
   "metadata": {},
1184
   "outputs": [
1185
    {
1186
     "data": {
1187
      "text/plain": [
1188
       "101"
1189
      ]
1190
     },
1191
     "execution_count": 210,
1192
     "metadata": {},
1193
     "output_type": "execute_result"
1194
    }
1195
   ],
1196
   "source": [
1197
    "ft_prep_df = fs_df.copy()\n",
1198
    "len(ft_prep_df)\n"
1199
   ]
1200
  },
1201
  {
1202
   "cell_type": "code",
1203
   "execution_count": 211,
1204
   "metadata": {},
1205
   "outputs": [
1206
    {
1207
     "data": {
1208
      "text/html": [
1209
       "<div>\n",
1210
       "<style scoped>\n",
1211
       "    .dataframe tbody tr th:only-of-type {\n",
1212
       "        vertical-align: middle;\n",
1213
       "    }\n",
1214
       "\n",
1215
       "    .dataframe tbody tr th {\n",
1216
       "        vertical-align: top;\n",
1217
       "    }\n",
1218
       "\n",
1219
       "    .dataframe thead th {\n",
1220
       "        text-align: right;\n",
1221
       "    }\n",
1222
       "</style>\n",
1223
       "<table border=\"1\" class=\"dataframe\">\n",
1224
       "  <thead>\n",
1225
       "    <tr style=\"text-align: right;\">\n",
1226
       "      <th></th>\n",
1227
       "      <th>Unnamed: 0</th>\n",
1228
       "      <th>Date</th>\n",
1229
       "      <th>Supplier</th>\n",
1230
       "      <th>Description</th>\n",
1231
       "      <th>Transaction value (£)</th>\n",
1232
       "      <th>Classification</th>\n",
1233
       "      <th>combined</th>\n",
1234
       "      <th>n_tokens</th>\n",
1235
       "      <th>babbage_similarity</th>\n",
1236
       "      <th>babbage_search</th>\n",
1237
       "    </tr>\n",
1238
       "  </thead>\n",
1239
       "  <tbody>\n",
1240
       "    <tr>\n",
1241
       "      <th>0</th>\n",
1242
       "      <td>0</td>\n",
1243
       "      <td>15/08/2016</td>\n",
1244
       "      <td>Creative Video Productions Ltd</td>\n",
1245
       "      <td>Kelvin Hall</td>\n",
1246
       "      <td>26866</td>\n",
1247
       "      <td>Other</td>\n",
1248
       "      <td>Supplier: Creative Video Productions Ltd; Desc...</td>\n",
1249
       "      <td>12</td>\n",
1250
       "      <td>[-0.009630300104618073, 0.009887108579277992, ...</td>\n",
1251
       "      <td>[-0.008217384107410908, 0.025170527398586273, ...</td>\n",
1252
       "    </tr>\n",
1253
       "    <tr>\n",
1254
       "      <th>1</th>\n",
1255
       "      <td>1</td>\n",
1256
       "      <td>29/05/2017</td>\n",
1257
       "      <td>John Graham Construction Ltd</td>\n",
1258
       "      <td>Causewayside Refurbishment</td>\n",
1259
       "      <td>74806</td>\n",
1260
       "      <td>Building Improvement</td>\n",
1261
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
1262
       "      <td>16</td>\n",
1263
       "      <td>[-0.006144719664007425, -0.0018709596479311585...</td>\n",
1264
       "      <td>[-0.007424891460686922, 0.008475713431835175, ...</td>\n",
1265
       "    </tr>\n",
1266
       "    <tr>\n",
1267
       "      <th>2</th>\n",
1268
       "      <td>2</td>\n",
1269
       "      <td>29/05/2017</td>\n",
1270
       "      <td>Morris &amp; Spottiswood Ltd</td>\n",
1271
       "      <td>George IV Bridge Work</td>\n",
1272
       "      <td>56448</td>\n",
1273
       "      <td>Building Improvement</td>\n",
1274
       "      <td>Supplier: Morris &amp; Spottiswood Ltd; Descriptio...</td>\n",
1275
       "      <td>17</td>\n",
1276
       "      <td>[-0.005225738976150751, 0.015156379900872707, ...</td>\n",
1277
       "      <td>[-0.007611643522977829, 0.030322374776005745, ...</td>\n",
1278
       "    </tr>\n",
1279
       "    <tr>\n",
1280
       "      <th>3</th>\n",
1281
       "      <td>3</td>\n",
1282
       "      <td>31/05/2017</td>\n",
1283
       "      <td>John Graham Construction Ltd</td>\n",
1284
       "      <td>Causewayside Refurbishment</td>\n",
1285
       "      <td>164691</td>\n",
1286
       "      <td>Building Improvement</td>\n",
1287
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
1288
       "      <td>16</td>\n",
1289
       "      <td>[-0.006144719664007425, -0.0018709596479311585...</td>\n",
1290
       "      <td>[-0.007424891460686922, 0.008475713431835175, ...</td>\n",
1291
       "    </tr>\n",
1292
       "    <tr>\n",
1293
       "      <th>4</th>\n",
1294
       "      <td>4</td>\n",
1295
       "      <td>24/07/2017</td>\n",
1296
       "      <td>John Graham Construction Ltd</td>\n",
1297
       "      <td>Causewayside Refurbishment</td>\n",
1298
       "      <td>27926</td>\n",
1299
       "      <td>Building Improvement</td>\n",
1300
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
1301
       "      <td>16</td>\n",
1302
       "      <td>[-0.006144719664007425, -0.0018709596479311585...</td>\n",
1303
       "      <td>[-0.007424891460686922, 0.008475713431835175, ...</td>\n",
1304
       "    </tr>\n",
1305
       "  </tbody>\n",
1306
       "</table>\n",
1307
       "</div>"
1308
      ],
1309
      "text/plain": [
1310
       "   Unnamed: 0        Date                        Supplier  \\\n",
1311
       "0           0  15/08/2016  Creative Video Productions Ltd   \n",
1312
       "1           1  29/05/2017    John Graham Construction Ltd   \n",
1313
       "2           2  29/05/2017        Morris & Spottiswood Ltd   \n",
1314
       "3           3  31/05/2017    John Graham Construction Ltd   \n",
1315
       "4           4  24/07/2017    John Graham Construction Ltd   \n",
1316
       "\n",
1317
       "                  Description  Transaction value (£)        Classification  \\\n",
1318
       "0                 Kelvin Hall                  26866                 Other   \n",
1319
       "1  Causewayside Refurbishment                  74806  Building Improvement   \n",
1320
       "2       George IV Bridge Work                  56448  Building Improvement   \n",
1321
       "3  Causewayside Refurbishment                 164691  Building Improvement   \n",
1322
       "4  Causewayside Refurbishment                  27926  Building Improvement   \n",
1323
       "\n",
1324
       "                                            combined  n_tokens  \\\n",
1325
       "0  Supplier: Creative Video Productions Ltd; Desc...        12   \n",
1326
       "1  Supplier: John Graham Construction Ltd; Descri...        16   \n",
1327
       "2  Supplier: Morris & Spottiswood Ltd; Descriptio...        17   \n",
1328
       "3  Supplier: John Graham Construction Ltd; Descri...        16   \n",
1329
       "4  Supplier: John Graham Construction Ltd; Descri...        16   \n",
1330
       "\n",
1331
       "                                  babbage_similarity  \\\n",
1332
       "0  [-0.009630300104618073, 0.009887108579277992, ...   \n",
1333
       "1  [-0.006144719664007425, -0.0018709596479311585...   \n",
1334
       "2  [-0.005225738976150751, 0.015156379900872707, ...   \n",
1335
       "3  [-0.006144719664007425, -0.0018709596479311585...   \n",
1336
       "4  [-0.006144719664007425, -0.0018709596479311585...   \n",
1337
       "\n",
1338
       "                                      babbage_search  \n",
1339
       "0  [-0.008217384107410908, 0.025170527398586273, ...  \n",
1340
       "1  [-0.007424891460686922, 0.008475713431835175, ...  \n",
1341
       "2  [-0.007611643522977829, 0.030322374776005745, ...  \n",
1342
       "3  [-0.007424891460686922, 0.008475713431835175, ...  \n",
1343
       "4  [-0.007424891460686922, 0.008475713431835175, ...  "
1344
      ]
1345
     },
1346
     "execution_count": 211,
1347
     "metadata": {},
1348
     "output_type": "execute_result"
1349
    }
1350
   ],
1351
   "source": [
1352
    "ft_prep_df.head()\n"
1353
   ]
1354
  },
1355
  {
1356
   "cell_type": "code",
1357
   "execution_count": 212,
1358
   "metadata": {},
1359
   "outputs": [
1360
    {
1361
     "data": {
1362
      "text/plain": [
1363
       "(   class_id                 class\n",
1364
       " 0         0  Literature & Archive\n",
1365
       " 1         1         Utility Bills\n",
1366
       " 2         2  Building Improvement\n",
1367
       " 3         3           Software/IT\n",
1368
       " 4         4                 Other,\n",
1369
       " 5)"
1370
      ]
1371
     },
1372
     "execution_count": 212,
1373
     "metadata": {},
1374
     "output_type": "execute_result"
1375
    }
1376
   ],
1377
   "source": [
1378
    "classes = list(set(ft_prep_df['Classification']))\n",
1379
    "class_df = pd.DataFrame(classes).reset_index()\n",
1380
    "class_df.columns = ['class_id','class']\n",
1381
    "class_df, len(class_df)\n"
1382
   ]
1383
  },
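Iteration order of a `set` of strings can differ between Python processes (string hashing is randomised by default), so the `class_id` assigned to each class above may change on a re-run. If the ids need to stay stable, for example to decode a fine-tuned model's predictions in a later session, sorting the class names first is one option. Note that sorted ids differ from the table above (Building Improvement becomes 0 rather than 2), so the same mapping must be used both to build the training data and to decode predictions. A sketch with an inverse lookup; the label list and the `" 2"` completion are hypothetical:

```python
# Class names as they appear in the labelled data (order irrelevant)
labels = ["Other", "Building Improvement", "Utility Bills", "Building Improvement",
          "Software/IT", "Literature & Archive"]

# sorted() gives the same order in every run, unlike iterating a set
classes = sorted(set(labels))
class_to_id = {name: i for i, name in enumerate(classes)}
id_to_class = {i: name for name, i in class_to_id.items()}

# Decoding a completion such as " 2" (leading space, as in the fine-tuning data)
predicted = " 2"
print(class_to_id)
print(id_to_class[int(predicted.strip())])  # → Other
```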
1384
  {
1385
   "cell_type": "code",
1386
   "execution_count": 215,
1387
   "metadata": {},
1388
   "outputs": [
1389
    {
1390
     "data": {
1391
      "text/html": [
1392
       "<div>\n",
1393
       "<style scoped>\n",
1394
       "    .dataframe tbody tr th:only-of-type {\n",
1395
       "        vertical-align: middle;\n",
1396
       "    }\n",
1397
       "\n",
1398
       "    .dataframe tbody tr th {\n",
1399
       "        vertical-align: top;\n",
1400
       "    }\n",
1401
       "\n",
1402
       "    .dataframe thead th {\n",
1403
       "        text-align: right;\n",
1404
       "    }\n",
1405
       "</style>\n",
1406
       "<table border=\"1\" class=\"dataframe\">\n",
1407
       "  <thead>\n",
1408
       "    <tr style=\"text-align: right;\">\n",
1409
       "      <th></th>\n",
1410
       "      <th>Unnamed: 0</th>\n",
1411
       "      <th>Date</th>\n",
1412
       "      <th>Supplier</th>\n",
1413
       "      <th>Description</th>\n",
1414
       "      <th>Transaction value (£)</th>\n",
1415
       "      <th>Classification</th>\n",
1416
       "      <th>combined</th>\n",
1417
       "      <th>n_tokens</th>\n",
1418
       "      <th>babbage_similarity</th>\n",
1419
       "      <th>babbage_search</th>\n",
1420
       "      <th>class_id</th>\n",
1421
       "      <th>prompt</th>\n",
1422
       "    </tr>\n",
1423
       "  </thead>\n",
1424
       "  <tbody>\n",
1425
       "    <tr>\n",
1426
       "      <th>0</th>\n",
1427
       "      <td>0</td>\n",
1428
       "      <td>15/08/2016</td>\n",
1429
       "      <td>Creative Video Productions Ltd</td>\n",
1430
       "      <td>Kelvin Hall</td>\n",
1431
       "      <td>26866</td>\n",
1432
       "      <td>Other</td>\n",
1433
       "      <td>Supplier: Creative Video Productions Ltd; Desc...</td>\n",
1434
       "      <td>12</td>\n",
1435
       "      <td>[-0.009630300104618073, 0.009887108579277992, ...</td>\n",
1436
       "      <td>[-0.008217384107410908, 0.025170527398586273, ...</td>\n",
1437
       "      <td>4</td>\n",
1438
       "      <td>Supplier: Creative Video Productions Ltd; Desc...</td>\n",
1439
       "    </tr>\n",
1440
       "    <tr>\n",
1441
       "      <th>1</th>\n",
1442
       "      <td>51</td>\n",
1443
       "      <td>31/03/2017</td>\n",
1444
       "      <td>NLS Foundation</td>\n",
1445
       "      <td>Grant Payment</td>\n",
1446
       "      <td>177500</td>\n",
1447
       "      <td>Other</td>\n",
1448
       "      <td>Supplier: NLS Foundation; Description: Grant P...</td>\n",
1449
       "      <td>11</td>\n",
1450
       "      <td>[-0.022305507212877274, 0.008543581701815128, ...</td>\n",
1451
       "      <td>[-0.020519884303212166, 0.01993306167423725, -...</td>\n",
1452
       "      <td>4</td>\n",
1453
       "      <td>Supplier: NLS Foundation; Description: Grant P...</td>\n",
1454
       "    </tr>\n",
1455
       "    <tr>\n",
1456
       "      <th>2</th>\n",
1457
       "      <td>70</td>\n",
1458
       "      <td>26/06/2017</td>\n",
1459
       "      <td>British Library</td>\n",
1460
       "      <td>Legal Deposit Services</td>\n",
       "      <td>50056</td>\n",
       "      <td>Other</td>\n",
       "      <td>Supplier: British Library; Description: Legal ...</td>\n",
       "      <td>11</td>\n",
       "      <td>[-0.01019938476383686, 0.015277703292667866, -...</td>\n",
       "      <td>[-0.01843327097594738, 0.03343546763062477, -0...</td>\n",
       "      <td>4</td>\n",
       "      <td>Supplier: British Library; Description: Legal ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>71</td>\n",
       "      <td>24/07/2017</td>\n",
       "      <td>ALDL</td>\n",
       "      <td>Legal Deposit Services</td>\n",
       "      <td>27067</td>\n",
       "      <td>Other</td>\n",
       "      <td>Supplier: ALDL; Description: Legal Deposit Ser...</td>\n",
       "      <td>11</td>\n",
       "      <td>[-0.008471488021314144, 0.004098685923963785, ...</td>\n",
       "      <td>[-0.012966590002179146, 0.01299362163990736, 0...</td>\n",
       "      <td>4</td>\n",
       "      <td>Supplier: ALDL; Description: Legal Deposit Ser...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>100</td>\n",
       "      <td>24/07/2017</td>\n",
       "      <td>AM Phillip</td>\n",
       "      <td>Vehicle Purchase</td>\n",
       "      <td>26604</td>\n",
       "      <td>Other</td>\n",
       "      <td>Supplier: AM Phillip; Description: Vehicle Pur...</td>\n",
       "      <td>10</td>\n",
       "      <td>[-0.003459023078903556, 0.004626389592885971, ...</td>\n",
       "      <td>[-0.0010945454705506563, 0.008626140654087067,...</td>\n",
       "      <td>4</td>\n",
       "      <td>Supplier: AM Phillip; Description: Vehicle Pur...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0        Date                        Supplier  \\\n",
       "0           0  15/08/2016  Creative Video Productions Ltd   \n",
       "1          51  31/03/2017                  NLS Foundation   \n",
       "2          70  26/06/2017                 British Library   \n",
       "3          71  24/07/2017                            ALDL   \n",
       "4         100  24/07/2017                      AM Phillip   \n",
       "\n",
       "              Description  Transaction value (£) Classification  \\\n",
       "0             Kelvin Hall                  26866          Other   \n",
       "1           Grant Payment                 177500          Other   \n",
       "2  Legal Deposit Services                  50056          Other   \n",
       "3  Legal Deposit Services                  27067          Other   \n",
       "4        Vehicle Purchase                  26604          Other   \n",
       "\n",
       "                                            combined  n_tokens  \\\n",
       "0  Supplier: Creative Video Productions Ltd; Desc...        12   \n",
       "1  Supplier: NLS Foundation; Description: Grant P...        11   \n",
       "2  Supplier: British Library; Description: Legal ...        11   \n",
       "3  Supplier: ALDL; Description: Legal Deposit Ser...        11   \n",
       "4  Supplier: AM Phillip; Description: Vehicle Pur...        10   \n",
       "\n",
       "                                  babbage_similarity  \\\n",
       "0  [-0.009630300104618073, 0.009887108579277992, ...   \n",
       "1  [-0.022305507212877274, 0.008543581701815128, ...   \n",
       "2  [-0.01019938476383686, 0.015277703292667866, -...   \n",
       "3  [-0.008471488021314144, 0.004098685923963785, ...   \n",
       "4  [-0.003459023078903556, 0.004626389592885971, ...   \n",
       "\n",
       "                                      babbage_search class_id  \\\n",
       "0  [-0.008217384107410908, 0.025170527398586273, ...        4   \n",
       "1  [-0.020519884303212166, 0.01993306167423725, -...        4   \n",
       "2  [-0.01843327097594738, 0.03343546763062477, -0...        4   \n",
       "3  [-0.012966590002179146, 0.01299362163990736, 0...        4   \n",
       "4  [-0.0010945454705506563, 0.008626140654087067,...        4   \n",
       "\n",
       "                                              prompt  \n",
       "0  Supplier: Creative Video Productions Ltd; Desc...  \n",
       "1  Supplier: NLS Foundation; Description: Grant P...  \n",
       "2  Supplier: British Library; Description: Legal ...  \n",
       "3  Supplier: ALDL; Description: Legal Deposit Ser...  \n",
       "4  Supplier: AM Phillip; Description: Vehicle Pur...  "
      ]
     },
     "execution_count": 215,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')\n",
    "\n",
    "# Adding a leading whitespace onto each completion to help the model\n",
    "ft_df_with_class['class_id'] = ft_df_with_class.apply(lambda x: ' ' + str(x['class_id']),axis=1)\n",
    "ft_df_with_class = ft_df_with_class.drop('class', axis=1)\n",
    "\n",
    "# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n",
    "ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n",
    "ft_df_with_class.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prompt</th>\n",
       "      <th>completion</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ordering</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Supplier: Sothebys; Description: Literary &amp; Ar...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Supplier: Sotheby'S; Description: Literary &amp; A...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Supplier: City Of Edinburgh Council; Descripti...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                     prompt completion\n",
       "ordering                                                              \n",
       "0         Supplier: Sothebys; Description: Literary & Ar...          0\n",
       "1         Supplier: Sotheby'S; Description: Literary & A...          0\n",
       "2         Supplier: City Of Edinburgh Council; Descripti...          1\n",
       "2         Supplier: John Graham Construction Ltd; Descri...          2\n",
       "3         Supplier: John Graham Construction Ltd; Descri...          2"
      ]
     },
     "execution_count": 236,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This step is unnecessary if you have a large number of observations in each class\n",
    "# In our case we don't, so we shuffle the data to improve the chance that every class appears in both the train and validation sets\n",
    "# Fine-tuned model creation will error if the validation set has fewer classes than the training set, so this step is necessary\n",
    "\n",
    "import random\n",
    "\n",
    "labels = [x for x in ft_df_with_class['class_id']]\n",
    "text = [x for x in ft_df_with_class['prompt']]\n",
    "ft_df = pd.DataFrame(zip(text, labels), columns = ['prompt','class_id']) #[:300]\n",
    "ft_df.columns = ['prompt','completion']\n",
    "ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n",
    "ft_df.set_index('ordering',inplace=True)\n",
    "ft_df_sorted = ft_df.sort_index(ascending=True)\n",
    "ft_df_sorted.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove any existing files if we've already produced training/validation sets for this classifier\n",
    "#!rm transactions_grouped*\n",
    "\n",
    "# Write our shuffled dataframe to a .jsonl file and run the prepare_data tool to generate our training and validation files\n",
    "ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n",
    "!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31\n",
      "8\n",
      "All good\n"
     ]
    }
   ],
   "source": [
1692
    "# This functions checks that your classes all appear in both prepared files\n",
1693
    "# If they don't, the fine-tuned model creation will fail\n",
    "check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This step creates your model\n",
    "!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n",
    "\n",
1706
    "# You can use following command to get fine tuning job status and model name, replace the job name with your job\n",
1707
    "#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx\n"
1708
   ]
1709
  },
1710
  {
1711
   "cell_type": "code",
1712
   "execution_count": 323,
1713
   "metadata": {},
1714
   "outputs": [],
1715
   "source": [
1716
    "# Congrats, you've got a fine-tuned model!\n",
1717
    "# Copy/paste the name provided into the variable below and we'll take it for a spin\n",
1718
    "fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'\n"
1719
   ]
1720
  },
1721
  {
1722
   "attachments": {},
1723
   "cell_type": "markdown",
1724
   "metadata": {},
1725
   "source": [
1726
    "### Applying Fine-tuned Classifier\n",
1727
    "\n",
1728
    "Now we'll apply our classifier to see how it performs. We only had 31 unique observations in our training set and 8 in our validation set, so lets see how the performance is"
1729
   ]
1730
  },
1731
  {
1732
   "cell_type": "code",
1733
   "execution_count": 324,
1734
   "metadata": {},
1735
   "outputs": [
1736
    {
1737
     "data": {
1738
      "text/html": [
1739
       "<div>\n",
1740
       "<style scoped>\n",
1741
       "    .dataframe tbody tr th:only-of-type {\n",
1742
       "        vertical-align: middle;\n",
1743
       "    }\n",
1744
       "\n",
1745
       "    .dataframe tbody tr th {\n",
1746
       "        vertical-align: top;\n",
1747
       "    }\n",
1748
       "\n",
1749
       "    .dataframe thead th {\n",
1750
       "        text-align: right;\n",
1751
       "    }\n",
1752
       "</style>\n",
1753
       "<table border=\"1\" class=\"dataframe\">\n",
1754
       "  <thead>\n",
1755
       "    <tr style=\"text-align: right;\">\n",
1756
       "      <th></th>\n",
1757
       "      <th>prompt</th>\n",
1758
       "      <th>completion</th>\n",
1759
       "    </tr>\n",
1760
       "  </thead>\n",
1761
       "  <tbody>\n",
1762
       "    <tr>\n",
1763
       "      <th>0</th>\n",
1764
       "      <td>Supplier: Wavetek Ltd; Description: Kelvin Hal...</td>\n",
1765
       "      <td>2</td>\n",
1766
       "    </tr>\n",
1767
       "    <tr>\n",
1768
       "      <th>1</th>\n",
1769
       "      <td>Supplier: ECG Facilities Service; Description:...</td>\n",
1770
       "      <td>1</td>\n",
1771
       "    </tr>\n",
1772
       "    <tr>\n",
1773
       "      <th>2</th>\n",
1774
       "      <td>Supplier: M &amp; J Ballantyne Ltd; Description: G...</td>\n",
1775
       "      <td>2</td>\n",
1776
       "    </tr>\n",
1777
       "    <tr>\n",
1778
       "      <th>3</th>\n",
1779
       "      <td>Supplier: Private Sale; Description: Literary ...</td>\n",
1780
       "      <td>0</td>\n",
1781
       "    </tr>\n",
1782
       "    <tr>\n",
1783
       "      <th>4</th>\n",
1784
       "      <td>Supplier: Ex Libris; Description: IT equipment...</td>\n",
1785
       "      <td>3</td>\n",
1786
       "    </tr>\n",
1787
       "  </tbody>\n",
1788
       "</table>\n",
1789
       "</div>"
1790
      ],
1791
      "text/plain": [
1792
       "                                              prompt  completion\n",
1793
       "0  Supplier: Wavetek Ltd; Description: Kelvin Hal...           2\n",
1794
       "1  Supplier: ECG Facilities Service; Description:...           1\n",
1795
       "2  Supplier: M & J Ballantyne Ltd; Description: G...           2\n",
1796
       "3  Supplier: Private Sale; Description: Literary ...           0\n",
1797
       "4  Supplier: Ex Libris; Description: IT equipment...           3"
1798
      ]
1799
     },
1800
     "execution_count": 324,
1801
     "metadata": {},
1802
     "output_type": "execute_result"
1803
    }
1804
   ],
1805
   "source": [
1806
    "test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n",
1807
    "test_set.head()\n"
1808
   ]
1809
  },
1810
  {
1811
   "cell_type": "code",
1812
   "execution_count": 325,
1813
   "metadata": {},
1814
   "outputs": [],
1815
   "source": [
1816
    "test_set['predicted_class'] = test_set.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
1817
    "test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)\n"
1818
   ]
1819
  },
1820
  {
1821
   "cell_type": "code",
1822
   "execution_count": 326,
1823
   "metadata": {},
1824
   "outputs": [],
1825
   "source": [
1826
    "test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)\n"
1827
   ]
1828
  },
1829
  {
1830
   "cell_type": "code",
1831
   "execution_count": 327,
1832
   "metadata": {},
1833
   "outputs": [
1834
    {
1835
     "data": {
1836
      "text/plain": [
1837
       "True     4\n",
1838
       "False    4\n",
1839
       "Name: result, dtype: int64"
1840
      ]
1841
     },
1842
     "execution_count": 327,
1843
     "metadata": {},
1844
     "output_type": "execute_result"
1845
    }
1846
   ],
1847
   "source": [
1848
    "test_set['result'].value_counts()\n"
1849
   ]
1850
  },
1851
  {
1852
   "attachments": {},
1853
   "cell_type": "markdown",
1854
   "metadata": {},
1855
   "source": [
1856
    "Performance is not great - unfortunately this is expected. With only a few examples of each class, the above approach with embeddings and a traditional classifier worked better.\n",
1857
    "\n",
1858
    "A fine-tuned model works best with a great number of labelled observations. If we had a few hundred or thousand we may get better results, but lets do one last test on a holdout set to confirm that it doesn't generalise well to a new set of observations"
1859
   ]
1860
  },
1861
  {
1862
   "cell_type": "code",
1863
   "execution_count": 330,
1864
   "metadata": {},
1865
   "outputs": [
1866
    {
1867
     "data": {
1868
      "text/html": [
1869
       "<div>\n",
1870
       "<style scoped>\n",
1871
       "    .dataframe tbody tr th:only-of-type {\n",
1872
       "        vertical-align: middle;\n",
1873
       "    }\n",
1874
       "\n",
1875
       "    .dataframe tbody tr th {\n",
1876
       "        vertical-align: top;\n",
1877
       "    }\n",
1878
       "\n",
1879
       "    .dataframe thead th {\n",
1880
       "        text-align: right;\n",
1881
       "    }\n",
1882
       "</style>\n",
1883
       "<table border=\"1\" class=\"dataframe\">\n",
1884
       "  <thead>\n",
1885
       "    <tr style=\"text-align: right;\">\n",
1886
       "      <th></th>\n",
1887
       "      <th>Date</th>\n",
1888
       "      <th>Supplier</th>\n",
1889
       "      <th>Description</th>\n",
1890
       "      <th>Transaction value (£)</th>\n",
1891
       "    </tr>\n",
1892
       "  </thead>\n",
1893
       "  <tbody>\n",
1894
       "    <tr>\n",
1895
       "      <th>101</th>\n",
1896
       "      <td>23/10/2017</td>\n",
1897
       "      <td>City Building LLP</td>\n",
1898
       "      <td>Causewayside Refurbishment</td>\n",
1899
       "      <td>53147.0</td>\n",
1900
       "    </tr>\n",
1901
       "    <tr>\n",
1902
       "      <th>102</th>\n",
1903
       "      <td>30/10/2017</td>\n",
1904
       "      <td>ECG Facilities Service</td>\n",
1905
       "      <td>Facilities Management Charge</td>\n",
1906
       "      <td>35758.0</td>\n",
1907
       "    </tr>\n",
1908
       "    <tr>\n",
1909
       "      <th>103</th>\n",
1910
       "      <td>30/10/2017</td>\n",
1911
       "      <td>ECG Facilities Service</td>\n",
1912
       "      <td>Facilities Management Charge</td>\n",
1913
       "      <td>35758.0</td>\n",
1914
       "    </tr>\n",
1915
       "    <tr>\n",
1916
       "      <th>104</th>\n",
1917
       "      <td>06/11/2017</td>\n",
1918
       "      <td>John Graham Construction Ltd</td>\n",
1919
       "      <td>Causewayside Refurbishment</td>\n",
1920
       "      <td>134208.0</td>\n",
1921
       "    </tr>\n",
1922
       "    <tr>\n",
1923
       "      <th>105</th>\n",
1924
       "      <td>06/11/2017</td>\n",
1925
       "      <td>ALDL</td>\n",
1926
       "      <td>Legal Deposit Services</td>\n",
1927
       "      <td>27067.0</td>\n",
1928
       "    </tr>\n",
1929
       "  </tbody>\n",
1930
       "</table>\n",
1931
       "</div>"
1932
      ],
1933
      "text/plain": [
1934
       "           Date                      Supplier                   Description  \\\n",
1935
       "101  23/10/2017             City Building LLP    Causewayside Refurbishment   \n",
1936
       "102  30/10/2017        ECG Facilities Service  Facilities Management Charge   \n",
1937
       "103  30/10/2017        ECG Facilities Service  Facilities Management Charge   \n",
1938
       "104  06/11/2017  John Graham Construction Ltd    Causewayside Refurbishment   \n",
1939
       "105  06/11/2017                          ALDL        Legal Deposit Services   \n",
1940
       "\n",
1941
       "     Transaction value (£)  \n",
1942
       "101                53147.0  \n",
1943
       "102                35758.0  \n",
1944
       "103                35758.0  \n",
1945
       "104               134208.0  \n",
1946
       "105                27067.0  "
1947
      ]
1948
     },
1949
     "execution_count": 330,
1950
     "metadata": {},
1951
     "output_type": "execute_result"
1952
    }
1953
   ],
1954
   "source": [
1955
    "holdout_df = transactions.copy().iloc[101:]\n",
1956
    "holdout_df.head()\n"
1957
   ]
1958
  },
1959
  {
1960
   "cell_type": "code",
1961
   "execution_count": 332,
1962
   "metadata": {},
1963
   "outputs": [],
1964
   "source": [
1965
    "holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
1966
    "holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
1967
    "holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)\n"
1968
   ]
1969
  },
1970
  {
1971
   "cell_type": "code",
1972
   "execution_count": 333,
1973
   "metadata": {},
1974
   "outputs": [
1975
    {
1976
     "data": {
1977
      "text/html": [
1978
       "<div>\n",
1979
       "<style scoped>\n",
1980
       "    .dataframe tbody tr th:only-of-type {\n",
1981
       "        vertical-align: middle;\n",
1982
       "    }\n",
1983
       "\n",
1984
       "    .dataframe tbody tr th {\n",
1985
       "        vertical-align: top;\n",
1986
       "    }\n",
1987
       "\n",
1988
       "    .dataframe thead th {\n",
1989
       "        text-align: right;\n",
1990
       "    }\n",
1991
       "</style>\n",
1992
       "<table border=\"1\" class=\"dataframe\">\n",
1993
       "  <thead>\n",
1994
       "    <tr style=\"text-align: right;\">\n",
1995
       "      <th></th>\n",
1996
       "      <th>Date</th>\n",
1997
       "      <th>Supplier</th>\n",
1998
       "      <th>Description</th>\n",
1999
       "      <th>Transaction value (£)</th>\n",
2000
       "      <th>combined</th>\n",
2001
       "      <th>prediction_result</th>\n",
2002
       "      <th>pred</th>\n",
2003
       "    </tr>\n",
2004
       "  </thead>\n",
2005
       "  <tbody>\n",
2006
       "    <tr>\n",
2007
       "      <th>101</th>\n",
2008
       "      <td>23/10/2017</td>\n",
2009
       "      <td>City Building LLP</td>\n",
2010
       "      <td>Causewayside Refurbishment</td>\n",
2011
       "      <td>53147.0</td>\n",
2012
       "      <td>Supplier: City Building LLP; Description: Caus...</td>\n",
2013
       "      <td>{'id': 'cmpl-63YDadbYLo8xKsGY2vReOFCMgTOvG', '...</td>\n",
2014
       "      <td>2</td>\n",
2015
       "    </tr>\n",
2016
       "    <tr>\n",
2017
       "      <th>102</th>\n",
2018
       "      <td>30/10/2017</td>\n",
2019
       "      <td>ECG Facilities Service</td>\n",
2020
       "      <td>Facilities Management Charge</td>\n",
2021
       "      <td>35758.0</td>\n",
2022
       "      <td>Supplier: ECG Facilities Service; Description:...</td>\n",
2023
       "      <td>{'id': 'cmpl-63YDbNK1D7UikDc3xi5ATihg5kQEt', '...</td>\n",
2024
       "      <td>2</td>\n",
2025
       "    </tr>\n",
2026
       "    <tr>\n",
2027
       "      <th>103</th>\n",
2028
       "      <td>30/10/2017</td>\n",
2029
       "      <td>ECG Facilities Service</td>\n",
2030
       "      <td>Facilities Management Charge</td>\n",
2031
       "      <td>35758.0</td>\n",
2032
       "      <td>Supplier: ECG Facilities Service; Description:...</td>\n",
2033
       "      <td>{'id': 'cmpl-63YDbwfiHjkjMWsfTKNt6naeqPzOe', '...</td>\n",
2034
       "      <td>2</td>\n",
2035
       "    </tr>\n",
2036
       "    <tr>\n",
2037
       "      <th>104</th>\n",
2038
       "      <td>06/11/2017</td>\n",
2039
       "      <td>John Graham Construction Ltd</td>\n",
2040
       "      <td>Causewayside Refurbishment</td>\n",
2041
       "      <td>134208.0</td>\n",
2042
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
2043
       "      <td>{'id': 'cmpl-63YDbWAndtsRqPTi2ZHZtPodZvOwr', '...</td>\n",
2044
       "      <td>2</td>\n",
2045
       "    </tr>\n",
2046
       "    <tr>\n",
2047
       "      <th>105</th>\n",
2048
       "      <td>06/11/2017</td>\n",
2049
       "      <td>ALDL</td>\n",
2050
       "      <td>Legal Deposit Services</td>\n",
2051
       "      <td>27067.0</td>\n",
2052
       "      <td>Supplier: ALDL; Description: Legal Deposit Ser...</td>\n",
2053
       "      <td>{'id': 'cmpl-63YDbDu7WM3svYWsRAMdDUKtSFDBu', '...</td>\n",
2054
       "      <td>2</td>\n",
2055
       "    </tr>\n",
2056
       "    <tr>\n",
2057
       "      <th>106</th>\n",
2058
       "      <td>27/11/2017</td>\n",
2059
       "      <td>Maggs Bros Ltd</td>\n",
2060
       "      <td>Literary &amp; Archival Items</td>\n",
2061
       "      <td>26500.0</td>\n",
2062
       "      <td>Supplier: Maggs Bros Ltd; Description: Literar...</td>\n",
2063
       "      <td>{'id': 'cmpl-63YDbxNNI8ZH5CJJNxQ0IF9Zf925C', '...</td>\n",
2064
       "      <td>0</td>\n",
2065
       "    </tr>\n",
2066
       "    <tr>\n",
2067
       "      <th>107</th>\n",
2068
       "      <td>30/11/2017</td>\n",
2069
       "      <td>Glasgow City Council</td>\n",
2070
       "      <td>Kelvin Hall</td>\n",
2071
       "      <td>42345.0</td>\n",
2072
       "      <td>Supplier: Glasgow City Council; Description: K...</td>\n",
2073
       "      <td>{'id': 'cmpl-63YDb8R1FWu4bjwM2xE775rouwneV', '...</td>\n",
2074
       "      <td>2</td>\n",
2075
       "    </tr>\n",
2076
       "    <tr>\n",
2077
       "      <th>108</th>\n",
2078
       "      <td>11/12/2017</td>\n",
2079
       "      <td>ECG Facilities Service</td>\n",
2080
       "      <td>Facilities Management Charge</td>\n",
2081
       "      <td>35758.0</td>\n",
2082
       "      <td>Supplier: ECG Facilities Service; Description:...</td>\n",
2083
       "      <td>{'id': 'cmpl-63YDcAPsp37WhbPs9kwfUX0kBk7Hv', '...</td>\n",
2084
       "      <td>2</td>\n",
2085
       "    </tr>\n",
2086
       "    <tr>\n",
2087
       "      <th>109</th>\n",
2088
       "      <td>11/12/2017</td>\n",
2089
       "      <td>John Graham Construction Ltd</td>\n",
2090
       "      <td>Causewayside Refurbishment</td>\n",
2091
       "      <td>159275.0</td>\n",
2092
       "      <td>Supplier: John Graham Construction Ltd; Descri...</td>\n",
2093
       "      <td>{'id': 'cmpl-63YDcML2welrC3wF0nuKgcNmVu1oQ', '...</td>\n",
2094
       "      <td>2</td>\n",
2095
       "    </tr>\n",
2096
       "    <tr>\n",
2097
       "      <th>110</th>\n",
2098
       "      <td>08/01/2018</td>\n",
2099
       "      <td>ECG Facilities Service</td>\n",
2100
       "      <td>Facilities Management Charge</td>\n",
2101
       "      <td>35758.0</td>\n",
2102
       "      <td>Supplier: ECG Facilities Service; Description:...</td>\n",
2103
       "      <td>{'id': 'cmpl-63YDc95SSdOHnIliFB2cjMEEm7Z2u', '...</td>\n",
2104
       "      <td>2</td>\n",
2105
       "    </tr>\n",
2106
       "  </tbody>\n",
2107
       "</table>\n",
2108
       "</div>"
2109
      ],
2110
      "text/plain": [
2111
       "           Date                      Supplier                   Description  \\\n",
2112
       "101  23/10/2017             City Building LLP    Causewayside Refurbishment   \n",
2113
       "102  30/10/2017        ECG Facilities Service  Facilities Management Charge   \n",
2114
       "103  30/10/2017        ECG Facilities Service  Facilities Management Charge   \n",
2115
       "104  06/11/2017  John Graham Construction Ltd    Causewayside Refurbishment   \n",
2116
       "105  06/11/2017                          ALDL        Legal Deposit Services   \n",
2117
       "106  27/11/2017                Maggs Bros Ltd     Literary & Archival Items   \n",
2118
       "107  30/11/2017          Glasgow City Council                   Kelvin Hall   \n",
2119
       "108  11/12/2017        ECG Facilities Service  Facilities Management Charge   \n",
2120
       "109  11/12/2017  John Graham Construction Ltd    Causewayside Refurbishment   \n",
2121
       "110  08/01/2018        ECG Facilities Service  Facilities Management Charge   \n",
2122
       "\n",
2123
       "     Transaction value (£)                                           combined  \\\n",
2124
       "101                53147.0  Supplier: City Building LLP; Description: Caus...   \n",
2125
       "102                35758.0  Supplier: ECG Facilities Service; Description:...   \n",
2126
       "103                35758.0  Supplier: ECG Facilities Service; Description:...   \n",
2127
       "104               134208.0  Supplier: John Graham Construction Ltd; Descri...   \n",
2128
       "105                27067.0  Supplier: ALDL; Description: Legal Deposit Ser...   \n",
2129
       "106                26500.0  Supplier: Maggs Bros Ltd; Description: Literar...   \n",
2130
       "107                42345.0  Supplier: Glasgow City Council; Description: K...   \n",
2131
       "108                35758.0  Supplier: ECG Facilities Service; Description:...   \n",
2132
       "109               159275.0  Supplier: John Graham Construction Ltd; Descri...   \n",
2133
       "110                35758.0  Supplier: ECG Facilities Service; Description:...   \n",
2134
       "\n",
2135
       "                                     prediction_result pred  \n",
2136
       "101  {'id': 'cmpl-63YDadbYLo8xKsGY2vReOFCMgTOvG', '...    2  \n",
2137
       "102  {'id': 'cmpl-63YDbNK1D7UikDc3xi5ATihg5kQEt', '...    2  \n",
2138
       "103  {'id': 'cmpl-63YDbwfiHjkjMWsfTKNt6naeqPzOe', '...    2  \n",
2139
       "104  {'id': 'cmpl-63YDbWAndtsRqPTi2ZHZtPodZvOwr', '...    2  \n",
2140
       "105  {'id': 'cmpl-63YDbDu7WM3svYWsRAMdDUKtSFDBu', '...    2  \n",
2141
       "106  {'id': 'cmpl-63YDbxNNI8ZH5CJJNxQ0IF9Zf925C', '...    0  \n",
2142
       "107  {'id': 'cmpl-63YDb8R1FWu4bjwM2xE775rouwneV', '...    2  \n",
2143
       "108  {'id': 'cmpl-63YDcAPsp37WhbPs9kwfUX0kBk7Hv', '...    2  \n",
2144
       "109  {'id': 'cmpl-63YDcML2welrC3wF0nuKgcNmVu1oQ', '...    2  \n",
2145
       "110  {'id': 'cmpl-63YDc95SSdOHnIliFB2cjMEEm7Z2u', '...    2  "
2146
      ]
2147
     },
2148
     "execution_count": 333,
2149
     "metadata": {},
2150
     "output_type": "execute_result"
2151
    }
2152
   ],
2153
   "source": [
2154
    "holdout_df.head(10)\n"
2155
   ]
2156
  },
2157
  {
2158
   "cell_type": "code",
2159
   "execution_count": 334,
2160
   "metadata": {},
2161
   "outputs": [
2162
    {
2163
     "data": {
2164
      "text/plain": [
2165
       " 2    231\n",
2166
       " 0     27\n",
2167
       "Name: pred, dtype: int64"
2168
      ]
2169
     },
2170
     "execution_count": 334,
2171
     "metadata": {},
2172
     "output_type": "execute_result"
2173
    }
2174
   ],
2175
   "source": [
2176
    "holdout_df['pred'].value_counts()\n"
2177
   ]
2178
  },
2179
  {
2180
   "attachments": {},
2181
   "cell_type": "markdown",
2182
   "metadata": {},
2183
   "source": [
2184
    "Well those results were similarly underwhelming - so we've learned that with a dataset with a small number of labelled observations, either zero-shot classification or traditional classification with embeddings return better results than a fine-tuned model.\n",
2185
    "\n",
2186
    "A fine-tuned model is still a great tool, but is more effective when you have a larger number of labelled examples for each class that you're looking to classify"
2187
   ]
2188
  }
2189
 ],
2190
 "metadata": {
2191
  "kernelspec": {
2192
   "display_name": "Python 3",
2193
   "language": "python",
2194
   "name": "python3"
2195
  },
2196
  "language_info": {
2197
   "codemirror_mode": {
2198
    "name": "ipython",
2199
    "version": 3
2200
   },
2201
   "file_extension": ".py",
2202
   "mimetype": "text/x-python",
2203
   "name": "python",
2204
   "nbconvert_exporter": "python",
2205
   "pygments_lexer": "ipython3",
2206
   "version": "3.11.3"
2207
  }
2208
 },
2209
 "nbformat": 4,
2210
 "nbformat_minor": 4
2211
}
2212

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.