Amazing-Python-Scripts

Форк
0
/
Fashion_MNIST_MultiClass_Image_Classification_(CNN).ipynb 
645 строк · 74.5 Кб
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": []
7
    },
8
    "kernelspec": {
9
      "name": "python3",
10
      "display_name": "Python 3"
11
    },
12
    "language_info": {
13
      "name": "python"
14
    }
15
  },
16
  "cells": [
17
    {
18
      "cell_type": "markdown",
19
      "source": [
20
        "## Libraries"
21
      ],
22
      "metadata": {
23
        "id": "X1g131LkMs2l"
24
      }
25
    },
26
    {
27
      "cell_type": "code",
28
      "execution_count": 47,
29
      "metadata": {
30
        "id": "H579zGsgDeBv"
31
      },
32
      "outputs": [],
33
      "source": [
34
        "import tensorflow as tf\n",
35
        "import pandas as pd\n",
36
        "import numpy as np\n",
37
        "from google.colab import files\n",
38
        "import matplotlib.pyplot as plt\n",
39
        "from keras.preprocessing import image"
40
      ]
41
    },
42
    {
43
      "cell_type": "code",
44
      "source": [
45
        "# Fashion-MNIST is bundled with Keras; load_data() yields (images, labels) pairs for the train and test splits.\n",
        "dataset = tf.keras.datasets.fashion_mnist\n",
        "(train, train_labels), (test, test_labels) = dataset.load_data()"
47
      ],
48
      "metadata": {
49
        "id": "r2pdmx2FM7MJ"
50
      },
51
      "execution_count": 8,
52
      "outputs": []
53
    },
54
    {
55
      "cell_type": "code",
56
      "source": [
57
        "np.shape(train)  # (images, height, width)"
58
      ],
59
      "metadata": {
60
        "colab": {
61
          "base_uri": "https://localhost:8080/"
62
        },
63
        "id": "rmMBgOUeODHu",
64
        "outputId": "71ae9891-afd6-4b53-aa39-e4327cf3b2e3"
65
      },
66
      "execution_count": 9,
67
      "outputs": [
68
        {
69
          "output_type": "execute_result",
70
          "data": {
71
            "text/plain": [
72
              "(60000, 28, 28)"
73
            ]
74
          },
75
          "metadata": {},
76
          "execution_count": 9
77
        }
78
      ]
79
    },
80
    {
81
      "cell_type": "code",
82
      "source": [
83
        "np.shape(test)  # (images, height, width)"
84
      ],
85
      "metadata": {
86
        "colab": {
87
          "base_uri": "https://localhost:8080/"
88
        },
89
        "id": "QmTbWMRxOK--",
90
        "outputId": "b21820ac-2a2f-4f62-8f2e-016f9db19fd5"
91
      },
92
      "execution_count": 10,
93
      "outputs": [
94
        {
95
          "output_type": "execute_result",
96
          "data": {
97
            "text/plain": [
98
              "(10000, 28, 28)"
99
            ]
100
          },
101
          "metadata": {},
102
          "execution_count": 10
103
        }
104
      ]
105
    },
106
    {
107
      "cell_type": "markdown",
108
      "source": [
109
        "## Image Preprocessing"
110
      ],
111
      "metadata": {
112
        "id": "VbqyMa3dN9l1"
113
      }
114
    },
115
    {
116
      "cell_type": "code",
117
      "source": [
118
        "# The CNN's input is (28, 28, 1), so add a trailing channel axis, and scale pixel values by 1/255.\n",
        "# Guarding on ndim keeps the cell idempotent: re-running it will not silently divide\n",
        "# the already-scaled data by 255 a second time. [..., np.newaxis] also avoids hard-coding\n",
        "# the number of images.\n",
        "if train.ndim == 3:\n",
        "    train = train[..., np.newaxis] / 255"
120
      ],
121
      "metadata": {
122
        "id": "k2AdnnteNzI8"
123
      },
124
      "execution_count": 12,
125
      "outputs": []
126
    },
127
    {
128
      "cell_type": "code",
129
      "source": [
130
        "# Apply exactly the same transform as the training images: add the channel axis and\n",
        "# scale by 1/255. The ndim guard makes re-running the cell a no-op instead of rescaling again.\n",
        "if test.ndim == 3:\n",
        "    test = test[..., np.newaxis] / 255"
132
      ],
133
      "metadata": {
134
        "id": "w89GRl4YOcti"
135
      },
136
      "execution_count": 14,
137
      "outputs": []
138
    },
139
    {
140
      "cell_type": "markdown",
141
      "source": [
142
        "## CNN"
143
      ],
144
      "metadata": {
145
        "id": "m1XWcdo7QAXK"
146
      }
147
    },
148
    {
149
      "cell_type": "code",
150
      "source": [
151
        "# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and training shuffles are repeatable.\n",
        "SEED = 42\n",
        "tf.keras.utils.set_random_seed(SEED)\n",
        "\n",
        "# Two conv/pool stages (28x28 -> 14x14 -> 7x7) feeding a dense classifier over 10 classes.\n",
        "cnn = tf.keras.models.Sequential([\n",
        "    tf.keras.Input(shape=(28, 28, 1)),\n",
        "    tf.keras.layers.Conv2D(64, (3, 3), activation=\"relu\", padding=\"same\"),\n",
        "    tf.keras.layers.MaxPooling2D(2, 2),\n",
        "    tf.keras.layers.Conv2D(32, (3, 3), activation=\"relu\", padding=\"same\"),\n",
        "    tf.keras.layers.MaxPooling2D(2, 2),\n",
        "    tf.keras.layers.Flatten(),\n",
        "    tf.keras.layers.Dense(128, activation=\"relu\"),\n",
        "    tf.keras.layers.Dense(64, activation=\"relu\"),\n",
        "    # softmax probabilities pair with sparse_categorical_crossentropy on integer labels\n",
        "    tf.keras.layers.Dense(10, activation=\"softmax\")\n",
        "])\n",
        "\n",
        "cnn.summary()"
163
      ],
164
      "metadata": {
165
        "colab": {
166
          "base_uri": "https://localhost:8080/"
167
        },
168
        "id": "BWJJK4-qOktf",
169
        "outputId": "8c0d251c-0add-477e-959f-752997d6e76e"
170
      },
171
      "execution_count": 18,
172
      "outputs": [
173
        {
174
          "output_type": "stream",
175
          "name": "stdout",
176
          "text": [
177
            "Model: \"sequential_1\"\n",
178
            "_________________________________________________________________\n",
179
            " Layer (type)                Output Shape              Param #   \n",
180
            "=================================================================\n",
181
            " conv2d_2 (Conv2D)           (None, 28, 28, 64)        640       \n",
182
            "                                                                 \n",
183
            " max_pooling2d_2 (MaxPooling  (None, 14, 14, 64)       0         \n",
184
            " 2D)                                                             \n",
185
            "                                                                 \n",
186
            " conv2d_3 (Conv2D)           (None, 14, 14, 32)        18464     \n",
187
            "                                                                 \n",
188
            " max_pooling2d_3 (MaxPooling  (None, 7, 7, 32)         0         \n",
189
            " 2D)                                                             \n",
190
            "                                                                 \n",
191
            " flatten_1 (Flatten)         (None, 1568)              0         \n",
192
            "                                                                 \n",
193
            " dense_3 (Dense)             (None, 128)               200832    \n",
194
            "                                                                 \n",
195
            " dense_4 (Dense)             (None, 64)                8256      \n",
196
            "                                                                 \n",
197
            " dense_5 (Dense)             (None, 10)                650       \n",
198
            "                                                                 \n",
199
            "=================================================================\n",
200
            "Total params: 228,842\n",
201
            "Trainable params: 228,842\n",
202
            "Non-trainable params: 0\n",
203
            "_________________________________________________________________\n"
204
          ]
205
        }
206
      ]
207
    },
208
    {
209
      "cell_type": "code",
210
      "source": [
211
        "cnn.compile(\n",
        "    optimizer=\"adam\",\n",
        "    # labels are integer class ids rather than one-hot vectors, hence the sparse loss\n",
        "    loss=\"sparse_categorical_crossentropy\",\n",
        "    metrics=[\"accuracy\"],\n",
        ")"
212
      ],
213
      "metadata": {
214
        "id": "hj6WRNLUQHp_"
215
      },
216
      "execution_count": 19,
217
      "outputs": []
218
    },
219
    {
220
      "cell_type": "code",
221
      "source": [
222
        "# fit() returns a History object (kept under the name `model` for the cells below);\n",
        "# its .history dict holds one value per epoch for each metric.\n",
        "# Passing the test split as validation_data also records test loss/accuracy after every\n",
        "# epoch (val_loss / val_accuracy). It is only monitored here; nothing is tuned on it.\n",
        "model = cnn.fit(train, train_labels, epochs=10, validation_data=(test, test_labels))"
223
      ],
224
      "metadata": {
225
        "colab": {
226
          "base_uri": "https://localhost:8080/"
227
        },
228
        "id": "7LIcK4A5QWYa",
229
        "outputId": "7609a3e0-9e5c-4859-aa3a-422d70de872a"
230
      },
231
      "execution_count": 20,
232
      "outputs": [
233
        {
234
          "output_type": "stream",
235
          "name": "stdout",
236
          "text": [
237
            "Epoch 1/10\n",
238
            "1875/1875 [==============================] - 134s 71ms/step - loss: 0.4358 - accuracy: 0.8425\n",
239
            "Epoch 2/10\n",
240
            "1875/1875 [==============================] - 128s 68ms/step - loss: 0.2832 - accuracy: 0.8957\n",
241
            "Epoch 3/10\n",
242
            "1875/1875 [==============================] - 124s 66ms/step - loss: 0.2402 - accuracy: 0.9115\n",
243
            "Epoch 4/10\n",
244
            "1875/1875 [==============================] - 133s 71ms/step - loss: 0.2101 - accuracy: 0.9212\n",
245
            "Epoch 5/10\n",
246
            "1875/1875 [==============================] - 124s 66ms/step - loss: 0.1834 - accuracy: 0.9311\n",
247
            "Epoch 6/10\n",
248
            "1875/1875 [==============================] - 120s 64ms/step - loss: 0.1645 - accuracy: 0.9382\n",
249
            "Epoch 7/10\n",
250
            "1875/1875 [==============================] - 122s 65ms/step - loss: 0.1472 - accuracy: 0.9445\n",
251
            "Epoch 8/10\n",
252
            "1875/1875 [==============================] - 124s 66ms/step - loss: 0.1311 - accuracy: 0.9505\n",
253
            "Epoch 9/10\n",
254
            "1875/1875 [==============================] - 122s 65ms/step - loss: 0.1148 - accuracy: 0.9565\n",
255
            "Epoch 10/10\n",
256
            "1875/1875 [==============================] - 121s 65ms/step - loss: 0.1023 - accuracy: 0.9611\n"
257
          ]
258
        }
259
      ]
260
    },
261
    {
262
      "cell_type": "code",
263
      "source": [
264
        "np.asarray(train_labels)  # integer class id for each training image"
265
      ],
266
      "metadata": {
267
        "colab": {
268
          "base_uri": "https://localhost:8080/"
269
        },
270
        "id": "wkar6_fjc_aB",
271
        "outputId": "336628f0-eacb-4bc4-d0dd-906b313e81ff"
272
      },
273
      "execution_count": 45,
274
      "outputs": [
275
        {
276
          "output_type": "execute_result",
277
          "data": {
278
            "text/plain": [
279
              "array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)"
280
            ]
281
          },
282
          "metadata": {},
283
          "execution_count": 45
284
        }
285
      ]
286
    },
287
    {
288
      "cell_type": "code",
289
      "source": [
290
        "# `model` is the History returned by fit(); derive the x-axis from the recorded epochs\n",
        "# instead of hard-coding 10, so changing `epochs` cannot desynchronise the plot.\n",
        "train_accuracy = model.history['accuracy']\n",
        "epoch_numbers = range(1, len(train_accuracy) + 1)\n",
        "\n",
        "plt.plot(epoch_numbers, train_accuracy, marker='o')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.title('Training Accuracy over Epochs')\n",
        "plt.show()"
294
      ],
295
      "metadata": {
296
        "colab": {
297
          "base_uri": "https://localhost:8080/",
298
          "height": 489
299
        },
300
        "id": "OKqilKpDXv3D",
301
        "outputId": "31ef7d1b-99b4-48b7-d3f6-a953b23d92d7"
302
      },
303
      "execution_count": 30,
304
      "outputs": [
305
        {
306
          "output_type": "execute_result",
307
          "data": {
308
            "text/plain": [
309
              "Text(0.5, 1.0, 'Training Accuracy over Epochs')"
310
            ]
311
          },
312
          "metadata": {},
313
          "execution_count": 30
314
        },
315
        {
316
          "output_type": "display_data",
317
          "data": {
318
            "text/plain": [
319
              "<Figure size 640x480 with 1 Axes>"
320
            ],
321
            "image/png": "\n"
322
          },
323
          "metadata": {}
324
        }
325
      ]
326
    },
327
    {
328
      "cell_type": "code",
329
      "source": [
330
        "# The trained model does not change between calls, so evaluating it repeatedly in a loop\n",
        "# just reproduces one number. Prefer the per-epoch test accuracy that fit() records when\n",
        "# it is given validation_data; otherwise evaluate once.\n",
        "if 'val_accuracy' in model.history:\n",
        "    testing_accuracy = model.history['val_accuracy']\n",
        "else:\n",
        "    test_loss, test_accuracy = cnn.evaluate(test, test_labels, verbose=0)\n",
        "    # repeated so the list still has one entry per training epoch for plotting\n",
        "    testing_accuracy = [test_accuracy] * len(model.history['accuracy'])"
334
      ],
335
      "metadata": {
336
        "id": "LXtuThIBYFIn"
337
      },
338
      "execution_count": 31,
339
      "outputs": []
340
    },
341
    {
342
      "cell_type": "code",
343
      "source": [
344
        "epoch_numbers = range(1, len(testing_accuracy) + 1)\n",
        "plt.plot(epoch_numbers, testing_accuracy, marker='o', label='Testing Accuracy')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.title('Testing Accuracy over Epochs')\n",
        "# the line label is only rendered once a legend is drawn\n",
        "plt.legend()\n",
        "plt.show()"
345
      ],
346
      "metadata": {
347
        "colab": {
348
          "base_uri": "https://localhost:8080/",
349
          "height": 447
350
        },
351
        "id": "ubNSWfTgYTX7",
352
        "outputId": "397ee654-720e-40f5-9583-d38fdd15d8cd"
353
      },
354
      "execution_count": 32,
355
      "outputs": [
356
        {
357
          "output_type": "execute_result",
358
          "data": {
359
            "text/plain": [
360
              "[<matplotlib.lines.Line2D at 0x7ca5b90c8b80>]"
361
            ]
362
          },
363
          "metadata": {},
364
          "execution_count": 32
365
        },
366
        {
367
          "output_type": "display_data",
368
          "data": {
369
            "text/plain": [
370
              "<Figure size 640x480 with 1 Axes>"
371
            ],
372
            "image/png": "\n"
373
          },
374
          "metadata": {}
375
        }
376
      ]
377
    },
378
    {
379
      "cell_type": "code",
380
      "source": [
381
        "# files.upload() returns a dict keyed by uploaded filename; it is empty when the user\n",
        "# cancels the upload, so take the first key defensively instead of indexing [0].\n",
        "uploaded_file = files.upload()\n",
        "uploaded_image_path = next(iter(uploaded_file), None)\n",
        "if uploaded_image_path is None:\n",
        "    print('No file was uploaded.')\n",
        "# NOTE(review): no later cell uses uploaded_image_path; the prediction cell runs on test[0]."
383
      ],
384
      "metadata": {
385
        "colab": {
386
          "base_uri": "https://localhost:8080/",
387
          "height": 73
388
        },
389
        "id": "uHpMx4y2RzxQ",
390
        "outputId": "81089118-b052-409f-c5f6-1168a474244a"
391
      },
392
      "execution_count": 23,
393
      "outputs": [
394
        {
395
          "output_type": "display_data",
396
          "data": {
397
            "text/plain": [
398
              "<IPython.core.display.HTML object>"
399
            ],
400
            "text/html": [
401
              "\n",
402
              "     <input type=\"file\" id=\"files-9fee1ebd-772a-4401-be69-08f68113fd98\" name=\"files[]\" multiple disabled\n",
403
              "        style=\"border:none\" />\n",
404
              "     <output id=\"result-9fee1ebd-772a-4401-be69-08f68113fd98\">\n",
405
              "      Upload widget is only available when the cell has been executed in the\n",
406
              "      current browser session. Please rerun this cell to enable.\n",
407
              "      </output>\n",
408
              "      <script>// Copyright 2017 Google LLC\n",
409
              "//\n",
410
              "// Licensed under the Apache License, Version 2.0 (the \"License\");\n",
411
              "// you may not use this file except in compliance with the License.\n",
412
              "// You may obtain a copy of the License at\n",
413
              "//\n",
414
              "//      http://www.apache.org/licenses/LICENSE-2.0\n",
415
              "//\n",
416
              "// Unless required by applicable law or agreed to in writing, software\n",
417
              "// distributed under the License is distributed on an \"AS IS\" BASIS,\n",
418
              "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
419
              "// See the License for the specific language governing permissions and\n",
420
              "// limitations under the License.\n",
421
              "\n",
422
              "/**\n",
423
              " * @fileoverview Helpers for google.colab Python module.\n",
424
              " */\n",
425
              "(function(scope) {\n",
426
              "function span(text, styleAttributes = {}) {\n",
427
              "  const element = document.createElement('span');\n",
428
              "  element.textContent = text;\n",
429
              "  for (const key of Object.keys(styleAttributes)) {\n",
430
              "    element.style[key] = styleAttributes[key];\n",
431
              "  }\n",
432
              "  return element;\n",
433
              "}\n",
434
              "\n",
435
              "// Max number of bytes which will be uploaded at a time.\n",
436
              "const MAX_PAYLOAD_SIZE = 100 * 1024;\n",
437
              "\n",
438
              "function _uploadFiles(inputId, outputId) {\n",
439
              "  const steps = uploadFilesStep(inputId, outputId);\n",
440
              "  const outputElement = document.getElementById(outputId);\n",
441
              "  // Cache steps on the outputElement to make it available for the next call\n",
442
              "  // to uploadFilesContinue from Python.\n",
443
              "  outputElement.steps = steps;\n",
444
              "\n",
445
              "  return _uploadFilesContinue(outputId);\n",
446
              "}\n",
447
              "\n",
448
              "// This is roughly an async generator (not supported in the browser yet),\n",
449
              "// where there are multiple asynchronous steps and the Python side is going\n",
450
              "// to poll for completion of each step.\n",
451
              "// This uses a Promise to block the python side on completion of each step,\n",
452
              "// then passes the result of the previous step as the input to the next step.\n",
453
              "function _uploadFilesContinue(outputId) {\n",
454
              "  const outputElement = document.getElementById(outputId);\n",
455
              "  const steps = outputElement.steps;\n",
456
              "\n",
457
              "  const next = steps.next(outputElement.lastPromiseValue);\n",
458
              "  return Promise.resolve(next.value.promise).then((value) => {\n",
459
              "    // Cache the last promise value to make it available to the next\n",
460
              "    // step of the generator.\n",
461
              "    outputElement.lastPromiseValue = value;\n",
462
              "    return next.value.response;\n",
463
              "  });\n",
464
              "}\n",
465
              "\n",
466
              "/**\n",
467
              " * Generator function which is called between each async step of the upload\n",
468
              " * process.\n",
469
              " * @param {string} inputId Element ID of the input file picker element.\n",
470
              " * @param {string} outputId Element ID of the output display.\n",
471
              " * @return {!Iterable<!Object>} Iterable of next steps.\n",
472
              " */\n",
473
              "function* uploadFilesStep(inputId, outputId) {\n",
474
              "  const inputElement = document.getElementById(inputId);\n",
475
              "  inputElement.disabled = false;\n",
476
              "\n",
477
              "  const outputElement = document.getElementById(outputId);\n",
478
              "  outputElement.innerHTML = '';\n",
479
              "\n",
480
              "  const pickedPromise = new Promise((resolve) => {\n",
481
              "    inputElement.addEventListener('change', (e) => {\n",
482
              "      resolve(e.target.files);\n",
483
              "    });\n",
484
              "  });\n",
485
              "\n",
486
              "  const cancel = document.createElement('button');\n",
487
              "  inputElement.parentElement.appendChild(cancel);\n",
488
              "  cancel.textContent = 'Cancel upload';\n",
489
              "  const cancelPromise = new Promise((resolve) => {\n",
490
              "    cancel.onclick = () => {\n",
491
              "      resolve(null);\n",
492
              "    };\n",
493
              "  });\n",
494
              "\n",
495
              "  // Wait for the user to pick the files.\n",
496
              "  const files = yield {\n",
497
              "    promise: Promise.race([pickedPromise, cancelPromise]),\n",
498
              "    response: {\n",
499
              "      action: 'starting',\n",
500
              "    }\n",
501
              "  };\n",
502
              "\n",
503
              "  cancel.remove();\n",
504
              "\n",
505
              "  // Disable the input element since further picks are not allowed.\n",
506
              "  inputElement.disabled = true;\n",
507
              "\n",
508
              "  if (!files) {\n",
509
              "    return {\n",
510
              "      response: {\n",
511
              "        action: 'complete',\n",
512
              "      }\n",
513
              "    };\n",
514
              "  }\n",
515
              "\n",
516
              "  for (const file of files) {\n",
517
              "    const li = document.createElement('li');\n",
518
              "    li.append(span(file.name, {fontWeight: 'bold'}));\n",
519
              "    li.append(span(\n",
520
              "        `(${file.type || 'n/a'}) - ${file.size} bytes, ` +\n",
521
              "        `last modified: ${\n",
522
              "            file.lastModifiedDate ? file.lastModifiedDate.toLocaleDateString() :\n",
523
              "                                    'n/a'} - `));\n",
524
              "    const percent = span('0% done');\n",
525
              "    li.appendChild(percent);\n",
526
              "\n",
527
              "    outputElement.appendChild(li);\n",
528
              "\n",
529
              "    const fileDataPromise = new Promise((resolve) => {\n",
530
              "      const reader = new FileReader();\n",
531
              "      reader.onload = (e) => {\n",
532
              "        resolve(e.target.result);\n",
533
              "      };\n",
534
              "      reader.readAsArrayBuffer(file);\n",
535
              "    });\n",
536
              "    // Wait for the data to be ready.\n",
537
              "    let fileData = yield {\n",
538
              "      promise: fileDataPromise,\n",
539
              "      response: {\n",
540
              "        action: 'continue',\n",
541
              "      }\n",
542
              "    };\n",
543
              "\n",
544
              "    // Use a chunked sending to avoid message size limits. See b/62115660.\n",
545
              "    let position = 0;\n",
546
              "    do {\n",
547
              "      const length = Math.min(fileData.byteLength - position, MAX_PAYLOAD_SIZE);\n",
548
              "      const chunk = new Uint8Array(fileData, position, length);\n",
549
              "      position += length;\n",
550
              "\n",
551
              "      const base64 = btoa(String.fromCharCode.apply(null, chunk));\n",
552
              "      yield {\n",
553
              "        response: {\n",
554
              "          action: 'append',\n",
555
              "          file: file.name,\n",
556
              "          data: base64,\n",
557
              "        },\n",
558
              "      };\n",
559
              "\n",
560
              "      let percentDone = fileData.byteLength === 0 ?\n",
561
              "          100 :\n",
562
              "          Math.round((position / fileData.byteLength) * 100);\n",
563
              "      percent.textContent = `${percentDone}% done`;\n",
564
              "\n",
565
              "    } while (position < fileData.byteLength);\n",
566
              "  }\n",
567
              "\n",
568
              "  // All done.\n",
569
              "  yield {\n",
570
              "    response: {\n",
571
              "      action: 'complete',\n",
572
              "    }\n",
573
              "  };\n",
574
              "}\n",
575
              "\n",
576
              "scope.google = scope.google || {};\n",
577
              "scope.google.colab = scope.google.colab || {};\n",
578
              "scope.google.colab._files = {\n",
579
              "  _uploadFiles,\n",
580
              "  _uploadFilesContinue,\n",
581
              "};\n",
582
              "})(self);\n",
583
              "</script> "
584
            ]
585
          },
586
          "metadata": {}
587
        },
588
        {
589
          "output_type": "stream",
590
          "name": "stdout",
591
          "text": [
592
            "Saving pngwing.com (3).png to pngwing.com (3).png\n"
593
          ]
594
        }
595
      ]
596
    },
597
    {
598
      "cell_type": "code",
599
      "source": [
600
        "# `test` was already given a channel axis and divided by 255 in preprocessing, so only a\n",
        "# batch axis is added here. Dividing by 255 a second time would feed the model pixels in\n",
        "# [0, 1/255], nothing like the data it was trained on.\n",
        "test_image = np.expand_dims(test[0], axis=0)\n",
        "\n",
        "probabilities = cnn.predict(test_image, verbose=0)\n",
        "predicted_label = np.argmax(probabilities)\n",
        "\n",
        "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
        "predicted_class = class_names[predicted_label]\n",
        "\n",
        "plt.imshow(test[0].reshape(28, 28), cmap='gray')\n",
        "plt.title(f\"Predicted Class: {predicted_class}\")\n",
        "plt.axis('off')\n",
        "plt.show()\n"
614
      ],
615
      "metadata": {
616
        "colab": {
617
          "base_uri": "https://localhost:8080/",
618
          "height": 445
619
        },
620
        "id": "vWQp4SSDcT4Z",
621
        "outputId": "a6e4a2a2-affd-4b6c-a957-f3f601f073b0"
622
      },
623
      "execution_count": 46,
624
      "outputs": [
625
        {
626
          "output_type": "stream",
627
          "name": "stdout",
628
          "text": [
629
            "1/1 [==============================] - 0s 25ms/step\n"
630
          ]
631
        },
632
        {
633
          "output_type": "display_data",
634
          "data": {
635
            "text/plain": [
636
              "<Figure size 640x480 with 1 Axes>"
637
            ],
638
            "image/png": "\n"
639
          },
640
          "metadata": {}
641
        }
642
      ]
643
    }
644
  ]
645
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.