google-research
254 lines · 8.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Dataset-specific utilities."""
17
18# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
19# pylint: disable=line-too-long, missing-docstring, g-importing-member
20from disentanglement_lib.data.ground_truth import dsprites21from disentanglement_lib.data.ground_truth import shapes3d22from disentanglement_lib.data.ground_truth import mpi3d23from disentanglement_lib.data.ground_truth import cars3d24from disentanglement_lib.data.ground_truth import norb25import numpy as np26import tensorflow.compat.v1 as tf27import gin28
29from weak_disentangle import utils as ut30
31
32def get_dlib_data(task):33ut.log("Loading {}".format(task))34if task == "dsprites":35# 5 factors36return dsprites.DSprites(list(range(1, 6)))37elif task == "shapes3d":38# 6 factors39return shapes3d.Shapes3D()40elif task == "norb":41# 4 factors + 1 nuisance (which we'll handle via n_dim=2)42return norb.SmallNORB()43elif task == "cars3d":44# 3 factors45return cars3d.Cars3D()46elif task == "mpi3d":47# 7 factors48return mpi3d.MPI3D()49elif task == "scream":50# 5 factors + 2 nuisance (handled as n_dim=2)51return dsprites.ScreamDSprites(list(range(1, 6)))52
53
@gin.configurable
def make_masks(string, s_dim, mask_type):
  """Parses a mask-specification string into per-pair factor masks.

  `string` has the form "<strategy>=<factors>", where strategy is one of
  "s" (share), "c" (change), "cs" (change + complement share), "r" (rank)
  or "l" (label). For "r" the result is a plain list of factor indices;
  otherwise it is a float32 array of shape (num_masks, s_dim) with 1s on
  the factors each mask selects.

  Args:
    string: str, "<strategy>=<factors>" specification.
    s_dim: int, total number of ground-truth factors.
    mask_type: str, sanity-checked against the strategy ("match", "rank",
      or "label").

  Returns:
    A list of ints (strategy "r") or a float32 array (num_masks, s_dim).
  """
  strategy, spec = string.split("=")
  assert strategy in {"s", "c", "r", "cs", "l"}, "Only allow label, share, change, rank-types"

  # mask_type exists only to sanity-check that the (strategy, mask_type)
  # pairing was not accidentally mismatched at config time.
  if strategy == "r":
    assert mask_type == "rank", "mask_type must match data collection strategy"
    # Rank strategy: the mask is simply the list of factor indices.
    # Assumes a single factor per comma-separated token.
    return [int(tok) for tok in spec.split(",")]
  if strategy in ("s", "c", "cs"):
    assert mask_type == "match", "mask_type must match data collection strategy"
  elif strategy == "l":
    assert mask_type == "label", "mask_type must match data collection strategy"

  if strategy == "cs":
    # Augment the single changing index with its complement set.
    idx = int(spec)
    complement = [d for d in range(s_dim) if d != idx]
    spec = "{},{}".format(idx, "".join(str(d) for d in complement))

  # Each comma-separated group is a string of single-digit factor indices.
  groups = [[int(ch) for ch in grp] for grp in spec.split(",")]
  masks = np.zeros((len(groups), s_dim), dtype=np.float32)
  for row, indices in enumerate(groups):
    masks[row, indices] = 1

  if strategy == "s":
    # Share strategy marks the factors that stay fixed, so invert.
    masks = 1 - masks
  elif strategy == "l":
    assert len(masks) == 1, "Only one mask allowed for label-strategy"

  ut.log("make_masks output:")
  ut.log(masks)
  return masks
91
92def sample_match_factors(dset, batch_size, masks, random_state):93factor1 = dset.sample_factors(batch_size, random_state)94factor2 = dset.sample_factors(batch_size, random_state)95mask_idx = np.random.choice(len(masks), batch_size)96mask = masks[mask_idx]97factor2 = factor2 * mask + factor1 * (1 - mask)98factors = np.concatenate((factor1, factor2), 0)99return factors, mask_idx100
101
102def sample_rank_factors(dset, batch_size, masks, random_state):103# We assume for ranking that masks is just a list of indices104factors = dset.sample_factors(2 * batch_size, random_state)105factor1, factor2 = np.split(factors, 2)106y = (factor1 > factor2)[:, masks].astype(np.float32)107return factors, y108
109
110def sample_match_images(dset, batch_size, masks, random_state):111factors, mask_idx = sample_match_factors(dset, batch_size, masks, random_state)112images = dset.sample_observations_from_factors(factors, random_state)113x1, x2 = np.split(images, 2)114return x1, x2, mask_idx115
116
117def sample_rank_images(dset, batch_size, masks, random_state):118factors, y = sample_rank_factors(dset, batch_size, masks, random_state)119images = dset.sample_observations_from_factors(factors, random_state)120x1, x2 = np.split(images, 2)121return x1, x2, y122
123
124def sample_images(dset, batch_size, random_state):125factors = dset.sample_factors(batch_size, random_state)126return dset.sample_observations_from_factors(factors, random_state)127
128
129@gin.configurable130def paired_data_generator(dset, masks, random_seed=None, mask_type="match"):131if mask_type == "match":132return match_data_generator(dset, masks, random_seed)133elif mask_type == "rank":134return rank_data_generator(dset, masks, random_seed)135elif mask_type == "label":136return label_data_generator(dset, masks, random_seed)137
138
def match_data_generator(dset, masks, random_seed=None):
  """tf.data.Dataset yielding (x1, x2, mask_idx) match pairs, one at a time."""

  def generator():
    # Each generator instance owns its RNG so restarts are reproducible.
    rng = np.random.RandomState(random_seed)
    while True:
      x1, x2, idx = sample_match_images(dset, 1, masks, rng)
      # Indexing with [0] strips the singleton batch dimension.
      yield x1[0], x2[0], idx.item(0)

  return tf.data.Dataset.from_generator(
      generator,
      (tf.float32, tf.float32, tf.int32),
      output_shapes=(dset.observation_shape, dset.observation_shape, ()))
153
def rank_data_generator(dset, masks, random_seed=None):
  """tf.data.Dataset yielding (x1, x2, y) ranking pairs, one at a time."""

  def generator():
    # Each generator instance owns its RNG so restarts are reproducible.
    rng = np.random.RandomState(random_seed)
    while True:
      x1, x2, y = sample_rank_images(dset, 1, masks, rng)
      # Indexing with [0] strips the singleton batch dimension.
      yield x1[0], x2[0], y[0]

  # For ranking, masks is a plain list of indices, so its length is the
  # label dimensionality.
  y_dim = len(masks)
  return tf.data.Dataset.from_generator(
      generator,
      (tf.float32, tf.float32, tf.float32),
      output_shapes=(dset.observation_shape, dset.observation_shape, (y_dim,)))
169
def label_data_generator(dset, masks, random_seed=None):
  """tf.data.Dataset yielding (x, y): image plus masked, normalized labels."""
  # Per-factor mean and stddev over each factor's value grid, used to
  # normalize the sampled factor values.
  means = np.array([np.mean(list(range(n))) for n in dset.factors_num_values])
  stds = np.array([np.std(list(range(n))) for n in dset.factors_num_values])

  def generator():
    # Each generator instance owns its RNG so restarts are reproducible.
    rng = np.random.RandomState(random_seed)
    while True:
      factors = dset.sample_factors(1, rng)
      x = dset.sample_observations_from_factors(factors, rng)
      normalized = (factors - means) / stds
      # Zero out the unlabeled factors.
      y = normalized * masks
      # Indexing with [0] strips the singleton batch dimension.
      yield x[0], y[0]

  # The mask is 1-hot and equal in length to s_dim.
  y_dim = masks.shape[-1]
  return tf.data.Dataset.from_generator(
      generator,
      (tf.float32, tf.float32),
      output_shapes=(dset.observation_shape, (y_dim,)))
197
198@gin.configurable199def paired_randn(batch_size, z_dim, masks, mask_type="match"):200if mask_type == "match":201return match_randn(batch_size, z_dim, masks)202elif mask_type == "rank":203return rank_randn(batch_size, z_dim, masks)204elif mask_type == "label":205return label_randn(batch_size, z_dim, masks)206
207
def match_randn(batch_size, z_dim, masks):
  """Samples a latent pair that agrees everywhere outside a sampled mask.

  masks.shape[-1] = s_dim and we assume s_dim <= z_dim; any excess
  z_dim - s_dim dimensions are treated as unconstrained nuisance latents.
  NOTE: random-op creation order is kept identical to preserve graph-level
  seeding behavior.
  """
  n_dim = z_dim - masks.shape[-1]

  if n_dim == 0:
    z1 = tf.random_normal((batch_size, z_dim))
    z2 = tf.random_normal((batch_size, z_dim))
  else:
    # Sample only the controllable (maskable) latents first.
    z1 = tf.random_normal((batch_size, masks.shape[-1]))
    z2 = tf.random_normal((batch_size, masks.shape[-1]))

  # Variable fixing: where the sampled mask is 0, z2 copies z1.
  mask_idx = tf.random_uniform((batch_size,), maxval=len(masks), dtype=tf.int32)
  mask = tf.gather(masks, mask_idx)
  z2 = z2 * mask + z1 * (1 - mask)

  # Append independently sampled nuisance (uncontrollable) dims.
  if n_dim > 0:
    extra1 = tf.random_normal((batch_size, n_dim))
    extra2 = tf.random_normal((batch_size, n_dim))
    z1 = tf.concat((z1, extra1), axis=-1)
    z2 = tf.concat((z2, extra2), axis=-1)

  return z1, z2, mask_idx
234
235def rank_randn(batch_size, z_dim, masks):236z1 = tf.random.normal((batch_size, z_dim))237z2 = tf.random.normal((batch_size, z_dim))238y = tf.gather(z1 > z2, masks, axis=-1)239y = tf.cast(y, tf.float32)240return z1, z2, y241
242
# pylint: disable=unused-argument
def label_randn(batch_size, z_dim, masks):
  """Samples latents with the masked (labeled) dims zeroed out.

  masks.shape[-1] = s_dim and we assume s_dim <= z_dim; any excess
  z_dim - s_dim dimensions are free nuisance latents.
  """
  n_dim = z_dim - masks.shape[-1]

  if n_dim == 0:
    return tf.random.normal((batch_size, z_dim)) * (1 - masks)
  # Zero the masked dims of the controllable part, then append nuisance dims.
  controllable = tf.random.normal((batch_size, masks.shape[-1])) * (1 - masks)
  nuisance = tf.random.normal((batch_size, n_dim))
  return tf.concat((controllable, nuisance), axis=-1)