# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import tensorflow as tf
# pylint: skip-file

ATT_ID = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2,
          'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6,
          'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10,
          'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13,
          'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16,
          'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19,
          'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22,
          'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25,
          'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28,
          'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31,
          'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34,
          'Wearing_Hat': 35, 'Wearing_Lipstick': 36,
          'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}
ID_ATT = {v: k for k, v in ATT_ID.items()}
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 162770
CENTRAL_FRACTION = 0.89
LOAD_SIZE = 142  # 286
CROP_SIZE = 128  # 256

def cal_eo(a, y_label, y_pred):
  """Computes per-group accuracies and the equalized-odds gap.

  `a` is the binary sensitive attribute, `y_label` the binary target, and
  `y_pred` hard 0/1 predictions. Returns (d00, d01, d10, d11, eo), where
  dxy is the accuracy on the group with a == x and y_label == y, and
  eo = |d00 - d10| + |d01 - d11|.
  """
  a = np.array(a)
  y_label = np.array(y_label)
  y_pred = np.array(y_pred)

  idx00 = np.logical_and(a == 0, y_label == 0)
  idx01 = np.logical_and(a == 0, y_label == 1)
  idx10 = np.logical_and(a == 1, y_label == 0)
  idx11 = np.logical_and(a == 1, y_label == 1)

  d00 = 1 - np.sum(y_pred[idx00]) / y_pred[idx00].shape[0]
  d01 = np.sum(y_pred[idx01]) / y_pred[idx01].shape[0]
  d10 = 1 - np.sum(y_pred[idx10]) / y_pred[idx10].shape[0]
  d11 = np.sum(y_pred[idx11]) / y_pred[idx11].shape[0]

  eo = np.abs(d00 - d10) + np.abs(d01 - d11)
  return (d00, d01, d10, d11, eo)

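# Worked toy example for `cal_eo` (illustrative values, not from the dataset):
#
#   a      = [0, 0, 0, 0, 1, 1, 1, 1]
#   y_true = [0, 0, 1, 1, 0, 0, 1, 1]
#   y_pred = [0, 1, 1, 1, 0, 0, 1, 0]
#   cal_eo(a, y_true, y_pred)
#   # -> (0.5, 1.0, 1.0, 0.5, 1.0): the a == 0 group has TNR 0.5 and TPR 1.0,
#   #    the a == 1 group has TNR 1.0 and TPR 0.5, so eo = 0.5 + 0.5 = 1.0.
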
def reorg(label_path, af, bf):
  """Subsamples a CelebA-style attribute file into a fixed group ratio.

  Binarizes attributes `af` and `bf` (from their -1/1 encoding), buckets the
  entries into the four (af, bf) groups, and keeps at most min_leng entries of
  the (0,0) and (1,1) groups and at most 3 * min_leng of the mixed groups,
  where min_leng is the size of the smallest group.
  """
  img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
  labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 41))
  entry = np.concatenate((img_names[:, np.newaxis], labels), axis=1)
  a = np.asarray((labels[:, ATT_ID[af]] + 1) // 2)
  b = np.asarray((labels[:, ATT_ID[bf]] + 1) // 2)
  d00 = []
  d01 = []
  d10 = []
  d11 = []
  for i in range(labels.shape[0]):
    if a[i] == 0:
      if b[i] == 0:
        d00.append(entry[i])
      elif b[i] == 1:
        d01.append(entry[i])
    elif a[i] == 1:
      if b[i] == 0:
        d10.append(entry[i])
      elif b[i] == 1:
        d11.append(entry[i])
  min_leng = np.min([len(d00), len(d01), len(d10), len(d11)])
  new_list = (d00[:min_leng] + d01[:3 * min_leng] +
              d10[:3 * min_leng] + d11[:min_leng])
  return np.array(new_list)

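# Illustrative sizing for `reorg` (assumed counts, not real CelebA statistics):
# if the four (af, bf) buckets hold 80k, 40k, 60k, and 10k entries, then
# min_leng = 10k and the returned array keeps 10k + 30k + 30k + 10k entries,
# i.e. a 1:3:3:1 ratio across the (0,0), (0,1), (1,0), (1,1) groups (buckets
# smaller than their cap are kept whole).
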
def reorg_fake(label_path, af, bf):
  """Identical group-balancing to `reorg`; used by `data_fake`."""
  img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
  labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 41))
  entry = np.concatenate((img_names[:, np.newaxis], labels), axis=1)
  a = np.asarray((labels[:, ATT_ID[af]] + 1) // 2)
  b = np.asarray((labels[:, ATT_ID[bf]] + 1) // 2)
  d00 = []
  d01 = []
  d10 = []
  d11 = []
  for i in range(labels.shape[0]):
    if a[i] == 0:
      if b[i] == 0:
        d00.append(entry[i])
      elif b[i] == 1:
        d01.append(entry[i])
    elif a[i] == 1:
      if b[i] == 0:
        d10.append(entry[i])
      elif b[i] == 1:
        d11.append(entry[i])
  min_leng = np.min([len(d00), len(d01), len(d10), len(d11)])
  new_list = (d00[:min_leng] + d01[:3 * min_leng] +
              d10[:3 * min_leng] + d11[:min_leng])
  return np.array(new_list)

def load_train(image_path, label, att):
  image = tf.io.read_file(image_path)
  image = tf.image.decode_jpeg(image)
  image = tf.image.resize(image, [LOAD_SIZE, LOAD_SIZE])
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_crop(image, [CROP_SIZE, CROP_SIZE, 3])
  # Scale pixels from [0, 255] to [-1, 1].
  image = tf.clip_by_value(image, 0, 255) / 127.5 - 1
  # Map the -1/1 attribute encoding to 0/1.
  label = (label + 1) // 2
  att = (att + 1) // 2
  image = tf.cast(image, tf.float32)
  label = tf.cast(label, tf.int32)
  att = tf.cast(att, tf.int32)
  return (image, label, att)

def load_test(image_path, label, att):
  image = tf.io.read_file(image_path)
  image = tf.image.decode_jpeg(image)
  image = tf.image.resize(image, [LOAD_SIZE, LOAD_SIZE])
  # CENTRAL_FRACTION = 0.89 is chosen so the central crop of LOAD_SIZE comes
  # out at CROP_SIZE (142 -> 128, and 286 -> 256 for the commented-out sizes).
  image = tf.image.central_crop(image, CENTRAL_FRACTION)
  image = tf.clip_by_value(image, 0, 255) / 127.5 - 1
  label = (label + 1) // 2
  att = (att + 1) // 2
  image = tf.cast(image, tf.float32)
  label = tf.cast(label, tf.int32)
  att = tf.cast(att, tf.int32)
  return (image, label, att)

# load balanced training dataset
def data_train(image_path, label_path, batch_size):
  a = 'Male'
  b = 'Arched_Eyebrows'
  new_entry = reorg(label_path, a, b)
  n_examples = new_entry.shape[0]
  img_names = new_entry[:, 0]
  img_paths = np.array([os.path.join(image_path, img_name) for img_name in img_names])
  img_labels = new_entry[:, 1:]
  labels = img_labels[:, ATT_ID['Arched_Eyebrows']].astype(int)
  att = img_labels[:, ATT_ID['Male']].astype(int)

  train_dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels, att))
  train_dataset = train_dataset.map(load_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.shuffle(n_examples)
  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
  train_dataset = train_dataset.repeat().prefetch(1)

  train_iter = train_dataset.make_one_shot_iterator()
  batch = train_iter.get_next()

  return batch, int(np.ceil(n_examples / batch_size))

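# Minimal consumption sketch (an assumption, not part of the original file):
# `make_one_shot_iterator` is a TF 1.x graph-mode API, so the returned `batch`
# is consumed inside a session. Paths below are hypothetical.
#
#   batch, steps_per_epoch = data_train('/path/to/img_align_celeba',
#                                       '/path/to/list_attr_celeba.txt',
#                                       batch_size=64)
#   with tf.Session() as sess:
#     images, labels, att = sess.run(batch)
#     # images: float32 in [-1, 1]; labels/att: int32 in {0, 1}
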
def data_fake(image_path, label_path, batch_size):
  a = 'Male'
  b = 'Arched_Eyebrows'
  new_entry = reorg_fake(label_path, a, b)
  n_examples = new_entry.shape[0]
  img_names = new_entry[:, 0]
  img_paths = np.array([os.path.join(image_path, img_name) for img_name in img_names])
  img_labels = new_entry[:, 1:]
  labels = img_labels[:, ATT_ID['Arched_Eyebrows']].astype(int)
  att = img_labels[:, ATT_ID['Male']].astype(int)

  train_dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels, att))
  train_dataset = train_dataset.map(load_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.shuffle(n_examples)
  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
  train_dataset = train_dataset.repeat().prefetch(1)

  train_iter = train_dataset.make_one_shot_iterator()
  batch = train_iter.get_next()

  return batch, int(np.ceil(n_examples / batch_size))

# load entire training dataset
# def data_train(image_path, label_path, batch_size):
#   img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
#   img_paths = np.array([os.path.join(image_path, img_name) for img_name in img_names])
#   labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 41))
#   n_examples = img_names.shape[0]
#   # labels = labels[:,ATT_ID['Male']]
#   labels = labels[:,ATT_ID['Smiling']]
#   # labels = labels[:,ATT_ID['Arched_Eyebrows']]

#   train_dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))
#   train_dataset = train_dataset.map(load_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#   train_dataset = train_dataset.shuffle(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, seed=0)
#   train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
#   train_dataset = train_dataset.repeat().prefetch(1)

#   train_iter = train_dataset.make_one_shot_iterator()
#   batch = train_iter.get_next()

#   return batch, int(np.ceil(n_examples/batch_size))

def data_test(image_path, label_path, batch_size):
  img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
  img_paths = np.array([os.path.join(image_path, img_name) for img_name in img_names])
  img_labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 41))
  n_examples = img_names.shape[0]
  labels = img_labels[:, ATT_ID['Arched_Eyebrows']]
  att = img_labels[:, ATT_ID['Male']]

  test_dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels, att))
  test_dataset = test_dataset.map(load_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  test_dataset = test_dataset.batch(batch_size, drop_remainder=False)
  test_dataset = test_dataset.repeat().prefetch(1)

  test_iter = test_dataset.make_one_shot_iterator()
  batch = test_iter.get_next()

  return batch, int(np.ceil(n_examples / batch_size))

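# Evaluation sketch (an assumption for illustration; `predict_fn` is a
# hypothetical 0/1 classifier, and a TF 1.x session is assumed as above):
# iterate exactly `n_batches` times to cover the test set once, then feed the
# accumulated attributes, labels, and predictions to `cal_eo`.
#
#   batch, n_batches = data_test(image_dir, label_file, batch_size=64)
#   all_a, all_y, all_p = [], [], []
#   with tf.Session() as sess:
#     for _ in range(n_batches):
#       images, labels, att = sess.run(batch)
#       all_a.extend(att)
#       all_y.extend(labels)
#       all_p.extend(predict_fn(images))  # hypothetical classifier
#   d00, d01, d10, d11, eo = cal_eo(all_a, all_y, all_p)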