google-research
119 lines · 5.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse

import numpy as np
import tensorflow as tf

import data
import model
# pylint: skip-file

# Command-line configuration for training a classifier on a mix of real
# CelebA images and GAN-generated fake images.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
parser.add_argument('--real_path', default='../data/resize128')
parser.add_argument('--fake_path', default='../data/fake')
parser.add_argument('--train_label', default='../data/annotations/train_label.txt')
parser.add_argument('--test_label', default='../data/annotations/test_label.txt')
parser.add_argument('--valid_label', default='../data/annotations/val_label.txt')
parser.add_argument('--max_epoch', type=int, default=20, help='Epoch to run [default: 20]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch Size during training [default: 64]')
parser.add_argument('--n_class', type=int, default=2, help='Number of class [default: 2]')
parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate [default: 0.1]')
# BUG FIX: help text previously said 'Initial learning rate' (copy-paste error);
# this flag is the optimizer momentum, not a learning rate.
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for the optimizer [default: 0.9]')
parser.add_argument('--optimizer', default='momentum', help='adam or momentum [default: momentum]')
FLAGS = parser.parse_args()
39
40
41
# CelebA attribute name -> column index in the annotation files.
# The 40 attributes are listed in their canonical annotation-file order,
# so the index is simply each name's position in the tuple.
_ATT_NAMES = (
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
    'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
    'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
    'Wearing_Necklace', 'Wearing_Necktie', 'Young',
)
ATT_ID = {name: idx for idx, name in enumerate(_ATT_NAMES)}
55
# Make CUDA enumerate devices in PCI bus order and expose only the GPU(s)
# requested via --gpu to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu

# Fix the graph-level random seed for reproducibility
# (an earlier run used seed 0).
tf.set_random_seed(100)
60
# Build the four input pipelines. Training mixes two sources per step:
# real images and GAN-generated fakes (the fake pipeline reuses the train
# label file). Validation/test read real images at --batch_size; the two
# training pipelines are hard-coded to 64 images per batch each.
# Each call returns ((images, labels, attributes), iterations_per_epoch).
(train_images, train_labels, train_att), train_iters = data.data_train(
    FLAGS.real_path, FLAGS.train_label, 64)
(fake_images, fake_labels, fake_att), fake_iters = data.data_fake(
    FLAGS.fake_path, FLAGS.train_label, 64)
(valid_images, valid_labels, valid_att), valid_iters = data.data_test(
    FLAGS.real_path, FLAGS.valid_label, FLAGS.batch_size)
(test_images, test_labels, test_att), test_iters = data.data_test(
    FLAGS.real_path, FLAGS.test_label, FLAGS.batch_size)
65
# Graph inputs: 128x128 RGB image batches and integer class labels.
batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
batch_labels = tf.placeholder(tf.int32, [None])
is_training = tf.placeholder(tf.bool)
lr_ph = tf.placeholder(tf.float32)  # learning rate is fed per step so it can decay
lr = FLAGS.lr

# Forward pass: VGG backbone -> class scores -> softmax probabilities.
Y_score = model.vgg(batch_images, FLAGS.n_class, is_training)
Y_hat = tf.nn.softmax(Y_score)
Y_pred = tf.argmax(Y_hat, 1)
Y_label = tf.to_float(tf.one_hot(batch_labels, FLAGS.n_class))

# Cross-entropy loss and batch accuracy.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Y_score, labels=Y_label)
loss_op = tf.reduce_mean(cross_entropy)
correct_prediction = tf.equal(tf.argmax(Y_hat, 1), tf.argmax(Y_label, 1))
acc_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# BUG FIX: the --optimizer flag advertises 'adam or momentum' but the code
# unconditionally built a MomentumOptimizer. Honor the flag; the default
# ('momentum') keeps the original behavior.
if FLAGS.optimizer == 'adam':
    _optimizer = tf.train.AdamOptimizer(lr_ph)
else:
    _optimizer = tf.train.MomentumOptimizer(lr_ph, FLAGS.momentum)
# NOTE(review): if model.vgg uses batch normalization, the UPDATE_OPS
# collection should be run alongside this train op — confirm against model.py.
update_op = _optimizer.minimize(loss_op)
init = tf.global_variables_initializer()
83
print("================\n\n", train_iters, fake_iters)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(FLAGS.max_epoch):
        # Step-wise learning-rate decay; only fires when --max_epoch > 30.
        if epoch == 30:
            lr *= 0.1
        elif epoch == 40:
            lr *= 0.1

        # Training: every step concatenates one real batch with one fake batch.
        for step in range(train_iters):
            tr_images, tr_labels = sess.run([train_images, train_labels])
            fa_images, fa_labels = sess.run([fake_images, fake_labels])
            co_images = np.concatenate((tr_images, fa_images), axis=0)
            co_labels = np.concatenate((tr_labels, fa_labels), axis=0)
            loss, acc, _ = sess.run(
                [loss_op, acc_op, update_op],
                {batch_images: co_images, batch_labels: co_labels,
                 lr_ph: lr, is_training: True})
            if step % 50 == 0:
                print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (epoch, step, loss, acc))

        # Validation pass after each epoch: accumulate accuracy plus
        # per-example predictions/labels/attributes for the fairness metric.
        valid_acc = 0.0
        y_pred = []
        y_label = []
        y_att = []
        for _ in range(valid_iters):
            va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
            batch_acc, batch_pred = sess.run(
                [acc_op, Y_pred],
                {batch_images: va_images, batch_labels: va_labels, is_training: False})
            valid_acc += batch_acc
            y_pred += batch_pred.tolist()
            y_label += va_labels.tolist()
            y_att += va_att.tolist()
        valid_acc = valid_acc / float(valid_iters)
        # Equalized-odds style fairness metric over the protected attribute;
        # the last element appears to be the summary value — see data.cal_eo.
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        print('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (epoch, valid_acc, valid_eo[-1]))
        print('eo: ', valid_eo[0], valid_eo[1])
        print('eo: ', valid_eo[2], valid_eo[3])
120
121
122