# google-research
# 119 lines · 5.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse

import numpy as np
import tensorflow as tf

import data
import model
# pylint: skip-file

# ---- Command-line configuration ---------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
# Directory handed to the real-image loaders (train/valid/test) and to the
# generated-image loader, respectively.
parser.add_argument('--real_path', default='../data/resize128')
parser.add_argument('--fake_path', default='../data/fake')
# Per-split annotation files.
parser.add_argument('--train_label', default='../data/annotations/train_label.txt')
parser.add_argument('--test_label', default='../data/annotations/test_label.txt')
parser.add_argument('--valid_label', default='../data/annotations/val_label.txt')
parser.add_argument('--max_epoch', type=int, default=20, help='Epoch to run [default: 20]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size during training [default: 64]')
parser.add_argument('--n_class', type=int, default=2, help='Number of class [default: 2]')
parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate [default: 0.1]')
# This is the momentum coefficient of the momentum optimizer; the help text
# previously mislabelled it as a learning rate.
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for the momentum optimizer [default: 0.9]')
# Only two optimizers are supported; reject anything else at parse time rather
# than letting a typo silently select an unintended optimizer.
parser.add_argument('--optimizer', default='momentum', choices=['adam', 'momentum'],
                    help='adam or momentum [default: momentum]')
FLAGS = parser.parse_args()

# Attribute name -> integer column index (0..39). The index of each name is its
# position in the tuple below, so the ordering here is significant.
# NOTE(review): these are the 40 CelebA-style facial attributes; confirm the
# order matches the columns of the annotation files used by `data`.
ATT_ID = {
    attr_name: position
    for position, attr_name in enumerate((
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young',
    ))
}

# Enumerate CUDA devices by PCI bus ID and expose only the device(s) selected
# with --gpu to this process.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES=FLAGS.gpu,
)

# Fixed TensorFlow random seed for repeatable runs. An earlier configuration
# used seed 0 (original note: "0 for 512").
tf.set_random_seed(100)

# ---- Input pipelines --------------------------------------------------------
# Each data.* helper returns ((images, labels, attributes), iters_per_epoch);
# the tensors are pulled batch by batch with sess.run in the loop below.
#
# The training and fake-image pipelines previously hard-coded a batch size of 64,
# so --batch_size ("Batch Size during training") only affected evaluation. They
# now honour the flag; its default is 64, so default runs are unchanged.
# NOTE(review): the fake pipeline is indexed with --train_label, presumably
# because generated images mirror the training split; confirm.
(train_images, train_labels, train_att), train_iters = data.data_train(
    FLAGS.real_path, FLAGS.train_label, FLAGS.batch_size)
(fake_images, fake_labels, fake_att), fake_iters = data.data_fake(
    FLAGS.fake_path, FLAGS.train_label, FLAGS.batch_size)
(valid_images, valid_labels, valid_att), valid_iters = data.data_test(
    FLAGS.real_path, FLAGS.valid_label, FLAGS.batch_size)
(test_images, test_labels, test_att), test_iters = data.data_test(
    FLAGS.real_path, FLAGS.test_label, FLAGS.batch_size)

# ---- Graph construction -----------------------------------------------------
# Batches are fed through placeholders: NHWC images of 128x128x3, integer class
# labels, a train/eval switch for the model, and the current learning rate
# (fed each step so the training loop can decay it).
batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
batch_labels = tf.placeholder(tf.int32, [None])
is_training = tf.placeholder(tf.bool)
lr_ph = tf.placeholder(tf.float32)
lr = FLAGS.lr

# Y_score: unnormalised class scores from the model; Y_hat: softmax probabilities;
# Y_pred: predicted class index.
Y_score = model.vgg(batch_images, FLAGS.n_class, is_training)
Y_hat = tf.nn.softmax(Y_score)
Y_pred = tf.argmax(Y_hat, 1)
# tf.one_hot already yields float32, so no separate cast is needed.
Y_label = tf.one_hot(batch_labels, FLAGS.n_class, dtype=tf.float32)

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Y_score, labels=Y_label)
loss_op = tf.reduce_mean(cross_entropy)
correct_prediction = tf.equal(Y_pred, tf.argmax(Y_label, 1))
acc_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Honour --optimizer. Previously this flag was parsed but ignored, and momentum
# was always used.
if FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(lr_ph, FLAGS.momentum)
elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(lr_ph)
else:
    raise ValueError('Unsupported --optimizer %r; expected "adam" or "momentum".'
                     % FLAGS.optimizer)

# Run any ops registered in UPDATE_OPS together with each training step. This is
# a no-op when the collection is empty.
# NOTE(review): it matters if model.vgg uses tf.layers batch normalization, whose
# moving statistics are only updated via UPDATE_OPS; model code is not visible here.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    update_op = optimizer.minimize(loss_op)
init = tf.global_variables_initializer()

print("================\n\n", train_iters, fake_iters)

# ---- Training / validation loop ---------------------------------------------
# Note: the test-split pipeline (test_images, test_labels, test_att) is built
# above but is not evaluated in this loop.
with tf.Session() as sess:
    sess.run(init)
    for i in range(FLAGS.max_epoch):
        # Step decay by 10x at epochs 30 and 40. With the default --max_epoch=20
        # neither milestone is reached, so the rate stays at --lr.
        if i == 30:
            lr *= 0.1
        elif i == 40:
            lr *= 0.1

        # Train for one epoch on real images only. To mix in generated images,
        # also fetch fake_images/fake_labels and np.concatenate each pair along
        # axis 0 before feeding.
        for j in range(train_iters):
            co_images, co_labels = sess.run([train_images, train_labels])
            loss, acc, _ = sess.run(
                [loss_op, acc_op, update_op],
                {batch_images: co_images, batch_labels: co_labels,
                 lr_ph: lr, is_training: True})
            if j % 50 == 0:
                print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, loss, acc))

        # Validate: collect predictions, labels and attributes for the whole split.
        y_pred = []
        y_label = []
        y_att = []
        for _ in range(valid_iters):
            va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
            batch_pred = sess.run(Y_pred, {batch_images: va_images, is_training: False})
            y_pred.extend(batch_pred.tolist())
            y_label.extend(va_labels.tolist())
            y_att.extend(va_att.tolist())
        # Accuracy over all validation samples. The previous mean of per-batch
        # accuracies over-weighted a smaller final batch, and it raised
        # ZeroDivisionError when there were no validation batches.
        valid_acc = float(np.mean(np.equal(y_pred, y_label))) if y_label else float('nan')
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        print('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
        print('eo: ', valid_eo[0], valid_eo[1])
        print('eo: ', valid_eo[2], valid_eo[3])
