google-research

Форк
0
201 строка · 7.7 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15

16
import os
17
import numpy as np
18
import tensorflow as tf
19
# pylint: skip-file
20

21
def get_weight(shape, stddev, reg, name):
  """Creates (or reuses) a Xavier-initialised weight variable.

  Args:
    shape: shape of the variable.
    stddev: currently unused. It was consumed by a normal initializer that
      has been replaced by Xavier initialisation; kept for call compatibility.
    reg: if truthy, an L2 regularizer is attached to the variable.
    name: variable name handed to `tf.get_variable`.

  Returns:
    The variable returned by `tf.get_variable`.
  """
  weight_decay = 0.0
  initializer = tf.contrib.layers.xavier_initializer()
  extra_kwargs = {}
  if reg:
    # A scale of 0.0 makes the contrib L2 regularizer a no-op; the hook is
    # kept so the decay can be turned on by changing `weight_decay`.
    extra_kwargs['regularizer'] = tf.contrib.layers.l2_regularizer(weight_decay)
  return tf.get_variable(name, shape, initializer=initializer, **extra_kwargs)
31

32
def get_bias(shape, init_bias, reg, name):
  """Creates (or reuses) a variable filled with a constant initial value.

  Args:
    shape: shape of the variable.
    init_bias: constant every element is initialised to.
    reg: if truthy, an L2 regularizer is attached to the variable.
    name: variable name handed to `tf.get_variable`.

  Returns:
    The variable returned by `tf.get_variable`.
  """
  weight_decay = 0.0
  initializer = tf.constant_initializer(init_bias)
  if not reg:
    return tf.get_variable(name, shape, initializer=initializer)
  # A scale of 0.0 makes the contrib L2 regularizer a no-op.
  regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  return tf.get_variable(name, shape, initializer=initializer,
                         regularizer=regularizer)
41

42
def batch_norm(x, phase_train, moments_dim):
  """
  Batch normalization whose inference statistics are exponential moving averages.
  Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
  Args:
    x:           Tensor to normalize. Its last dimension is the channel
                 dimension (one gamma/beta per channel). In this file it is
                 called with 2D tensors (from `fc`) and 4D NHWC maps (from the
                 conv layers).
    phase_train: boolean scalar tensor used as the `tf.cond` predicate; True
                 selects the training branch (batch statistics plus EMA
                 update), False selects the stored moving averages.
    moments_dim: axes over which `tf.nn.moments` computes the batch mean and
                 variance, e.g. [0] for 2D inputs or [0, 1, 2] for NHWC maps.
  Return:
    normed:      batch-normalized tensor.
  Variables (gamma, beta and the EMA shadow variables) are created under the
  variable scope 'bn' nested inside the caller's scope.
  """
  with tf.variable_scope('bn'):
    n_out = x.get_shape().as_list()[-1]
    # Per-channel scale (init 1.0) and offset (init 0.0).
    gamma = get_bias(n_out, 1.0, True, 'gamma')
    beta = get_bias(n_out, 0.0, True, 'beta')
    batch_mean, batch_var = tf.nn.moments(x, moments_dim, name='moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.999)

    def mean_var_with_update():
      # Update the moving averages, then return the batch statistics; the
      # identities make the returned tensors depend on the update op.
      ema_apply_op = ema.apply([batch_mean, batch_var])
      with tf.control_dependencies([ema_apply_op]):
        return tf.identity(batch_mean), tf.identity(batch_var)

    # NOTE(review): the else-branch reads ema.average(...), whose shadow
    # variables are created by ema.apply in the true branch; this presumably
    # relies on tf.cond building the true branch first — confirm before reordering.
    mean, var = tf.cond(phase_train,
              mean_var_with_update,
              lambda: (ema.average(batch_mean), ema.average(batch_var)))
    # Argument order is (x, mean, variance, offset, scale, epsilon).
    normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
  return normed
70

71
def max_pool(inputs, name, k_shape=(1, 2, 2, 1), s_shape=(1, 2, 2, 1)):
  """Max pooling with 'SAME' padding (2x2 window, stride 2 by default).

  Args:
    inputs: tensor to pool; the default shapes index it as NHWC.
    name: name used for both the variable scope and the pooling op.
    k_shape: 4-element window size (any sequence). The default is an
      immutable tuple so no list object is shared between calls.
    s_shape: 4-element strides (any sequence).

  Returns:
    The pooled tensor.
  """
  with tf.variable_scope(name):
    # Convert to lists so the op receives the same argument type as before,
    # whether callers pass lists or tuples.
    outputs = tf.nn.max_pool(inputs, ksize=list(k_shape), strides=list(s_shape),
                             padding='SAME', name=name)
  return outputs
75

76
def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
  """Returns the output length of a transposed convolution along one axis.

  Args:
    dim_size: input length along the axis, or None if unknown. Any value
      supporting `*` and `+` with ints (e.g. a scalar tensor) also works.
    stride_size: stride along the axis.
    kernel_size: kernel length along the axis.
    padding: 'SAME' or 'VALID' (TensorFlow's upper-case convention).

  Returns:
    dim_size * stride_size, plus max(kernel_size - stride_size, 0) for
    'VALID' padding; None when dim_size is None.
  """
  # Check for an unknown size before multiplying: `None * int` raises
  # TypeError, which made the previous `is not None` guard unreachable.
  if dim_size is None:
    return None
  dim_size *= stride_size
  if padding == 'VALID':
    dim_size += max(kernel_size - stride_size, 0)
  return dim_size
81

82
def fc(inputs, n_output, is_training, name, bias=0.0, relu=True, reg=True, bn=True):
  """Fully connected layer with optional batch norm and leaky ReLU.

  Args:
    inputs: tensor whose statically known last dimension is the input width
      (it is multiplied with a 2D weight via tf.matmul, so presumably 2D).
    n_output: number of output units.
    is_training: boolean tensor forwarded to `batch_norm`.
    name: variable scope for this layer's variables.
    bias: initial value of the bias vector.
    relu: apply `tf.nn.leaky_relu` at the end.
    reg: attach regularizers to the weight and bias. This is now honoured;
      it was previously hard-coded to True. The default is unchanged.
    bn: apply `batch_norm` over the batch axis.

  Returns:
    The layer output tensor.
  """
  with tf.variable_scope(name):
    n_input = inputs.get_shape().as_list()[-1]
    shape = [n_input, n_output]
    filt = get_weight(shape, stddev=tf.sqrt(2.0/tf.to_float(n_input+n_output)), reg=reg, name='weight')
    # Use a distinct name so the `bias` init-value argument is not shadowed.
    bias_var = get_bias([n_output], init_bias=bias, reg=reg, name='bias')
    outputs = tf.matmul(inputs, filt)
    outputs = tf.nn.bias_add(outputs, bias_var)
    if bn:
      # Note: batch norm subtracts the batch mean, which largely cancels the bias above.
      outputs = batch_norm(outputs, is_training, [0,])
    if relu:
      outputs = tf.nn.leaky_relu(outputs)
  return outputs
96

97
def conv_2d(inputs, ksize, n_output, is_training, name, stride=1, pad='SAME', relu=True, reg=True, bn=True):
  """Square 2D convolution (no bias) with optional batch norm and leaky ReLU.

  Args:
    inputs: NHWC tensor; the channel count (axis 3) must be statically known.
    ksize: kernel height and width.
    n_output: number of output channels.
    is_training: boolean tensor forwarded to `batch_norm`.
    name: variable scope for this layer's variables.
    stride: spatial stride in both H and W.
    pad: 'SAME' or 'VALID'.
    relu: apply `tf.nn.leaky_relu` at the end.
    reg: attach a regularizer to the kernel.
    bn: apply `batch_norm` over axes [0, 1, 2].

  Returns:
    The layer output tensor.
  """
  with tf.variable_scope(name):
    in_channels = inputs.get_shape().as_list()[3]
    fan_sum = tf.to_float(in_channels + n_output)
    kernel = get_weight([ksize, ksize, in_channels, n_output],
                        stddev=tf.sqrt(2.0 / fan_sum), reg=reg, name='weight')
    net = tf.nn.conv2d(inputs, kernel, [1, stride, stride, 1], padding=pad)
    net = batch_norm(net, is_training, [0, 1, 2]) if bn else net
    return tf.nn.leaky_relu(net) if relu else net
109

110
def conv_2d_trans(inputs, ksize, n_output, is_training, name, stride=1, pad='SAME', relu=True, reg=True, bn=True):
  """Square transposed 2D convolution (no bias) with optional BN and ReLU.

  Args:
    inputs: NHWC tensor; the channel count (axis 3) must be statically known.
      Height and width may differ and may be unknown statically.
    ksize: kernel height and width.
    n_output: number of output channels.
    is_training: boolean tensor forwarded to `batch_norm`.
    name: variable scope for this layer's variables.
    stride: spatial upsampling stride in both H and W.
    pad: 'SAME' or 'VALID'.
    relu: apply plain `tf.nn.relu` at the end (unlike `conv_2d`, which
      uses leaky ReLU).
    reg: attach a regularizer to the kernel.
    bn: apply `batch_norm` over axes [0, 1, 2].

  Returns:
    The layer output tensor.
  """
  with tf.variable_scope(name):
    static_shape = inputs.get_shape().as_list()
    dynamic_shape = tf.shape(inputs)
    batch_size = dynamic_shape[0]
    # Prefer static spatial sizes; fall back to the runtime size when unknown.
    in_h = static_shape[1] if static_shape[1] is not None else dynamic_shape[1]
    in_w = static_shape[2] if static_shape[2] is not None else dynamic_shape[2]
    n_input = static_shape[3]
    shape = [ksize, ksize, n_output, n_input]
    # Compute H and W separately so non-square inputs work. get_deconv_dim
    # adds the extra border that 'VALID' padding produces; for 'SAME' it is
    # simply size * stride.
    output_shape = tf.stack([batch_size,
                             get_deconv_dim(in_h, stride, ksize, pad),
                             get_deconv_dim(in_w, stride, ksize, pad),
                             n_output])
    filt = get_weight(shape, stddev=tf.sqrt(2.0/tf.to_float(n_input+n_output)), reg=reg, name='weight')
    outputs = tf.nn.conv2d_transpose(inputs, filt, output_shape, [1, stride, stride, 1], padding=pad)
    if bn:
      outputs = batch_norm(outputs, is_training, [0,1,2])
    if relu:
      outputs = tf.nn.relu(outputs)
  return outputs
125

126
class Adv_cls():
  """Adversarial classifier applied to the last element of a feature list."""

  def build(self, inputs, n_class, is_training):
    """Builds (or reuses) the classifier under variable scope 'Adv'.

    Args:
      inputs: sequence of feature maps; only `inputs[-1]` is used
        (presumably the deepest encoder output — confirm with callers).
      n_class: number of output channels of the final 1x1 conv.
      is_training: boolean tensor forwarded to batch norm.

    Returns:
      Raw scores (no activation). The shape is [batch, n_class], or [batch]
      when n_class == 1. This assumes the 'fc1' 4x4 VALID conv reduces the
      map to 1x1.

    Side effects:
      Sets `self.vars` (trainable variables under 'Adv') and
      `self.reg_loss` (regularization losses under 'Adv').
    """
    with tf.variable_scope('Adv', reuse=tf.AUTO_REUSE):
      net = inputs[-1]

      for i in range(3): #4x4
        net = conv_2d(net, 3, 512, is_training, 'conv1_'+str(i))

      for i in range(3): #4x4
        net = conv_2d(net, 3, 256, is_training, 'conv2_'+str(i))

      # Convolutional "fully connected" layers: a 4x4 VALID conv followed by 1x1 convs.
      net = conv_2d(net, 4, 256, is_training, 'fc1', pad='VALID')
      net = conv_2d(net, 1, 128, is_training, 'fc2', pad='VALID')
      logits = conv_2d(net, 1, n_class, is_training, 'fc3', pad='VALID', relu=False, bn=False)

      # Squeeze only the 1x1 spatial axes. An unrestricted squeeze would also
      # drop the batch axis when the batch size is 1, and it loses the static
      # rank when the batch size is unknown.
      net = tf.squeeze(logits, axis=[1, 2])
      if n_class == 1:
        # Preserve the previous [batch] output for single-score heads.
        net = tf.squeeze(net, axis=[1])

      self.vars = tf.trainable_variables('Adv')
      self.reg_loss = tf.losses.get_regularization_losses('Adv')

      return net
149

150
class Genc():
  """Encoder: five stride-2, 4x4 conv layers that keep every intermediate map."""

  def build(self, inputs, is_training):
    """Builds (or reuses) the encoder under variable scope 'Genc'.

    Args:
      inputs: input image tensor (NHWC, per `conv_2d`).
      is_training: boolean tensor forwarded to batch norm.

    Returns:
      List of the 5 encoder feature maps, shallowest first, with
      64, 128, 256, 512 and 1024 channels respectively.

    Side effects:
      Sets `self.vars` and `self.reg_loss` for the 'Genc' scope.
    """
    with tf.variable_scope('Genc', reuse=tf.AUTO_REUSE):
      features = []
      net = inputs
      for depth, channels in enumerate(64 * 2**k for k in range(5)):
        net = conv_2d(net, 4, channels, is_training, 'enc_%d' % depth, stride=2)
        features.append(net)

      self.vars = tf.trainable_variables('Genc')
      self.reg_loss = tf.losses.get_regularization_losses('Genc')
      return features
162

163
class Gdec():
  """Decoder conditioned on a per-example label, with one skip connection."""

  def build(self, inputs, labels, is_training):
    """Builds (or reuses) the decoder under variable scope 'Gdec'.

    Args:
      inputs: sequence of encoder feature maps; `inputs[-1]` is the decoder
        input and `inputs[-2]` is concatenated as a skip connection after the
        first transposed conv (presumably the maps returned by `Genc.build`).
      labels: label values, cast to float and reshaped to [-1, 1, 1, 1], so
        this assumes one scalar label per example (TODO confirm).
      is_training: boolean tensor forwarded to batch norm.

    Returns:
      tanh-activated output with 3 channels.

    Side effects:
      Sets `self.vars` and `self.reg_loss` for the 'Gdec' scope.
    """
    with tf.variable_scope('Gdec', reuse=tf.AUTO_REUSE):
      labels = tf.reshape(tf.to_float(labels),[-1,1,1,1]) # -> [B, 1, 1, 1]
      net = inputs[-1]
      # Broadcast the label over the spatial grid and append it as a channel.
      tile_labels = tf.tile(labels,[1,net.shape[1],net.shape[2],1])
      net = tf.concat([net, tile_labels],axis=-1)

      # Four stride-2 transposed convs with 1024, 512, 256, 128 output channels.
      for i in range(4):
        if i==1:
          # Skip connection plus a re-tiled label at the new resolution. The
          # concat order fixes the dec_1 kernel's channel layout.
          net = tf.concat([net, inputs[-2]],axis=-1)
          tile_labels = tf.tile(labels,[1,net.shape[1],net.shape[2],1])
          net = tf.concat([net, tile_labels],axis=-1)
        net = conv_2d_trans(net, 4, int(1024 / 2**i), is_training, 'dec_'+str(i), stride=2)

      # Final upsampling to 3 channels, squashed to [-1, 1] by tanh.
      net = tf.nn.tanh(conv_2d_trans(net, 4, 3, is_training, 'dec_f', stride=2, relu=False, bn=False))

      self.vars = tf.trainable_variables('Gdec')
      self.reg_loss = tf.losses.get_regularization_losses('Gdec')
      return net
183

184
class D():
  """Discriminator: a shared conv trunk with separate GAN and class heads."""

  def build(self, inputs, is_training):
    """Builds (or reuses) the discriminator under variable scope 'D'.

    Args:
      inputs: image tensor (NHWC, per `conv_2d`). The batch dimension may be
        statically unknown.
      is_training: boolean tensor forwarded to batch norm.

    Returns:
      (gan_net, cls_net): two raw-score tensors of shape [batch, 1].

    Side effects:
      Sets `self.vars` and `self.reg_loss` for the 'D' scope.
    """
    with tf.variable_scope('D', reuse=tf.AUTO_REUSE):
      net = inputs
      for i in range(5):
        net = conv_2d(net, 4, int(64 * 2**i), is_training, 'D_'+str(i), stride=2)

      # Flatten to [batch, features]. Reshaping with the static batch size
      # (as before) fails when that size is None, so infer the batch axis
      # instead whenever the feature size is statically known.
      dims = net.get_shape().as_list()
      if None not in dims[1:]:
        flat_shape = [-1, int(np.prod(dims[1:]))]
      else:
        flat_shape = [dims[0] if dims[0] is not None else tf.shape(net)[0], -1]
      net = tf.reshape(net, flat_shape)

      gan_net = fc(net, 1024, is_training, 'gan1')
      gan_net = fc(gan_net, 1, is_training, 'gan2', relu=False, bn=False)

      cls_net = fc(net, 1024, is_training, 'cls1')
      cls_net = fc(cls_net, 1, is_training, 'cls2', relu=False, bn=False)

      self.vars = tf.trainable_variables('D')
      self.reg_loss = tf.losses.get_regularization_losses('D')
      return gan_net, cls_net
202

203

204

205

206

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.