Amazing-Python-Scripts

# Image captioning on Flickr8k: ResNet50 image encoder + LSTM caption decoder
import collections
import re

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import (Dense, Dropout, Embedding, Input, LSTM,
                                     add)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import img_to_array, load_img, to_categorical


# Read text captions
def readTextFile(path):
    with open(path) as f:
        captions = f.read()
    return captions


# Location of captions
captions = readTextFile('files/captions.txt')
captions = captions.split("\n")[1:-1]  # drop the header row and the trailing empty line
print(len(captions))  # Total captions

# Creating dictionary - {"image name": ["caption1", "caption2", ...]}
description = {}
for x in captions:
    img_name, comment = x.split(',', 1)  # a caption may itself contain commas
    img_name = img_name[:-4]  # strip the ".jpg" extension
    if description.get(img_name) is None:
        description[img_name] = []
    description[img_name].append(comment)
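
# Quick sanity check (assumes the standard Flickr8k "image.jpg,caption"
# row layout): show one image id and its first two captions
sample_key = next(iter(description))
print(sample_key, description[sample_key][:2])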

# Data cleaning
# Don't remove stopwords, because the captions must stay meaningful;
# stemming is not applicable either, because the generated text needs a
# correct vocabulary.
# Lowercase the text and remove numbers and punctuation.


def clean_text(sentence):
    sentence = sentence.lower()
    sentence = re.sub("[^a-z]+", " ", sentence)
    sentence = sentence.split()

    # drop single-character tokens such as "a" left over from cleaning
    sentence = [s for s in sentence if len(s) > 1]
    sentence = " ".join(sentence)
    return sentence
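
# Example of what clean_text produces; the input string is illustrative,
# not taken from the dataset
print(clean_text("A child, in 2 pink dresses!"))  # -> "child in pink dresses"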

# Clean all captions
for key, caption_list in description.items():
    for i in range(len(caption_list)):
        caption_list[i] = clean_text(caption_list[i])

# Collect every word across all the cleaned captions
total_words = []
for key in description.keys():
    for des in description[key]:
        total_words.extend(des.split())
print(len(total_words))

# Filter words from the vocab according to a threshold frequency
counter = collections.Counter(total_words)
freq_cnt = dict(counter)

# Sort this dictionary according to frequency count
sorted_freq_cnt = sorted(freq_cnt.items(), reverse=True, key=lambda x: x[1])

# Filtering
threshold = 5
sorted_freq_cnt = [x for x in sorted_freq_cnt if x[1] > threshold]
total_words = [x[0] for x in sorted_freq_cnt]
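
# With threshold = 5 this is expected to leave about 2572 words on
# Flickr8k (the exact count depends on the captions file version)
print(len(total_words))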

# Prepare train/test data
train_filedata = readTextFile("files/Flickr_8k.trainImages.txt")
test_filedata = readTextFile("files/Flickr_8k.testImages.txt")

train = [row.split(".")[0] for row in train_filedata.split("\n")[:-1]]
test = [row.split(".")[0] for row in test_filedata.split("\n")[:-1]]

# Prepare descriptions for the training data
# Tweak - wrap every training caption in "startseq" and "endseq" tokens
train_description = {}
for img_id in train:
    train_description[img_id] = []
    for cap in description[img_id]:
        cap_to_append = "startseq " + cap + " endseq"
        train_description[img_id].append(cap_to_append)


# Transfer learning
# Step 1. Image feature extraction
# Use a pretrained ResNet50 model to extract image features
model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
model.summary()

# Drop the final classification layer; the new output is the 2048-d
# pooled feature vector
new_model = Model(model.input, model.layers[-2].output)
new_model.summary()


def preprocess_img(img):
    img = load_img(img, target_size=(224, 224))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    # normalisation -> preprocess_input
    img = preprocess_input(img)
    return img


def encode_image(img):
    img = preprocess_img(img)
    feature_vector = new_model.predict(img, verbose=0)
    # print(feature_vector.shape)
    feature_vector = feature_vector.reshape((-1,))  # flatten to (2048,)
    return feature_vector
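
# Example (commented out; "files/Images/example.jpg" is a placeholder path):
# vec = encode_image("files/Images/example.jpg")
# print(vec.shape)  # (2048,)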

# Encode all train images
encoding_train = {}
# image_id --> feature vector extracted from ResNet50
for ix, img_id in enumerate(train):
    img_path = "files/Images/" + img_id + ".jpg"
    encoding_train[img_id] = encode_image(img_path)
    # if ix % 100 == 0:
    #     print(ix)

# Encode all test images
encoding_test = {}
# image_id --> feature vector extracted from ResNet50
for ix, img_id in enumerate(test):
    img_path = "files/Images/" + img_id + ".jpg"
    encoding_test[img_id] = encode_image(img_path)
    # if ix % 100 == 0:
    #     print(ix)


word_to_idx = {}
idx_to_word = {}
for i, word in enumerate(total_words):
    word_to_idx[word] = i + 1
    idx_to_word[i + 1] = word

# Reserve the next two indices for the sequence delimiters
# (2573 and 2574 for this dataset and threshold)
token_idx = len(word_to_idx) + 1
word_to_idx['startseq'] = token_idx
idx_to_word[token_idx] = 'startseq'
word_to_idx['endseq'] = token_idx + 1
idx_to_word[token_idx + 1] = 'endseq'

# Index 0 is reserved for padding, so the vocabulary is one larger than
# the number of indexed words
vocab_size = len(word_to_idx) + 1
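
# Sanity check: 2572 filtered words + 2 delimiter tokens + 1 padding
# index is expected to give vocab_size == 2575 here
print(vocab_size)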

# Model training
# Find the max length of any caption to decide the RNN input size
max_len = 0
for key in train_description.keys():
    for cap in train_description[key]:
        max_len = max(max_len, len(cap.split()))
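
# The longest training caption (about 35 words on Flickr8k) sets the
# fixed sequence length used for padding
print(max_len)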

# Data loader (generator)


def data_generator(train_description, encoding_train, word_to_idx,
                   max_len, batch_size, vocab_size):
    x1, x2, y = [], [], []
    n = 0

    while True:
        for key, desc_list in train_description.items():
            n += 1
            photo = encoding_train[key]
            for desc in desc_list:
                seq = [word_to_idx[word]
                       for word in desc.split() if word in word_to_idx]
                for i in range(1, len(seq)):
                    xi = seq[0:i]  # partial caption
                    yi = seq[i]    # next word to predict

                    xi = pad_sequences([xi], maxlen=max_len,
                                       value=0, padding='post')[0]
                    yi = to_categorical([yi], num_classes=vocab_size)[0]
                    x1.append(photo)  # 2048-d image feature
                    x2.append(xi)     # max_len word indices
                    y.append(yi)      # one-hot target over vocab_size classes

            if n == batch_size:
                # model.fit expects (inputs, targets) tuples from a generator
                yield (np.array(x1), np.array(x2)), np.array(y)
                x1, x2, y = [], [], []
                n = 0
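
# Shape check for one generated batch (commented out; runs once the
# encodings and vocabulary above exist):
# gen = data_generator(train_description, encoding_train, word_to_idx,
#                      max_len, batch_size=3, vocab_size=vocab_size)
# (bx1, bx2), by = next(gen)
# print(bx1.shape, bx2.shape, by.shape)  # (N, 2048), (N, max_len), (N, vocab_size)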

# WORD EMBEDDINGS
# The text data must be embedded before being passed to the LSTM layer
embedding_index = {}
with open("files/glove.6B.50d.txt", encoding='utf8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        word_embedding = np.array(values[1:], dtype='float')
        embedding_index[word] = word_embedding
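
# Example lookup (glove.6B.50d contains common words such as "dog"):
# print(embedding_index["dog"].shape)  # (50,)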


def get_embedding_matrix(vocab_size):
    emb_dim = 50  # glove.6B.50d vectors are 50-dimensional
    matrix = np.zeros((vocab_size, emb_dim))
    for word, idx in word_to_idx.items():
        embedding_vector = embedding_index.get(word)
        if embedding_vector is not None:
            matrix[idx] = embedding_vector
    return matrix


embedding_matrix = get_embedding_matrix(vocab_size)

# Image features as input => 2048 -> 256
input_img_features = Input(shape=(2048,))
input_img1 = Dropout(0.3)(input_img_features)
input_img2 = Dense(256, activation="relu")(input_img1)

# Captions as input => batch_size*35 -> batch_size*35*50 -> 256
input_captions = Input(shape=(max_len,))
# The embedding layer starts randomly initialised; the pretrained GloVe
# weights are injected after the model is built
input_cap1 = Embedding(input_dim=vocab_size, output_dim=50,
                       mask_zero=True)(input_captions)
input_cap2 = Dropout(0.3)(input_cap1)
input_cap3 = LSTM(256)(input_cap2)

# Add both encodings and decode them into a word distribution
decoder1 = add([input_img2, input_cap3])
decoder2 = Dense(256, activation='relu')(decoder1)
outputs = Dense(vocab_size, activation='softmax')(decoder2)

# COMBINED MODEL
model = Model(inputs=[input_img_features, input_captions], outputs=outputs)

# Embedding layer: load the pretrained GloVe matrix and freeze it so its
# weights are not updated during training
model.layers[2].set_weights([embedding_matrix])
model.layers[2].trainable = False
model.compile(loss="categorical_crossentropy", optimizer="adam")

print(model.summary())
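
# Optional check: the positional index model.layers[2] is assumed to be
# the Embedding layer; uncomment to verify before training
# assert isinstance(model.layers[2], Embedding)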
244

245
# Training of Model
246
epochs = 10
247
batch_size = 3  # no if images per batch
248
steps = len(train_description)//batch_size
249

250

251
def train():
252
    for i in range(epochs):
253
        generator = data_generator(
254
            train_description, encoding_train, word_to_idx, max_len, batch_size)
255
        model.fit(generator, epochs=1, steps_per_epoch=steps, verbose=1)
256

257

258
model.save("models/"+"9"+'.h5')
259

260
train()

# Prediction function


def predict_caption(photo):
    in_text = "startseq"
    for i in range(max_len):
        sequence = [word_to_idx[w]
                    for w in in_text.split() if w in word_to_idx]
        sequence = pad_sequences([sequence], maxlen=max_len, padding='post')
        ypred = model.predict([photo, sequence], verbose=0)
        ypred = ypred.argmax()  # word with max probability -> greedy sampling
        # index 0 is the padding class; treat it as end-of-sequence
        word = idx_to_word.get(ypred, 'endseq')
        in_text += (' ' + word)
        if word == 'endseq':
            break
    final_caption = in_text.split()[1:-1]  # drop the startseq/endseq tokens
    final_caption = ' '.join(final_caption)
    return final_caption
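
# Example usage (commented out; requires the trained model and encodings):
# sample_id = list(encoding_test.keys())[0]
# print(predict_caption(encoding_test[sample_id].reshape((1, 2048))))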

# Pick some random test images and caption them
all_img_names = list(encoding_test.keys())
for i in range(15):
    no = np.random.randint(0, len(all_img_names))
    img_name = all_img_names[no]
    photo_2048 = encoding_test[img_name].reshape((1, 2048))

    caption = predict_caption(photo_2048)

    img = plt.imread("files/Images/" + img_name + ".jpg")
    print(caption)
    plt.imshow(img)
    plt.axis("off")
    plt.show()