# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main script for all experiments."""

# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
# pylint: disable=line-too-long, missing-docstring, g-importing-member
# pylint: disable=no-value-for-parameter, unused-argument
import gin
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow.compat.v1 as tf
import time
from absl import app
from absl import flags
from tensorflow.compat.v1 import gfile
from tqdm import tqdm

from weak_disentangle import datasets, viz, networks, evaluate
from weak_disentangle import utils as ut

tf.enable_v2_behavior()
tfk = tf.keras


@gin.configurable
def train(dset_name, s_dim, n_dim, factors,
          batch_size, dec_lr, enc_lr_mul, iterations,
          model_type="gen"):
  ut.log("In train")
  masks = datasets.make_masks(factors, s_dim)
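  # `factors` encodes the supervision spec: its prefix before "=" selects the
  # supervision type (asserted against model_type below), and make_masks
  # turns the spec into the masks consumed by datasets.paired_randn and
  # train_enc_step.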
  z_dim = s_dim + n_dim
  enc_lr = enc_lr_mul * dec_lr

  # Load data
  dset = datasets.get_dlib_data(dset_name)
  if dset is None:
    x_shape = [64, 64, 1]
  else:
    x_shape = dset.observation_shape
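  # Targets for the real/fake discriminator: each train step concatenates
  # real examples before fake ones, so the labels are batch_size ones
  # followed by batch_size zeros.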
  targets_real = tf.ones((batch_size, 1))
  targets_fake = tf.zeros((batch_size, 1))
  targets = tf.concat((targets_real, targets_fake), axis=0)

  # Networks
  if model_type == "gen":
    assert factors.split("=")[0] in {"c", "s", "cs", "r"}
    y_dim = len(masks)
    dis = networks.Discriminator(x_shape, y_dim)
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))
  elif model_type == "enc":
    assert factors.split("=")[0] in {"r"}
    enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
    ut.log(enc.read(enc.WITH_VARS))
  elif model_type == "van":
    assert factors.split("=")[0] in {"l"}
    dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))

  # Create optimizers
  if model_type in {"gen", "van"}:
    gen_opt = tfk.optimizers.Adam(learning_rate=dec_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    dis_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
  elif model_type == "enc":
    enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

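  # train_gen_step performs one discriminator/encoder update followed by one
  # generator update. In the first tape the fake images are wrapped in
  # tf.stop_gradient, so only dis and enc receive gradients there; the
  # persistent tape lets one forward pass serve both dis_loss and enc_loss.
  # The second tape then updates gen against the fixed discriminator.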
  @tf.function
  def train_gen_step(x1_real, x2_real, y_real):
    gen.train()
    dis.train()
    enc.train()
    # Alternate discriminator step and generator step
    with tf.GradientTape(persistent=True) as tape:
      # Generate
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = tf.stop_gradient(gen(z1))
      x2_fake = tf.stop_gradient(gen(z2))

      # Discriminate
      x1 = tf.concat((x1_real, x1_fake), 0)
      x2 = tf.concat((x2_real, x2_fake), 0)
      y = tf.concat((y_real, y_fake), 0)
      logits = dis(x1, x2, y)

      # Encode
      p_z = enc(x1_fake)

      dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=targets))
      # Encoder ignores nuisance parameters (if they exist)
      enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

    dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
    enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

    dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

    with tf.GradientTape(persistent=False) as tape:
      # Generate
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = gen(z1)
      x2_fake = gen(z2)

      # Discriminate
      logits_fake = dis(x1_fake, x2_fake, y_fake)

      gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits_fake, labels=targets_real))

    gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

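  # train_van_step is the label-conditioned ("vanilla") variant: the observed
  # label y_real is added to the sampled latent as an offset (zero-padded
  # over the n_dim nuisance dimensions), and the discriminator scores
  # (image, label) pairs, with real and fake images sharing the same labels.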
  @tf.function
  def train_van_step(x_real, y_real):
    gen.train()
    dis.train()
    enc.train()

    if n_dim > 0:
      padding = tf.zeros((y_real.shape[0], n_dim))
      y_real_pad = tf.concat((y_real, padding), axis=-1)
    else:
      y_real_pad = y_real

    # Alternate discriminator step and generator step
    with tf.GradientTape(persistent=False) as tape:
      # Generate
      z_fake = datasets.paired_randn(batch_size, z_dim, masks)
      z_fake = z_fake + y_real_pad
      x_fake = gen(z_fake)

      # Discriminate
      logits_fake = dis(x_fake, y_real)

      gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits_fake, labels=targets_real))

    gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    with tf.GradientTape(persistent=True) as tape:
      # Generate
      z_fake = datasets.paired_randn(batch_size, z_dim, masks)
      z_fake = z_fake + y_real_pad
      x_fake = tf.stop_gradient(gen(z_fake))

      # Discriminate
      x = tf.concat((x_real, x_fake), 0)
      y = tf.concat((y_real, y_real), 0)
      logits = dis(x, y)

      # Encode
      p_z = enc(x_fake)

      dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=targets))
      # Encoder ignores nuisance parameters (if they exist)
      enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

    dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
    enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

    dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

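  # train_enc_step trains the encoder alone (no GAN): the masked coordinates
  # of enc(x1) - enc(x2) serve directly as logits for predicting the
  # pairwise labels y_real.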
  @tf.function
  def train_enc_step(x1_real, x2_real, y_real):
    with tf.GradientTape() as tape:
      z1 = enc(x1_real).mean()
      z2 = enc(x2_real).mean()
      logits = tf.gather(z1 - z2, masks, axis=-1)
      loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=y_real))

    enc_grads = tape.gradient(loss, enc.trainable_variables)
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
    return dict(gen_loss=0, dis_loss=0, enc_loss=loss)

  @tf.function
  def gen_eval(z):
    gen.eval()
    return gen(z)

  @tf.function
  def enc_eval(x):
    enc.eval()
    return enc(x).mean()

  enc_np = lambda x: enc_eval(x).numpy()

  # Initial preparation
  if FLAGS.debug:
    iter_log = 100
    iter_save = 2000
    train_range = range(iterations)
    basedir = FLAGS.basedir
    vizdir = FLAGS.basedir
    ckptdir = FLAGS.basedir
    new_run = True
  else:
    iter_log = 5000
    iter_save = 50000
    iter_metric = iter_save * 5  # Make sure this is a factor of 500k
    basedir = os.path.join(FLAGS.basedir, "exp")
    ckptdir = os.path.join(basedir, "ckptdir")
    vizdir = os.path.join(basedir, "vizdir")
    gfile.MakeDirs(basedir)
    gfile.MakeDirs(ckptdir)
    gfile.MakeDirs(vizdir)  # train_range will be specified below

  ckpt_prefix = os.path.join(ckptdir, "model")
  if model_type in {"gen", "van"}:
    ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                    gen=gen, gen_opt=gen_opt,
                                    enc=enc, enc_opt=enc_opt)
  elif model_type == "enc":
    ckpt_root = tf.train.Checkpoint(enc=enc, enc_opt=enc_opt)

  # When not in debug mode, check whether we are resuming a previous run
  if not FLAGS.debug:
    latest_ckpt = tf.train.latest_checkpoint(ckptdir)
    if latest_ckpt is None:
      new_run = True
      ut.log("Starting a completely new model")
      train_range = range(iterations)
    else:
      new_run = False
      ut.log("Restarting from {}".format(latest_ckpt))
      ckpt_root.restore(latest_ckpt)
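      # NOTE: this assumes one checkpoint was written at step 0 plus one
      # every iter_save steps, so save_counter - 1 completed save intervals
      # have already been trained.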
      resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
      train_range = range(resuming_iteration, iterations)

  # Training
  if dset is None:
    ut.log("Dataset {} is not available".format(dset_name))
    ut.log("Ending program having checked that the networks can be built.")
    return

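  # Infinite stream of paired training batches; prefetch keeps the input
  # pipeline running ahead of the train loop.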
  batches = datasets.paired_data_generator(dset, masks).repeat().batch(batch_size).prefetch(1000)
  batches = iter(batches)
  start_time = time.time()
  train_time = 0

  if FLAGS.debug:
    train_range = tqdm(train_range)

  for global_step in train_range:
    stopwatch = time.time()
    if model_type == "gen":
      x1, x2, y = next(batches)
      vals = train_gen_step(x1, x2, y)
    elif model_type == "enc":
      x1, x2, y = next(batches)
      vals = train_enc_step(x1, x2, y)
    elif model_type == "van":
      x, y = next(batches)
      vals = train_van_step(x, y)
    train_time += time.time() - stopwatch

    # Generic bookkeeping
    if (global_step + 1) % iter_log == 0 or global_step == 0:
      elapsed_time = time.time() - start_time
      string = ", ".join((
          "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}".format(
              global_step, elapsed_time, global_step / elapsed_time, global_step / train_time),
          "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}".format(**vals)
      )) + "."
      ut.log(string)

    # Log visualizations and evaluations
    if (global_step + 1) % iter_save == 0 or global_step == 0:
      if model_type == "gen":
        viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir, global_step + 1)
      elif model_type == "van":
        viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir, global_step + 1)

      if FLAGS.debug:
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)
      else:
        dlib_metrics = (global_step + 1) % iter_metric == 0
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=10000,
                              dlib_metrics=dlib_metrics)

    # Save model
    if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
      # Save the model only after all measurements are taken, so that a
      # restart always recomputes the evals.
      ut.log("Saved to", ckpt_root.save(ckpt_prefix))


def main(_):
  if FLAGS.debug:
    FLAGS.gin_bindings += ["log.debug = True"]
  gin.parse_config_files_and_bindings(
      [FLAGS.gin_file],
      FLAGS.gin_bindings,
      finalize_config=False)
  ut.log("\n" + "*" * 80 + "\nBegin program\n" + "*" * 80)
  ut.log("In main")
  train()
  ut.log("\n" + "*" * 80 + "\nEnd program\n" + "*" * 80)


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  flags.DEFINE_string(
      "basedir",
      "/tmp",
      "Path to directory where to store results.")
  flags.DEFINE_boolean(
      "debug",
      False,
      "Debugging mode (shorter run times, etc.).")
  flags.DEFINE_boolean(
      "debug_dlib_metrics",
      False,
      "Whether to evaluate dlib metrics when debugging.")
  flags.DEFINE_string(
      "gin_file",
      "weak_disentangle/configs/gan.gin",
      "Path to the gin config file.")
  flags.DEFINE_multi_string(
      "gin_bindings", [],
      "Gin bindings to override values in gin config.")
  app.run(main)
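
# Example invocation (illustrative; assumes this file lives at
# weak_disentangle/main.py and is run from the repository root):
#
#   python -m weak_disentangle.main \
#     --basedir=/tmp/weak_disentangle \
#     --gin_file=weak_disentangle/configs/gan.gin \
#     --debug
#
# Hyperparameters such as dset_name, s_dim, n_dim, factors, and model_type
# are not flags; they are bound to train() via the gin config and can be
# overridden with repeated --gin_bindings arguments.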