Amazing-Python-Scripts

Форк
0
63 строки · 2.2 Кб
1
import os
2
from keras.preprocessing import image
3
import matplotlib.pyplot as plt
4
import numpy as np
5
from keras.utils.np_utils import to_categorical
6
import random
7
import shutil
8
from keras.models import Sequential
9
from keras.layers import Dropout, Conv2D, Flatten, Dense, MaxPooling2D, BatchNormalization
10
from keras.models import load_model
11

12

13
def generator(dir, gen=image.ImageDataGenerator(rescale=1./255), shuffle=True, batch_size=1, target_size=(24, 24), class_mode='categorical'):
14

15
    return gen.flow_from_directory(dir, batch_size=batch_size, shuffle=shuffle, color_mode='grayscale', class_mode=class_mode, target_size=target_size)
16

17

18
BS = 32
19
TS = (24, 24)
20
train_batch = generator('data/train', shuffle=True,
21
                        batch_size=BS, target_size=TS)
22
valid_batch = generator('data/valid', shuffle=True,
23
                        batch_size=BS, target_size=TS)
24
SPE = len(train_batch.classes)//BS
25
VS = len(valid_batch.classes)//BS
26
print(SPE, VS)
27

28

29
# img,labels= next(train_batch)
30
# print(img.shape)
31

32
model = Sequential([
33
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(24, 24, 1)),
34
    MaxPooling2D(pool_size=(1, 1)),
35
    Conv2D(32, (3, 3), activation='relu'),
36
    MaxPooling2D(pool_size=(1, 1)),
37
    # 32 convolution filters used each of size 3x3
38
    # again
39
    Conv2D(64, (3, 3), activation='relu'),
40
    MaxPooling2D(pool_size=(1, 1)),
41

42
    # 64 convolution filters used each of size 3x3
43
    # choose the best features via pooling
44

45
    # randomly turn neurons on and off to improve convergence
46
    Dropout(0.25),
47
    # flatten since too many dimensions, we only want a classification output
48
    Flatten(),
49
    # fully connected to get all relevant data
50
    Dense(128, activation='relu'),
51
    # one more dropout for convergence' sake :)
52
    Dropout(0.5),
53
    # output a softmax to squash the matrix into output probabilities
54
    Dense(2, activation='softmax')
55
])
56

57
model.compile(optimizer='adam', loss='categorical_crossentropy',
58
              metrics=['accuracy'])
59

60
model.fit_generator(train_batch, validation_data=valid_batch,
61
                    epochs=15, steps_per_epoch=SPE, validation_steps=VS)
62

63
model.save('models/cnnCat2.h5', overwrite=True)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.