google-research
207 lines · 6.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models."""

# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
# pylint: disable=line-too-long, missing-docstring, g-importing-member
# pylint: disable=g-wrong-blank-lines, missing-super-argument
import gin
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from functools import partial
from collections import OrderedDict
import numpy as np

from weak_disentangle import tensorsketch as ts
from weak_disentangle import utils as ut

# Shorthand for the tfp distributions namespace used by Encoder.forward.
tfd = tfp.distributions
# Expose the tensorsketch layer constructors (and the weight-norm /
# batch-norm "add" hooks) as gin-configurable factories so that layer
# hyperparameters can be set from gin config files.
dense = gin.external_configurable(ts.Dense)
conv = gin.external_configurable(ts.Conv2d)
deconv = gin.external_configurable(ts.ConvTranspose2d)
add_wn = gin.external_configurable(ts.WeightNorm.add)
add_bn = gin.external_configurable(ts.BatchNorm.add)
@gin.configurable
class Encoder(ts.Module):
  """Convolutional encoder mapping an image to a diagonal Gaussian over z.

  Four strided 4x4 convolutions halve the spatial resolution each step,
  followed by a dense layer and a final dense layer emitting both the mean
  and the (pre-softplus) scale of the latent distribution.
  """

  def __init__(self, x_shape, z_dim, width=1, spectral_norm=True):
    super().__init__()
    # Build the conv trunk with a loop: two 32-channel then two 64-channel
    # strided convs, each followed by a leaky-ReLU.
    layers = []
    for channels in (32, 32, 64, 64):
      layers += [conv(channels * width, 4, 2, "same"), ts.LeakyReLU()]
    layers += [
        ts.Flatten(),
        dense(128 * width), ts.LeakyReLU(),
        # 2 * z_dim outputs: one half for the mean, one half for the scale.
        dense(2 * z_dim),
    ]
    self.net = ts.Sequential(*layers)

    if spectral_norm:
      self.net.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building encoder...")
    # Build with a singleton batch dimension prepended to the image shape.
    self.build([1] + x_shape)
    self.apply(ut.reset_parameters)

  def forward(self, x):
    """Return a MultivariateNormalDiag posterior q(z | x)."""
    loc, scale_param = tf.split(self.net(x), 2, axis=-1)
    # Softplus keeps the scale positive; the epsilon keeps it bounded away
    # from zero for numerical stability.
    scale = tf.nn.softplus(scale_param) + 1e-8
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
@gin.configurable
class LabelDiscriminator(ts.Module):
  """Discriminator scoring (image, label) pairs.

  The image passes through a conv trunk (`body`), the label through a small
  dense stem (`aux`); the concatenated features are scored by a dense `head`.
  All affine layers are spectrally normalized.
  """

  def __init__(self, x_shape, y_dim, width=1, share_dense=False,
               uncond_bias=False):
    super().__init__()
    self.y_dim = y_dim

    # Image trunk: four strided 4x4 convs, then flatten.
    trunk = []
    for channels in (32, 32, 64, 64):
      trunk += [conv(channels * width, 4, 2, "same"), ts.LeakyReLU()]
    trunk.append(ts.Flatten())
    self.body = ts.Sequential(*trunk)

    # Label stem.
    self.aux = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
    )

    if share_dense:
      # Give both branches an extra shared-width dense layer.
      self.body.append(dense(128 * width), ts.LeakyReLU())
      self.aux.append(dense(128 * width), ts.LeakyReLU())

    # Joint scoring head over the concatenated image/label features.
    self.head = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
        dense(128 * width), ts.LeakyReLU(),
        dense(1, bias=uncond_bias),
    )

    for module in (self.body, self.aux, self.head):
      module.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building label discriminator...")
    x_shape, y_shape = [1] + x_shape, (1, y_dim)
    self.build(x_shape, y_shape)
    self.apply(ut.reset_parameters)

  def forward(self, x, y):
    """Score an (image, label) pair; returns a single logit per example."""
    image_feat = self.body(x)
    label_feat = self.aux(y)
    return self.head(tf.concat((image_feat, label_feat), axis=-1))
@gin.configurable
class Discriminator(ts.Module):
  """Pairwise discriminator over image pairs (x1, x2) with weak label y.

  Supports two pairing schemes:
    * "match": both images pass through a shared conv trunk; the concatenated
      features go through a dense `neck`, and the score is the sum of an
      unconditional head and a projection-style conditional head indexed by
      the integer label y.
    * "rank": the trunk directly emits (1 + y_dim) outputs per image; the
      score combines the two unconditional outputs with a signed difference
      of the y_dim "ranking" outputs, where y in {0, 1} selects the sign.

  Raises:
    ValueError: if `mask_type` is not "match" or "rank". (Previously an
      unknown mask_type fell through and crashed later with an opaque
      UnboundLocalError on x_shape/y_shape.)
  """

  def __init__(self, x_shape, y_dim, width=1, share_dense=False,
               uncond_bias=False, cond_bias=False, mask_type="match"):
    super().__init__()
    self.y_dim = y_dim
    self.mask_type = mask_type
    # Shared conv trunk applied to each image of the pair.
    self.body = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        )

    if share_dense:
      self.body.append(dense(128 * width), ts.LeakyReLU())

    if mask_type == "match":
      # Dense neck over the concatenated per-image features.
      self.neck = ts.Sequential(
          dense(128 * width), ts.LeakyReLU(),
          dense(128 * width), ts.LeakyReLU(),
          )

      self.head_uncond = dense(1, bias=uncond_bias)
      # Projection head: one 128*width-dim embedding per label value.
      self.head_cond = dense(128 * width, bias=cond_bias)

      for m in (self.body, self.neck, self.head_uncond):
        m.apply(ts.SpectralNorm.add, targets=ts.Affine)
      # Weight norm (not spectral norm) on the conditional projection.
      add_wn(self.head_cond)
      # y is an integer class index here.
      x_shape, y_shape = [1] + x_shape, ((1,), tf.int32)

    elif mask_type == "rank":
      # Trunk emits 1 unconditional logit plus y_dim ranking outputs.
      self.body.append(
          dense(128 * width), ts.LeakyReLU(),
          dense(128 * width), ts.LeakyReLU(),
          dense(1 + y_dim, bias=uncond_bias)
          )

      self.body.apply(ts.SpectralNorm.add, targets=ts.Affine)
      # y is a {0, 1}-valued vector here.
      x_shape, y_shape = [1] + x_shape, (1, y_dim)

    else:
      # Fail fast with a clear message instead of an UnboundLocalError
      # at self.build(...) below.
      raise ValueError("Unknown mask_type: {}".format(mask_type))

    ut.log("Building {} discriminator...".format(mask_type))
    self.build(x_shape, x_shape, y_shape)
    self.apply(ut.reset_parameters)

  def forward(self, x1, x2, y):
    """Score a pair (x1, x2) under weak label y; returns one logit each."""
    if self.mask_type == "match":
      # Run both images through the shared trunk in one batch, then split.
      h = self.body(tf.concat((x1, x2), axis=0))
      h1, h2 = tf.split(h, 2, axis=0)
      h = self.neck(tf.concat((h1, h2), axis=-1))
      o_uncond = self.head_uncond(h)

      # Projection discriminator: inner product with the label embedding.
      w = self.head_cond(tf.one_hot(y, self.y_dim))
      o_cond = tf.reduce_sum(h * w, axis=-1, keepdims=True)
      return o_uncond + o_cond

    elif self.mask_type == "rank":
      h = self.body(tf.concat((x1, x2), axis=0))
      h1, h2 = tf.split(h, 2, axis=0)
      o1, z1 = tf.split(h1, (1, self.y_dim), axis=-1)
      o2, z2 = tf.split(h2, (1, self.y_dim), axis=-1)
      y_pm = y * 2 - 1  # convert from {0, 1} to {-1, 1}
      diff = (z1 - z2) * y_pm
      o_diff = tf.reduce_sum(diff, axis=-1, keepdims=True)
      return o1 + o2 + o_diff

    # Previously an unknown mask_type silently returned None.
    raise ValueError("Unknown mask_type: {}".format(self.mask_type))

  def expose_encoder(self, x):
    """Return the y_dim ranking outputs for x (only meaningful for "rank")."""
    h = self.body(x)
    _, z = tf.split(h, (1, self.y_dim), axis=-1)
    return z
@gin.configurable
class Generator(ts.Module):
  """Deconvolutional generator mapping a latent vector z to an image.

  Two dense layers project z to a 4x4x64 feature map, which four
  transposed 4x4 stride-2 convolutions upsample to the output resolution;
  a final sigmoid keeps pixel values in [0, 1].
  """

  def __init__(self, x_shape, z_dim, batch_norm=True):
    super().__init__()
    out_channels = x_shape[-1]

    # Dense stem projecting z into a spatial feature map.
    layers = [
        dense(128), ts.ReLU(),
        dense(4 * 4 * 64), ts.ReLU(), ts.Reshape((-1, 4, 4, 64)),
    ]
    # Upsampling stack: 64 -> 32 -> 32 channels, then project to the
    # image's channel count with a sigmoid output.
    for channels in (64, 32, 32):
      layers += [deconv(channels, 4, 2, "same"), ts.LeakyReLU()]
    layers += [deconv(out_channels, 4, 2, "same"), ts.Sigmoid()]
    self.net = ts.Sequential(*layers)

    # Add batchnorm post-activation (attach to activation out_hook)
    if batch_norm:
      self.net.apply(add_bn, targets=(ts.ReLU, ts.LeakyReLU))

    ut.log("Building generator...")
    self.build((1, z_dim))
    self.apply(ut.reset_parameters)

  def forward(self, z):
    """Generate an image batch from latent codes z."""
    return self.net(z)