google-research

Форк
0
112 строк · 3.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Training script for VAE."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import os
22
import sys
23

24
from absl import app
25
from absl import flags
26
import h5py
27
from .models.vae import ImageTransformSC
28
import numpy as np
29
import tensorflow.compat.v1 as tf
30
from .utils import sample_batch_vae
31
from .utils import save_im
32

33
FLAGS = flags.FLAGS
34

35
flags.DEFINE_integer('batchsize', 256,
36
                     'Batch Size')
37
flags.DEFINE_integer('latentsize', 8,
38
                     'Latent Size')
39
flags.DEFINE_integer('trainsteps', 100000,
40
                     'Train Steps')
41
flags.DEFINE_string('datapath', '/tmp/test.hdf5',
42
                    'Path to the HDF5 dataset')
43
flags.DEFINE_string('savedir', '/tmp/mazevae/',
44
                    'Where to save the model')
45

46

47
def main(argv):
48
  if len(argv) > 1:
49
    raise app.UsageError('Too many command-line arguments.')
50

51
  batchsize = FLAGS.batchsize
52
  latentsize = FLAGS.latentsize
53

54
  savedir = FLAGS.savedir + str(batchsize) + '_' + str(latentsize) + '/'
55
  path = FLAGS.datapath
56

57
  if not os.path.exists(savedir):
58
    os.makedirs(savedir)
59

60
  f = h5py.File(path, 'r')
61
  ims = f['sim']['ims'][:, :, :, :, :]
62

63
  it = ImageTransformSC(latentsize)
64
  outall = it(batchsize)
65
  out, _, mu, var = outall
66

67
  likelihood2 = tf.reduce_sum(it.s2 * tf.log(out) + (1-it.s2)*tf.log(1-out),
68
                              axis=[1, 2, 3])
69
  likelihood = likelihood2
70
  kl = 0.5 * tf.reduce_sum(-1 - tf.log(1e-5 +var) + tf.math.square(mu) + var,
71
                           axis=[1])
72
  loss = -1 * (tf.reduce_mean(likelihood) - tf.reduce_mean(kl))
73

74
  optim = tf.train.AdamOptimizer(0.0001)
75
  optimizer_step = optim.minimize(loss)
76

77
  saver = tf.train.Saver()
78
  with tf.Session() as sess:
79
    sess.run(tf.local_variables_initializer())
80
    sess.run(tf.global_variables_initializer())
81

82
    for i in range(FLAGS.trainsteps):
83
      batch = sample_batch_vae(batchsize, ims, env='maze', epnum=ims.shape[0],
84
                               epsize=ims.shape[1])
85
      forward_feed = {
86
          it.s1: batch[:, 0],
87
          it.s2: batch[:, 1]
88
      }
89

90
      o, l, _ = sess.run([outall, loss, optimizer_step], forward_feed)
91
      delta, rc, _, _ = o
92
      if i % 10000 == 0:
93
        saver.save(sess, savedir + 'model', global_step=i)
94
        save_im(255*batch[0, 0], savedir+ 's1_'+str(i)+'.jpg')
95
        save_im(255*batch[0, 1], savedir+'s2_'+str(i)+'.jpg')
96
        save_im(255*(delta[0]), savedir+'s2pred_'+str(i)+'.jpg')
97
        save_im(255*(rc[0]), savedir+'s1pred_'+str(i)+'.jpg')
98

99
        sys.stdout.write(str(l) + ', ' +str(i) + '\n')
100

101
        forward_feed = {
102
            it.s1: np.repeat(np.expand_dims(batch[0, 0], 0), batchsize, 0),
103
            it.s2: np.repeat(np.expand_dims(batch[0, 1], 0), batchsize, 0),
104
            it.z: np.random.normal([0.]*latentsize, [1.]*latentsize,
105
                                   (batchsize, latentsize))
106
        }
107
        delta = sess.run(out, forward_feed)
108
        for j in range(batchsize)[:20]:
109
          save_im(255*(delta[j]), savedir + 'gen'+str(i)+'_'+str(j)+'.jpg')
110

111
if __name__ == '__main__':
112
  app.run(main)
113

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.