Amazing-Python-Scripts
536 строк · 20.4 Кб
1{
2"nbformat": 4,
3"nbformat_minor": 0,
4"metadata": {
5"colab": {
6"name": "AutoEncoders.ipynb",
7"provenance": [],
8"collapsed_sections": [],
9"toc_visible": true,
10"include_colab_link": true
11},
12"kernelspec": {
13"name": "python3",
14"display_name": "Python 3"
15},
16"accelerator": "GPU"
17},
18"cells": [
19{
20"cell_type": "markdown",
21"metadata": {
22"id": "view-in-github",
23"colab_type": "text"
24},
25"source": [
26"<a href=\"https://colab.research.google.com/github/ayush-09/AutoEncoder-Deep-Learning/blob/main/AutoEncoders.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
27]
28},
29{
30"cell_type": "markdown",
31"metadata": {
32"id": "K4f4JG1gdKqj"
33},
34"source": [
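"# AutoEncoders\n",
"\n",
"A stacked autoencoder (SAE) that reconstructs each user's movie-rating vector from MovieLens 100k: it is trained on `u1.base` and evaluated on the held-out ratings in `u1.test`."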
36]
37},
38{
39"cell_type": "code",
40"metadata": {
41"id": "rjOPzue7FCXJ",
42"colab": {
43"base_uri": "https://localhost:8080/"
44},
45"outputId": "c488b69b-9cc1-4cd7-b954-0928c41b410d"
46},
47"source": [
"# Download and extract MovieLens 100k (the dataset the model trains on).\n",
"# -nc skips the download when the zip is already present and -o lets unzip overwrite\n",
"# without an interactive prompt, so this cell is safe to re-run.\n",
"!wget -nc \"https://files.grouplens.org/datasets/movielens/ml-100k.zip\"\n",
"!unzip -o ml-100k.zip\n",
"!ls"
51],
52"execution_count": 1,
53"outputs": [
54{
55"output_type": "stream",
56"text": [
57"--2021-01-09 14:08:13-- http://files.grouplens.org/datasets/movielens/ml-100k.zip\n",
58"Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152\n",
59"Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.\n",
60"HTTP request sent, awaiting response... 200 OK\n",
61"Length: 4924029 (4.7M) [application/zip]\n",
62"Saving to: ‘ml-100k.zip’\n",
63"\n",
64"ml-100k.zip 100%[===================>] 4.70M 18.8MB/s in 0.2s \n",
65"\n",
66"2021-01-09 14:08:13 (18.8 MB/s) - ‘ml-100k.zip’ saved [4924029/4924029]\n",
67"\n",
68"Archive: ml-100k.zip\n",
69" creating: ml-100k/\n",
70" inflating: ml-100k/allbut.pl \n",
71" inflating: ml-100k/mku.sh \n",
72" inflating: ml-100k/README \n",
73" inflating: ml-100k/u.data \n",
74" inflating: ml-100k/u.genre \n",
75" inflating: ml-100k/u.info \n",
76" inflating: ml-100k/u.item \n",
77" inflating: ml-100k/u.occupation \n",
78" inflating: ml-100k/u.user \n",
79" inflating: ml-100k/u1.base \n",
80" inflating: ml-100k/u1.test \n",
81" inflating: ml-100k/u2.base \n",
82" inflating: ml-100k/u2.test \n",
83" inflating: ml-100k/u3.base \n",
84" inflating: ml-100k/u3.test \n",
85" inflating: ml-100k/u4.base \n",
86" inflating: ml-100k/u4.test \n",
87" inflating: ml-100k/u5.base \n",
88" inflating: ml-100k/u5.test \n",
89" inflating: ml-100k/ua.base \n",
90" inflating: ml-100k/ua.test \n",
91" inflating: ml-100k/ub.base \n",
92" inflating: ml-100k/ub.test \n",
93"ml-100k ml-100k.zip sample_data\n"
94],
95"name": "stdout"
96}
97]
98},
99{
100"cell_type": "code",
101"metadata": {
102"id": "LOly1yfAfTjd",
103"colab": {
104"base_uri": "https://localhost:8080/"
105},
106"outputId": "97173fc2-9434-4529-9841-2bb989507c05"
107},
108"source": [
"# Download and extract MovieLens 1M (loaded for reference only; the model uses ml-100k).\n",
"# -nc / -o make the cell safe to re-run: no duplicate download, no overwrite prompt.\n",
"!wget -nc \"https://files.grouplens.org/datasets/movielens/ml-1m.zip\"\n",
"!unzip -o ml-1m.zip\n",
"!ls"
112],
113"execution_count": 2,
114"outputs": [
115{
116"output_type": "stream",
117"text": [
118"--2021-01-09 14:08:24-- http://files.grouplens.org/datasets/movielens/ml-1m.zip\n",
119"Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152\n",
120"Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.\n",
121"HTTP request sent, awaiting response... 200 OK\n",
122"Length: 5917549 (5.6M) [application/zip]\n",
123"Saving to: ‘ml-1m.zip’\n",
124"\n",
125"ml-1m.zip 100%[===================>] 5.64M 22.1MB/s in 0.3s \n",
126"\n",
127"2021-01-09 14:08:25 (22.1 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]\n",
128"\n",
129"Archive: ml-1m.zip\n",
130" creating: ml-1m/\n",
131" inflating: ml-1m/movies.dat \n",
132" inflating: ml-1m/ratings.dat \n",
133" inflating: ml-1m/README \n",
134" inflating: ml-1m/users.dat \n",
135"ml-100k ml-100k.zip ml-1m ml-1m.zip\tsample_data\n"
136],
137"name": "stdout"
138}
139]
140},
141{
142"cell_type": "code",
143"metadata": {
144"id": "_LvGeU1CeCtg"
145},
146"source": [
147"import numpy as np\n",
148"import pandas as pd\n",
149"import torch\n",
150"import torch.nn as nn\n",
151"import torch.nn.parallel\n",
152"import torch.optim as optim\n",
153"import torch.utils.data\n",
154"from torch.autograd import Variable"
155],
156"execution_count": 3,
157"outputs": []
158},
159{
160"cell_type": "code",
161"metadata": {
162"id": "UJw2p3-Cewo4"
163},
164"source": [
"# We won't be using this dataset: the ml-1m tables are loaded for reference only,\n",
"# the model below trains on ml-100k.\n",
"# '::' is a multi-character separator, which pandas only supports with engine='python'.\n",
"dat_options = dict(sep = '::', header = None, engine = 'python', encoding = 'latin-1')\n",
"movies = pd.read_csv('ml-1m/movies.dat', **dat_options)\n",
"users = pd.read_csv('ml-1m/users.dat', **dat_options)\n",
"ratings = pd.read_csv('ml-1m/ratings.dat', **dat_options)"
169],
170"execution_count": 4,
171"outputs": []
172},
173{
174"cell_type": "code",
175"metadata": {
176"id": "2usLKJBEgPE2"
177},
178"source": [
"# Load the u1 train/test split of MovieLens 100k as integer arrays.\n",
"# Columns 0, 1, 2 are used below as user id, movie id and rating.\n",
"# The split files have no header row: without header=None pandas would treat the\n",
"# first rating as column names and silently drop it from each split.\n",
"training_set = pd.read_csv('ml-100k/u1.base', delimiter = '\\t', header = None)\n",
"training_set = np.array(training_set, dtype = 'int')\n",
"test_set = pd.read_csv('ml-100k/u1.test', delimiter = '\\t', header = None)\n",
"test_set = np.array(test_set, dtype = 'int')"
183],
184"execution_count": 6,
185"outputs": []
186},
187{
188"cell_type": "code",
189"metadata": {
190"id": "gPaGZqdniC5m"
191},
192"source": [
"# Ids are 1-based, so the largest id seen in either split is the number of\n",
"# rows (users) / columns (movies) the rating matrices need.\n",
"nb_users = int(max(training_set[:, 0].max(), test_set[:, 0].max()))\n",
"nb_movies = int(max(training_set[:, 1].max(), test_set[:, 1].max()))"
195],
196"execution_count": 7,
197"outputs": []
198},
199{
200"cell_type": "code",
201"metadata": {
202"id": "-wASs2YFiDaa"
203},
204"source": [
"def convert(data):\n",
"    \"\"\"Turn (user id, movie id, rating, ...) rows into a dense rating matrix.\n",
"\n",
"    Row u-1 holds user u's ratings and column m-1 holds movie m; movies a user\n",
"    did not rate stay 0. The matrix is nb_users x nb_movies and is returned as a\n",
"    list of per-user lists so it can be passed straight to a tensor constructor.\n",
"    \"\"\"\n",
"    matrix = np.zeros((nb_users, nb_movies))\n",
"    # One vectorised scatter instead of re-scanning every row once per user.\n",
"    matrix[data[:, 0] - 1, data[:, 1] - 1] = data[:, 2]\n",
"    return matrix.tolist()\n",
"training_set = convert(training_set)\n",
"test_set = convert(test_set)"
216],
217"execution_count": 8,
218"outputs": []
219},
220{
221"cell_type": "code",
222"metadata": {
223"id": "TwD-KD8yiEEw"
224},
225"source": [
"# Dense float32 (nb_users x nb_movies) rating matrices for PyTorch.\n",
"training_set = torch.tensor(training_set, dtype = torch.float32)\n",
"test_set = torch.tensor(test_set, dtype = torch.float32)"
228],
229"execution_count": 9,
230"outputs": []
231},
232{
233"cell_type": "code",
234"metadata": {
235"id": "oU2nyh76iE6M"
236},
237"source": [
"class SAE(nn.Module):\n",
"    \"\"\"Stacked autoencoder: nb_movies -> 20 -> 10 -> 20 -> nb_movies.\n",
"\n",
"    A sigmoid follows each of the first three layers; the final layer is linear,\n",
"    so the reconstructed ratings are not squashed into (0, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(nb_movies, 20)  # encoder\n",
"        self.fc2 = nn.Linear(20, 10)  # bottleneck\n",
"        self.fc3 = nn.Linear(10, 20)  # decoder\n",
"        self.fc4 = nn.Linear(20, nb_movies)  # one output per movie\n",
"        self.activation = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        for layer in (self.fc1, self.fc2, self.fc3):\n",
"            x = self.activation(layer(x))\n",
"        return self.fc4(x)\n",
"\n",
"\n",
"sae = SAE()\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.RMSprop(sae.parameters(), lr = 0.01, weight_decay = 0.5)"
255],
256"execution_count": 10,
257"outputs": []
258},
259{
260"cell_type": "code",
261"metadata": {
262"id": "FEz9hRaciFTs",
263"colab": {
264"base_uri": "https://localhost:8080/"
265},
266"outputId": "8ffae3d0-f74c-47af-c046-fbe8a5fec0b6"
267},
268"source": [
"# Train one user at a time: the input is the user's rating vector and the\n",
"# target is the same vector (autoencoder reconstruction).\n",
"nb_epoch = 200\n",
"for epoch in range(1, nb_epoch + 1):\n",
"    train_loss = 0.\n",
"    s = 0.  # number of users that contributed to train_loss this epoch\n",
"    for id_user in range(nb_users):\n",
"        inputs = training_set[id_user].unsqueeze(0)\n",
"        target = inputs.clone()\n",
"        if torch.sum(target > 0) > 0:  # skip users with no ratings\n",
"            # backward() adds into .grad; clear it so each step uses only this user's gradient.\n",
"            optimizer.zero_grad()\n",
"            output = sae(inputs)\n",
"            # Score only the movies this user actually rated.\n",
"            output = output.masked_fill(target == 0, 0)\n",
"            loss = criterion(output, target)\n",
"            # MSELoss averages over all nb_movies columns; rescale to a mean over rated movies.\n",
"            mean_corrector = nb_movies / float(torch.sum(target > 0) + 1e-10)\n",
"            loss.backward()\n",
"            train_loss += np.sqrt(loss.item() * mean_corrector)\n",
"            s += 1.\n",
"            optimizer.step()\n",
"    print('epoch: ' + str(epoch) + ' loss: ' + str(train_loss / s))"
287],
288"execution_count": 11,
289"outputs": [
290{
291"output_type": "stream",
292"text": [
293"epoch: 1loss: tensor(1.7710)\n",
294"epoch: 2loss: tensor(1.0968)\n",
295"epoch: 3loss: tensor(1.0533)\n",
296"epoch: 4loss: tensor(1.0386)\n",
297"epoch: 5loss: tensor(1.0307)\n",
298"epoch: 6loss: tensor(1.0267)\n",
299"epoch: 7loss: tensor(1.0237)\n",
300"epoch: 8loss: tensor(1.0220)\n",
301"epoch: 9loss: tensor(1.0206)\n",
302"epoch: 10loss: tensor(1.0196)\n",
303"epoch: 11loss: tensor(1.0188)\n",
304"epoch: 12loss: tensor(1.0185)\n",
305"epoch: 13loss: tensor(1.0178)\n",
306"epoch: 14loss: tensor(1.0176)\n",
307"epoch: 15loss: tensor(1.0174)\n",
308"epoch: 16loss: tensor(1.0170)\n",
309"epoch: 17loss: tensor(1.0166)\n",
310"epoch: 18loss: tensor(1.0166)\n",
311"epoch: 19loss: tensor(1.0163)\n",
312"epoch: 20loss: tensor(1.0160)\n",
313"epoch: 21loss: tensor(1.0161)\n",
314"epoch: 22loss: tensor(1.0160)\n",
315"epoch: 23loss: tensor(1.0157)\n",
316"epoch: 24loss: tensor(1.0156)\n",
317"epoch: 25loss: tensor(1.0157)\n",
318"epoch: 26loss: tensor(1.0155)\n",
319"epoch: 27loss: tensor(1.0155)\n",
320"epoch: 28loss: tensor(1.0150)\n",
321"epoch: 29loss: tensor(1.0134)\n",
322"epoch: 30loss: tensor(1.0120)\n",
323"epoch: 31loss: tensor(1.0096)\n",
324"epoch: 32loss: tensor(1.0085)\n",
325"epoch: 33loss: tensor(1.0048)\n",
326"epoch: 34loss: tensor(1.0045)\n",
327"epoch: 35loss: tensor(1.0007)\n",
328"epoch: 36loss: tensor(0.9990)\n",
329"epoch: 37loss: tensor(0.9970)\n",
330"epoch: 38loss: tensor(0.9964)\n",
331"epoch: 39loss: tensor(0.9932)\n",
332"epoch: 40loss: tensor(0.9912)\n",
333"epoch: 41loss: tensor(0.9876)\n",
334"epoch: 42loss: tensor(0.9896)\n",
335"epoch: 43loss: tensor(0.9856)\n",
336"epoch: 44loss: tensor(0.9904)\n",
337"epoch: 45loss: tensor(0.9861)\n",
338"epoch: 46loss: tensor(0.9854)\n",
339"epoch: 47loss: tensor(0.9880)\n",
340"epoch: 48loss: tensor(0.9873)\n",
341"epoch: 49loss: tensor(0.9877)\n",
342"epoch: 50loss: tensor(0.9880)\n",
343"epoch: 51loss: tensor(0.9817)\n",
344"epoch: 52loss: tensor(0.9830)\n",
345"epoch: 53loss: tensor(0.9797)\n",
346"epoch: 54loss: tensor(0.9768)\n",
347"epoch: 55loss: tensor(0.9728)\n",
348"epoch: 56loss: tensor(0.9808)\n",
349"epoch: 57loss: tensor(0.9754)\n",
350"epoch: 58loss: tensor(0.9756)\n",
351"epoch: 59loss: tensor(0.9712)\n",
352"epoch: 60loss: tensor(0.9708)\n",
353"epoch: 61loss: tensor(0.9720)\n",
354"epoch: 62loss: tensor(0.9683)\n",
355"epoch: 63loss: tensor(0.9652)\n",
356"epoch: 64loss: tensor(0.9625)\n",
357"epoch: 65loss: tensor(0.9638)\n",
358"epoch: 66loss: tensor(0.9625)\n",
359"epoch: 67loss: tensor(0.9599)\n",
360"epoch: 68loss: tensor(0.9605)\n",
361"epoch: 69loss: tensor(0.9611)\n",
362"epoch: 70loss: tensor(0.9588)\n",
363"epoch: 71loss: tensor(0.9564)\n",
364"epoch: 72loss: tensor(0.9566)\n",
365"epoch: 73loss: tensor(0.9554)\n",
366"epoch: 74loss: tensor(0.9574)\n",
367"epoch: 75loss: tensor(0.9542)\n",
368"epoch: 76loss: tensor(0.9544)\n",
369"epoch: 77loss: tensor(0.9526)\n",
370"epoch: 78loss: tensor(0.9506)\n",
371"epoch: 79loss: tensor(0.9498)\n",
372"epoch: 80loss: tensor(0.9484)\n",
373"epoch: 81loss: tensor(0.9472)\n",
374"epoch: 82loss: tensor(0.9473)\n",
375"epoch: 83loss: tensor(0.9461)\n",
376"epoch: 84loss: tensor(0.9461)\n",
377"epoch: 85loss: tensor(0.9450)\n",
378"epoch: 86loss: tensor(0.9442)\n",
379"epoch: 87loss: tensor(0.9432)\n",
380"epoch: 88loss: tensor(0.9428)\n",
381"epoch: 89loss: tensor(0.9423)\n",
382"epoch: 90loss: tensor(0.9425)\n",
383"epoch: 91loss: tensor(0.9407)\n",
384"epoch: 92loss: tensor(0.9414)\n",
385"epoch: 93loss: tensor(0.9408)\n",
386"epoch: 94loss: tensor(0.9402)\n",
387"epoch: 95loss: tensor(0.9396)\n",
388"epoch: 96loss: tensor(0.9393)\n",
389"epoch: 97loss: tensor(0.9386)\n",
390"epoch: 98loss: tensor(0.9384)\n",
391"epoch: 99loss: tensor(0.9381)\n",
392"epoch: 100loss: tensor(0.9390)\n",
393"epoch: 101loss: tensor(0.9381)\n",
394"epoch: 102loss: tensor(0.9378)\n",
395"epoch: 103loss: tensor(0.9370)\n",
396"epoch: 104loss: tensor(0.9372)\n",
397"epoch: 105loss: tensor(0.9360)\n",
398"epoch: 106loss: tensor(0.9363)\n",
399"epoch: 107loss: tensor(0.9354)\n",
400"epoch: 108loss: tensor(0.9353)\n",
401"epoch: 109loss: tensor(0.9346)\n",
402"epoch: 110loss: tensor(0.9355)\n",
403"epoch: 111loss: tensor(0.9347)\n",
404"epoch: 112loss: tensor(0.9349)\n",
405"epoch: 113loss: tensor(0.9337)\n",
406"epoch: 114loss: tensor(0.9335)\n",
407"epoch: 115loss: tensor(0.9332)\n",
408"epoch: 116loss: tensor(0.9334)\n",
409"epoch: 117loss: tensor(0.9331)\n",
410"epoch: 118loss: tensor(0.9333)\n",
411"epoch: 119loss: tensor(0.9327)\n",
412"epoch: 120loss: tensor(0.9328)\n",
413"epoch: 121loss: tensor(0.9322)\n",
414"epoch: 122loss: tensor(0.9319)\n",
415"epoch: 123loss: tensor(0.9317)\n",
416"epoch: 124loss: tensor(0.9317)\n",
417"epoch: 125loss: tensor(0.9318)\n",
418"epoch: 126loss: tensor(0.9311)\n",
419"epoch: 127loss: tensor(0.9313)\n",
420"epoch: 128loss: tensor(0.9310)\n",
421"epoch: 129loss: tensor(0.9313)\n",
422"epoch: 130loss: tensor(0.9312)\n",
423"epoch: 131loss: tensor(0.9307)\n",
424"epoch: 132loss: tensor(0.9305)\n",
425"epoch: 133loss: tensor(0.9298)\n",
426"epoch: 134loss: tensor(0.9300)\n",
427"epoch: 135loss: tensor(0.9299)\n",
428"epoch: 136loss: tensor(0.9299)\n",
429"epoch: 137loss: tensor(0.9293)\n",
430"epoch: 138loss: tensor(0.9294)\n",
431"epoch: 139loss: tensor(0.9287)\n",
432"epoch: 140loss: tensor(0.9285)\n",
433"epoch: 141loss: tensor(0.9281)\n",
434"epoch: 142loss: tensor(0.9284)\n",
435"epoch: 143loss: tensor(0.9279)\n",
436"epoch: 144loss: tensor(0.9278)\n",
437"epoch: 145loss: tensor(0.9274)\n",
438"epoch: 146loss: tensor(0.9278)\n",
439"epoch: 147loss: tensor(0.9275)\n",
440"epoch: 148loss: tensor(0.9273)\n",
441"epoch: 149loss: tensor(0.9270)\n",
442"epoch: 150loss: tensor(0.9269)\n",
443"epoch: 151loss: tensor(0.9260)\n",
444"epoch: 152loss: tensor(0.9264)\n",
445"epoch: 153loss: tensor(0.9259)\n",
446"epoch: 154loss: tensor(0.9257)\n",
447"epoch: 155loss: tensor(0.9252)\n",
448"epoch: 156loss: tensor(0.9256)\n",
449"epoch: 157loss: tensor(0.9249)\n",
450"epoch: 158loss: tensor(0.9250)\n",
451"epoch: 159loss: tensor(0.9245)\n",
452"epoch: 160loss: tensor(0.9249)\n",
453"epoch: 161loss: tensor(0.9238)\n",
454"epoch: 162loss: tensor(0.9249)\n",
455"epoch: 163loss: tensor(0.9244)\n",
456"epoch: 164loss: tensor(0.9242)\n",
457"epoch: 165loss: tensor(0.9234)\n",
458"epoch: 166loss: tensor(0.9253)\n",
459"epoch: 167loss: tensor(0.9234)\n",
460"epoch: 168loss: tensor(0.9237)\n",
461"epoch: 169loss: tensor(0.9227)\n",
462"epoch: 170loss: tensor(0.9235)\n",
463"epoch: 171loss: tensor(0.9227)\n",
464"epoch: 172loss: tensor(0.9225)\n",
465"epoch: 173loss: tensor(0.9217)\n",
466"epoch: 174loss: tensor(0.9224)\n",
467"epoch: 175loss: tensor(0.9225)\n",
468"epoch: 176loss: tensor(0.9247)\n",
469"epoch: 177loss: tensor(0.9219)\n",
470"epoch: 178loss: tensor(0.9210)\n",
471"epoch: 179loss: tensor(0.9210)\n",
472"epoch: 180loss: tensor(0.9211)\n",
473"epoch: 181loss: tensor(0.9209)\n",
474"epoch: 182loss: tensor(0.9210)\n",
475"epoch: 183loss: tensor(0.9204)\n",
476"epoch: 184loss: tensor(0.9208)\n",
477"epoch: 185loss: tensor(0.9205)\n",
478"epoch: 186loss: tensor(0.9206)\n",
479"epoch: 187loss: tensor(0.9202)\n",
480"epoch: 188loss: tensor(0.9201)\n",
481"epoch: 189loss: tensor(0.9195)\n",
482"epoch: 190loss: tensor(0.9200)\n",
483"epoch: 191loss: tensor(0.9200)\n",
484"epoch: 192loss: tensor(0.9199)\n",
485"epoch: 193loss: tensor(0.9195)\n",
486"epoch: 194loss: tensor(0.9193)\n",
487"epoch: 195loss: tensor(0.9198)\n",
488"epoch: 196loss: tensor(0.9195)\n",
489"epoch: 197loss: tensor(0.9181)\n",
490"epoch: 198loss: tensor(0.9188)\n",
491"epoch: 199loss: tensor(0.9183)\n",
492"epoch: 200loss: tensor(0.9221)\n"
493],
494"name": "stdout"
495}
496]
497},
498{
499"cell_type": "code",
500"metadata": {
501"id": "5ztvzYRtiGCz",
502"colab": {
503"base_uri": "https://localhost:8080/",
504"height": 34
505},
506"outputId": "d0e8ea8b-9ac4-40e5-a19a-7fcfc6934d61"
507},
508"source": [
"# Reconstruct from each user's training ratings, then score only the movies\n",
"# that user rated in the held-out test split.\n",
"test_loss = 0.\n",
"s = 0.  # number of users with at least one test rating\n",
"with torch.no_grad():  # evaluation only: no autograd graph needed\n",
"    for id_user in range(nb_users):\n",
"        inputs = training_set[id_user].unsqueeze(0)\n",
"        target = test_set[id_user].unsqueeze(0)\n",
"        if torch.sum(target > 0) > 0:\n",
"            output = sae(inputs)\n",
"            output = output.masked_fill(target == 0, 0)\n",
"            loss = criterion(output, target)\n",
"            # MSELoss averages over all nb_movies columns; rescale to a mean over rated movies.\n",
"            mean_corrector = nb_movies / float(torch.sum(target > 0) + 1e-10)\n",
"            test_loss += np.sqrt(loss.item() * mean_corrector)\n",
"            s += 1.\n",
"print('test loss: ' + str(test_loss / s))"
523],
524"execution_count": null,
525"outputs": [
526{
527"output_type": "stream",
528"text": [
529"test loss: tensor(0.9681)\n"
530],
531"name": "stdout"
532}
533]
534}
535]
536}