# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generate and load dsprites and 3dident datasets.

Creates dataframes of dsprites and 3dident datasets so that we can sample from
and search them, and handles the train/eval/test splitting and conversion to
tf.data.Dataset format in a reproducible way.

Also generates and loads datasets for contrastive learning experiments, where
we need to sample large datasets of similar pairs of images based on their
latents.
"""

import os

from absl import flags
from absl import logging

import numpy as np
import pandas as pd

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

from graph_compression.contrastive_learning.learning_latents import data_utils

FLAGS = flags.FLAGS


### functions for loading existing datasets


def get_standard_dataset(name,
                         dataframe_path=None,
                         img_size=None,
                         num_channels=None,
                         eval_split=None,
                         seed=None,
                         reset_index=True):
  """Loads dsprites/3dident in pandas dataframe and tensorflow dataset formats.

  Note that 3dident already has a train/test split by default, so specifying
  eval_split will split off a third dataframe/dataset out of the train set.
  Similarly, if a non-default source dataframe is used with a 'split' column
  containing the value 'train' or 'test' for each entry, this will produce
  the same effect.

  Args:
    name: 'dsprites' or 'threedident'.
    dataframe_path: Str, location of a saved csv containing latent values and
      paths to saved images for each example. For dsprites this can be created
      using build_dsprites_dataframe.
    img_size: 2-tuple, optional, size to reshape images to during
      preprocessing.
    num_channels: Int, optional, number of channels in processed images.
    eval_split: Float in [0, 1], optional, splits out this fraction of the
      train dataset into a separate dataframe+dataset.
    seed: Int, optional, random seed for reproducibility of the eval split.
    reset_index: Bool, default True, whether to reindex the train and eval
      dataframes after splitting.

  Returns:
    Dict of tuples (pd dataframe, tf dataset, number of examples) with keys
    'train', 'test' (if a train/test split exists in the original dataframe)
    and 'eval' (if eval_split is not None).
  """
  if name == 'dsprites':
    preprocess_fn = data_utils.get_preprocess_dsprites_fn(
        img_size, num_channels)
    path = dataframe_path
  elif name == 'threedident':
    preprocess_fn = data_utils.get_preprocess_threedident_fn(
        img_size, num_channels)
    path = dataframe_path
  else:
    raise ValueError(
        f'Dataset name must be one of "dsprites" or "threedident", you '
        f'provided {name}.')

  with tf.io.gfile.GFile(path, 'rb') as f:
    df = pd.read_csv(f)

  datasets = {}
  if 'split' in df.columns:
    test_df, df = df[df.split == 'test'].copy(), df[df.split == 'train'].copy()
    num_test_examples = len(test_df)
    test_ds = data_utils.df_to_ds(test_df, preprocess_fn)
    datasets.update({'test': (test_df, test_ds, num_test_examples)})

  num_examples = len(df)
  if eval_split is None:
    ds = data_utils.df_to_ds(df, preprocess_fn)
    datasets.update({'train': (df, ds, num_examples)})
  else:
    train_df, num_train_examples, eval_df, num_eval_examples = (
        data_utils.pd_train_eval_split(df, eval_split, seed, reset_index))
    train_ds = data_utils.df_to_ds(train_df, preprocess_fn)
    eval_ds = data_utils.df_to_ds(eval_df, preprocess_fn)
    datasets.update({
        'train': (train_df, train_ds, num_train_examples),
        'eval': (eval_df, eval_ds, num_eval_examples)
    })
  return datasets
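

# Example usage (a sketch, not part of the original module): load dsprites
# from a prebuilt dataframe csv and carve out a 10% eval split. The csv path
# below is hypothetical; the dsprites dataframe has no 'split' column, so only
# the 'train' and 'eval' keys are present in the result.
def _example_get_standard_dataset():
  datasets = get_standard_dataset(
      'dsprites',
      dataframe_path='/tmp/dsprites/dsprites_df.csv',  # hypothetical path
      img_size=(64, 64),
      num_channels=1,
      eval_split=0.1,
      seed=0)
  _, train_ds, num_train = datasets['train']
  _, eval_ds, num_eval = datasets['eval']
  logging.info('train: %d examples, eval: %d examples', num_train, num_eval)
  return train_ds, eval_ds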


def get_contrastive_dataset(name,
                            dataframe_path=None,
                            img_size=None,
                            num_channels=None):
  """Loads an existing contrastive dataset.

  Args:
    name: 'dsprites' or 'threedident'.
    dataframe_path: Str, location of the saved contrastive dataframe to load.
      This can be created using build_contrastive_dataframe.
    img_size: 2-tuple, optional; size to reshape images to during the
      preprocessing step.
    num_channels: Int, optional; number of channels for processed images.

  Returns:
    Tuple of form (pandas dataframe, tf dataset, number of examples).
  """
  if name == 'dsprites':
    preprocess_fn = data_utils.get_preprocess_dsprites_fn(
        img_size, num_channels)
    path = dataframe_path
  elif name == 'threedident':
    preprocess_fn = data_utils.get_preprocess_threedident_fn(
        img_size, num_channels)
    path = dataframe_path
  else:
    raise ValueError(
        f'Dataset name must be one of "dsprites" or "threedident", you '
        f'provided {name}.')

  with tf.io.gfile.GFile(path, 'rb') as f:
    # header=[0, 1] handles the multi-index; remove this if
    # build_contrastive_dataframe changes its output format.
    contrastive_df = pd.read_csv(f, header=[0, 1])

  z_ds = data_utils.df_to_ds(contrastive_df['z'], preprocess_fn)
  zprime_ds = data_utils.df_to_ds(contrastive_df['zprime'], preprocess_fn)

  contrastive_ds = tf.data.Dataset.zip({'z': z_ds, 'zprime': zprime_ds})

  return contrastive_df, contrastive_ds, len(contrastive_df)
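

# Example usage (a sketch, not part of the original module): load a
# contrastive dsprites dataframe built by build_contrastive_dataframe and
# batch the paired dataset for training. The csv path is hypothetical.
def _example_get_contrastive_dataset():
  contrastive_df, contrastive_ds, num_pairs = get_contrastive_dataset(
      'dsprites',
      dataframe_path='/tmp/dsprites/contrastive_df.csv',  # hypothetical path
      img_size=(64, 64),
      num_channels=1)
  # Each element of contrastive_ds is a dict {'z': ..., 'zprime': ...} holding
  # a preprocessed example and its positive pair.
  logging.info('%d contrastive pairs available', num_pairs)
  return contrastive_df, contrastive_ds.shuffle(10_000).batch(128)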


### functions to recreate a dataset from scratch


def build_dsprites_dataframe(target_path):
  """Recreates the dsprites dataframe from the base tfds version.

  Each image is converted to png and written to the 'images' subfolder of the
  specified target_path.

  The dataframe contains the latent values and labels of each example, a
  one-hot encoding of its shape, and the path to the corresponding image.

  Args:
    target_path: Str, path where the dataframe and images should be saved.

  Returns:
    Location where the dataframe was saved.
  """
  tfds_dataset, tfds_info = tfds.load(
      'dsprites', split='train', with_info=True, shuffle_files=False)
  num_examples = tfds_info.splits['train'].num_examples

  # List the features we care about.
  feature_keys = list(tfds_info.features.keys())
  feature_keys.remove('image')
  feature_keys.remove('value_shape')
  feature_keys.remove('label_shape')
  shapes = ['square', 'ellipse', 'heart']

  # Helper function to modify how the data is stored in the tf dataset before
  # we convert it to a pandas dataframe.
  def pandas_setup(x):
    # Encoding the image as a png byte string turns out to be a convenient way
    # of temporarily storing the images until we can write them to disk.
    img = tf.io.encode_png(x['image'])
    latents = {k: x[k] for k in feature_keys}
    latents.update(
        {k: int(x['label_shape'] == i) for i, k in enumerate(shapes)})
    latents['png'] = img
    return latents

  temp_ds = tfds_dataset.map(pandas_setup)
  dsprites_df = tfds.as_dataframe(temp_ds)
  dsprites_df = dsprites_df[shapes + feature_keys + ['png']]  # reorder columns

  # Set up for saving the pngs to disk.
  if os.path.basename(target_path).endswith('.csv'):
    dataset_dir = os.path.dirname(target_path)
    dataframe_location = target_path
  else:
    dataset_dir = target_path
    dataframe_location = os.path.join(target_path, 'dsprites_df.csv')

  images_path = os.path.join(dataset_dir, 'images')
  tf.io.gfile.makedirs(images_path)  # creates any missing parent directories

  padding = len(str(num_examples))
  temp_index = pd.Series(range(num_examples))

  def create_image_paths(x):
    path_to_file = os.path.join(images_path, str(x).zfill(padding) + '.png')
    return path_to_file

  # Create a column in the dataframe for the image file path.
  dsprites_df['img_path'] = temp_index.apply(create_image_paths)

  # Iterate through the dataframe and save each image to the specified folder.
  for i, x in dsprites_df.iterrows():
    img = tf.io.decode_image(x['png'])
    with tf.io.gfile.GFile(x['img_path'], 'wb') as f:
      tf.keras.preprocessing.image.save_img(f, img.numpy(), file_format='PNG')
    if i % 100 == 0:
      logging.info('%s of %s images processed', i + 1, num_examples)

  dsprites_df.drop(columns=['png'], inplace=True)
  logging.info('finished processing images')

  logging.info('conversion complete, saving...')
  with tf.io.gfile.GFile(dataframe_location, 'wb') as f:
    dsprites_df.to_csv(f, index=False)

  # Also make a backup copy, so that if the original dataframe gets corrupted
  # we don't have to rerun the entire generation process.
  _ = data_utils.make_backup(dataframe_location)

  return dataframe_location
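

# Example usage (a sketch, not part of the original module): regenerate the
# dsprites dataframe and pngs under a scratch directory, then load the result
# with get_standard_dataset. The target path is hypothetical; note that a full
# run writes all ~737k dsprites images to disk.
def _example_build_dsprites_dataframe():
  csv_path = build_dsprites_dataframe('/tmp/dsprites')  # hypothetical path
  # Images are saved to /tmp/dsprites/images and the csv to
  # /tmp/dsprites/dsprites_df.csv (plus a backup copy).
  return get_standard_dataset('dsprites', dataframe_path=csv_path)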


def build_contrastive_dataframe(df,
                                save_location,
                                num_samples,
                                sample_fn,
                                seed=None):
  """Builds a contrastive dataframe from scratch.

  Given a universe of examples and a rule for how to choose z_prime conditioned
  on z, generates a dataframe consisting of num_samples positive pairs for use
  in contrastive training.

  Args:
    df: The dataframe of available examples.
    save_location: Str, where to save the new contrastive dataframe.
    num_samples: Int, how many pairs to generate.
    sample_fn: Function that specifies how to choose z_prime given z.
    seed: Int, optional; use if the dataframe construction needs to be
      reproducible.

  Returns:
    Location of the new contrastive dataframe.
  """
  z_df = df.sample(n=num_samples, random_state=seed, replace=True)

  if seed is not None:
    np.random.seed(seed)

  tf.io.gfile.makedirs(os.path.split(save_location)[0])

  zprime_index = []
  counter = 0
  for _, z in z_df.iterrows():
    z_prime = data_utils.get_contrastive_example_idx(z, df, sample_fn)
    zprime_index.append(z_prime)
    # We need a separate counter because iterrows keys off the index, which is
    # not sequential here.
    counter += 1
    if counter % 100 == 0:
      logging.info('%s of %s examples generated', counter, num_samples)

  zprime_df = df.loc[zprime_index].reset_index(drop=True)
  z_df.reset_index(drop=True, inplace=True)

  contrastive_df = pd.concat([z_df, zprime_df], axis=1, keys=['z', 'zprime'])

  with tf.io.gfile.GFile(save_location, 'wb') as f:
    contrastive_df.to_csv(f, index=False)

  # Also make a backup copy, so that if the original dataframe gets corrupted
  # we don't have to rerun the entire generation process.
  _ = data_utils.make_backup(save_location)

  logging.info('dataframe saved at %s', save_location)

  return save_location
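

# Example usage (a sketch, not part of the original module). The sample_fn
# contract is defined by data_utils.get_contrastive_example_idx, which is not
# shown here, so the sample_fn below is purely hypothetical: it perturbs one
# latent of z to describe the desired positive example. All paths are
# placeholders.
def _example_build_contrastive_dataframe(dsprites_df):

  def sample_fn(z):  # hypothetical; the real contract lives in data_utils
    z_prime = z.copy()
    # 'value_orientation' is one of the dsprites latent columns.
    z_prime['value_orientation'] += np.random.normal(scale=0.1)
    return z_prime

  return build_contrastive_dataframe(
      dsprites_df,
      save_location='/tmp/dsprites/contrastive_df.csv',  # hypothetical path
      num_samples=50_000,
      sample_fn=sample_fn,
      seed=0)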
