# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
import numpy as np
import tensorflow as tf
import data
import model
# pylint: skip-file


parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
parser.add_argument('--real_path', default='../data/resize128')
parser.add_argument('--fake_path', default='../data/fakedata')
parser.add_argument('--train_label', default='../data/annotations/train_label.txt')
parser.add_argument('--test_label', default='../data/annotations/test_label.txt')
parser.add_argument('--valid_label', default='../data/annotations/val_label.txt')
parser.add_argument('--log_dir', default='test_log', help='Log dir [default: test_log]')
parser.add_argument('--n_episode', type=int, default=500, help='Number of episodes to run [default: 500]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size during training [default: 64]')
parser.add_argument('--n_class', type=int, default=2, help='Number of classes [default: 2]')
parser.add_argument('--n_action', type=int, default=2, help='Number of actions [default: 2]')
parser.add_argument('--lr', type=float, default=0.1, help='Initial learning rate [default: 0.1]')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum [default: 0.9]')
parser.add_argument('--optimizer', default='momentum', help='adam or momentum [default: momentum]')
FLAGS = parser.parse_args()
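
# Example invocation (an illustrative sketch, not taken from this file;
# flags shown are the defaults above):
#   python <this_script>.py --gpu 0 --log_dir test_log --n_episode 500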

#####################  config  #####################

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
tf.set_random_seed(100)  # tf.set_random_seed(0)  # 0 for 512
REAL_PATH = FLAGS.real_path
FAKE_PATH = FLAGS.fake_path
TRAIN_LABEL = FLAGS.train_label
TEST_LABEL = FLAGS.test_label
VALID_LABEL = FLAGS.valid_label
BATCH_SIZE = FLAGS.batch_size
N_EPISODE = FLAGS.n_episode
N_CLASS = FLAGS.n_class
N_ACTION = FLAGS.n_action
LR = FLAGS.lr
MOMENTUM = FLAGS.momentum
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
LOG_PATH = os.path.join(ROOT_PATH, FLAGS.log_dir)
if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH)
# Pick the first unused log index so reruns never overwrite earlier logs.
acc_count = 0
while True:
  if os.path.exists(os.path.join(LOG_PATH, 'log_%02d.txt' % acc_count)): acc_count += 1
  else: break
LOG_FNAME = 'log_%02d.txt' % acc_count
LOG_FOUT = open(os.path.join(LOG_PATH, LOG_FNAME), 'w')

# Four input pipelines: real and GAN-generated ("fake") training streams,
# plus real validation and test streams.
(train_images, train_labels, train_att), train_iters = data.data_train(REAL_PATH, TRAIN_LABEL, BATCH_SIZE)
(fake_images, fake_labels, fake_att), fake_iters = data.data_train(FAKE_PATH, TRAIN_LABEL, BATCH_SIZE)
(valid_images, valid_labels, valid_att), valid_iters = data.data_test(REAL_PATH, VALID_LABEL, BATCH_SIZE)
(test_images, test_labels, test_att), test_iters = data.data_test(REAL_PATH, TEST_LABEL, BATCH_SIZE)
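# Note that the fake stream reuses TRAIN_LABEL: the GAN-generated images are
# presumably aligned index-for-index with the real training labels/attributes.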

####################################################

def log_string(out_str):
  LOG_FOUT.write(out_str + '\n')
  LOG_FOUT.flush()
  print(out_str)

def choose_action(prob_actions):
  # Sample one action per example from the policy's softmax distribution
  # (stochastic, rather than greedy argmax).
  actions = []
  for i in range(prob_actions.shape[0]):
    action = np.random.choice(range(prob_actions.shape[1]), p=prob_actions[i])
    actions.append(action)
  return np.array(actions)
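
# e.g. choose_action(np.array([[0.9, 0.1], [0.2, 0.8]])) usually yields
# array([0, 1]); sampling keeps the controller exploring actions that
# argmax would never pick.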

def vgg_graph(sess, phs):
  # Build the VGG classifier and its loss, accuracy, and update ops.
  VGG = model.VGG()
  Y_score = VGG.build(phs['batch_images'], N_CLASS, phs['is_training_ph'])

  Y_hat = tf.nn.softmax(Y_score)
  Y_pred = tf.argmax(Y_hat, 1)
  Y_label = tf.to_float(tf.one_hot(phs['batch_labels'], N_CLASS))

  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Y_score, labels=Y_label)
  loss_op = tf.reduce_mean(cross_entropy)
  correct_prediction = tf.equal(tf.argmax(Y_hat, 1), tf.argmax(Y_label, 1))
  acc_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  update_op = tf.train.MomentumOptimizer(LR, MOMENTUM).minimize(loss_op, var_list=VGG.vars)

  return loss_op, acc_op, cross_entropy, Y_hat, update_op, Y_pred, VGG.vars

def rl_graph(sess, phrl):
  # Policy network: maps a state vector to a distribution over N_ACTION actions.
  Actor = model.Actor()
  Y_score = Actor.build(phrl['states_rl'], N_ACTION, phrl['is_training_rl'])
  Y_prob = tf.nn.softmax(Y_score)

  # REINFORCE-style loss: -log pi(a|s), weighted by the observed value.
  neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=Y_score, labels=phrl['actions_rl'])
  loss_op = tf.reduce_mean(neg_log_prob * phrl['values_rl'])

  # update_op = tf.train.MomentumOptimizer(LR, MOMENTUM).minimize(loss_op, var_list=Actor.vars)
  update_op = tf.train.AdamOptimizer(1e-3).minimize(loss_op, var_list=Actor.vars)

  return loss_op, Y_prob, update_op, Actor.vars
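
# Quick numeric check of the policy-gradient loss above (illustrative):
# for logits [2.0, 0.0], sampled action 0 and value 1.0,
# loss = -log(softmax([2.0, 0.0])[0]) = -log(0.881) ~= 0.127.
# A positive value pushes the policy toward the sampled action; a negative
# value pushes it away.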

def train():
  batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
  batch_labels = tf.placeholder(tf.int32, [None])
  is_training_ph = tf.placeholder(tf.bool)
  lr_ph = tf.placeholder(tf.float32)

  # 11-d controller state; see the breakdown where the state is assembled below.
  states_rl = tf.placeholder(tf.float32, [None, 11])
  actions_rl = tf.placeholder(tf.int32, [None])
  values_rl = tf.placeholder(tf.float32, [None])
  is_training_rl = tf.placeholder(tf.bool)
  lr_rl = tf.placeholder(tf.float32)

  phs = {'batch_images': batch_images,
         'batch_labels': batch_labels,
         'is_training_ph': is_training_ph,
         'lr_ph': lr_ph}

  phrl = {'states_rl': states_rl,
          'actions_rl': actions_rl,
          'values_rl': values_rl,
          'is_training_rl': is_training_rl,
          'lr_rl': lr_rl}

  with tf.Session() as sess:
    vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
    rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
    vgg_init = tf.variables_initializer(var_list=vgg_vars)
    saver = tf.train.Saver(vgg_vars)
    all_saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(N_EPISODE):
      # Restore both networks from their saved checkpoints at the start of
      # each episode.
      # sess.run(vgg_init)
      all_saver.restore(sess, LOG_PATH + '/all.ckpt')
      saver.restore(sess, LOG_PATH + '/vgg.ckpt')
      # state_list = []
      # action_list = []
      # reward_list = []
      for j in range(train_iters*20):
        tr_images, tr_labels, tr_att = sess.run([train_images, train_labels, train_att])
        fa_images, fa_labels, fa_att = sess.run([fake_images, fake_labels, fake_att])

        # Evaluate the current classifier on the real batch to build the
        # controller's state.
        train_dict = {phs['batch_images']: tr_images,
                      phs['batch_labels']: tr_labels,
                      phs['is_training_ph']: False}
        ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
        ce = np.clip(ce, 0, 10) / 10.0  # clip and rescale per-example CE to [0, 1]
        model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
        model_stat.append(np.mean(ce))
        model_stat = np.tile(model_stat, (BATCH_SIZE, 1))
        state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis], prob, ce[:, np.newaxis], model_stat), axis=1)
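        # With N_CLASS = 2 the state is 11-d per example: label (1) +
        # attribute (1) + class probabilities (2) + scaled CE (1) + tiled
        # model_stat (6: data.cal_eo's five outputs, judging by the indexing
        # at validation time, plus the batch-mean CE). This matches the
        # [None, 11] placeholder.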

        rl_dict = {phrl['states_rl']: state,
                   phrl['is_training_rl']: False}
        action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))

        # Action 1 keeps the real example; action 0 replaces it with the fake
        # example from the parallel fake batch. Train on the mixed batch.
        bool_train = list(map(bool, action))
        bool_fake = list(map(bool, 1 - action))
        co_images = np.concatenate((tr_images[bool_train], fa_images[bool_fake]), axis=0)
        co_labels = np.concatenate((tr_labels[bool_train], fa_labels[bool_fake]), axis=0)

        update_dict = {phs['batch_images']: co_images,
                       phs['batch_labels']: co_labels,
                       phs['is_training_ph']: True}
        _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)

        if j % 100 == 0:
          print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
          print(action, np.sum(action))

      # Validate on the real validation stream after each episode.
      valid_acc = 0.0
      y_pred = []
      y_label = []
      y_att = []
      for k in range(valid_iters):
        va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
        valid_dict = {phs['batch_images']: va_images,
                      phs['batch_labels']: va_labels,
                      phs['is_training_ph']: False}
        batch_acc, batch_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
        valid_acc += batch_acc
        y_pred += batch_pred.tolist()
        y_label += va_labels.tolist()
        y_att += va_att.tolist()
      valid_acc = valid_acc / float(valid_iters)
      valid_eo = data.cal_eo(y_att, y_label, y_pred)
      log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
      print('eo: ', valid_eo[0], valid_eo[1])
      print('eo: ', valid_eo[2], valid_eo[3])


if __name__ == "__main__":
  train()
  LOG_FOUT.close()