google-research

Форк
0
412 строк · 13.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility functions for working with dsprites and 3dident datasets.
17
"""
18

19
import datetime
20
import functools
21
import os
22

23
from absl import flags
24

25
import numpy as np
26
import tensorflow.compat.v2 as tf
27

28

29
# Opt this module into TensorFlow 2.x behaviour. This is a module-level side
# effect that runs once, when the module is first imported.
tf.compat.v1.enable_v2_behavior()

# Handle on the global absl flag registry (not used by the functions in this
# file).
FLAGS = flags.FLAGS

# Names of the one-hot shape columns in dsprites dataframes.
DSPRITES_SHAPE_NAMES = ['square', 'ellipse', 'heart']

# The non-shape dsprites latents, in the order used throughout this module.
_DSPRITES_LATENT_SUFFIXES = (
    'scale', 'orientation', 'x_position', 'y_position')

# Label columns of the non-shape latents.
DSPRITES_LABEL_NAMES = [
    'label_' + suffix for suffix in _DSPRITES_LATENT_SUFFIXES
]

# Value columns of the same latents, in the same order.
DSPRITES_VALUE_NAMES = [
    'value_' + suffix for suffix in _DSPRITES_LATENT_SUFFIXES
]
43

44
### General-purpose helpers for pandas and TensorFlow
45

46

47
def tf_train_eval_split(ds, num_examples, eval_split=0.0):
  """Divides a tf.data dataset into an eval prefix and a train remainder.

  tf.data offers no random access, so the eval split is simply the first
  `int(num_examples * eval_split)` elements -- shuffle `ds` before calling
  this if a random split is wanted.

  Args:
    ds: Dataset to split (anything exposing `take` and `skip`).
    num_examples: Total number of elements in `ds`. It must be passed in
      because counting a dataset requires iterating over all of it.
    eval_split: Fraction in [0, 1) of the elements to reserve for eval.

  Returns:
    Tuple (train dataset, num train examples, eval dataset, num eval examples).
  """
  eval_count = int(num_examples * eval_split)
  eval_ds = ds.take(eval_count)  # the leading eval_count elements
  train_ds = ds.skip(eval_count)  # everything after them
  return train_ds, num_examples - eval_count, eval_ds, eval_count
69

70

71
def pd_train_eval_split(df, eval_split=0.0, seed=None, reset_index=False):
  """Splits pandas dataframe into train/eval sets.

  Uses pandas DataFrame.sample to pick the eval rows; every remaining row goes
  to the train set, keeping its original order. The split is done by row
  position, so it is also correct when `df` has duplicate index labels.

  Args:
    df: Dataframe of examples to be split.
    eval_split: Float in range [0, 1), fraction of dataset to split out as eval
      set.
    seed: Optional, for reproducibility of eval split (passed to
      DataFrame.sample as `random_state`).
    reset_index: If False (default) then retains original df indexing in
      returned train/eval dataframes, if True then resets index of both.

  Returns:
    Tuple (train dataset, num train examples, eval dataset, num eval examples).

    Sizes of returned dataframes aren't necessary but included to match the
    corresponding tensorflow function.
  """
  # Sample a column-free copy of df carrying a 0..n-1 index so the sampled
  # index *is* the row positions. DataFrame.sample only depends on the row
  # count and random_state, so these are exactly the rows df.sample would
  # pick. Splitting by position (instead of df.drop by label) avoids dropping
  # extra train rows that merely share an index label with an eval row.
  eval_pos = (
      df.iloc[:, :0].reset_index(drop=True)
      .sample(frac=eval_split, random_state=seed, replace=False)
      .index.to_numpy())
  is_train = np.ones(len(df), dtype=bool)
  is_train[eval_pos] = False
  eval_df = df.iloc[eval_pos]
  train_df = df.iloc[np.flatnonzero(is_train)]
  num_eval_examples, num_train_examples = len(eval_df), len(train_df)
  if reset_index:
    train_df = train_df.reset_index(drop=True)
    eval_df = eval_df.reset_index(drop=True)
  return train_df, num_train_examples, eval_df, num_eval_examples
98

99

100
def make_backup(file_path, overwrite=False):
  """Makes a backup copy of a file in a 'backups' subfolder.

  Use for e.g. pandas dataframes of datasets where accidentally modifying/losing
  the dataframe later would be really really annoying.

  The backup is first written to `<dir of file_path>/backups/<file name>`. If
  that copy fails (e.g. a backup already exists and `overwrite` is False), a
  copy is written instead to a name with the current local date/time inserted
  before the file extension, e.g. `data.csv` -> `data.20240131-235959.csv`
  (`notes` -> `notes.20240131-235959` when there is no extension).

  Args:
    file_path: Path (str) to where file is located.
    overwrite: If True, overwrites any existing backup. If False (default)
      creates new file name with current date/time rather than overwrite an
      existing backup.

  Returns:
    Path of backup file (str).

  Raises:
    tf.errors.OpError: If the timestamped fallback copy also fails (for
      example when `file_path` does not exist).
  """
  file_name = os.path.basename(file_path)
  backup_dir = os.path.join(os.path.dirname(file_path), 'backups')
  tf.io.gfile.makedirs(backup_dir)
  backup_path = os.path.join(backup_dir, file_name)
  try:
    tf.io.gfile.copy(file_path, backup_path, overwrite=overwrite)
  except tf.errors.OpError:
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    # Insert current_time right before the file extension. splitext only
    # inspects the final path component, so dots in directory names and files
    # without an extension are handled correctly (splitting the whole path on
    # '.' mangled both cases).
    root, ext = os.path.splitext(backup_path)
    backup_path = f'{root}.{current_time}{ext}'
    tf.io.gfile.copy(file_path, backup_path)
  return backup_path
130

131

132
def df_to_ds(df, preprocess_fn):
  """Builds a tf.data.Dataset from a dataframe and maps preprocess_fn over it.

  The dataframe is converted column-wise (`dict(df)`) and sliced into one
  element per row; each element is then passed through `preprocess_fn`.

  Args:
    df: Dataframe whose columns become the dataset features.
    preprocess_fn: Function applied to every dataset element via `map`.

  Returns:
    The mapped tf.data.Dataset.
  """
  return tf.data.Dataset.from_tensor_slices(dict(df)).map(preprocess_fn)
136

137

138
def get_image(img_path, num_channels=None):
  """Reads the file at img_path and decodes it as an image.

  Args:
    img_path: Path of the image file.
    num_channels: Optional number of channels for the decoded image; None
      keeps the decoder's default.

  Returns:
    Decoded image tensor.
  """
  raw_bytes = tf.io.read_file(img_path)
  # expand_animations=False so the decoded image always has a shape attribute;
  # without it tf.image.resize runs into problems later on.
  return tf.image.decode_image(
      raw_bytes, channels=num_channels, expand_animations=False)
144

145

146
def get_image_by_id(df, idx, num_channels=None):
  """Loads the image whose path is stored in df's 'img_path' column at idx.

  Args:
    df: Dataframe with an 'img_path' column.
    idx: Index label of the row to load (looked up with `df.loc`).
    num_channels: Optional channel count forwarded to get_image.

  Returns:
    Decoded image tensor.
  """
  path = df.loc[idx, 'img_path']
  return get_image(path, num_channels=num_channels)
148

149

150
def latent_lookup_map(df, latents):
  """Finds indices of examples in dataframe whose latents match those specified.

  Args:
    df: The dataframe in which to search for latents.
    latents: Dict of latent values to match on. Ignores any keys that aren't
      present in dataframe columns.

  Returns:
    Numpy array of the index labels of every row whose values match all given
    latents (compared with np.isclose). If none of the keys are dataframe
    columns there is nothing to constrain on, so every row's index is returned.
  """
  keys = [k for k in latents if k in df.columns]
  if not keys:
    # np.all([], axis=0) collapses to a scalar True, which pandas would try to
    # use as a column label (KeyError). With no constraints, every row matches.
    return df.index.values

  # Use np.isclose rather than an equality test because we're often dealing
  # with floats.
  check_conditions = np.all([np.isclose(df[k], latents[k]) for k in keys],
                            axis=0)
  return df[check_conditions].index.values
169

170

171
### dsprites-specific functions
172

173

174
def preprocess_dsprites_images(x, img_size=None, num_channels=None):
  """Loads the dsprites image referenced by x['img_path'] and preprocesses it.

  Args:
    x: Pandas Series, dict, etc. with an 'img_path' entry pointing to the image.
    img_size: (optional) 2-tuple to resize the image to with tf.image.resize.
      None (default) keeps the original size.
    num_channels: (optional) Channel count forwarded to the image decoder.

  Returns:
    float32 tensor of the image. Pixel values are clipped (not rescaled) to
    [0, 1], so every value >= 1 becomes exactly 1.
  """
  decoded = get_image(x['img_path'], num_channels=num_channels)
  resized = decoded if img_size is None else tf.image.resize(decoded, img_size)
  return tf.clip_by_value(tf.cast(resized, tf.float32), 0., 1.)
193

194

195
def preprocess_dsprites_latents(x):
  """Converts dsprites latents into the standard (labels, values) format.

  Args:
    x: Row from dsprites dataframe containing all latents for the example.

  Returns:
    Tuple of tensors (labels, values). Both start with the example's one-hot
    shape entries (DSPRITES_SHAPE_NAMES), followed by the label columns
    (DSPRITES_LABEL_NAMES) or value columns (DSPRITES_VALUE_NAMES)
    respectively. The orientation value is scaled by 1 / (2 * pi).
  """
  shape_part = [float(x[name]) for name in DSPRITES_SHAPE_NAMES]
  label_part = [float(x[name]) for name in DSPRITES_LABEL_NAMES]
  value_part = [float(x[name]) for name in DSPRITES_VALUE_NAMES]
  # Position 4 is value_orientation: 3 shape entries, then value_scale.
  per_value_scale = tf.constant([1, 1, 1, 1, 1 / (2 * np.pi), 1, 1])
  labels = tf.convert_to_tensor(shape_part + label_part)
  values = tf.convert_to_tensor((shape_part + value_part) * per_value_scale)
  return labels, values
211

212

213
def preprocess_dsprites(x, img_size=None, num_channels=None):
  """Runs every dsprites preprocessing step; convenient for dataset map fns.

  Args:
    x: Row from pandas dataframe.
    img_size: 2-tuple of desired image size, or None for no resizing.
    num_channels: Desired number of channels in image, or None for the default
      determined by tf.image.decode_image.

  Returns:
    Dict with keys 'image' (output of preprocess_dsprites_images) and 'labels'
    and 'values' (the two tensors from preprocess_dsprites_latents).
  """
  example = {'image': preprocess_dsprites_images(x, img_size, num_channels)}
  example['labels'], example['values'] = preprocess_dsprites_latents(x)
  return example
229

230

231
def get_preprocess_dsprites_fn(img_size=None, num_channels=None):
  """Returns preprocess_dsprites with img_size and num_channels pre-bound.

  The result takes a single row argument, which makes it directly usable as a
  dataset map function.
  """
  bound_kwargs = dict(img_size=img_size, num_channels=num_channels)
  return functools.partial(preprocess_dsprites, **bound_kwargs)
236

237

238
##### 3dident-specific functions
239

240

241
# Latent value columns of the 3DIdent dataframe, in the order used for the
# 'values' tensor built by preprocess_threedident_values.
THREEDIDENT_VALUE_NAMES = [
    'pos_x',
    'pos_y',
    'pos_z',
    'rot_phi',
    'rot_theta',
    'rot_psi',
    'spotlight',
    'hue_object',
    'hue_spotlight',
    'hue_background',
]
245

246

247
def preprocess_threedident_images(x, img_size=None, num_channels=3):
  """Loads the 3DIdent image referenced by x['img_path'] and preprocesses it.

  Args:
    x: Pandas Series, dict, etc. with an 'img_path' entry holding the path to
      the image.
    img_size: (optional) 2-tuple to resize the image to. None (default) keeps
      the original size.
    num_channels: Number of channels of the decoded image (default 3).

  Returns:
    float32 tensor of the image: pixel values divided by 255, optionally
    resized, then clipped to [0, 1].
  """
  raw = get_image(x['img_path'], num_channels=num_channels)
  scaled = tf.cast(raw, tf.float32) / 255.0
  if img_size is not None:
    scaled = tf.image.resize(scaled, img_size)
  return tf.clip_by_value(scaled, 0., 1.)
268

269

270
def preprocess_threedident_values(x):
  """Collects the 3DIdent latent values of x into one tensor.

  Args:
    x: Row / mapping containing every column in THREEDIDENT_VALUE_NAMES.

  Returns:
    1-D tensor of the latents (each converted with float()), ordered as in
    THREEDIDENT_VALUE_NAMES.
  """
  latent_floats = [float(x[column]) for column in THREEDIDENT_VALUE_NAMES]
  return tf.convert_to_tensor(latent_floats)
273

274

275
def preprocess_threedident(x, img_size, num_channels):
  """Runs every 3DIdent preprocessing step in one convenience function.

  Args:
    x: Row from pandas dataframe.
    img_size: 2-tuple of desired image size, or None for no resizing.
    num_channels: Desired number of channels in image.

  Returns:
    Dict with keys 'image' (output of preprocess_threedident_images) and
    'values' (output of preprocess_threedident_values).
  """
  return {
      'image': preprocess_threedident_images(x, img_size, num_channels),
      'values': preprocess_threedident_values(x),
  }
290

291

292
def get_preprocess_threedident_fn(img_size, num_channels):
  """Returns preprocess_threedident with img_size and num_channels pre-bound.

  The result takes a single row argument, which makes it directly usable as a
  dataset map function.
  """
  bound_kwargs = dict(img_size=img_size, num_channels=num_channels)
  return functools.partial(preprocess_threedident, **bound_kwargs)
297

298

299
### functions for contrastive example generation
300

301

302
def get_contrastive_example_idx(z, df, sample_fn, deterministic=False):
  """Finds a suitable contrastive example z_prime conditioned on z.

  Used for generating contrastive pairs from a dataframe of examples, where we
  modify the example at the level of the latent values and are constrained to
  return another example from the given dataframe.

  Args:
    z: The example to condition on.
    df: The dataframe of examples to sample the conditional z_prime from.
    sample_fn: The function used to generate z_prime given z; called as
      `sample_fn(z, df)` and must return a dict of latents.
    deterministic: If True, always returns first result; if False, returns a
      random one (default False).

  Returns:
    Index of new example z_prime in dataframe df.

  Raises:
    TypeError: If sample_fn does not return a dict.
    ValueError: If no example in df matches the latents returned by sample_fn.
  """
  zprime_latents = sample_fn(z, df)
  if not isinstance(zprime_latents, dict):
    raise TypeError(
        'Contrastive sample function should return a dict of latents.')
  matches = latent_lookup_map(df, zprime_latents)
  if len(matches) == 0:  # pylint: disable=g-explicit-length-test
    # Fail with a clear message instead of an IndexError from matches[0] or
    # numpy's generic "a cannot be empty" ValueError from np.random.choice.
    raise ValueError(
        f'No example in df matches the sampled latents {zprime_latents}.')
  if deterministic:
    idx = matches[0]
  else:
    idx = np.random.choice(matches)
  return int(idx)
328

329

330
def dsprites_simple_noise_fn(z, df=None):
  """Applies random noise to dsprites example for generating contrastive pair.

  Given latents for a dsprites example z, generates the latent labels for
  z_prime | z by randomly adding -1, 0 or +1 to each of scale, orientation,
  x_pos, y_pos (redrawing until at least one label changes). Orientation is
  computed mod 40 (the number of orientation labels) so it is treated as
  circular; the other labels are clipped to their valid range.

  Args:
    z: Pandas Series containing latents for a dsprite example. Must
      contain the label and shape latents, value latents are optional.
    df: Dataframe, not used here.

  Returns:
    zprime_latents: Dict containing shape and label latents for zprime (as
      int32). Guaranteed to be different to z, guaranteed to leave shape
      latent unchanged.
  """
  del df  # Not used; accepted so every sample_fn shares the (z, df) signature.

  features = DSPRITES_LABEL_NAMES
  shapes = DSPRITES_SHAPE_NAMES
  # Exclusive upper bound of each label latent; labels stay in [0, max - 1].
  max_values = np.array([6, 40, 32, 32])

  # copy=True: under pandas Copy-on-Write, to_numpy() may return a read-only
  # view, which would make the in-place += below raise.
  zprime = z[features].to_numpy(copy=True)
  while np.array_equal(z[features], zprime):
    zprime += np.random.randint(-1, 2, size=4)
    # Orientation is circular: wrap it instead of clipping.
    zprime[1] = zprime[1] % max_values[1]
    zprime = np.minimum(np.maximum(zprime, np.zeros(4)), max_values - 1)

  zprime = np.concatenate((z[shapes], zprime)).astype('int32')
  return dict(zip(shapes + features, zprime))
365

366

367
def threedident_simple_noise_fn(z, df, tol=1.0, mult=1.1, deterministic=False):
  """Finds example that is a small perturbation away from given example z.

  Given an example z, finds a nearby example z_prime subject to the condition
  that z_prime exists in the dataset. This is achieved by considering a ball of
  radius tol around each latent and sampling a (non-identity) example from the
  intersection. While the intersection is empty, every radius is multiplied by
  mult and the search is repeated.

  Args:
    z: Dataframe row of the example to condition on; rows of df carrying the
      label z.name are never returned.
    df: Dataframe of all available examples.
    tol: Starting radius of balls around each latent (Default 1.0). The first
      three latents use tol, the remaining ones tol / 2.
    mult: Float > 1, multiplier to scale ball radius by, used when intersection
      of balls is empty.
    deterministic: If True, always returns the first match in dataframe order;
      if False, returns a uniformly random match (default False).

  Returns:
    Dict containing latent values of new example z_prime (guaranteed to be
    different from input example).

  Raises:
    ValueError: If no result can ever be found: df has no row other than z
      with finite values for every latent, or the radii cannot grow
      (mult <= 1 or tol <= 0) while no candidate lies within them.
  """
  latents = THREEDIDENT_VALUE_NAMES
  # First three latents have twice the range of the others.
  scaling = np.array([1, 1, 1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
  tol_array = scaling * tol

  # Per-latent distance of every example from z, shape (len(df), len(latents)).
  # Computed once; only the radii change between iterations.
  dists = np.stack(
      [np.abs(df[name] - z[name]).to_numpy(dtype=float) for name in latents],
      axis=1)
  # Drop the original z (every row carrying its index label).
  not_z = np.asarray(df.index != z.name)
  # Rows with a NaN latent can never fall inside a ball; without any other
  # usable row the loop below would never terminate.
  if not np.any(not_z & np.all(np.isfinite(dists), axis=1)):
    raise ValueError(
        'No example other than z with finite latent values exists in df.')

  while True:
    # Positions of rows inside the ball around z for every latent at once.
    # Working in dataframe order (rather than iterating a set of labels) keeps
    # the choice reproducible, including for string index labels.
    candidates = np.flatnonzero(np.all(dists < tol_array, axis=1) & not_z)
    if candidates.size:
      if deterministic:
        position = candidates[0]
      else:
        position = np.random.choice(candidates)
      result = df.iloc[position]
      return {k: result[k] for k in latents}
    if mult <= 1 or not np.all(tol_array > 0):
      raise ValueError(
          'Ball radii cannot grow (need mult > 1 and tol > 0) and no '
          'candidate example lies within them.')
    # Slightly increase the size of the balls and try again.
    tol_array *= mult
413

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.