Amazing-Python-Scripts

Форк
0
/
Predicting_Boston_Housing_Prices_using_CatBoost_Regression (1).ipynb 
335 строк · 65.4 Кб
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": []
7
    },
8
    "kernelspec": {
9
      "name": "python3",
10
      "display_name": "Python 3"
11
    },
12
    "language_info": {
13
      "name": "python"
14
    }
15
  },
16
  "cells": [
17
    {
18
      "cell_type": "code",
19
      "source": [
20
        "%pip install catboost"
21
      ],
22
      "metadata": {
23
        "colab": {
24
          "base_uri": "https://localhost:8080/"
25
        },
26
        "id": "fjtCd9-2YJtA",
27
        "outputId": "5f63d4c3-48c2-4cb4-9fd7-d026c7c7ddb3"
28
      },
29
      "execution_count": 44,
30
      "outputs": [
31
        {
32
          "output_type": "stream",
33
          "name": "stdout",
34
          "text": [
35
            "Requirement already satisfied: catboost in /usr/local/lib/python3.10/dist-packages (1.2)\n",
36
            "Requirement already satisfied: graphviz in /usr/local/lib/python3.10/dist-packages (from catboost) (0.20.1)\n",
37
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from catboost) (3.7.1)\n",
38
            "Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.10/dist-packages (from catboost) (1.23.5)\n",
39
            "Requirement already satisfied: pandas>=0.24 in /usr/local/lib/python3.10/dist-packages (from catboost) (1.5.3)\n",
40
            "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from catboost) (1.10.1)\n",
41
            "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from catboost) (5.13.1)\n",
42
            "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from catboost) (1.16.0)\n",
43
            "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24->catboost) (2.8.2)\n",
44
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24->catboost) (2022.7.1)\n",
45
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (1.1.0)\n",
46
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (0.11.0)\n",
47
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (4.42.0)\n",
48
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (1.4.4)\n",
49
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (23.1)\n",
50
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (9.4.0)\n",
51
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->catboost) (3.1.1)\n",
52
            "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->catboost) (8.2.2)\n"
53
          ]
54
        }
55
      ]
56
    },
57
    {
58
      "cell_type": "code",
59
      "source": [
60
        "import pandas as pd\n",
61
        "import numpy as np\n",
62
        "from catboost import CatBoostRegressor\n",
63
        "from sklearn.model_selection import train_test_split\n",
64
        "from sklearn.metrics import mean_squared_error\n",
65
        "import matplotlib.pyplot as plt\n",
66
        "import seaborn as sns"
67
      ],
68
      "metadata": {
69
        "id": "lCEz3mIoXx6Y"
70
      },
71
      "execution_count": 45,
72
      "outputs": []
73
    },
74
    {
75
      "cell_type": "code",
76
      "source": [
77
        "data_url = \"http://lib.stat.cmu.edu/datasets/boston\"\n",
78
        "raw_df = pd.read_csv(data_url, sep=r\"\\s+\", skiprows=22, header=None)  # raw-string regex: no invalid-escape warning\n",
79
        "df = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])  # even rows + first two columns of odd rows -> features\n",
80
        "target = raw_df.values[1::2, 2]  # third column of odd rows -> target"
81
      ],
82
      "metadata": {
83
        "id": "PRd11JRTX1Qn"
84
      },
85
      "execution_count": 46,
86
      "outputs": []
87
    },
88
    {
89
      "cell_type": "code",
90
      "source": [
91
        "df.shape"
92
      ],
93
      "metadata": {
94
        "colab": {
95
          "base_uri": "https://localhost:8080/"
96
        },
97
        "id": "-EvoYWuIaCwk",
98
        "outputId": "5bef08c0-ae53-4598-9672-6b97cce059f7"
99
      },
100
      "execution_count": 47,
101
      "outputs": [
102
        {
103
          "output_type": "execute_result",
104
          "data": {
105
            "text/plain": [
106
              "(506, 13)"
107
            ]
108
          },
109
          "metadata": {},
110
          "execution_count": 47
111
        }
112
      ]
113
    },
114
    {
115
      "cell_type": "code",
116
      "source": [
117
        "df.dtype"
118
      ],
119
      "metadata": {
120
        "colab": {
121
          "base_uri": "https://localhost:8080/"
122
        },
123
        "id": "RenipdVGaaf7",
124
        "outputId": "fa6ed386-1042-4399-857b-825ee627c4ad"
125
      },
126
      "execution_count": 48,
127
      "outputs": [
128
        {
129
          "output_type": "execute_result",
130
          "data": {
131
            "text/plain": [
132
              "dtype('float64')"
133
            ]
134
          },
135
          "metadata": {},
136
          "execution_count": 48
137
        }
138
      ]
139
    },
140
    {
141
      "cell_type": "code",
142
      "source": [
143
        "X_train, X_test, y_train, y_test = train_test_split(df, target, test_size=0.2, random_state=42)\n"
144
      ],
145
      "metadata": {
146
        "id": "e-GxZmfEYm1j"
147
      },
148
      "execution_count": 49,
149
      "outputs": []
150
    },
151
    {
152
      "cell_type": "code",
153
      "source": [
154
        "model = CatBoostRegressor(iterations=1000, depth=8, learning_rate=0.05, loss_function='RMSE',\n",
155
        "                           random_seed=42, l2_leaf_reg=1)"
156
      ],
157
      "metadata": {
158
        "id": "G0x4Y6z6Xgun"
159
      },
160
      "execution_count": 50,
161
      "outputs": []
162
    },
163
    {
164
      "cell_type": "code",
165
      "source": [
166
        "# Early stopping monitors a validation split carved from the training data,\n",
        "# so the test set stays untouched until the final MSE is computed.\n",
        "X_fit, X_val, y_fit, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n",
        "model.fit(X_fit, y_fit, verbose=100, eval_set=(X_val, y_val), early_stopping_rounds=50)"
167
      ],
168
      "metadata": {
169
        "colab": {
170
          "base_uri": "https://localhost:8080/"
171
        },
172
        "id": "-jD7L2X8XkEf",
173
        "outputId": "da585195-5991-40a8-bb26-2304c756c1c8"
174
      },
175
      "execution_count": 51,
176
      "outputs": [
177
        {
178
          "output_type": "stream",
179
          "name": "stdout",
180
          "text": [
181
            "0:\tlearn: 9.0020476\ttest: 8.4095123\tbest: 8.4095123 (0)\ttotal: 13.6ms\tremaining: 13.6s\n",
182
            "100:\tlearn: 1.7496020\ttest: 3.2646698\tbest: 3.2646698 (100)\ttotal: 601ms\tremaining: 5.35s\n",
183
            "200:\tlearn: 0.9726505\ttest: 3.0362205\tbest: 3.0362205 (200)\ttotal: 1.44s\tremaining: 5.74s\n",
184
            "300:\tlearn: 0.5818381\ttest: 2.9857265\tbest: 2.9857265 (300)\ttotal: 2.07s\tremaining: 4.82s\n",
185
            "Stopped by overfitting detector  (50 iterations wait)\n",
186
            "\n",
187
            "bestTest = 2.980601834\n",
188
            "bestIteration = 314\n",
189
            "\n",
190
            "Shrink model to first 315 iterations.\n"
191
          ]
192
        },
193
        {
194
          "output_type": "execute_result",
195
          "data": {
196
            "text/plain": [
197
              "<catboost.core.CatBoostRegressor at 0x7bb8f2b08a30>"
198
            ]
199
          },
200
          "metadata": {},
201
          "execution_count": 51
202
        }
203
      ]
204
    },
205
    {
206
      "cell_type": "code",
207
      "source": [
208
        "y_pred = model.predict(X_test)"
209
      ],
210
      "metadata": {
211
        "id": "IB6QbK9xXnZa"
212
      },
213
      "execution_count": 52,
214
      "outputs": []
215
    },
216
    {
217
      "cell_type": "code",
218
      "source": [
219
        "y_pred"
220
      ],
221
      "metadata": {
222
        "colab": {
223
          "base_uri": "https://localhost:8080/"
224
        },
225
        "id": "N7JUDoILYt7y",
226
        "outputId": "adfc6754-2010-403e-ea8d-6f6a69eb0cb4"
227
      },
228
      "execution_count": 53,
229
      "outputs": [
230
        {
231
          "output_type": "execute_result",
232
          "data": {
233
            "text/plain": [
234
              "array([24.61951657, 29.79612399, 16.10995832, 23.90291017, 16.16479168,\n",
235
              "       21.96620609, 19.41533747, 14.85955243, 21.19257177, 20.46921239,\n",
236
              "       21.25112561, 19.42673204,  7.80799285, 21.73588536, 19.43968294,\n",
237
              "       23.57411435, 19.75321019,  9.37650806, 43.76299704, 14.28887094,\n",
238
              "       24.95484852, 25.1917329 , 13.33651447, 21.58680042, 14.40969043,\n",
239
              "       16.09083103, 22.30573291, 13.90180886, 20.20803133, 20.63283845,\n",
240
              "       20.95367068, 23.56590007, 21.37815241, 22.03890376, 15.11999286,\n",
241
              "       16.49008812, 35.52581167, 18.52302   , 22.9998298 , 23.24337579,\n",
242
              "       18.09930231, 29.42451049, 42.64947385, 19.48418839, 22.93099906,\n",
243
              "       13.81754368, 14.22441547, 24.09969962, 19.20374409, 26.05365903,\n",
244
              "       22.53293348, 36.45414434, 18.23660788, 24.90167829, 45.03091217,\n",
245
              "       22.50490127, 14.2964888 , 32.60026588, 22.04437289, 18.49242439,\n",
246
              "       23.55633875, 35.51997742, 29.5665242 , 18.17335086, 23.85466921,\n",
247
              "       18.37995176, 13.73776741, 23.64060698, 28.58094162, 14.67940568,\n",
248
              "       20.55768573, 24.54235025, 10.55305368, 20.77955652, 22.48437537,\n",
249
              "        7.04140151, 20.37419805, 43.59246225, 11.52873761, 10.56385587,\n",
250
              "       21.70285593, 13.76140818, 20.07970065, 10.45062855, 20.04264377,\n",
251
              "       27.82619285, 16.36425016, 23.43126727, 24.48350906, 18.47783108,\n",
252
              "       22.5117491 ,  9.62743311, 19.41891295, 18.22799834, 29.42397234,\n",
253
              "       19.59499851, 30.82681357, 11.09794744, 11.96605782, 15.49374648,\n",
254
              "       20.45985914, 24.01940107])"
255
            ]
256
          },
257
          "metadata": {},
258
          "execution_count": 53
259
        }
260
      ]
261
    },
262
    {
263
      "cell_type": "code",
264
      "source": [
265
        "plt.figure(figsize=(10, 6))\n",
266
        "ax = sns.regplot(x=y_test, y=y_pred, scatter_kws={'s': 20})  # regplot returns the Axes it drew on\n",
267
        "ax.set(title=\"Regression Plot: True Values vs Predicted Values\", xlabel=\"True values\", ylabel=\"Predicted values\");"
268
      ],
269
      "metadata": {
270
        "colab": {
271
          "base_uri": "https://localhost:8080/",
272
          "height": 562
273
        },
274
        "id": "ekdHHCKOZXuu",
275
        "outputId": "ee2a9b75-a932-401b-92a9-858bea4eae7d"
276
      },
277
      "execution_count": 54,
278
      "outputs": [
279
        {
280
          "output_type": "execute_result",
281
          "data": {
282
            "text/plain": [
283
              "Text(0.5, 1.0, 'Regression Plot: True Values vs Predicted Values')"
284
            ]
285
          },
286
          "metadata": {},
287
          "execution_count": 54
288
        },
289
        {
290
          "output_type": "display_data",
291
          "data": {
292
            "text/plain": [
293
              "<Figure size 1000x600 with 1 Axes>"
294
            ],
295
            "image/png": "\n"
296
          },
297
          "metadata": {}
298
        }
299
      ]
300
    },
301
    {
302
      "cell_type": "code",
303
      "source": [
304
        "mse = mean_squared_error(y_test, y_pred)\n",
305
        "print(\"Mean Squared Error:\", mse)"
306
      ],
307
      "metadata": {
308
        "colab": {
309
          "base_uri": "https://localhost:8080/"
310
        },
311
        "id": "fEx9UCU8YxAn",
312
        "outputId": "eec37ce6-8733-4b10-effe-d0a312c5672d"
313
      },
314
      "execution_count": 55,
315
      "outputs": [
316
        {
317
          "output_type": "stream",
318
          "name": "stdout",
319
          "text": [
320
            "Mean Squared Error: 8.883987158812726\n"
321
          ]
322
        }
323
      ]
324
    },
325
    {
326
      "cell_type": "code",
327
      "source": [],
328
      "metadata": {
329
        "id": "WjzdGpQRZXRd"
330
      },
331
      "execution_count": 55,
332
      "outputs": []
333
    }
334
  ]
335
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.