google-research
112 lines · 3.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Training script for VAE."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import sys
23
24from absl import app
25from absl import flags
26import h5py
27from .models.vae import ImageTransformSC
28import numpy as np
29import tensorflow.compat.v1 as tf
30from .utils import sample_batch_vae
31from .utils import save_im
32
33FLAGS = flags.FLAGS
34
35flags.DEFINE_integer('batchsize', 256,
36'Batch Size')
37flags.DEFINE_integer('latentsize', 8,
38'Latent Size')
39flags.DEFINE_integer('trainsteps', 100000,
40'Train Steps')
41flags.DEFINE_string('datapath', '/tmp/test.hdf5',
42'Path to the HDF5 dataset')
43flags.DEFINE_string('savedir', '/tmp/mazevae/',
44'Where to save the model')
45
46
def main(argv):
  """Trains the VAE, periodically checkpointing and dumping sample images.

  Args:
    argv: Command-line arguments from absl; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  batchsize = FLAGS.batchsize
  latentsize = FLAGS.latentsize

  # Keep runs with different hyperparameters in separate directories.
  savedir = FLAGS.savedir + str(batchsize) + '_' + str(latentsize) + '/'
  path = FLAGS.datapath

  if not os.path.exists(savedir):
    os.makedirs(savedir)

  # Load the whole image dataset into memory; the [:] copy lets us close the
  # file immediately. Assumes ims is (episodes, steps, H, W, C) with values
  # in [0, 1] — TODO confirm against the dataset writer.
  with h5py.File(path, 'r') as f:
    ims = f['sim']['ims'][:, :, :, :, :]

  it = ImageTransformSC(latentsize)
  outall = it(batchsize)
  out, _, mu, var = outall

  # Bernoulli log-likelihood of the target image under the reconstruction.
  # The epsilon guards against log(0) when the decoder output saturates at
  # exactly 0 or 1 (mirrors the 1e-5 already used in the KL term below).
  eps = 1e-8
  likelihood = tf.reduce_sum(
      it.s2 * tf.log(eps + out) + (1 - it.s2) * tf.log(eps + 1 - out),
      axis=[1, 2, 3])
  # KL(q(z|x) || N(0, I)) in closed form; 1e-5 keeps log(var) finite.
  kl = 0.5 * tf.reduce_sum(
      -1 - tf.log(1e-5 + var) + tf.math.square(mu) + var,
      axis=[1])
  # Negative ELBO: maximize likelihood, minimize KL.
  loss = -1 * (tf.reduce_mean(likelihood) - tf.reduce_mean(kl))

  optim = tf.train.AdamOptimizer(0.0001)
  optimizer_step = optim.minimize(loss)

  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    for i in range(FLAGS.trainsteps):
      batch = sample_batch_vae(batchsize, ims, env='maze',
                               epnum=ims.shape[0], epsize=ims.shape[1])
      forward_feed = {
          it.s1: batch[:, 0],
          it.s2: batch[:, 1]
      }

      o, l, _ = sess.run([outall, loss, optimizer_step], forward_feed)
      delta, rc, _, _ = o
      if i % 10000 == 0:
        # Checkpoint and save qualitative examples: the two input frames,
        # the predicted s2, and the reconstructed s1.
        saver.save(sess, savedir + 'model', global_step=i)
        save_im(255 * batch[0, 0], savedir + 's1_' + str(i) + '.jpg')
        save_im(255 * batch[0, 1], savedir + 's2_' + str(i) + '.jpg')
        save_im(255 * (delta[0]), savedir + 's2pred_' + str(i) + '.jpg')
        save_im(255 * (rc[0]), savedir + 's1pred_' + str(i) + '.jpg')

        sys.stdout.write(str(l) + ', ' + str(i) + '\n')

        # Sample diverse generations for a single (s1, s2) pair by drawing
        # the latent z from the standard-normal prior instead of the encoder.
        forward_feed = {
            it.s1: np.repeat(np.expand_dims(batch[0, 0], 0), batchsize, 0),
            it.s2: np.repeat(np.expand_dims(batch[0, 1], 0), batchsize, 0),
            it.z: np.random.normal([0.] * latentsize, [1.] * latentsize,
                                   (batchsize, latentsize))
        }
        delta = sess.run(out, forward_feed)
        # Save at most 20 generated samples regardless of batch size.
        for j in range(min(20, batchsize)):
          save_im(255 * (delta[j]),
                  savedir + 'gen' + str(i) + '_' + str(j) + '.jpg')
110
if __name__ == '__main__':
  # absl parses flags and invokes main with the remaining argv.
  app.run(main)