# google-research
# 120 lines · 5.5 KB (file-listing residue, translated from Russian)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
import data
import model

# pylint: skip-file

# Command-line flags for the face-attribute classifier trained from the
# "real" and "fake" image pipelines defined in data.py.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
parser.add_argument('--real_path', default='../data/resize128')
parser.add_argument('--fake_path', default='../data/fake')
parser.add_argument('--train_label', default='../data/annotations/train_label.txt')
parser.add_argument('--test_label', default='../data/annotations/test_label.txt')
parser.add_argument('--valid_label', default='../data/annotations/val_label.txt')
parser.add_argument('--max_epoch', type=int, default=20, help='Epoch to run [default: 20]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size during training [default: 64]')
parser.add_argument('--n_class', type=int, default=2, help='Number of class [default: 2]')
parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate [default: 0.1]')
# Fixed: help text previously said 'Initial learning rate' (copy-paste from --lr).
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for the momentum optimizer [default: 0.9]')
parser.add_argument('--optimizer', default='momentum', help='adam or momentum [default: momentum]')
FLAGS = parser.parse_args()
41
42
# CelebA attribute name -> column index in the annotation files.
# The tuple order defines the indices (0..39).
_ATT_NAMES = (
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
    'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
    'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
    'Wearing_Necklace', 'Wearing_Necktie', 'Young')
ATT_ID = {name: idx for idx, name in enumerate(_ATT_NAMES)}
# Make CUDA enumerate devices in PCI bus order so --gpu indices match
# nvidia-smi, then restrict TensorFlow to the requested GPU.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.gpu
# tf.set_random_seed(0)# 0 for 512
# Fix the graph-level random seed for reproducibility.
tf.set_random_seed(100)

# Build the input pipelines. Each data.* helper returns
# ((images, labels, attributes), iterations_per_epoch).
# NOTE(review): the train/fake pipelines hard-code batch size 64 while the
# eval pipelines use FLAGS.batch_size — presumably intentional for the
# experiment, but confirm before changing --batch_size.
(train_images, train_labels, train_att), train_iters = data.data_train(FLAGS.real_path, FLAGS.train_label, 64)
(fake_images, fake_labels, fake_att), fake_iters = data.data_fake(FLAGS.fake_path, FLAGS.train_label, 64)
(valid_images, valid_labels, valid_att), valid_iters = data.data_test(FLAGS.real_path, FLAGS.valid_label, FLAGS.batch_size)
(test_images, test_labels, test_att), test_iters = data.data_test(FLAGS.real_path, FLAGS.test_label, FLAGS.batch_size)
# Graph inputs: 128x128 RGB image batches and integer class labels.
batch_images = tf.placeholder(tf.float32,[None,128,128,3])
batch_labels = tf.placeholder(tf.int32,[None,])
is_training = tf.placeholder(tf.bool)
# Learning rate is fed per-step so it can be decayed between epochs.
lr_ph = tf.placeholder(tf.float32)
lr = FLAGS.lr

# Forward pass: VGG classifier producing raw logits over n_class classes.
Y_score = model.vgg(batch_images, FLAGS.n_class, is_training)
Y_hat = tf.nn.softmax(Y_score)
Y_pred = tf.argmax(Y_hat, 1)
# One-hot float labels for the softmax cross-entropy below.
Y_label = tf.to_float(tf.one_hot(batch_labels, FLAGS.n_class))
# Softmax cross-entropy against the one-hot labels; Y_score are raw logits.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits = Y_score, labels = Y_label)
loss_op = tf.reduce_mean(cross_entropy)
correct_prediction = tf.equal(tf.argmax(Y_hat, 1), tf.argmax(Y_label, 1))
acc_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Honor --optimizer: previously the flag was parsed but ignored and
# MomentumOptimizer was always used. Default ('momentum') is unchanged.
if FLAGS.optimizer == 'adam':
    _optimizer = tf.train.AdamOptimizer(lr_ph)
else:
    _optimizer = tf.train.MomentumOptimizer(lr_ph, FLAGS.momentum)
# NOTE(review): if model.vgg contains batch normalization, minimize() should
# run under tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
# so the moving statistics update — confirm against model.py.
update_op = _optimizer.minimize(loss_op)
init = tf.global_variables_initializer()

print("================\n\n",train_iters, fake_iters)
with tf.Session() as sess:
    sess.run(init)
    for i in range(FLAGS.max_epoch):
        # Step-decay schedule: x0.1 at epochs 30 and 40.
        # NOTE(review): with the default --max_epoch=20 these branches never
        # fire — dead unless max_epoch is raised; confirm intended schedule.
        if i == 30:
            lr *= 0.1
        elif i == 40:
            lr *= 0.1

        # Training: batches are drawn from the FAKE pipeline only; the
        # commented lines below are the alternative real+fake concatenation.
        # Note the loop length is train_iters even though data comes from
        # the fake pipeline.
        for j in range(train_iters):
            co_images, co_labels = sess.run([fake_images,fake_labels])
            # tr_images, tr_labels = sess.run([train_images,train_labels])
            # fa_images, fa_labels = sess.run([fake_images,fake_labels])
            # co_images = np.concatenate((tr_images,fa_images),axis=0)
            # co_labels = np.concatenate((tr_labels,fa_labels),axis=0)
            loss, acc, _ = sess.run([loss_op, acc_op, update_op], {batch_images:co_images, batch_labels:co_labels, lr_ph:lr, is_training:True})
            if j % 50 == 0:
                print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, loss, acc))

        # Validation after every epoch: accumulate accuracy plus
        # per-example predictions/labels/attributes for the fairness metric.
        valid_acc = 0.0
        y_pred =[]
        y_label = []
        y_att = []
        for k in range(valid_iters):
            va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
            batch_acc, batch_pred = sess.run([acc_op,Y_pred], {batch_images:va_images, batch_labels:va_labels, is_training:False})
            valid_acc += batch_acc
            y_pred += batch_pred.tolist()
            y_label += va_labels.tolist()
            y_att += va_att.tolist()
        # Mean of per-batch accuracies (exact only if all batches are equal-sized).
        valid_acc = valid_acc / float(valid_iters)
        # data.cal_eo presumably returns per-group rates with an aggregate
        # equalized-odds gap in the last slot — confirm against data.py.
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        print('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
        print('eo: ',valid_eo[0],valid_eo[1])
        print('eo: ',valid_eo[2],valid_eo[3])
122
123