# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models."""

# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
# pylint: disable=line-too-long, missing-docstring, g-importing-member
# pylint: disable=g-wrong-blank-lines, missing-super-argument
import gin
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from functools import partial
from collections import OrderedDict
import numpy as np

from weak_disentangle import tensorsketch as ts
from weak_disentangle import utils as ut

tfd = tfp.distributions
dense = gin.external_configurable(ts.Dense)
conv = gin.external_configurable(ts.Conv2d)
deconv = gin.external_configurable(ts.ConvTranspose2d)
add_wn = gin.external_configurable(ts.WeightNorm.add)
add_bn = gin.external_configurable(ts.BatchNorm.add)
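
# `dense`, `conv`, and `deconv` above are gin-registered wrappers around the
# tensorsketch layer constructors, so their constructor arguments can be
# overridden from gin bindings rather than hard-coded here; `add_wn` and
# `add_bn` likewise expose weight-norm / batch-norm attachment to gin.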


@gin.configurable
class Encoder(ts.Module):
  def __init__(self, x_shape, z_dim, width=1, spectral_norm=True):
    super().__init__()
    self.net = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        dense(128 * width), ts.LeakyReLU(),
        dense(2 * z_dim)
        )

    if spectral_norm:
      self.net.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building encoder...")
    self.build([1] + x_shape)
    self.apply(ut.reset_parameters)

  def forward(self, x):
    h = self.net(x)
    a, b = tf.split(h, 2, axis=-1)
    return tfd.MultivariateNormalDiag(
        loc=a,
        scale_diag=tf.nn.softplus(b) + 1e-8)
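
# The encoder maps an image to a diagonal-Gaussian posterior over the latent
# code: the final dense layer emits 2 * z_dim units that are split into a mean
# and a softplus-rectified scale. A minimal usage sketch, assuming 64x64
# single-channel inputs and z_dim=10 (illustrative values) and that calling a
# ts.Module dispatches to forward():
#
#   enc = Encoder(x_shape=[64, 64, 1], z_dim=10)
#   q_z = enc(x_batch)        # tfd.MultivariateNormalDiag over R^10
#   z = q_z.sample()          # shape (batch, 10)
#   log_q = q_z.log_prob(z)   # shape (batch,)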


@gin.configurable
class LabelDiscriminator(ts.Module):
  def __init__(self, x_shape, y_dim, width=1, share_dense=False,
               uncond_bias=False):
    super().__init__()
    self.y_dim = y_dim
    self.body = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        )

    self.aux = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
        )

    if share_dense:
      self.body.append(dense(128 * width), ts.LeakyReLU())
      self.aux.append(dense(128 * width), ts.LeakyReLU())

    self.head = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
        dense(128 * width), ts.LeakyReLU(),
        dense(1, bias=uncond_bias)
        )

    for m in (self.body, self.aux, self.head):
      m.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building label discriminator...")
    x_shape, y_shape = [1] + x_shape, (1, y_dim)
    self.build(x_shape, y_shape)
    self.apply(ut.reset_parameters)

  def forward(self, x, y):
    hx = self.body(x)
    hy = self.aux(y)
    o = self.head(tf.concat((hx, hy), axis=-1))
    return o
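
# The label discriminator scores an (image, label) pair: `body` extracts conv
# features from x, `aux` embeds the y_dim label vector, and `head` maps the
# concatenated features to a single logit. Spectral normalization is attached
# to every affine layer, a standard choice for stabilizing GAN discriminators.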
109

110

111
@gin.configurable
112
class Discriminator(ts.Module):
113
  def __init__(self, x_shape, y_dim, width=1, share_dense=False,
114
               uncond_bias=False, cond_bias=False, mask_type="match"):
115
    super().__init__()
116
    self.y_dim = y_dim
117
    self.mask_type = mask_type
118
    self.body = ts.Sequential(
119
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
120
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
121
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
122
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
123
        ts.Flatten(),
124
        )
125

126
    if share_dense:
127
      self.body.append(dense(128 * width), ts.LeakyReLU())
128

129
    if mask_type == "match":
130
      self.neck = ts.Sequential(
131
          dense(128 * width), ts.LeakyReLU(),
132
          dense(128 * width), ts.LeakyReLU(),
133
          )
134

135
      self.head_uncond = dense(1, bias=uncond_bias)
136
      self.head_cond = dense(128 * width, bias=cond_bias)
137

138
      for m in (self.body, self.neck, self.head_uncond):
139
        m.apply(ts.SpectralNorm.add, targets=ts.Affine)
140
      add_wn(self.head_cond)
141
      x_shape, y_shape = [1] + x_shape, ((1,), tf.int32)
142

143
    elif mask_type == "rank":
144
      self.body.append(
145
          dense(128 * width), ts.LeakyReLU(),
146
          dense(128 * width), ts.LeakyReLU(),
147
          dense(1 + y_dim, bias=uncond_bias)
148
          )
149

150
      self.body.apply(ts.SpectralNorm.add, targets=ts.Affine)
151
      x_shape, y_shape = [1] + x_shape, (1, y_dim)
152

153
    ut.log("Building {} discriminator...".format(mask_type))
154
    self.build(x_shape, x_shape, y_shape)
155
    self.apply(ut.reset_parameters)
156

157
  def forward(self, x1, x2, y):
158
    if self.mask_type == "match":
159
      h = self.body(tf.concat((x1, x2), axis=0))
160
      h1, h2 = tf.split(h, 2, axis=0)
161
      h = self.neck(tf.concat((h1, h2), axis=-1))
162
      o_uncond = self.head_uncond(h)
163

164
      w = self.head_cond(tf.one_hot(y, self.y_dim))
165
      o_cond = tf.reduce_sum(h * w, axis=-1, keepdims=True)
166
      return o_uncond + o_cond
167

168
    elif self.mask_type == "rank":
169
      h = self.body(tf.concat((x1, x2), axis=0))
170
      h1, h2 = tf.split(h, 2, axis=0)
171
      o1, z1 = tf.split(h1, (1, self.y_dim), axis=-1)
172
      o2, z2 = tf.split(h2, (1, self.y_dim), axis=-1)
173
      y_pm = y * 2 - 1  # convert from {0, 1} to {-1, 1}
174
      diff = (z1 - z2) * y_pm
175
      o_diff = tf.reduce_sum(diff, axis=-1, keepdims=True)
176
      return o1 + o2 + o_diff
177

178
  def expose_encoder(self, x):
179
    h = self.body(x)
180
    _, z = tf.split(h, (1, self.y_dim), axis=-1)
181
    return z
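
# In "match" mode the two images share the conv body, their features are fused
# by `neck`, and the pair score is the sum of an unconditional logit
# (`head_uncond`) and a conditional term in which `head_cond` embeds the integer
# factor index y and the embedding is dotted with the fused features, in the
# spirit of a projection discriminator. In "rank" mode the body emits, per
# image, a scalar logit plus a y_dim vector; the pair score adds both logits and
# the signed difference (z1 - z2) * (2y - 1), so each entry of y in {0, 1}
# indicates which image should rank higher along that dimension.
# `expose_encoder` reuses the rank-mode body as a feature extractor.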


@gin.configurable
class Generator(ts.Module):
  def __init__(self, x_shape, z_dim, batch_norm=True):
    super().__init__()
    ch = x_shape[-1]
    self.net = ts.Sequential(
        dense(128), ts.ReLU(),
        dense(4 * 4 * 64), ts.ReLU(), ts.Reshape((-1, 4, 4, 64)),
        deconv(64, 4, 2, "same"), ts.LeakyReLU(),
        deconv(32, 4, 2, "same"), ts.LeakyReLU(),
        deconv(32, 4, 2, "same"), ts.LeakyReLU(),
        deconv(ch, 4, 2, "same"), ts.Sigmoid(),
        )

    # Add batchnorm post-activation (attach to activation out_hook)
    if batch_norm:
      self.net.apply(add_bn, targets=(ts.ReLU, ts.LeakyReLU))

    ut.log("Building generator...")
    self.build((1, z_dim))
    self.apply(ut.reset_parameters)

  def forward(self, z):
    return self.net(z)
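
# The generator mirrors the encoder: z is lifted through dense layers to a
# 4x4x64 feature map and upsampled by four stride-2 transposed convolutions,
# yielding a 64x64 image with x_shape[-1] channels and sigmoid outputs in
# [0, 1]. A minimal sketch (z_dim=10 and the single-channel shape are
# illustrative assumptions):
#
#   gen = Generator(x_shape=[64, 64, 1], z_dim=10)
#   x_fake = gen(tf.random.normal([8, 10]))  # shape (8, 64, 64, 1)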
