# Amazing-Python-Scripts
# 63 lines · 2.2 KB
1import os
2from keras.preprocessing import image
3import matplotlib.pyplot as plt
4import numpy as np
5from keras.utils.np_utils import to_categorical
6import random
7import shutil
8from keras.models import Sequential
9from keras.layers import Dropout, Conv2D, Flatten, Dense, MaxPooling2D, BatchNormalization
10from keras.models import load_model
11
12
def generator(dir, gen=None, shuffle=True, batch_size=1, target_size=(24, 24), class_mode='categorical'):
    """Return a grayscale directory iterator for the images under ``dir``.

    Args:
        dir: Root directory passed to ``flow_from_directory``.
        gen: Object exposing ``flow_from_directory``. When omitted (or
            ``None``), a fresh ``image.ImageDataGenerator(rescale=1./255)``
            is created for this call.
        shuffle: Forwarded unchanged to ``flow_from_directory``.
        batch_size: Forwarded unchanged to ``flow_from_directory``.
        target_size: Forwarded unchanged to ``flow_from_directory``.
        class_mode: Forwarded unchanged to ``flow_from_directory``.

    Returns:
        Whatever ``gen.flow_from_directory`` returns; images are always
        requested with ``color_mode='grayscale'``.
    """
    if gen is None:
        # Build the default lazily: an instance created in the signature is
        # constructed once at definition time and then shared by every call.
        gen = image.ImageDataGenerator(rescale=1. / 255)
    return gen.flow_from_directory(
        dir,
        batch_size=batch_size,
        shuffle=shuffle,
        color_mode='grayscale',
        class_mode=class_mode,
        target_size=target_size,
    )
16
17
# Batch size and input resolution shared by both data iterators.
BS = 32
TS = (24, 24)

# Build the training iterator first, then the validation iterator, with
# identical batch size, target size and shuffling.
train_batch, valid_batch = (
    generator(split_dir, shuffle=True, batch_size=BS, target_size=TS)
    for split_dir in ('data/train', 'data/valid')
)

# Number of whole batches in each split (floor division, so a trailing
# partial batch is not counted).
SPE, VS = (
    len(batches.classes) // BS
    for batches in (train_batch, valid_batch)
)
print(SPE, VS)
27
28
# img, labels = next(train_batch)
# print(img.shape)
31
# Convolutional classifier for 24x24 single-channel inputs with a 2-way
# softmax output, assembled one layer at a time.
model = Sequential()

# First stage: 32 convolution filters of size 3x3, followed by pooling.
# NOTE(review): pool_size=(1, 1) looks like it performs no downsampling --
# confirm this is intended before changing it (here and in the stages below).
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(24, 24, 1)))
model.add(MaxPooling2D(pool_size=(1, 1)))

# Second stage: again 32 convolution filters of size 3x3.
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(1, 1)))

# Third stage: 64 convolution filters of size 3x3.
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(1, 1)))

# Dropout with rate 0.25 ahead of the dense head.
model.add(Dropout(0.25))

# Flatten the feature maps; only a classification output is wanted.
model.add(Flatten())

# Fully connected layer with 128 units.
model.add(Dense(128, activation='relu'))

# Second dropout, this time with rate 0.5.
model.add(Dropout(0.5))

# Softmax over the two output units to produce class probabilities.
model.add(Dense(2, activation='softmax'))
56
# Configure training: Adam optimizer, categorical cross-entropy loss, and
# accuracy reported as the metric.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Make sure the output folder exists *before* training. Otherwise a missing
# 'models/' directory only surfaces at save time, after all epochs have run,
# and the trained model is lost.
os.makedirs('models', exist_ok=True)

# Train for 15 epochs using the step counts computed from the iterators.
# NOTE(review): newer Keras releases are understood to deprecate
# fit_generator in favour of fit -- confirm against the installed version.
model.fit_generator(
    train_batch,
    validation_data=valid_batch,
    epochs=15,
    steps_per_epoch=SPE,
    validation_steps=VS,
)

# Persist the trained model, replacing any previous file at this path.
model.save('models/cnnCat2.h5', overwrite=True)
64