import collections
import re

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import LSTM, Dense, Dropout, Embedding, Input, add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import img_to_array, load_img, to_categorical


# Read text captions
def readTextFile(path):
    with open(path) as f:
        captions = f.read()
    return captions


# Location of captions
captions = readTextFile('files/captions.txt')
captions = captions.split("\n")[1:-1]
print(len(captions))  # Total number of captions
# Creating dictionary - {"image name": ["caption1", "caption2", ...]}
description = {}
for x in captions:
    # split only on the first comma so captions containing commas stay intact
    img_name, comment = x.split(',', 1)
    img_name = img_name[:-4]  # drop the ".jpg" extension
    if description.get(img_name) is None:
        description[img_name] = []
    description[img_name].append(comment)

# Data cleaning
# Don't remove stopwords, since the captions need to stay complete, meaningful
# sentences. Stemming is also not applicable, because the generated text must
# use correct vocabulary.
# Lowercase the text and remove numbers and punctuation.

def clean_text(sentence):
    sentence = sentence.lower()
    sentence = re.sub("[^a-z]+", " ", sentence)
    sentence = sentence.split()
    # drop single-character tokens left over after cleaning (e.g. "a")
    sentence = [s for s in sentence if len(s) > 1]
    sentence = " ".join(sentence)
    return sentence
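

# Quick sanity check on a made-up caption (illustrative string, not from the dataset):
assert clean_text("A child is playing in the park, 2 dogs nearby!") == \
    "child is playing in the park dogs nearby"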

# Clean all captions in place
for key, caption_list in description.items():
    for i in range(len(caption_list)):
        caption_list[i] = clean_text(caption_list[i])

# Total number of words across all the sentences
total_words = []
for key in description.keys():
    for des in description[key]:
        total_words.extend(des.split())
print(len(total_words))  # total (non-unique) word count

# Filter words from the vocab according to a threshold frequency
counter = collections.Counter(total_words)
freq_cnt = dict(counter)

# Sort this dictionary by frequency count
sorted_freq_cnt = sorted(freq_cnt.items(), reverse=True, key=lambda x: x[1])

# Filtering: keep only words that occur more than `threshold` times
threshold = 5
sorted_freq_cnt = [x for x in sorted_freq_cnt if x[1] > threshold]
total_words = [x[0] for x in sorted_freq_cnt]
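
# The hard-coded special-token indices used further below (2573 and 2574) assume
# this filtered vocabulary contains 2572 words; verify that before training.
print(len(total_words))  # expected: 2572 for this captions file and threshold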

# Prepare train/test data
train_filedata = readTextFile("files/Flickr_8k.trainImages.txt")
test_filedata = readTextFile("files/Flickr_8k.testImages.txt")

train = [row.split(".")[0] for row in train_filedata.split("\n")[:-1]]
test = [row.split(".")[0] for row in test_filedata.split("\n")[:-1]]

# Prepare descriptions for the training data
# Tweak - wrap every training caption with "startseq" and "endseq" tokens
train_description = {}
for img_id in train:
    train_description[img_id] = []
    for cap in description[img_id]:
        cap_to_append = "startseq " + cap + " endseq"
        train_description[img_id].append(cap_to_append)

# Transfer learning
# Step 1. Image feature extraction
# Use a pretrained ResNet50 model to extract features from the preprocessed images
model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
model.summary()

# Drop the final classification layer; the 2048-d avg_pool output becomes the image encoding
new_model = Model(model.input, model.layers[-2].output)
new_model.summary()


def preprocess_img(img):
    img = load_img(img, target_size=(224, 224))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    # normalisation -> preprocess_input (ImageNet channel-wise preprocessing)
    img = preprocess_input(img)
    return img


def encode_image(img):
    img = preprocess_img(img)
    feature_vector = new_model.predict(img, verbose=0)
    # print(feature_vector.shape)
    feature_vector = feature_vector.reshape((-1,))  # (1, 2048) -> (2048,)
    return feature_vector
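

# Illustrative usage (hypothetical path; any image on disk works):
# feat = encode_image("files/Images/some_image.jpg")
# feat.shape -> (2048,), the avg_pool activation used as the image embedding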


# Encode all train images
encoding_train = {}
# image_id --> feature vector extracted from ResNet50
for ix, img_id in enumerate(train):
    img_path = "files/Images/" + img_id + ".jpg"
    encoding_train[img_id] = encode_image(img_path)
    # if ix % 100 == 0:
    #     print(ix)

# Encode all test images
encoding_test = {}
# image_id --> feature vector extracted from ResNet50
for ix, img_id in enumerate(test):
    img_path = "files/Images/" + img_id + ".jpg"
    encoding_test[img_id] = encode_image(img_path)
    # if ix % 100 == 0:
    #     print(ix)


word_to_idx = {}
idx_to_word = {}
for i, word in enumerate(total_words):
    word_to_idx[word] = i + 1       # index 0 is reserved for padding
    idx_to_word[i + 1] = word
# The two special tokens take the indices right after the 2572 filtered words
word_to_idx['startseq'] = 2573
word_to_idx['endseq'] = 2574
idx_to_word[2573] = 'startseq'
idx_to_word[2574] = 'endseq'

# Model training
# RNN (caption) model
# Find the max length of any caption to decide the padded input size
max_len = 0
for key in train_description.keys():
    for cap in train_description[key]:
        max_len = max(max_len, len(cap.split()))  # max length of any caption

# Data loader (generator)


def data_generator(train_description, encoding_train, word_to_idx, max_len, batch_size, vocab_size=2574):
    x1, x2, y = [], [], []
    n = 0

    while True:
        for key, desc_list in train_description.items():
            n += 1
            photo = encoding_train[key]
            for desc in desc_list:
                seq = [word_to_idx[word]
                       for word in desc.split() if word in word_to_idx.keys()]
                for i in range(1, len(seq)):
                    xi = seq[0:i]   # partial caption so far
                    yi = seq[i]     # next word to predict

                    xi = pad_sequences([xi], maxlen=max_len,
                                       value=0, padding='post')[0]
                    # shift by 1 so the class ids fit in [0, vocab_size)
                    yi = to_categorical([yi - 1], num_classes=vocab_size)[0]
                    x1.append(photo)  # 2048-d image feature
                    x2.append(xi)     # max_len word indices
                    y.append(yi)      # one-hot over vocab_size -> 2574

            if n == batch_size:
                yield ([np.array(x1), np.array(x2)], np.array(y))
                x1, x2, y = [], [], []
                n = 0
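

# Illustrative shape check on one batch (safe to delete; it only pulls a single
# batch from a fresh generator instance). num_pairs is the number of
# (partial caption, next word) pairs produced by batch_size images.
(x_img, x_seq), y_batch = next(data_generator(
    train_description, encoding_train, word_to_idx, max_len, batch_size=3))
print(x_img.shape, x_seq.shape, y_batch.shape)
# expected: (num_pairs, 2048) (num_pairs, max_len) (num_pairs, 2574)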

# WORD EMBEDDINGS
# The caption text is embedded (50-d GloVe vectors) before being fed to the LSTM layer
embedding_index = {}
with open("files/glove.6B.50d.txt", encoding='utf8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        word_embedding = np.array(values[1:], dtype='float')
        embedding_index[word] = word_embedding


def get_embedding_matrix(vocab_size=2574):
    emb_dim = 50
    matrix = np.zeros((vocab_size, emb_dim))
    for word, idx in word_to_idx.items():
        embedding_vector = embedding_index.get(word)
        # words without a GloVe vector (including the special tokens) keep an all-zero row
        if embedding_vector is not None:
            matrix[idx] = embedding_vector
    return matrix


embedding_matrix = get_embedding_matrix()
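
# Quick shape check: one 50-d GloVe row per word index; the padding row (index 0)
# and rows for words missing from GloVe stay all-zero.
print(embedding_matrix.shape)  # expected: (2574, 50)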

vocab_size = 2574

# Image feature as input => 2048 -> 256
input_img_features = Input(shape=(2048,))
input_img1 = Dropout(0.3)(input_img_features)
input_img2 = Dense(256, activation="relu")(input_img1)

# Captions as input => batch_size x max_len -> batch_size x max_len x 50 -> 256
input_captions = Input(shape=(max_len,))
# The Embedding layer is created with random weights here; the pretrained GloVe
# matrix is loaded into it (and frozen) after the model is built below
input_cap1 = Embedding(input_dim=vocab_size, output_dim=50,
                       mask_zero=True)(input_captions)
input_cap2 = Dropout(0.3)(input_cap1)
input_cap3 = LSTM(256)(input_cap2)

# Add both 256-d representations and decode them into a distribution over the vocabulary
decoder1 = add([input_img2, input_cap3])
decoder2 = Dense(256, activation='relu')(decoder1)
outputs = Dense(vocab_size, activation='softmax')(decoder2)

# COMBINED MODEL
model = Model(inputs=[input_img_features, input_captions], outputs=outputs)

# Important: load the pretrained GloVe matrix into the Embedding layer and freeze it
model.layers[2].set_weights([embedding_matrix])
model.layers[2].trainable = False
model.compile(loss="categorical_crossentropy", optimizer="adam")
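
# Sanity check: model.layers[2] is assumed to be the Embedding layer in this
# functional graph (InputLayer, InputLayer, Embedding, ...). If the layer order
# differs in your Keras version, look it up by type instead, e.g.:
# emb_layer = next(l for l in model.layers if isinstance(l, Embedding))
# emb_layer.set_weights([embedding_matrix])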

print(model.summary())

# Training of the model
epochs = 10
batch_size = 3  # no. of images per batch
steps = len(train_description) // batch_size


def train():
    for i in range(epochs):
        generator = data_generator(
            train_description, encoding_train, word_to_idx, max_len, batch_size)
        model.fit(generator, epochs=1, steps_per_epoch=steps, verbose=1)


train()
# Save the trained model
model.save("models/9.h5")

# Prediction function


def predict_caption(photo):
    in_text = "startseq"
    for i in range(max_len):
        sequence = [word_to_idx[w]
                    for w in in_text.split() if w in word_to_idx]
        sequence = pad_sequences([sequence], maxlen=max_len, padding='post')
        ypred = model.predict([photo, sequence], verbose=0)
        ypred = ypred.argmax()  # word with max probability -> greedy sampling
        word = idx_to_word[ypred + 1]  # undo the -1 shift used when building targets
        in_text += (' ' + word)
        if word == 'endseq':
            break
    final_caption = in_text.split()[1:-1]  # drop the startseq/endseq tokens
    final_caption = ' '.join(final_caption)
    return final_caption


# Pick some random test images and caption them
all_img_names = list(encoding_test.keys())
for i in range(15):
    no = np.random.randint(0, len(all_img_names))
    img_name = all_img_names[no]
    photo_2048 = encoding_test[img_name].reshape((1, 2048))

    caption = predict_caption(photo_2048)

    img = plt.imread("files/Images/" + img_name + ".jpg")
    print(caption)
    plt.imshow(img)
    plt.axis("off")
    plt.show()