{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "AutoEncoders.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ayush-09/AutoEncoder-Deep-Learning/blob/main/AutoEncoders.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K4f4JG1gdKqj"
      },
      "source": [
        "# AutoEncoders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rjOPzue7FCXJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c488b69b-9cc1-4cd7-b954-0928c41b410d"
      },
      "source": [
        "!wget \"http://files.grouplens.org/datasets/movielens/ml-100k.zip\"\n",
        "!unzip ml-100k.zip\n",
        "!ls"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2021-01-09 14:08:13--  http://files.grouplens.org/datasets/movielens/ml-100k.zip\n",
            "Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152\n",
            "Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4924029 (4.7M) [application/zip]\n",
            "Saving to: ‘ml-100k.zip’\n",
            "\n",
            "ml-100k.zip         100%[===================>]   4.70M  18.8MB/s    in 0.2s    \n",
            "\n",
            "2021-01-09 14:08:13 (18.8 MB/s) - ‘ml-100k.zip’ saved [4924029/4924029]\n",
            "\n",
            "Archive:  ml-100k.zip\n",
            "   creating: ml-100k/\n",
            "  inflating: ml-100k/allbut.pl       \n",
            "  inflating: ml-100k/mku.sh          \n",
            "  inflating: ml-100k/README          \n",
            "  inflating: ml-100k/u.data          \n",
            "  inflating: ml-100k/u.genre         \n",
            "  inflating: ml-100k/u.info          \n",
            "  inflating: ml-100k/u.item          \n",
            "  inflating: ml-100k/u.occupation    \n",
            "  inflating: ml-100k/u.user          \n",
            "  inflating: ml-100k/u1.base         \n",
            "  inflating: ml-100k/u1.test         \n",
            "  inflating: ml-100k/u2.base         \n",
            "  inflating: ml-100k/u2.test         \n",
            "  inflating: ml-100k/u3.base         \n",
            "  inflating: ml-100k/u3.test         \n",
            "  inflating: ml-100k/u4.base         \n",
            "  inflating: ml-100k/u4.test         \n",
            "  inflating: ml-100k/u5.base         \n",
            "  inflating: ml-100k/u5.test         \n",
            "  inflating: ml-100k/ua.base         \n",
            "  inflating: ml-100k/ua.test         \n",
            "  inflating: ml-100k/ub.base         \n",
            "  inflating: ml-100k/ub.test         \n",
            "ml-100k  ml-100k.zip  sample_data\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LOly1yfAfTjd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "97173fc2-9434-4529-9841-2bb989507c05"
      },
      "source": [
        "!wget \"http://files.grouplens.org/datasets/movielens/ml-1m.zip\"\n",
        "!unzip ml-1m.zip\n",
        "!ls"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2021-01-09 14:08:24--  http://files.grouplens.org/datasets/movielens/ml-1m.zip\n",
            "Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152\n",
            "Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 5917549 (5.6M) [application/zip]\n",
            "Saving to: ‘ml-1m.zip’\n",
            "\n",
            "ml-1m.zip           100%[===================>]   5.64M  22.1MB/s    in 0.3s    \n",
            "\n",
            "2021-01-09 14:08:25 (22.1 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]\n",
            "\n",
            "Archive:  ml-1m.zip\n",
            "   creating: ml-1m/\n",
            "  inflating: ml-1m/movies.dat        \n",
            "  inflating: ml-1m/ratings.dat       \n",
            "  inflating: ml-1m/README            \n",
            "  inflating: ml-1m/users.dat         \n",
            "ml-100k  ml-100k.zip  ml-1m  ml-1m.zip\tsample_data\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_LvGeU1CeCtg"
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.parallel\n",
        "import torch.optim as optim\n",
        "import torch.utils.data\n",
        "from torch.autograd import Variable"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UJw2p3-Cewo4"
      },
      "source": [
        "# Load the ml-1m metadata for reference only; the model below is trained on the ml-100k splits.\n",
        "movies = pd.read_csv('ml-1m/movies.dat', sep = '::', header = None, engine = 'python', encoding = 'latin-1')\n",
        "users = pd.read_csv('ml-1m/users.dat', sep = '::', header = None, engine = 'python', encoding = 'latin-1')\n",
        "ratings = pd.read_csv('ml-1m/ratings.dat', sep = '::', header = None, engine = 'python', encoding = 'latin-1')"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2usLKJBEgPE2"
      },
      "source": [
        "# u1.base/u1.test are tab-separated and have no header row, so header = None is\n",
        "# needed to avoid silently consuming the first rating as column names.\n",
        "training_set = pd.read_csv('ml-100k/u1.base', delimiter = '\\t', header = None)\n",
        "training_set = np.array(training_set, dtype = 'int')\n",
        "test_set = pd.read_csv('ml-100k/u1.test', delimiter = '\\t', header = None)\n",
        "test_set = np.array(test_set, dtype = 'int')"
      ],
      "execution_count": 6,
      "outputs": []
    },
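    {
      "cell_type": "markdown",
      "metadata": {
        "id": "added-data-peek-md"
      },
      "source": [
        "For orientation, a quick peek at the raw arrays (an added illustrative check, not part of the original run): each row is `[user_id, movie_id, rating, timestamp]`, with ratings on a 1-5 scale."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "added-data-peek-code"
      },
      "source": [
        "# First few rows: columns are user ID, movie ID, rating (1-5), timestamp.\n",
        "print(training_set[:5])\n",
        "print('training ratings:', len(training_set), ' test ratings:', len(test_set))"
      ],
      "execution_count": null,
      "outputs": []
    },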
    {
      "cell_type": "code",
      "metadata": {
        "id": "gPaGZqdniC5m"
      },
      "source": [
        "# Total numbers of users and movies: the highest ID seen in either split.\n",
        "nb_users = int(max(max(training_set[:, 0]), max(test_set[:, 0])))\n",
        "nb_movies = int(max(max(training_set[:, 1]), max(test_set[:, 1])))"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-wASs2YFiDaa"
      },
      "source": [
        "# Reshape the ratings into one vector of length nb_movies per user,\n",
        "# with 0 marking movies the user has not rated.\n",
        "def convert(data):\n",
        "  new_data = []\n",
        "  for id_users in range(1, nb_users + 1):\n",
        "    id_movies = data[:, 1][data[:, 0] == id_users]\n",
        "    id_ratings = data[:, 2][data[:, 0] == id_users]\n",
        "    ratings = np.zeros(nb_movies)\n",
        "    ratings[id_movies - 1] = id_ratings\n",
        "    new_data.append(list(ratings))\n",
        "  return new_data\n",
        "training_set = convert(training_set)\n",
        "test_set = convert(test_set)"
      ],
      "execution_count": 8,
      "outputs": []
    },
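    {
      "cell_type": "markdown",
      "metadata": {
        "id": "added-convert-check-md"
      },
      "source": [
        "A quick sanity check, added here as an illustrative sketch rather than part of the original notebook: `convert` should yield one rating vector per user, each of length `nb_movies`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "added-convert-check-code"
      },
      "source": [
        "# Illustrative check of the user x movie matrix produced by convert().\n",
        "assert len(training_set) == nb_users\n",
        "assert all(len(row) == nb_movies for row in training_set)\n",
        "print(nb_users, 'users,', nb_movies, 'movies per rating vector')"
      ],
      "execution_count": null,
      "outputs": []
    },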
    {
      "cell_type": "code",
      "metadata": {
        "id": "TwD-KD8yiEEw"
      },
      "source": [
        "# Convert the lists of lists into PyTorch tensors for training.\n",
        "training_set = torch.FloatTensor(training_set)\n",
        "test_set = torch.FloatTensor(test_set)"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oU2nyh76iE6M"
      },
      "source": [
        "# Stacked AutoEncoder: compress the nb_movies-dimensional rating vector down to\n",
        "# 10 features, then decode it back into a full vector of predicted ratings.\n",
        "class SAE(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(SAE, self).__init__()\n",
        "        self.fc1 = nn.Linear(nb_movies, 20)  # encoder: input -> 20\n",
        "        self.fc2 = nn.Linear(20, 10)         # encoder: 20 -> 10 (bottleneck)\n",
        "        self.fc3 = nn.Linear(10, 20)         # decoder: 10 -> 20\n",
        "        self.fc4 = nn.Linear(20, nb_movies)  # decoder: 20 -> reconstruction\n",
        "        self.activation = nn.Sigmoid()\n",
        "    def forward(self, x):\n",
        "        x = self.activation(self.fc1(x))\n",
        "        x = self.activation(self.fc2(x))\n",
        "        x = self.activation(self.fc3(x))\n",
        "        x = self.fc4(x)  # no activation on the output layer\n",
        "        return x\n",
        "sae = SAE()\n",
        "criterion = nn.MSELoss()\n",
        "optimizer = optim.RMSprop(sae.parameters(), lr = 0.01, weight_decay = 0.5)"
      ],
      "execution_count": 10,
      "outputs": []
    },
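    {
      "cell_type": "markdown",
      "metadata": {
        "id": "added-sae-check-md"
      },
      "source": [
        "A minimal sketch (not from the original notebook) to confirm the wiring before training: pass one user's rating vector through the untrained autoencoder and check that the reconstruction has the same shape as the input."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "added-sae-check-code"
      },
      "source": [
        "# Illustrative forward pass: the output should be one predicted rating per movie.\n",
        "with torch.no_grad():\n",
        "    sample = training_set[0].unsqueeze(0)  # shape (1, nb_movies)\n",
        "    reconstruction = sae(sample)\n",
        "print(sample.shape, '->', reconstruction.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },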
    {
      "cell_type": "code",
      "metadata": {
        "id": "FEz9hRaciFTs",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8ffae3d0-f74c-47af-c046-fbe8a5fec0b6"
      },
      "source": [
        "nb_epoch = 200\n",
        "for epoch in range(1, nb_epoch + 1):\n",
        "  train_loss = 0\n",
        "  s = 0.\n",
        "  for id_user in range(nb_users):\n",
        "    input = Variable(training_set[id_user]).unsqueeze(0)\n",
        "    target = input.clone()\n",
        "    if torch.sum(target.data > 0) > 0:  # skip users with no ratings\n",
        "      output = sae(input)\n",
        "      target.requires_grad = False\n",
        "      output[target == 0] = 0  # ignore unrated movies in the loss\n",
        "      loss = criterion(output, target)\n",
        "      mean_corrector = nb_movies/float(torch.sum(target.data > 0) + 1e-10)\n",
        "      optimizer.zero_grad()  # clear gradients accumulated from the previous user\n",
        "      loss.backward()\n",
        "      train_loss += np.sqrt(loss.data*mean_corrector)\n",
        "      s += 1.\n",
        "      optimizer.step()\n",
        "  print('epoch: ' + str(epoch) + ' loss: ' + str(train_loss/s))"
      ],
      "execution_count": 11,
289
      "outputs": [
290
        {
291
          "output_type": "stream",
292
          "text": [
293
            "epoch: 1loss: tensor(1.7710)\n",
294
            "epoch: 2loss: tensor(1.0968)\n",
295
            "epoch: 3loss: tensor(1.0533)\n",
296
            "epoch: 4loss: tensor(1.0386)\n",
297
            "epoch: 5loss: tensor(1.0307)\n",
298
            "epoch: 6loss: tensor(1.0267)\n",
299
            "epoch: 7loss: tensor(1.0237)\n",
300
            "epoch: 8loss: tensor(1.0220)\n",
301
            "epoch: 9loss: tensor(1.0206)\n",
302
            "epoch: 10loss: tensor(1.0196)\n",
303
            "epoch: 11loss: tensor(1.0188)\n",
304
            "epoch: 12loss: tensor(1.0185)\n",
305
            "epoch: 13loss: tensor(1.0178)\n",
306
            "epoch: 14loss: tensor(1.0176)\n",
307
            "epoch: 15loss: tensor(1.0174)\n",
308
            "epoch: 16loss: tensor(1.0170)\n",
309
            "epoch: 17loss: tensor(1.0166)\n",
310
            "epoch: 18loss: tensor(1.0166)\n",
311
            "epoch: 19loss: tensor(1.0163)\n",
312
            "epoch: 20loss: tensor(1.0160)\n",
313
            "epoch: 21loss: tensor(1.0161)\n",
314
            "epoch: 22loss: tensor(1.0160)\n",
315
            "epoch: 23loss: tensor(1.0157)\n",
316
            "epoch: 24loss: tensor(1.0156)\n",
317
            "epoch: 25loss: tensor(1.0157)\n",
318
            "epoch: 26loss: tensor(1.0155)\n",
319
            "epoch: 27loss: tensor(1.0155)\n",
320
            "epoch: 28loss: tensor(1.0150)\n",
321
            "epoch: 29loss: tensor(1.0134)\n",
322
            "epoch: 30loss: tensor(1.0120)\n",
323
            "epoch: 31loss: tensor(1.0096)\n",
324
            "epoch: 32loss: tensor(1.0085)\n",
325
            "epoch: 33loss: tensor(1.0048)\n",
326
            "epoch: 34loss: tensor(1.0045)\n",
327
            "epoch: 35loss: tensor(1.0007)\n",
328
            "epoch: 36loss: tensor(0.9990)\n",
329
            "epoch: 37loss: tensor(0.9970)\n",
330
            "epoch: 38loss: tensor(0.9964)\n",
331
            "epoch: 39loss: tensor(0.9932)\n",
332
            "epoch: 40loss: tensor(0.9912)\n",
333
            "epoch: 41loss: tensor(0.9876)\n",
334
            "epoch: 42loss: tensor(0.9896)\n",
335
            "epoch: 43loss: tensor(0.9856)\n",
336
            "epoch: 44loss: tensor(0.9904)\n",
337
            "epoch: 45loss: tensor(0.9861)\n",
338
            "epoch: 46loss: tensor(0.9854)\n",
339
            "epoch: 47loss: tensor(0.9880)\n",
340
            "epoch: 48loss: tensor(0.9873)\n",
341
            "epoch: 49loss: tensor(0.9877)\n",
342
            "epoch: 50loss: tensor(0.9880)\n",
343
            "epoch: 51loss: tensor(0.9817)\n",
344
            "epoch: 52loss: tensor(0.9830)\n",
345
            "epoch: 53loss: tensor(0.9797)\n",
346
            "epoch: 54loss: tensor(0.9768)\n",
347
            "epoch: 55loss: tensor(0.9728)\n",
348
            "epoch: 56loss: tensor(0.9808)\n",
349
            "epoch: 57loss: tensor(0.9754)\n",
350
            "epoch: 58loss: tensor(0.9756)\n",
351
            "epoch: 59loss: tensor(0.9712)\n",
352
            "epoch: 60loss: tensor(0.9708)\n",
353
            "epoch: 61loss: tensor(0.9720)\n",
354
            "epoch: 62loss: tensor(0.9683)\n",
355
            "epoch: 63loss: tensor(0.9652)\n",
356
            "epoch: 64loss: tensor(0.9625)\n",
357
            "epoch: 65loss: tensor(0.9638)\n",
358
            "epoch: 66loss: tensor(0.9625)\n",
359
            "epoch: 67loss: tensor(0.9599)\n",
360
            "epoch: 68loss: tensor(0.9605)\n",
361
            "epoch: 69loss: tensor(0.9611)\n",
362
            "epoch: 70loss: tensor(0.9588)\n",
363
            "epoch: 71loss: tensor(0.9564)\n",
364
            "epoch: 72loss: tensor(0.9566)\n",
365
            "epoch: 73loss: tensor(0.9554)\n",
366
            "epoch: 74loss: tensor(0.9574)\n",
367
            "epoch: 75loss: tensor(0.9542)\n",
368
            "epoch: 76loss: tensor(0.9544)\n",
369
            "epoch: 77loss: tensor(0.9526)\n",
370
            "epoch: 78loss: tensor(0.9506)\n",
371
            "epoch: 79loss: tensor(0.9498)\n",
372
            "epoch: 80loss: tensor(0.9484)\n",
373
            "epoch: 81loss: tensor(0.9472)\n",
374
            "epoch: 82loss: tensor(0.9473)\n",
375
            "epoch: 83loss: tensor(0.9461)\n",
376
            "epoch: 84loss: tensor(0.9461)\n",
377
            "epoch: 85loss: tensor(0.9450)\n",
378
            "epoch: 86loss: tensor(0.9442)\n",
379
            "epoch: 87loss: tensor(0.9432)\n",
380
            "epoch: 88loss: tensor(0.9428)\n",
381
            "epoch: 89loss: tensor(0.9423)\n",
382
            "epoch: 90loss: tensor(0.9425)\n",
383
            "epoch: 91loss: tensor(0.9407)\n",
384
            "epoch: 92loss: tensor(0.9414)\n",
385
            "epoch: 93loss: tensor(0.9408)\n",
386
            "epoch: 94loss: tensor(0.9402)\n",
387
            "epoch: 95loss: tensor(0.9396)\n",
388
            "epoch: 96loss: tensor(0.9393)\n",
389
            "epoch: 97loss: tensor(0.9386)\n",
390
            "epoch: 98loss: tensor(0.9384)\n",
391
            "epoch: 99loss: tensor(0.9381)\n",
392
            "epoch: 100loss: tensor(0.9390)\n",
393
            "epoch: 101loss: tensor(0.9381)\n",
394
            "epoch: 102loss: tensor(0.9378)\n",
395
            "epoch: 103loss: tensor(0.9370)\n",
396
            "epoch: 104loss: tensor(0.9372)\n",
397
            "epoch: 105loss: tensor(0.9360)\n",
398
            "epoch: 106loss: tensor(0.9363)\n",
399
            "epoch: 107loss: tensor(0.9354)\n",
400
            "epoch: 108loss: tensor(0.9353)\n",
401
            "epoch: 109loss: tensor(0.9346)\n",
402
            "epoch: 110loss: tensor(0.9355)\n",
403
            "epoch: 111loss: tensor(0.9347)\n",
404
            "epoch: 112loss: tensor(0.9349)\n",
405
            "epoch: 113loss: tensor(0.9337)\n",
406
            "epoch: 114loss: tensor(0.9335)\n",
407
            "epoch: 115loss: tensor(0.9332)\n",
408
            "epoch: 116loss: tensor(0.9334)\n",
409
            "epoch: 117loss: tensor(0.9331)\n",
410
            "epoch: 118loss: tensor(0.9333)\n",
411
            "epoch: 119loss: tensor(0.9327)\n",
412
            "epoch: 120loss: tensor(0.9328)\n",
413
            "epoch: 121loss: tensor(0.9322)\n",
414
            "epoch: 122loss: tensor(0.9319)\n",
415
            "epoch: 123loss: tensor(0.9317)\n",
416
            "epoch: 124loss: tensor(0.9317)\n",
417
            "epoch: 125loss: tensor(0.9318)\n",
418
            "epoch: 126loss: tensor(0.9311)\n",
419
            "epoch: 127loss: tensor(0.9313)\n",
420
            "epoch: 128loss: tensor(0.9310)\n",
421
            "epoch: 129loss: tensor(0.9313)\n",
422
            "epoch: 130loss: tensor(0.9312)\n",
423
            "epoch: 131loss: tensor(0.9307)\n",
424
            "epoch: 132loss: tensor(0.9305)\n",
425
            "epoch: 133loss: tensor(0.9298)\n",
426
            "epoch: 134loss: tensor(0.9300)\n",
427
            "epoch: 135loss: tensor(0.9299)\n",
428
            "epoch: 136loss: tensor(0.9299)\n",
429
            "epoch: 137loss: tensor(0.9293)\n",
430
            "epoch: 138loss: tensor(0.9294)\n",
431
            "epoch: 139loss: tensor(0.9287)\n",
432
            "epoch: 140loss: tensor(0.9285)\n",
433
            "epoch: 141loss: tensor(0.9281)\n",
434
            "epoch: 142loss: tensor(0.9284)\n",
435
            "epoch: 143loss: tensor(0.9279)\n",
436
            "epoch: 144loss: tensor(0.9278)\n",
437
            "epoch: 145loss: tensor(0.9274)\n",
438
            "epoch: 146loss: tensor(0.9278)\n",
439
            "epoch: 147loss: tensor(0.9275)\n",
440
            "epoch: 148loss: tensor(0.9273)\n",
441
            "epoch: 149loss: tensor(0.9270)\n",
442
            "epoch: 150loss: tensor(0.9269)\n",
443
            "epoch: 151loss: tensor(0.9260)\n",
444
            "epoch: 152loss: tensor(0.9264)\n",
445
            "epoch: 153loss: tensor(0.9259)\n",
446
            "epoch: 154loss: tensor(0.9257)\n",
447
            "epoch: 155loss: tensor(0.9252)\n",
448
            "epoch: 156loss: tensor(0.9256)\n",
449
            "epoch: 157loss: tensor(0.9249)\n",
450
            "epoch: 158loss: tensor(0.9250)\n",
451
            "epoch: 159loss: tensor(0.9245)\n",
452
            "epoch: 160loss: tensor(0.9249)\n",
453
            "epoch: 161loss: tensor(0.9238)\n",
454
            "epoch: 162loss: tensor(0.9249)\n",
455
            "epoch: 163loss: tensor(0.9244)\n",
456
            "epoch: 164loss: tensor(0.9242)\n",
457
            "epoch: 165loss: tensor(0.9234)\n",
458
            "epoch: 166loss: tensor(0.9253)\n",
459
            "epoch: 167loss: tensor(0.9234)\n",
460
            "epoch: 168loss: tensor(0.9237)\n",
461
            "epoch: 169loss: tensor(0.9227)\n",
462
            "epoch: 170loss: tensor(0.9235)\n",
463
            "epoch: 171loss: tensor(0.9227)\n",
464
            "epoch: 172loss: tensor(0.9225)\n",
465
            "epoch: 173loss: tensor(0.9217)\n",
466
            "epoch: 174loss: tensor(0.9224)\n",
467
            "epoch: 175loss: tensor(0.9225)\n",
468
            "epoch: 176loss: tensor(0.9247)\n",
469
            "epoch: 177loss: tensor(0.9219)\n",
470
            "epoch: 178loss: tensor(0.9210)\n",
471
            "epoch: 179loss: tensor(0.9210)\n",
472
            "epoch: 180loss: tensor(0.9211)\n",
473
            "epoch: 181loss: tensor(0.9209)\n",
474
            "epoch: 182loss: tensor(0.9210)\n",
475
            "epoch: 183loss: tensor(0.9204)\n",
476
            "epoch: 184loss: tensor(0.9208)\n",
477
            "epoch: 185loss: tensor(0.9205)\n",
478
            "epoch: 186loss: tensor(0.9206)\n",
479
            "epoch: 187loss: tensor(0.9202)\n",
480
            "epoch: 188loss: tensor(0.9201)\n",
481
            "epoch: 189loss: tensor(0.9195)\n",
482
            "epoch: 190loss: tensor(0.9200)\n",
483
            "epoch: 191loss: tensor(0.9200)\n",
484
            "epoch: 192loss: tensor(0.9199)\n",
485
            "epoch: 193loss: tensor(0.9195)\n",
486
            "epoch: 194loss: tensor(0.9193)\n",
487
            "epoch: 195loss: tensor(0.9198)\n",
488
            "epoch: 196loss: tensor(0.9195)\n",
489
            "epoch: 197loss: tensor(0.9181)\n",
490
            "epoch: 198loss: tensor(0.9188)\n",
491
            "epoch: 199loss: tensor(0.9183)\n",
492
            "epoch: 200loss: tensor(0.9221)\n"
493
          ],
494
          "name": "stdout"
495
        }
496
      ]
497
    },
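    {
      "cell_type": "markdown",
      "metadata": {
        "id": "added-save-md"
      },
      "source": [
        "Optionally, the trained weights can be checkpointed so the model does not need to be retrained each session. This is a standard PyTorch pattern, added here as a sketch; the filename `sae.pth` is an arbitrary choice."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "added-save-code"
      },
      "source": [
        "# Save the trained parameters; reload later with sae.load_state_dict(torch.load('sae.pth')).\n",
        "torch.save(sae.state_dict(), 'sae.pth')"
      ],
      "execution_count": null,
      "outputs": []
    },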
    {
      "cell_type": "code",
      "metadata": {
        "id": "5ztvzYRtiGCz",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "d0e8ea8b-9ac4-40e5-a19a-7fcfc6934d61"
      },
      "source": [
        "test_loss = 0\n",
        "s = 0.\n",
        "for id_user in range(nb_users):\n",
        "  input = Variable(training_set[id_user]).unsqueeze(0)  # predict from the user's training ratings\n",
        "  target = Variable(test_set[id_user]).unsqueeze(0)  # evaluate against held-out test ratings\n",
        "  if torch.sum(target.data > 0) > 0:\n",
        "    output = sae(input)\n",
        "    target.requires_grad = False\n",
        "    output[target == 0] = 0\n",
        "    loss = criterion(output, target)\n",
        "    mean_corrector = nb_movies/float(torch.sum(target.data > 0) + 1e-10)\n",
        "    test_loss += np.sqrt(loss.data*mean_corrector)\n",
        "    s += 1.\n",
        "print('test loss: ' + str(test_loss/s))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "test loss: tensor(0.9681)\n"
          ],
          "name": "stdout"
        }
      ]
    },
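    {
      "cell_type": "markdown",
      "metadata": {
        "id": "added-recommend-md"
      },
      "source": [
        "Finally, an illustrative sketch (not part of the original notebook) of how the trained autoencoder can be used: predict ratings for one user and list the top unseen movies. The user index and the number of recommendations are arbitrary choices."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "added-recommend-code"
      },
      "source": [
        "# Recommend the 10 highest-predicted unrated movies for an example user.\n",
        "user_id = 0  # arbitrary example user (0-indexed row of training_set)\n",
        "with torch.no_grad():\n",
        "    predicted = sae(training_set[user_id].unsqueeze(0)).squeeze(0)\n",
        "unseen = training_set[user_id] == 0  # movies this user has not rated\n",
        "scores = predicted.clone()\n",
        "scores[~unseen] = -float('inf')  # exclude already-rated movies\n",
        "top_movies = torch.topk(scores, 10).indices + 1  # back to 1-based movie IDs\n",
        "print('Top recommendations (movie IDs):', top_movies.tolist())"
      ],
      "execution_count": null,
      "outputs": []
    }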
  ]
}
