google-research

Форк
0
360 строк · 12.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main script for all experiments: training, visualization and evaluation."""
18

19
# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
20
# pylint: disable=line-too-long, missing-docstring, g-importing-member
21
# pylint: disable=no-value-for-parameter, unused-argument
22
import gin
23
import matplotlib.pyplot as plt
24
import numpy as np
25
import os
26
import tensorflow.compat.v1 as tf
27
import time
28
from absl import app
29
from absl import flags
30
from tensorflow.compat.v1 import gfile
31
from tqdm import tqdm
32

33
from weak_disentangle import datasets, viz, networks, evaluate
34
from weak_disentangle import utils as ut
35
# Use the tf.compat.v1 module with TF2 behavior enabled.
tf.enable_v2_behavior()
tfk = tf.keras
37

38

39
@gin.configurable
def train(dset_name, s_dim, n_dim, factors,
          batch_size, dec_lr, enc_lr_mul, iterations,
          model_type="gen"):
  """Builds the networks selected by `model_type` and runs the training loop.

  Args:
    dset_name: Dataset name passed to `datasets.get_dlib_data`. If that
      returns None, the networks are only built (to check that they can be)
      and the function returns without training.
    s_dim: Number of latent dimensions the encoder predicts.
    n_dim: Number of extra (nuisance) latent dimensions; the generator input
      has `s_dim + n_dim` dimensions.
    factors: Factor specification string. The prefix before "=" must be one
      of the factor types allowed for `model_type` (see
      `allowed_factor_types` below).
    batch_size: Number of real examples per step; the same number of fake
      samples is drawn.
    dec_lr: Learning rate of the generator optimizer.
    enc_lr_mul: Multiplier applied to `dec_lr` to obtain the discriminator and
      encoder learning rate.
    iterations: Total number of training iterations, including those already
      completed by a restored checkpoint.
    model_type: "gen" (generator, discriminator and encoder trained on paired
      observations), "enc" (encoder only, trained on paired observations), or
      "van" (generator, label discriminator and encoder).

  Raises:
    ValueError: If `model_type` is unknown or the factor type in `factors` is
      not allowed for it.
  """
  ut.log("In train")

  # Validate with exceptions rather than `assert` (asserts are stripped under
  # `python -O`). An unknown model_type would otherwise only surface later as
  # a NameError on networks/optimizers that were never built.
  allowed_factor_types = {
      "gen": {"c", "s", "cs", "r"},
      "enc": {"r"},
      "van": {"l"},
  }
  if model_type not in allowed_factor_types:
    raise ValueError("Unknown model_type {!r}; expected one of {}.".format(
        model_type, sorted(allowed_factor_types)))
  factor_type = factors.split("=")[0]
  if factor_type not in allowed_factor_types[model_type]:
    raise ValueError(
        "Factor type {!r} is not supported by model_type {!r}; expected one "
        "of {}.".format(factor_type, model_type,
                        sorted(allowed_factor_types[model_type])))

  masks = datasets.make_masks(factors, s_dim)
  z_dim = s_dim + n_dim
  enc_lr = enc_lr_mul * dec_lr

  # Load data
  dset = datasets.get_dlib_data(dset_name)
  if dset is None:
    x_shape = [64, 64, 1]
  else:
    x_shape = dset.observation_shape

  # Discriminator labels: the first batch_size rows are real (1) and the next
  # batch_size rows are fake (0). This matches the (real, fake) concatenation
  # order used in the discriminator updates below. They are built
  # unconditionally so the step functions never close over a name that is
  # only defined on one branch.
  targets_real = tf.ones((batch_size, 1))
  targets_fake = tf.zeros((batch_size, 1))
  targets = tf.concat((targets_real, targets_fake), axis=0)

  # Networks. The encoder only predicts the first s_dim latents and ignores
  # the nuisance dimensions.
  if model_type == "gen":
    y_dim = len(masks)
    dis = networks.Discriminator(x_shape, y_dim)
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.Encoder(x_shape, s_dim)
    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))
  elif model_type == "enc":
    enc = networks.Encoder(x_shape, s_dim)
    ut.log(enc.read(enc.WITH_VARS))
  else:  # model_type == "van"
    dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.Encoder(x_shape, s_dim)
    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))

  def make_adam(learning_rate):
    # Every optimizer in this script shares the same Adam hyperparameters.
    return tfk.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5,
                               beta_2=0.999, epsilon=1e-8)

  # Create optimizers: the generator uses dec_lr; discriminator/encoder enc_lr.
  if model_type in {"gen", "van"}:
    gen_opt = make_adam(dec_lr)
    dis_opt = make_adam(enc_lr)
    enc_opt = make_adam(enc_lr)
  else:  # model_type == "enc"
    enc_opt = make_adam(enc_lr)

  @tf.function
  def train_gen_step(x1_real, x2_real, y_real):
    """Runs one discriminator+encoder update, then one generator update."""
    gen.train()
    dis.train()
    enc.train()
    # Alternate discriminator step and generator step. The first tape is
    # persistent because two gradients (dis and enc) are taken from it.
    with tf.GradientTape(persistent=True) as tape:
      # Generate; stop_gradient keeps this tape from reaching the generator.
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = tf.stop_gradient(gen(z1))
      x2_fake = tf.stop_gradient(gen(z2))

      # Discriminate real pairs (first half) and fake pairs (second half).
      x1 = tf.concat((x1_real, x1_fake), 0)
      x2 = tf.concat((x2_real, x2_fake), 0)
      y = tf.concat((y_real, y_fake), 0)
      logits = dis(x1, x2, y)

      # Encode
      p_z = enc(x1_fake)

      dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=targets))
      # Encoder ignores nuisance parameters (if they exist)
      enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

    dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
    enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

    dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

    with tf.GradientTape(persistent=False) as tape:
      # Generate fresh samples; the generator is trained so the discriminator
      # labels them as real.
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = gen(z1)
      x2_fake = gen(z2)

      # Discriminate
      logits_fake = dis(x1_fake, x2_fake, y_fake)

      gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits_fake, labels=targets_real))

    gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

  @tf.function
  def train_van_step(x_real, y_real):
    """Runs one generator update, then one discriminator+encoder update."""
    gen.train()
    dis.train()
    enc.train()

    # Pad the labels with zeros over the nuisance dimensions so they can be
    # added to a z_dim-wide latent sample.
    if n_dim > 0:
      padding = tf.zeros((y_real.shape[0], n_dim))
      y_real_pad = tf.concat((y_real, padding), axis=-1)
    else:
      y_real_pad = y_real

    # Alternate discriminator step and generator step
    with tf.GradientTape(persistent=False) as tape:
      # Generate.
      # NOTE(review): paired_randn's result is used here as a single tensor,
      # unlike the 3-tuple unpacked in train_gen_step; presumably it returns
      # only z for label ("l") masks. Confirm in datasets.
      z_fake = datasets.paired_randn(batch_size, z_dim, masks)
      z_fake = z_fake + y_real_pad
      x_fake = gen(z_fake)

      # Discriminate
      logits_fake = dis(x_fake, y_real)

      gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits_fake, labels=targets_real))

    gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    # Persistent: two gradients (dis and enc) are taken from this tape.
    with tf.GradientTape(persistent=True) as tape:
      # Generate
      z_fake = datasets.paired_randn(batch_size, z_dim, masks)
      z_fake = z_fake + y_real_pad
      x_fake = tf.stop_gradient(gen(z_fake))

      # Discriminate real (first half) and fake (second half) images.
      x = tf.concat((x_real, x_fake), 0)
      y = tf.concat((y_real, y_real), 0)
      logits = dis(x, y)

      # Encode
      p_z = enc(x_fake)

      dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=targets))
      # Encoder ignores nuisance parameters (if they exist)
      enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

    dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
    enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

    dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

  @tf.function
  def train_enc_step(x1_real, x2_real, y_real):
    """Trains the encoder alone on paired observations.

    The loss is a sigmoid cross-entropy between `y_real` and the differences
    of the two encoded means, gathered at the `masks` indices.
    """
    # Put the encoder in train mode explicitly, as the other step functions
    # do. enc_eval (below) calls enc.eval(), so otherwise the mode seen when
    # this function is traced would depend on call order.
    enc.train()
    with tf.GradientTape() as tape:
      z1 = enc(x1_real).mean()
      z2 = enc(x2_real).mean()
      logits = tf.gather(z1 - z2, masks, axis=-1)
      loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=y_real))

    enc_grads = tape.gradient(loss, enc.trainable_variables)
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
    # No generator/discriminator in this mode; report 0 for their losses.
    return dict(gen_loss=0, dis_loss=0, enc_loss=loss)

  @tf.function
  def gen_eval(z):
    gen.eval()
    return gen(z)

  @tf.function
  def enc_eval(x):
    enc.eval()
    return enc(x).mean()

  def enc_np(x):
    # NumPy-returning encoder, as passed to evaluate.evaluate_enc.
    return enc_eval(x).numpy()

  # Initial preparation
  if FLAGS.debug:
    iter_log = 100
    iter_save = 2000
    train_range = range(iterations)
    basedir = FLAGS.basedir
    vizdir = FLAGS.basedir
    ckptdir = FLAGS.basedir
    new_run = True
  else:
    iter_log = 5000
    iter_save = 50000
    iter_metric = iter_save * 5  # Make sure this is a factor of 500k
    basedir = os.path.join(FLAGS.basedir, "exp")
    ckptdir = os.path.join(basedir, "ckptdir")
    vizdir = os.path.join(basedir, "vizdir")
    gfile.MakeDirs(basedir)
    gfile.MakeDirs(ckptdir)
    gfile.MakeDirs(vizdir)  # train_range will be specified below

  ckpt_prefix = os.path.join(ckptdir, "model")
  if model_type in {"gen", "van"}:
    ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                    gen=gen, gen_opt=gen_opt,
                                    enc=enc, enc_opt=enc_opt)
  else:  # model_type == "enc"
    ckpt_root = tf.train.Checkpoint(enc=enc, enc_opt=enc_opt)

  # Check if we're resuming training if not in debugging mode
  if not FLAGS.debug:
    latest_ckpt = tf.train.latest_checkpoint(ckptdir)
    if latest_ckpt is None:
      new_run = True
      ut.log("Starting a completely new model")
      train_range = range(iterations)
    else:
      new_run = False
      ut.log("Restarting from {}".format(latest_ckpt))
      ckpt_root.restore(latest_ckpt)
      # A new run saves once at step 0 and then every iter_save steps (see
      # "Save model" below), so save_counter - 1 saves cover that many
      # iter_save-sized chunks.
      resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
      train_range = range(resuming_iteration, iterations)

  # Training
  if dset is None:
    ut.log("Dataset {} is not available".format(dset_name))
    ut.log("Ending program having checked that the networks can be built.")
    return

  batches = datasets.paired_data_generator(dset, masks).repeat().batch(batch_size).prefetch(1000)
  batches = iter(batches)
  start_time = time.time()
  train_time = 0
  # Throughput counts only the steps run by this process. When resuming,
  # global_step starts at first_step, not 0.
  first_step = train_range.start

  if FLAGS.debug:
    train_range = tqdm(train_range)

  for global_step in train_range:
    stopwatch = time.time()
    if model_type == "gen":
      x1, x2, y = next(batches)
      vals = train_gen_step(x1, x2, y)
    elif model_type == "enc":
      x1, x2, y = next(batches)
      vals = train_enc_step(x1, x2, y)
    else:  # model_type == "van"
      x, y = next(batches)
      vals = train_van_step(x, y)
    train_time += time.time() - stopwatch

    # Generic bookkeeping
    if (global_step + 1) % iter_log == 0 or global_step == 0:
      elapsed_time = time.time() - start_time
      steps_done = global_step - first_step
      string = ", ".join((
          "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}".format(
              global_step, elapsed_time, steps_done / elapsed_time, steps_done / train_time),
          "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}".format(**vals)
      )) + "."
      ut.log(string)

    # Log visualizations and evaluations
    if (global_step + 1) % iter_save == 0 or global_step == 0:
      if model_type == "gen":
        viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir, global_step + 1)
      elif model_type == "van":
        viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir, global_step + 1)

      if FLAGS.debug:
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)
      else:
        # Full dlib metrics only every iter_metric steps.
        dlib_metrics = (global_step + 1) % iter_metric == 0
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=10000,
                              dlib_metrics=dlib_metrics)

    # Save model
    if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
      # Save model only after ensuring all measurements are taken.
      # This ensures that restarts always computes the evals
      ut.log("Saved to", ckpt_root.save(ckpt_prefix))
324

325

326
def main(_):
  """Entry point for `app.run`: parses the gin config and runs `train`.

  The positional argument (argv from absl) is unused. In debug mode the
  binding "log.debug = True" is appended to the gin bindings before parsing.
  Reads the module-level FLAGS, which is only bound when this file is run as
  a script.
  """
  if FLAGS.debug:
    FLAGS.gin_bindings += ["log.debug = True"]
  gin.parse_config_files_and_bindings(
      [FLAGS.gin_file],
      FLAGS.gin_bindings,
      finalize_config=False)
  ut.log("\n" + "*" * 80 + "\nBegin program\n" + "*" * 80)
  ut.log("In main")
  train()  # All train() arguments come from the gin configuration.
  ut.log("\n" + "*" * 80 + "\nEnd program\n" + "*" * 80)
337

338

339
if __name__ == "__main__":
  # FLAGS is bound here and read as a global by main() and train(), so those
  # functions only work when this module is executed as a script.
  FLAGS = flags.FLAGS
  flags.DEFINE_string(
      "basedir",
      "/tmp",
      "Path to directory where to store results.")
  flags.DEFINE_boolean(
      "debug",
      False,
      "Flag debugging mode (shorter run-times, etc)")
  flags.DEFINE_boolean(
      "debug_dlib_metrics",
      False,
      "Flag evaluating dlib metrics when debugging")
  flags.DEFINE_string(
      "gin_file",
      "weak_disentangle/configs/gan.gin",
      "Path to the gin config file.")
  flags.DEFINE_multi_string(
      "gin_bindings", [],
      "Gin bindings to override values in gin config.")
  app.run(main)
361

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.