google-research

Форк
0
254 строки · 8.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Dataset-specific utilities."""
17

18
# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
19
# pylint: disable=line-too-long, missing-docstring, g-importing-member
20
from disentanglement_lib.data.ground_truth import dsprites
21
from disentanglement_lib.data.ground_truth import shapes3d
22
from disentanglement_lib.data.ground_truth import mpi3d
23
from disentanglement_lib.data.ground_truth import cars3d
24
from disentanglement_lib.data.ground_truth import norb
25
import numpy as np
26
import tensorflow.compat.v1 as tf
27
import gin
28

29
from weak_disentangle import utils as ut
30

31

32
def get_dlib_data(task):
  """Instantiates the disentanglement_lib ground-truth dataset for `task`.

  Args:
    task: Dataset name; one of "dsprites", "shapes3d", "norb", "cars3d",
      "mpi3d", or "scream".

  Returns:
    A disentanglement_lib ground-truth dataset object.

  Raises:
    ValueError: If `task` is not a recognized dataset name (previously an
      unknown task silently returned None).
  """
  ut.log("Loading {}".format(task))
  if task == "dsprites":
    # 5 factors
    return dsprites.DSprites(list(range(1, 6)))
  elif task == "shapes3d":
    # 6 factors
    return shapes3d.Shapes3D()
  elif task == "norb":
    # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
    return norb.SmallNORB()
  elif task == "cars3d":
    # 3 factors
    return cars3d.Cars3D()
  elif task == "mpi3d":
    # 7 factors
    return mpi3d.MPI3D()
  elif task == "scream":
    # 5 factors + 2 nuisance (handled as n_dim=2)
    return dsprites.ScreamDSprites(list(range(1, 6)))
  raise ValueError("Unknown task: {}".format(task))
52

53

54
@gin.configurable
def make_masks(string, s_dim, mask_type):
  """Parses a strategy spec string into factor masks.

  The spec has the form "<strategy>=<factors>", where strategy is one of
  "s" (share), "c" (change), "r" (rank), "cs" (change + complement set), or
  "l" (label). Factor indices are single digits; comma-separated groups each
  become one mask row.

  Args:
    string: Strategy specification, e.g. "c=0,12".
    s_dim: Total number of ground-truth factors.
    mask_type: One of "match", "rank", "label"; sanity-checked against the
      strategy to catch mismatched (strategy, mask_type) configurations.

  Returns:
    For "r": a plain list of int factor indices. Otherwise a float32 array of
    shape (num_masks, s_dim) with 1s marking the selected factors.
  """
  strategy, factors = string.split("=")
  assert strategy in {"s", "c", "r", "cs", "l"}, "Only allow label, share, change, rank-types"

  # mask_type exists purely as a cross-check against the collection strategy.
  if strategy == "r":
    assert mask_type == "rank", "mask_type must match data collection strategy"
    # Rank masks are just factor indices (one single-digit factor per group).
    return [int(tok) for tok in factors.split(",")]
  elif strategy in {"s", "c", "cs"}:
    assert mask_type == "match", "mask_type must match data collection strategy"
  elif strategy in {"l"}:
    assert mask_type == "label", "mask_type must match data collection strategy"

  if strategy == "cs":
    # Expand the single changed factor into "<factor>,<all other factors>".
    changed = int(factors)
    complement = [d for d in range(s_dim) if d != changed]
    factors = "{},{}".format(changed, "".join(str(d) for d in complement))

  # Each comma group is a string of single-digit factor indices.
  groups = [[int(ch) for ch in grp] for grp in factors.split(",")]
  masks = np.zeros((len(groups), s_dim), dtype=np.float32)
  for row, idxs in enumerate(groups):
    masks[row, idxs] = 1

  if strategy == "s":
    # Share-strategy: the mask selects everything *except* the listed factors.
    masks = 1 - masks
  elif strategy == "l":
    assert len(masks) == 1, "Only one mask allowed for label-strategy"

  ut.log("make_masks output:")
  ut.log(masks)
  return masks
90

91

92
def sample_match_factors(dset, batch_size, masks, random_state):
  """Samples paired factors that agree everywhere outside a sampled mask.

  Args:
    dset: Ground-truth dataset exposing `sample_factors(num, random_state)`.
    batch_size: Number of factor pairs to sample.
    masks: Float array of shape (num_masks, s_dim); 1 marks factors that may
      differ between the two members of a pair.
    random_state: np.random.RandomState driving all sampling.

  Returns:
    Tuple (factors, mask_idx): factors has shape (2 * batch_size, s_dim) with
    the first half paired row-wise with the second half; mask_idx gives the
    mask index chosen per pair.
  """
  factor1 = dset.sample_factors(batch_size, random_state)
  factor2 = dset.sample_factors(batch_size, random_state)
  # Bug fix: draw the mask index from the provided random_state (was the
  # global np.random), so mask choice is reproducible like the factors.
  mask_idx = random_state.choice(len(masks), batch_size)
  mask = masks[mask_idx]
  # Where mask == 0, factor2 copies factor1 so the pair matches there.
  factor2 = factor2 * mask + factor1 * (1 - mask)
  factors = np.concatenate((factor1, factor2), 0)
  return factors, mask_idx
100

101

102
def sample_rank_factors(dset, batch_size, masks, random_state):
  """Samples factor pairs and binary rank labels for the masked indices.

  For the ranking strategy, `masks` is assumed to be a plain list of factor
  indices rather than a 2-D mask array.

  Returns:
    Tuple (factors, y): factors of shape (2 * batch_size, s_dim), and y of
    shape (batch_size, len(masks)) where y[i, j] == 1 iff the j-th ranked
    factor is larger in the first member of pair i.
  """
  both = dset.sample_factors(2 * batch_size, random_state)
  first_half, second_half = np.split(both, 2)
  y = (first_half > second_half)[:, masks].astype(np.float32)
  return both, y
108

109

110
def sample_match_images(dset, batch_size, masks, random_state):
  """Renders matched factor pairs into image pairs.

  Returns:
    Tuple (x1, x2, mask_idx): two image batches of size batch_size (paired
    row-wise) and the mask index sampled for each pair.
  """
  factors, mask_idx = sample_match_factors(dset, batch_size, masks, random_state)
  observations = dset.sample_observations_from_factors(factors, random_state)
  x1, x2 = np.split(observations, 2)
  return x1, x2, mask_idx
115

116

117
def sample_rank_images(dset, batch_size, masks, random_state):
  """Renders ranked factor pairs into image pairs with rank labels.

  Returns:
    Tuple (x1, x2, y): two image batches of size batch_size and the binary
    rank labels produced by sample_rank_factors.
  """
  factors, y = sample_rank_factors(dset, batch_size, masks, random_state)
  observations = dset.sample_observations_from_factors(factors, random_state)
  x1, x2 = np.split(observations, 2)
  return x1, x2, y
122

123

124
def sample_images(dset, batch_size, random_state):
  """Samples a batch of observations at uniformly drawn factor settings."""
  drawn = dset.sample_factors(batch_size, random_state)
  return dset.sample_observations_from_factors(drawn, random_state)
127

128

129
@gin.configurable
def paired_data_generator(dset, masks, random_seed=None, mask_type="match"):
  """Builds the tf.data generator matching `mask_type`.

  Args:
    dset: Ground-truth dataset.
    masks: Mask structure as produced by make_masks for this mask_type.
    random_seed: Optional seed for the underlying np.random.RandomState.
    mask_type: One of "match", "rank", "label".

  Returns:
    A tf.data.Dataset from the corresponding generator.

  Raises:
    ValueError: If `mask_type` is not recognized (previously returned None).
  """
  if mask_type == "match":
    return match_data_generator(dset, masks, random_seed)
  elif mask_type == "rank":
    return rank_data_generator(dset, masks, random_seed)
  elif mask_type == "label":
    return label_data_generator(dset, masks, random_seed)
  raise ValueError("Unknown mask_type: {}".format(mask_type))
137

138

139
def match_data_generator(dset, masks, random_seed=None):
  """tf.data pipeline yielding (x1, x2, mask_idx) matched-pair examples."""

  def gen():
    rng = np.random.RandomState(random_seed)
    while True:
      x1, x2, idx = sample_match_images(dset, 1, masks, rng)
      # [0]-indexing strips the singleton batch dimension.
      yield x1[0], x2[0], idx.item(0)

  return tf.data.Dataset.from_generator(
      gen,
      (tf.float32, tf.float32, tf.int32),
      output_shapes=(dset.observation_shape, dset.observation_shape, ()))
152

153

154
def rank_data_generator(dset, masks, random_seed=None):
  """tf.data pipeline yielding (x1, x2, y) ranked-pair examples."""

  def gen():
    rng = np.random.RandomState(random_seed)
    while True:
      x1, x2, y = sample_rank_images(dset, 1, masks, rng)
      # [0]-indexing strips the singleton batch dimension.
      yield x1[0], x2[0], y[0]

  # For the rank strategy, masks is a plain list of factor indices, so its
  # length is the label dimensionality.
  label_dim = len(masks)
  return tf.data.Dataset.from_generator(
      gen,
      (tf.float32, tf.float32, tf.float32),
      output_shapes=(dset.observation_shape, dset.observation_shape,
                     (label_dim,)))
168

169

170
def label_data_generator(dset, masks, random_seed=None):
  """tf.data pipeline yielding (x, y) with y = masked, normalized factors."""
  # Per-factor mean/stddev over the value grid 0..num_values-1, used to
  # standardize the raw factor labels.
  means, stds = [], []
  for num_values in dset.factors_num_values:
    grid = np.arange(num_values)
    means.append(grid.mean())
    stds.append(grid.std())
  means = np.array(means)
  stds = np.array(stds)

  def gen():
    rng = np.random.RandomState(random_seed)
    while True:
      factors = dset.sample_factors(1, rng)
      x = dset.sample_observations_from_factors(factors, rng)
      normalized = (factors - means) / stds
      y = normalized * masks
      # [0]-indexing strips the singleton batch dimension.
      yield x[0], y[0]

  # The mask is 1-hot with length s_dim, which is the label dimensionality.
  y_dim = masks.shape[-1]
  return tf.data.Dataset.from_generator(
      gen,
      (tf.float32, tf.float32),
      output_shapes=(dset.observation_shape, (y_dim,)))
196

197

198
@gin.configurable
def paired_randn(batch_size, z_dim, masks, mask_type="match"):
  """Samples latent variables according to `mask_type`.

  Args:
    batch_size: Number of latent samples (or pairs) to draw.
    z_dim: Dimensionality of the latent space.
    masks: Mask structure as produced by make_masks for this mask_type.
    mask_type: One of "match", "rank", "label".

  Returns:
    The output of the corresponding *_randn sampler.

  Raises:
    ValueError: If `mask_type` is not recognized (previously returned None).
  """
  if mask_type == "match":
    return match_randn(batch_size, z_dim, masks)
  elif mask_type == "rank":
    return rank_randn(batch_size, z_dim, masks)
  elif mask_type == "label":
    return label_randn(batch_size, z_dim, masks)
  raise ValueError("Unknown mask_type: {}".format(mask_type))
206

207

208
def match_randn(batch_size, z_dim, masks):
  """Samples a matched latent pair (z1, z2) plus the mask index per pair.

  The first masks.shape[-1] dims are controllable and matched across the
  pair wherever the sampled mask is 0; any remaining z_dim - s_dim dims are
  nuisance dims drawn independently for each member of the pair. Assumes
  s_dim <= z_dim.
  """
  s_dim = masks.shape[-1]
  n_dim = z_dim - s_dim

  if n_dim == 0:
    z1 = tf.random_normal((batch_size, z_dim))
    z2 = tf.random_normal((batch_size, z_dim))
  else:
    # Only sample the controllable latents for now; nuisance dims are
    # appended after variable fixing.
    z1 = tf.random_normal((batch_size, s_dim))
    z2 = tf.random_normal((batch_size, s_dim))

  # Variable fixing: where the sampled mask is 0, z2 copies z1.
  mask_idx = tf.random_uniform((batch_size,), maxval=len(masks), dtype=tf.int32)
  mask = tf.gather(masks, mask_idx)
  z2 = z2 * mask + z1 * (1 - mask)

  # Append independently sampled nuisance (uncontrollable) latents.
  if n_dim > 0:
    z1_tail = tf.random_normal((batch_size, n_dim))
    z2_tail = tf.random_normal((batch_size, n_dim))
    z1 = tf.concat((z1, z1_tail), axis=-1)
    z2 = tf.concat((z2, z2_tail), axis=-1)

  return z1, z2, mask_idx
233

234

235
def rank_randn(batch_size, z_dim, masks):
  """Samples a latent pair plus binary rank labels for the masked dims.

  `masks` is a plain list of latent indices; y[i, j] == 1 iff z1 exceeds z2
  at the j-th ranked dimension for sample i.
  """
  z1 = tf.random.normal((batch_size, z_dim))
  z2 = tf.random.normal((batch_size, z_dim))
  comparisons = tf.gather(z1 > z2, masks, axis=-1)
  return z1, z2, tf.cast(comparisons, tf.float32)
241

242

243
# pylint: disable=unused-argument
def label_randn(batch_size, z_dim, masks):
  """Samples latents with the masked (labeled) dims zeroed out.

  The first masks.shape[-1] dims are controllable (assumes s_dim <= z_dim);
  any remaining dims are unconstrained nuisance dims sampled normally.
  """
  s_dim = masks.shape[-1]
  n_dim = z_dim - s_dim

  if n_dim == 0:
    return tf.random.normal((batch_size, z_dim)) * (1 - masks)
  # Zero the labeled dims, then append freely sampled nuisance dims.
  controllable = tf.random.normal((batch_size, s_dim)) * (1 - masks)
  nuisance = tf.random.normal((batch_size, n_dim))
  return tf.concat((controllable, nuisance), axis=-1)
255

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.