google-research
201 lines · 7.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import os
17import numpy as np
18import tensorflow as tf
19# pylint: skip-file
20
def get_weight(shape, stddev, reg, name, wd=0.0):
  """Creates a weight variable with Xavier initialization.

  Args:
    shape: list of ints, shape of the weight tensor.
    stddev: ignored — Xavier initialization is used instead (the
      random-normal initializer was deliberately disabled); kept in the
      signature for caller compatibility.
    reg: bool, whether to attach an L2 regularizer to the variable.
    name: string, variable name within the current scope.
    wd: float, L2 weight-decay coefficient. Defaults to 0.0, which keeps
      the historical behavior (regularizer attached but contributing zero
      loss) while letting callers opt into real weight decay.

  Returns:
    The created (or scope-reused) `tf.Variable`.
  """
  init = tf.contrib.layers.xavier_initializer()
  if reg:
    regu = tf.contrib.layers.l2_regularizer(wd)
    filt = tf.get_variable(name, shape, initializer=init, regularizer=regu)
  else:
    filt = tf.get_variable(name, shape, initializer=init)
  return filt
31
def get_bias(shape, init_bias, reg, name, wd=0.0):
  """Creates a bias variable with a constant initializer.

  Args:
    shape: list of ints (or int), shape of the bias tensor.
    init_bias: float, constant initial value.
    reg: bool, whether to attach an L2 regularizer to the variable.
    name: string, variable name within the current scope.
    wd: float, L2 weight-decay coefficient. Defaults to 0.0, preserving
      the original behavior of a zero-weight regularizer.

  Returns:
    The created (or scope-reused) `tf.Variable`.
  """
  init = tf.constant_initializer(init_bias)
  if reg:
    regu = tf.contrib.layers.l2_regularizer(wd)
    bias = tf.get_variable(name, shape, initializer=init, regularizer=regu)
  else:
    bias = tf.get_variable(name, shape, initializer=init)
  return bias
41
def batch_norm(x, phase_train, moments_dim):
  """Batch normalization with an exponential-moving-average of statistics.

  Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow

  Args:
    x: Tensor; 4D BHWD conv maps (use moments_dim=[0,1,2]) or 2D
      fully-connected activations (use moments_dim=[0]).
    phase_train: boolean tf.Variable/placeholder; True selects batch
      statistics (and updates the moving averages), False selects the
      EMA statistics for inference.
    moments_dim: list of ints, axes over which mean/variance are computed.

  Returns:
    normed: batch-normalized tensor, same shape as `x`.
  """
  with tf.variable_scope('bn'):
    n_out = x.get_shape().as_list()[-1]
    # gamma/beta are created through get_bias so they share the same
    # variable/regularizer plumbing as the rest of the model (wd is 0.0,
    # so the regularizer contributes nothing).
    gamma = get_bias(n_out, 1.0, True, 'gamma')
    beta = get_bias(n_out, 0.0, True, 'beta')
    batch_mean, batch_var = tf.nn.moments(x, moments_dim, name='moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.999)

    def mean_var_with_update():
      # Apply the EMA update, then return the *batch* statistics; the
      # control dependency guarantees the moving averages advance every
      # training step that evaluates the normalized output.
      ema_apply_op = ema.apply([batch_mean, batch_var])
      with tf.control_dependencies([ema_apply_op]):
        return tf.identity(batch_mean), tf.identity(batch_var)

    # Training branch updates+uses batch stats; inference branch reads the
    # accumulated shadow variables.
    mean, var = tf.cond(phase_train,
                        mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
  return normed
70
def max_pool(inputs, name, k_shape=(1, 2, 2, 1), s_shape=(1, 2, 2, 1)):
  """2x2 (by default) max pooling with SAME padding.

  Args:
    inputs: 4D BHWC tensor.
    name: string, variable/op scope name.
    k_shape: pooling window, [batch, height, width, channels]. Tuples are
      used instead of lists to avoid the mutable-default-argument pitfall.
    s_shape: pooling strides, same layout as `k_shape`.

  Returns:
    The pooled tensor.
  """
  with tf.variable_scope(name):
    return tf.nn.max_pool(inputs, ksize=list(k_shape), strides=list(s_shape),
                          padding='SAME', name=name)
75
def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
  """Computes one output spatial dimension of a transposed convolution.

  Args:
    dim_size: int or None, input spatial size (None for a dynamic/unknown
      dimension).
    stride_size: int, transposed-convolution stride.
    kernel_size: int, kernel extent along this dimension.
    padding: 'SAME' or 'VALID'.

  Returns:
    int output size, or None if `dim_size` is None.
  """
  # Bug fix: the None check must precede the multiplication — the original
  # computed `dim_size *= stride_size` first, which raises TypeError for a
  # dynamic (None) dimension that the guard clearly intended to support.
  if dim_size is None:
    return None
  dim_size *= stride_size
  if padding == 'VALID':
    # VALID padding adds the part of the kernel that extends past the
    # last stride step.
    dim_size += max(kernel_size - stride_size, 0)
  return dim_size
81
def fc(inputs, n_output, is_training, name, bias=0.0, relu=True, reg=True, bn=True):
  """Fully connected layer with optional batch norm and leaky ReLU.

  Args:
    inputs: 2D tensor [batch, n_input].
    n_output: int, number of output units.
    is_training: boolean tensor, forwarded to batch_norm.
    name: string, variable scope name.
    bias: float, initial bias value.
    relu: bool, apply leaky ReLU activation.
    reg: bool, attach L2 regularizers to weight and bias.
    bn: bool, apply batch normalization over axis 0.

  Returns:
    The layer output tensor [batch, n_output].
  """
  with tf.variable_scope(name):
    n_input = inputs.get_shape().as_list()[-1]
    shape = [n_input, n_output]
    # Glorot-style stddev; note get_weight currently uses Xavier init and
    # ignores this value.
    # Bug fix: `reg` was accepted but ignored (hard-coded True below); it
    # is now forwarded so callers can actually disable regularization.
    filt = get_weight(shape, stddev=tf.sqrt(2.0 / tf.to_float(n_input + n_output)),
                      reg=reg, name='weight')
    # Renamed from `bias` to avoid shadowing the float parameter.
    bias_var = get_bias([n_output], init_bias=bias, reg=reg, name='bias')
    outputs = tf.matmul(inputs, filt)
    outputs = tf.nn.bias_add(outputs, bias_var)
    if bn:
      outputs = batch_norm(outputs, is_training, [0])
    if relu:
      outputs = tf.nn.leaky_relu(outputs)
    return outputs
96
def conv_2d(inputs, ksize, n_output, is_training, name, stride=1, pad='SAME', relu=True, reg=True, bn=True):
  """2-D convolution layer with optional batch norm and leaky ReLU.

  Args:
    inputs: 4D BHWC tensor.
    ksize: int, square kernel size.
    n_output: int, number of output channels.
    is_training: boolean tensor, forwarded to batch_norm.
    name: string, variable scope name.
    stride: int, spatial stride (same in both dims).
    pad: 'SAME' or 'VALID'.
    relu: bool, apply leaky ReLU.
    reg: bool, attach an L2 regularizer to the kernel.
    bn: bool, apply batch normalization over axes [0, 1, 2].

  Returns:
    The convolved (and optionally normalized/activated) tensor.
  """
  with tf.variable_scope(name):
    n_input = inputs.get_shape().as_list()[3]
    filt_shape = [ksize, ksize, n_input, n_output]
    glorot_std = tf.sqrt(2.0 / tf.to_float(n_input + n_output))
    kernel = get_weight(filt_shape, stddev=glorot_std, reg=reg, name='weight')
    outputs = tf.nn.conv2d(inputs, kernel, [1, stride, stride, 1], padding=pad)
    if bn:
      outputs = batch_norm(outputs, is_training, [0, 1, 2])
    if relu:
      outputs = tf.nn.leaky_relu(outputs)
    return outputs
109
def conv_2d_trans(inputs, ksize, n_output, is_training, name, stride=1, pad='SAME', relu=True, reg=True, bn=True):
  """Transposed 2-D convolution layer with optional batch norm and ReLU.

  Args:
    inputs: 4D BHWC tensor with square spatial dims.
    ksize: int, square kernel size.
    n_output: int, number of output channels.
    is_training: boolean tensor, forwarded to batch_norm.
    name: string, variable scope name.
    stride: int, upsampling stride.
    pad: padding mode. NOTE(review): output_shape is computed as
      input_size * stride, which matches SAME padding only — confirm
      before calling with pad='VALID'.
    relu: bool, apply (plain) ReLU.
    reg: bool, attach an L2 regularizer to the kernel.
    bn: bool, apply batch normalization over axes [0, 1, 2].

  Returns:
    The upsampled tensor.
  """
  with tf.variable_scope(name):
    dyn_batch = tf.shape(inputs)[0]
    static_shape = inputs.get_shape().as_list()
    spatial, n_input = static_shape[1], static_shape[3]
    # conv2d_transpose kernels are [h, w, out_channels, in_channels].
    kernel_shape = [ksize, ksize, n_output, n_input]
    out_shape = tf.stack([dyn_batch, spatial * stride, spatial * stride, n_output])
    glorot_std = tf.sqrt(2.0 / tf.to_float(n_input + n_output))
    kernel = get_weight(kernel_shape, stddev=glorot_std, reg=reg, name='weight')
    outputs = tf.nn.conv2d_transpose(inputs, kernel, out_shape,
                                     [1, stride, stride, 1], padding=pad)
    if bn:
      outputs = batch_norm(outputs, is_training, [0, 1, 2])
    if relu:
      outputs = tf.nn.relu(outputs)
    return outputs
125
class Adv_cls():
  """Adversarial classifier head over the deepest encoder feature map."""

  def build(self, inputs, n_class, is_training):
    """Builds the classifier graph.

    Args:
      inputs: list of encoder feature maps; only the last one is used.
      n_class: int, number of output classes.
      is_training: boolean tensor for batch norm.

    Returns:
      Squeezed logits tensor.
    """
    with tf.variable_scope('Adv', reuse=tf.AUTO_REUSE):
      net = inputs[-1]

      # Two stacks of three 3x3 convolutions (spatial size assumed 4x4 here
      # per the original comments).
      for idx in range(3):
        net = conv_2d(net, 3, 512, is_training, 'conv1_' + str(idx))
      for idx in range(3):
        net = conv_2d(net, 3, 256, is_training, 'conv2_' + str(idx))

      # 1x1-spatial "fully connected" convolutions collapse the map.
      net = conv_2d(net, 4, 256, is_training, 'fc1', pad='VALID')
      net = conv_2d(net, 1, 128, is_training, 'fc2', pad='VALID')
      logits = conv_2d(net, 1, n_class, is_training, 'fc3', pad='VALID',
                       relu=False, bn=False)
      net = tf.squeeze(logits)

      self.vars = tf.trainable_variables('Adv')
      self.reg_loss = tf.losses.get_regularization_losses('Adv')

      return net
149
class Genc():
  """Encoder: five stride-2 convolutions doubling channels from 64 to 1024."""

  def build(self, inputs, is_training):
    """Builds the encoder graph.

    Args:
      inputs: 4D BHWC image tensor.
      is_training: boolean tensor for batch norm.

    Returns:
      List of the five intermediate feature maps, shallowest first.
    """
    with tf.variable_scope('Genc', reuse=tf.AUTO_REUSE):
      feature_maps = []
      net = inputs
      for depth in range(5):
        channels = int(64 * 2 ** depth)
        net = conv_2d(net, 4, channels, is_training, 'enc_' + str(depth), stride=2)
        feature_maps.append(net)

      self.vars = tf.trainable_variables('Genc')
      self.reg_loss = tf.losses.get_regularization_losses('Genc')
      return feature_maps
162
class Gdec():
  """Decoder: upsamples encoder features, conditioned on a scalar label."""

  def build(self, inputs, labels, is_training):
    """Builds the decoder graph.

    Args:
      inputs: list of encoder feature maps (as produced by Genc.build);
        uses the last map, with a skip connection from the second-to-last.
      labels: per-example conditioning values, broadcast to B,1,1,1.
      is_training: boolean tensor for batch norm.

    Returns:
      tanh-activated 3-channel image tensor.
    """
    with tf.variable_scope('Gdec', reuse=tf.AUTO_REUSE):
      labels = tf.reshape(tf.to_float(labels), [-1, 1, 1, 1])  # B,1,1,N

      net = inputs[-1]
      # Concatenate the label, tiled to the current spatial size.
      tiled = tf.tile(labels, [1, net.shape[1], net.shape[2], 1])
      net = tf.concat([net, tiled], axis=-1)

      for step in range(4):
        if step == 1:
          # Skip connection from the second-to-last encoder map, plus the
          # label re-tiled to the new spatial size.
          net = tf.concat([net, inputs[-2]], axis=-1)
          tiled = tf.tile(labels, [1, net.shape[1], net.shape[2], 1])
          net = tf.concat([net, tiled], axis=-1)
        channels = int(1024 / 2 ** step)
        net = conv_2d_trans(net, 4, channels, is_training, 'dec_' + str(step), stride=2)

      # Final upsample to 3 channels, squashed to [-1, 1].
      net = tf.nn.tanh(conv_2d_trans(net, 4, 3, is_training, 'dec_f',
                                     stride=2, relu=False, bn=False))

      self.vars = tf.trainable_variables('Gdec')
      self.reg_loss = tf.losses.get_regularization_losses('Gdec')
      return net
183
class D():
  """Discriminator with shared conv trunk and separate GAN/classifier heads."""

  def build(self, inputs, is_training):
    """Builds the discriminator graph.

    Args:
      inputs: 4D BHWC image tensor with a static batch dimension.
      is_training: boolean tensor for batch norm.

    Returns:
      (gan_net, cls_net): real/fake logit and classifier logit tensors.
    """
    with tf.variable_scope('D', reuse=tf.AUTO_REUSE):
      net = inputs
      # NOTE: relies on a statically known batch size for the reshape.
      batch_size = net.get_shape().as_list()[0]

      # Shared trunk: five stride-2 convolutions, 64 -> 1024 channels.
      for depth in range(5):
        net = conv_2d(net, 4, int(64 * 2 ** depth), is_training,
                      'D_' + str(depth), stride=2)
      net = tf.reshape(net, [batch_size, -1])

      # Real/fake head.
      gan_net = fc(net, 1024, is_training, 'gan1')
      gan_net = fc(gan_net, 1, is_training, 'gan2', relu=False, bn=False)

      # Attribute-classification head.
      cls_net = fc(net, 1024, is_training, 'cls1')
      cls_net = fc(cls_net, 1, is_training, 'cls2', relu=False, bn=False)

      self.vars = tf.trainable_variables('D')
      self.reg_loss = tf.losses.get_regularization_losses('D')
      return gan_net, cls_net
202
203
204
205
206