# google-research
# 223 lines · 8.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import os
17import sys
18import argparse
19import numpy as np
20import tensorflow as tf
21import data
22import model
23# pylint: skip-file
24
25
# Command-line configuration for the fairness/RL training run.
# All data paths default to the project's ../data layout; the parsed FLAGS
# object is consumed by the config section below.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
parser.add_argument('--real_path', default='../data/resize128')
parser.add_argument('--fake_path', default='../data/fakedata')
parser.add_argument('--train_label', default='../data/annotations/train_label.txt')
parser.add_argument('--test_label', default='../data/annotations/test_label.txt')
parser.add_argument('--valid_label', default='../data/annotations/val_label.txt')
# Help text fixed: the actual default is 'test_log', not 'log'.
parser.add_argument('--log_dir', default='test_log', help='Log dir [default: test_log]')
# Help text fixed: the actual default is 500 episodes, not 50 epochs.
parser.add_argument('--n_episode', type=int, default=500, help='Number of episodes to run [default: 500]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size during training [default: 64]')
parser.add_argument('--n_class', type=int, default=2, help='Number of classes [default: 2]')
parser.add_argument('--n_action', type=int, default=2, help='Number of actions [default: 2]')
parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate [default: 0.1]')
# Help text fixed: this flag is the optimizer momentum, not the learning rate.
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for MomentumOptimizer [default: 0.9]')
parser.add_argument('--optimizer', default='momentum', help='adam or momentum [default: momentum]')
FLAGS = parser.parse_args()
42
##################### config #####################

# Pin CUDA device numbering to the PCI bus order and restrict the process to
# the GPU requested on the command line.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.gpu
tf.set_random_seed(100) # tf.set_random_seed(0)# 0 for 512
# Unpack parsed flags into module-level constants used throughout the script.
REAL_PATH = FLAGS.real_path          # directory of real (128x128) images
FAKE_PATH = FLAGS.fake_path          # directory of generated/fake images
TRAIN_LABEL = FLAGS.train_label      # annotation file for training split
TEST_LABEL = FLAGS.test_label        # annotation file for test split
VALID_LABEL = FLAGS.valid_label      # annotation file for validation split
BATCH_SIZE = FLAGS.batch_size
N_EPISODE = FLAGS.n_episode          # number of outer RL episodes
N_CLASS = FLAGS.n_class              # classifier output classes
N_ACTION = FLAGS.n_action            # RL controller action count
LR = FLAGS.lr
MOMENTUM = FLAGS.momentum
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
LOG_PATH = os.path.join(ROOT_PATH, FLAGS.log_dir)
if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH)
# Pick the first unused numbered log file (log_00.txt, log_01.txt, ...) so
# consecutive runs in the same log dir never overwrite each other.
acc_count = 0
while True:
    if os.path.exists(os.path.join(LOG_PATH, 'log_%02d.txt' % acc_count)):
        acc_count += 1
    else:
        break
LOG_FNAME = 'log_%02d.txt' % acc_count
# Kept open for the whole run; written via log_string() and closed in __main__.
LOG_FOUT = open(os.path.join(LOG_PATH, LOG_FNAME), 'w')
68
# Build the four input pipelines via the project-local `data` module. Each
# call returns ((images, labels, attribute) tensors, iterations-per-epoch).
# NOTE(review): the fake pipeline reuses TRAIN_LABEL — presumably fake images
# mirror the real training annotations; verify against data.data_train.
(train_images, train_labels, train_att), train_iters = data.data_train(REAL_PATH, TRAIN_LABEL, BATCH_SIZE)
(fake_images, fake_labels, fake_att), fake_iters = data.data_train(FAKE_PATH, TRAIN_LABEL, BATCH_SIZE)
(valid_images, valid_labels, valid_att), valid_iters = data.data_test(REAL_PATH, VALID_LABEL, BATCH_SIZE)
(test_images, test_labels, test_att), test_iters = data.data_test(REAL_PATH, TEST_LABEL, BATCH_SIZE)
73
74####################################################
75
def log_string(out_str):
    """Append `out_str` as one line to the run's log file and echo it to stdout.

    The file is flushed immediately so the log survives a crash mid-run.
    """
    LOG_FOUT.write(f"{out_str}\n")
    LOG_FOUT.flush()
    print(out_str)
80
def choose_action(prob_actions):
    """Sample one action per row of a batch of action-probability vectors.

    Args:
        prob_actions: 2-D array, shape (batch, n_actions); each row is a
            probability distribution over actions.

    Returns:
        1-D numpy int array of length `batch` with the sampled action index
        for every row.
    """
    n_actions = prob_actions.shape[1]
    sampled = [np.random.choice(range(n_actions), p=row) for row in prob_actions]
    return np.array(sampled)
87
def vgg_graph(sess, phs):
    """Builds the TF1 graph for the VGG classifier.

    Args:
        sess: tf.Session. NOTE(review): unused inside this function; kept for
            signature symmetry with rl_graph.
        phs: dict of placeholders with keys 'batch_images', 'batch_labels'
            and 'is_training_ph'.

    Returns:
        Tuple (loss_op, acc_op, cross_entropy, Y_hat, update_op, Y_pred,
        VGG.vars) where cross_entropy is the *per-example* loss tensor.
    """
    VGG = model.VGG()
    # Per-class logits for the batch.
    Y_score = VGG.build(phs['batch_images'], N_CLASS, phs['is_training_ph'])

    Y_hat = tf.nn.softmax(Y_score)
    Y_pred = tf.argmax(Y_hat, 1)
    # One-hot float labels, as required by softmax_cross_entropy_with_logits.
    Y_label = tf.to_float(tf.one_hot(phs['batch_labels'], N_CLASS))

    # Deliberately not reduced: train() feeds the per-example losses into the
    # RL controller's state, so keep both the vector and its mean.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits = Y_score, labels = Y_label)
    loss_op = tf.reduce_mean(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(Y_hat, 1), tf.argmax(Y_label, 1))
    acc_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Momentum SGD restricted to the classifier's own variables.
    update_op = tf.train.MomentumOptimizer(LR, MOMENTUM).minimize(loss_op, var_list=VGG.vars)

    return loss_op, acc_op, cross_entropy, Y_hat, update_op, Y_pred, VGG.vars
104
def rl_graph(sess, phrl):
    """Builds the TF1 graph for the RL policy ("Actor") network.

    Implements a REINFORCE-style policy-gradient loss: the negative log
    probability of each taken action, weighted by its return/advantage.

    Args:
        sess: tf.Session. NOTE(review): unused inside this function.
        phrl: dict of placeholders with keys 'states_rl', 'actions_rl',
            'values_rl' and 'is_training_rl'.

    Returns:
        Tuple (loss_op, Y_prob, update_op, Actor.vars) where Y_prob is the
        softmax action-probability tensor.
    """
    Actor = model.Actor()
    Y_score = Actor.build(phrl['states_rl'], N_ACTION, phrl['is_training_rl'])
    Y_prob =tf.nn.softmax(Y_score)


    # sparse softmax CE of the chosen action == -log pi(a|s).
    neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = Y_score, labels = phrl['actions_rl'])
    # Policy-gradient objective: E[-log pi(a|s) * value].
    loss_op = tf.reduce_mean(neg_log_prob*phrl['values_rl'])

    # update_op = tf.train.MomentumOptimizer(LR, MOMENTUM).minimize(loss_op, var_list=Actor.vars)
    # Adam with a fixed 1e-3 step, updating only the actor's variables.
    update_op = tf.train.AdamOptimizer(1e-3).minimize(loss_op, var_list=Actor.vars)

    return loss_op, Y_prob, update_op, Actor.vars
118
def train():
    """Runs the test-time loop: restore checkpoints, fine-tune with RL-chosen
    batches, and report validation accuracy / equalized-odds per episode.

    NOTE(review): indentation reconstructed from a whitespace-mangled source;
    the validation loop is placed inside the episode loop (it logs per-epoch
    metrics with index `i`) — confirm against the original repository.
    """
    # --- placeholders for the classifier graph ---
    batch_images = tf.placeholder(tf.float32,[None,128,128,3])
    batch_labels = tf.placeholder(tf.int32,[None,])
    is_training_ph = tf.placeholder(tf.bool)
    lr_ph = tf.placeholder(tf.float32)

    # --- placeholders for the RL controller graph ---
    # State is an 11-dim vector per example (label, attribute, 2 probs,
    # clipped CE, plus 6 model statistics appended below).
    states_rl = tf.placeholder(tf.float32,[None,11])
    actions_rl = tf.placeholder(tf.int32,[None,])
    values_rl = tf.placeholder(tf.float32,[None,])
    is_training_rl = tf.placeholder(tf.bool)
    lr_rl = tf.placeholder(tf.float32)

    phs = {'batch_images': batch_images,
           'batch_labels': batch_labels,
           'is_training_ph': is_training_ph,
           'lr_ph': lr_ph}

    phrl = {'states_rl': states_rl,
            'actions_rl': actions_rl,
            'values_rl': values_rl,
            'is_training_rl': is_training_rl,
            'lr_rl': lr_rl}

    with tf.Session() as sess:
        # Build both graphs, then initialize everything once.
        vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
        rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
        vgg_init = tf.variables_initializer(var_list=vgg_vars)
        saver = tf.train.Saver(vgg_vars)      # classifier-only checkpoint
        all_saver = tf.train.Saver()          # full-model checkpoint
        init = tf.global_variables_initializer()
        sess.run(init)




        for i in range(N_EPISODE):
            # sess.run(vgg_init)
            # Each episode restarts from the saved checkpoints: first the full
            # model, then the VGG vars (the vgg restore overwrites those vars).
            all_saver.restore(sess,LOG_PATH+'/all.ckpt')
            saver.restore(sess,LOG_PATH+'/vgg.ckpt')
            # state_list = []
            # action_list = []
            # reward_list = []
            for j in range(train_iters*20):
                # One batch each from the real and the fake pipelines.
                tr_images, tr_labels, tr_att = sess.run([train_images,train_labels, train_att])
                fa_images, fa_labels, fa_att = sess.run([fake_images,fake_labels, fake_att])

                # Forward pass (no training) to collect per-example features
                # for the RL state.
                train_dict = {phs['batch_images']: tr_images,
                              phs['batch_labels']: tr_labels,
                              phs['is_training_ph']: False}
                ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
                # Clip and scale per-example CE into [0, 1].
                ce = np.clip(ce, 0, 10)/10.0
                # Batch-level statistics (equalized-odds numbers + mean CE),
                # tiled so every row of the state carries them.
                model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
                model_stat.append(np.mean(ce))
                model_stat = np.tile(model_stat,(BATCH_SIZE,1))
                # Per-example state: label, attribute, class probs, CE, stats.
                state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis], prob, ce[:, np.newaxis], model_stat), axis=1)



                # Sample a binary action per example from the (frozen) policy.
                rl_dict = {phrl['states_rl']: state,
                           phrl['is_training_rl']: False}
                action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))



                # action==1 keeps the real example, action==0 substitutes the
                # fake example at the same batch position.
                bool_train = list(map(bool,action))
                bool_fake = list(map(bool,1-action))
                co_images = np.concatenate((tr_images[bool_train],fa_images[bool_fake]),axis=0)
                co_labels = np.concatenate((tr_labels[bool_train],fa_labels[bool_fake]),axis=0)


                # Train the classifier on the composed batch.
                update_dict = {phs['batch_images']: co_images,
                               phs['batch_labels']: co_labels,
                               phs['is_training_ph']: True}
                _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)


                if j % 100 == 0:
                    print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
                    print(action, np.sum(action))


            # --- per-episode validation: accuracy and equalized odds ---
            valid_acc = 0.0
            y_pred =[]
            y_label = []
            y_att = []
            for k in range(valid_iters):
                va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
                valid_dict = {phs['batch_images']: va_images,
                              phs['batch_labels']: va_labels,
                              phs['is_training_ph']: False}
                batch_acc, batch_pred = sess.run([vgg_acc,vgg_pred], feed_dict=valid_dict)
                valid_acc += batch_acc
                y_pred += batch_pred.tolist()
                y_label += va_labels.tolist()
                y_att += va_att.tolist()
            valid_acc = valid_acc / float(valid_iters)
            valid_eo = data.cal_eo(y_att, y_label, y_pred)
            log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
            print('eo: ',valid_eo[0],valid_eo[1])
            print('eo: ',valid_eo[2],valid_eo[3])
219
220
if __name__ == "__main__":
    try:
        train()
    finally:
        # Close the module-level log file even if train() raises, so buffered
        # log lines are flushed to disk (the original leaked the handle on
        # any exception).
        LOG_FOUT.close()
224