google-research
412 lines · 13.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utility functions for working with dsprites and 3dident datasets.
17"""
18
19import datetime
20import functools
21import os
22
23from absl import flags
24
25import numpy as np
26import tensorflow.compat.v2 as tf
27
28
# Run with TF2 eager semantics even if a TF1-style runtime is the default.
tf.compat.v1.enable_v2_behavior()


FLAGS = flags.FLAGS

# dsprites shape indicator columns; preprocess_dsprites_latents reads these
# as a one-hot shape encoding.
DSPRITES_SHAPE_NAMES = ['square', 'ellipse', 'heart']

# Dataframe column names of the discrete (label) dsprites latents.
DSPRITES_LABEL_NAMES = [
    'label_scale', 'label_orientation', 'label_x_position', 'label_y_position'
]

# Dataframe column names of the continuous (value) dsprites latents.
DSPRITES_VALUE_NAMES = [
    'value_scale', 'value_orientation', 'value_x_position', 'value_y_position'
]
43
44### general useful functions for pandas and tensorflow
45
46
def tf_train_eval_split(ds, num_examples, eval_split=0.0):
  """Partitions a tensorflow dataset into train and eval datasets.

  tf datasets cannot be randomly sampled, so the eval set is simply the first
  `num_examples * eval_split` examples — shuffle the dataset beforehand!

  Args:
    ds: Tensorflow dataset to partition.
    num_examples: Total number of examples in `ds` (needed because counting a
      tf dataset requires a full pass over it).
    eval_split: Fraction of the dataset (float in (0, 1)) to hold out as the
      eval set.

  Returns:
    Tuple (train dataset, num train examples, eval dataset, num eval examples).
  """
  num_eval = int(num_examples * eval_split)
  num_train = num_examples - num_eval
  # Eval = first num_eval examples, train = everything after them.
  return ds.skip(num_eval), num_train, ds.take(num_eval), num_eval
69
70
def pd_train_eval_split(df, eval_split=0.0, seed=None, reset_index=False):
  """Randomly partitions a pandas dataframe into train and eval subsets.

  The eval rows are drawn without replacement via DataFrame.sample; the
  remaining rows form the train set.

  Args:
    df: Dataframe of examples to partition.
    eval_split: Fraction of rows (float in (0, 1)) to hold out as the eval
      set.
    seed: Optional random state, for a reproducible split.
    reset_index: If True, both returned dataframes get a fresh 0..n-1 index;
      if False (default) they keep their original row labels.

  Returns:
    Tuple (train dataset, num train examples, eval dataset, num eval examples).

    The sizes are redundant for dataframes but keep the return signature
    aligned with the corresponding tensorflow function.
  """
  eval_df = df.sample(frac=eval_split, replace=False, random_state=seed)
  # Train set = everything sample() did not pick.
  train_df = df.drop(index=eval_df.index)
  if reset_index:
    train_df = train_df.reset_index(drop=True)
    eval_df = eval_df.reset_index(drop=True)
  return train_df, len(train_df), eval_df, len(eval_df)
98
99
def make_backup(file_path, overwrite=False):
  """Makes a backup copy of a file in a 'backups' subfolder.

  Use for e.g. pandas dataframes of datasets where accidentally
  modifying/losing the dataframe later would be really really annoying.

  Args:
    file_path: Path (str) to where file is located.
    overwrite: If True, overwrites any existing backup. If False (default)
      creates a new file name with the current date/time rather than
      overwrite an existing backup.

  Returns:
    Path of backup file (str).
  """
  file_name = os.path.basename(file_path)
  backup_dir = os.path.join(os.path.dirname(file_path), 'backups')
  tf.io.gfile.makedirs(backup_dir)
  backup_path = os.path.join(backup_dir, file_name)
  try:
    tf.io.gfile.copy(file_path, backup_path, overwrite=overwrite)
  except tf.errors.OpError:
    # Backup already exists and overwrite=False: fall back to a timestamped
    # copy. Use os.path.splitext (not str.split('.')) so that filenames with
    # no extension and directory names containing dots are handled correctly:
    # 'name.csv' -> 'name.<time>.csv', 'name' -> 'name.<time>'.
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    root, ext = os.path.splitext(backup_path)
    new_backup_path = '{}.{}{}'.format(root, current_time, ext)
    tf.io.gfile.copy(file_path, new_backup_path)
    backup_path = new_backup_path
  return backup_path
130
131
def df_to_ds(df, preprocess_fn):
  """Wraps a dataframe as a tf.data.Dataset and maps preprocess_fn over it."""
  return tf.data.Dataset.from_tensor_slices(dict(df)).map(preprocess_fn)
136
137
def get_image(img_path, num_channels=None):
  """Reads and decodes the image file at img_path into a tensor.

  expand_animations=False guarantees the decoded tensor has a static shape
  attribute; without it, tf.image.resize fails later in the pipeline.
  """
  raw_bytes = tf.io.read_file(img_path)
  return tf.image.decode_image(
      raw_bytes, channels=num_channels, expand_animations=False)
144
145
def get_image_by_id(df, idx, num_channels=None):
  """Loads the image whose path is stored in row `idx` of `df`."""
  img_path = df.loc[idx, 'img_path']
  return get_image(img_path, num_channels=num_channels)
148
149
def latent_lookup_map(df, latents):
  """Finds indices of examples in a dataframe matching the given latents.

  Args:
    df: The dataframe in which to search for latents.
    latents: Dict of latent values to match on. Keys that are not dataframe
      columns are silently ignored.

  Returns:
    Array of indices of all dataframe examples matching the given latents.
  """
  shared_keys = [key for key in latents.keys() if key in df.columns]

  # Latents are often floats, so exact equality is unreliable; match each
  # column with np.isclose and keep rows where every column matches.
  matches = np.all(
      [np.isclose(df[key], latents[key]) for key in shared_keys], axis=0)
  return df[matches].index.values
169
170
171### dsprites-specific functions
172
173
def preprocess_dsprites_images(x, img_size=None, num_channels=None):
  """Fetches a dsprites image and preprocesses it.

  Args:
    x: Pandas Series, dict, etc containing an 'img_path' key whose value
      points to the image in question.
    img_size: (optional) 2-tuple to resize the image to. Use None (default)
      for no resizing.
    num_channels: (optional) How many channels the resulting image will have.

  Returns:
    Tensorflow tensor of the image, with values clipped to range [0, 1].
  """
  image = get_image(x['img_path'], num_channels=num_channels)
  if img_size is not None:
    image = tf.image.resize(image, img_size)
  image = tf.cast(image, tf.float32)
  # NOTE(review): unlike the 3dident pipeline there is no /255 rescale here —
  # presumably dsprites pixels are already binary-valued; confirm upstream.
  return tf.clip_by_value(image, 0., 1.)
193
194
def preprocess_dsprites_latents(x):
  """Converts dsprites latents into a standard format for our dataset.

  Args:
    x: Row from the dsprites dataframe containing all latents for an example.

  Returns:
    Tuple of tensors (labels, values). Both are prefixed with the one-hot
    encoding of the example's shape. The orientation value is rescaled by
    1 / 2pi.
  """
  one_hot_shape = [float(x[name]) for name in DSPRITES_SHAPE_NAMES]
  labels = one_hot_shape + [float(x[name]) for name in DSPRITES_LABEL_NAMES]
  values = one_hot_shape + [float(x[name]) for name in DSPRITES_VALUE_NAMES]
  # Entry 4 is orientation in radians; rescale it, leave everything else as-is.
  values = values * tf.constant([1, 1, 1, 1, 1 / (2 * np.pi), 1, 1])
  return tf.convert_to_tensor(labels), tf.convert_to_tensor(values)
211
212
def preprocess_dsprites(x, img_size=None, num_channels=None):
  """All dsprites preprocessing in one function, for use in dataset map fns.

  Args:
    x: Row from a pandas dataframe.
    img_size: 2-tuple of desired image size, or None for no resizing.
    num_channels: Desired number of channels in the image, or None for the
      default determined by tf.image.decode_image.

  Returns:
    Dict with 'image', 'labels' and 'values' keys containing the input image
    and its latent labels/values.
  """
  labels, values = preprocess_dsprites_latents(x)
  return {
      'image': preprocess_dsprites_images(x, img_size, num_channels),
      'labels': labels,
      'values': values,
  }
229
230
def get_preprocess_dsprites_fn(img_size=None, num_channels=None):
  """Returns a single-argument dsprites preprocess fn with fixed options."""
  return functools.partial(
      preprocess_dsprites, img_size=img_size, num_channels=num_channels)
236
237
238##### 3dident-specific functions
239
240
# Dataframe column names of the ten continuous 3DIdent latents: object
# position, object rotation, spotlight position, and the three hues.
THREEDIDENT_VALUE_NAMES = [
    'pos_x', 'pos_y', 'pos_z', 'rot_phi', 'rot_theta', 'rot_psi', 'spotlight',
    'hue_object', 'hue_spotlight', 'hue_background'
]
245
246
def preprocess_threedident_images(x, img_size=None, num_channels=3):
  """Fetches a 3DIdent image and preprocesses it.

  Args:
    x: Pandas Series, dict, etc containing an 'img_path' key whose value
      contains the path to the image in question.
    img_size: (optional) 2-tuple to resize the image to. Use None (default)
      for no resizing.
    num_channels: How many channels the resulting image will have. (Default 3)

  Returns:
    Tensorflow tensor of the image, values in range [0, 1].
  """
  image = get_image(x['img_path'], num_channels=num_channels)
  # Rescale from uint8 [0, 255] to [0, 1] before any resizing.
  image = tf.cast(image, tf.float32) / 255.0
  if img_size is not None:
    image = tf.image.resize(image, img_size)
  return tf.clip_by_value(image, 0., 1.)
268
269
def preprocess_threedident_values(x):
  """Collects the 3DIdent latent values from row x into a float tensor."""
  return tf.convert_to_tensor(
      [float(x[name]) for name in THREEDIDENT_VALUE_NAMES])
273
274
def preprocess_threedident(x, img_size, num_channels):
  """All 3DIdent preprocessing in one convenience function.

  Args:
    x: Row from a pandas dataframe.
    img_size: 2-tuple of desired image size, or None for no resizing.
    num_channels: Desired number of channels in the image.

  Returns:
    Dict with 'image' and 'values' keys, containing the input image and its
    latent values.
  """
  return {
      'image': preprocess_threedident_images(x, img_size, num_channels),
      'values': preprocess_threedident_values(x),
  }
290
291
def get_preprocess_threedident_fn(img_size, num_channels):
  """Returns a single-argument 3DIdent preprocess fn with fixed options."""
  return functools.partial(
      preprocess_threedident, img_size=img_size, num_channels=num_channels)
297
298
299### functions for contrastive example generation
300
301
def get_contrastive_example_idx(z, df, sample_fn, deterministic=False):
  """Finds a suitable contrastive example z_prime conditioned on z.

  Used for generating contrastive pairs from a dataframe of examples, where we
  modify the example at the level of the latent values and are constrained to
  return another example from the given dataframe.

  Args:
    z: The example to condition on.
    df: The dataframe of examples to sample the conditional z_prime from.
    sample_fn: The function used to generate z_prime given z.
    deterministic: If True, always returns first result; if False, returns a
      random one (default False).

  Returns:
    Index of new example z_prime in dataframe df.

  Raises:
    TypeError: If sample_fn does not return a dict of latents.
    ValueError: If no example in df matches the sampled latents.
  """
  zprime_latents = sample_fn(z, df)
  if not isinstance(zprime_latents, dict):
    raise TypeError(
        'Contrastive sample function should return a dict of latents.')
  candidates = latent_lookup_map(df, zprime_latents)
  # Fail loudly with a clear message instead of the opaque IndexError /
  # np.random.choice ValueError an empty candidate set would otherwise cause.
  if len(candidates) == 0:
    raise ValueError(
        'No example in the dataframe matches the sampled latents.')
  if deterministic:
    idx = candidates[0]
  else:
    idx = np.random.choice(candidates)
  return int(idx)
328
329
def dsprites_simple_noise_fn(z, df=None):
  """Applies random noise to a dsprites example to build a contrastive pair.

  Given latents for a dsprites example z, generates the latent labels for
  z_prime | z by randomly adding +1/-1 to some subset of scale, orientation,
  x_pos, y_pos. Orientation is computed mod 40 (the max label value) so it is
  treated as circular; the other labels are clamped to their valid ranges.

  Args:
    z: Pandas Series containing latents for a dsprites example. Must contain
      the label and shape latents; value latents are optional.
    df: Dataframe, not used here.

  Returns:
    zprime_latents: Dict containing shape and label latents for zprime.
      Guaranteed to differ from z, and guaranteed to leave the shape latent
      unchanged.
  """
  del df  # unused; kept for a uniform sample_fn signature.

  label_names = DSPRITES_LABEL_NAMES
  shape_names = DSPRITES_SHAPE_NAMES
  max_labels = np.array([6, 40, 32, 32])

  perturbed = z[label_names].to_numpy()
  # Keep resampling until at least one label actually changed.
  while np.array_equal(z[label_names], perturbed):
    perturbed += np.random.randint(-1, 2, size=4)
    # Orientation (index 1) wraps around; all labels are clamped into range.
    perturbed[1] = perturbed[1] % max_labels[1]
    perturbed = np.clip(perturbed, np.zeros(4), max_labels - 1)

  combined = np.concatenate((z[shape_names], perturbed)).astype('int32')
  feature_names = shape_names + label_names
  return {name: combined[i] for i, name in enumerate(feature_names)}
365
366
def threedident_simple_noise_fn(z, df, tol=1.0, mult=1.1, deterministic=False):
  """Finds an example that is a small perturbation away from a given example z.

  Given an example z, finds a nearby example z_prime subject to the condition
  that z_prime exists in the dataset. This is achieved by considering a ball
  of radius tol around each latent and sampling a (non-identity) example from
  the intersection, growing the balls until the intersection is non-empty.

  Args:
    z: Dataframe row of the example to condition on.
    df: Dataframe of all available examples.
    tol: Starting radius of the balls around each latent (Default 1.0).
    mult: Float > 1, multiplier to scale the ball radii by whenever the
      intersection of balls is empty.
    deterministic: If True, always returns the first result; if False,
      returns a random one. (default False).

  Returns:
    Dict containing latent values of the new example z_prime (guaranteed to
    be different from the input example).
  """
  latent_names = THREEDIDENT_VALUE_NAMES
  # The first three latents span twice the range of the others, so their
  # balls start twice as wide.
  radii = np.array([1, 1, 1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) * tol
  while True:
    # One ball of rows per latent, centered on z's value for that latent.
    balls = [
        df[np.abs(df[name] - z[name]) < radius]
        for name, radius in zip(latent_names, radii)
    ]
    # Candidates: rows inside every ball, excluding z itself.
    candidates = set(balls[0].drop(index=[z.name]).index)
    for ball in balls:
      candidates = candidates.intersection(ball.index)
    if candidates:
      if deterministic:
        chosen_id = list(candidates)[0]
      else:
        chosen_id = np.random.choice(list(candidates))
      chosen = df.loc[chosen_id]
      return {name: chosen[name] for name in latent_names}
    else:
      # Empty intersection: widen every ball slightly and try again.
      radii *= mult
413