google-research
313 lines · 10.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Generate and load dsprites and 3dident datasets.
17
18Creates dataframes of dsprites and 3dident datasets so that we can sample from
19and search them, and handles the train/eval/test splitting and conversion to
20tf.data.Dataset format in a reproducible way.
21
22Also generates and loads datasets for contrastive learning experiments, where
23we need to sample large datasets of similar pairs of images based on their
24latents.
25"""
26
27import os28
29from absl import flags30from absl import logging31
32import numpy as np33import pandas as pd34
35import tensorflow.compat.v2 as tf36import tensorflow_datasets as tfds37
38from graph_compression.contrastive_learning.learning_latents import data_utils39
40FLAGS = flags.FLAGS41
42
43### functions for loading existing datasets
44
45
def get_standard_dataset(name,
                         dataframe_path=None,
                         img_size=None,
                         num_channels=None,
                         eval_split=None,
                         seed=None,
                         reset_index=True):
  """Loads dsprites/3dident in pandas dataframe and tensorflow dataset formats.

  Note that 3dident already has a train/test split by default, so specifying
  eval_split will split off a third dataframe/dataset out of the train set.
  Similarly if a non-default original dataframe is used, with a 'split'
  column containing values 'train' or 'test' for each entry, this will produce
  the same effect.

  Args:
    name: 'dsprites' or 'threedident'.
    dataframe_path: Str, location of saved csv containing latent values and
      paths to saved images for each example. For dsprites this can be created
      using build_dsprites_dataframe.
    img_size: 2-tuple, optional, size to reshape images to during
      preprocessing.
    num_channels: Int, optional, specify number of channels in processed
      images.
    eval_split: Float in [0,1], optional, splits out this fraction of train
      dataset into a separate dataframe+dataset.
    seed: Int, optional, specify random seed for reproducibility of eval set.
    reset_index: Bool, default True, whether to reindex the train and eval
      dataframes after splitting.

  Returns:
    Dict of tuples (pd dataframe, tf dataset, number of examples) with keys
    'train', 'test' (if train/test split exists in original dataframe) and
    'eval' (if eval_split is not None).

  Raises:
    ValueError: if name is not one of the two supported dataset names.
  """
  # Pick the per-example preprocessing function for the requested dataset;
  # everything downstream is dataset-agnostic.
  if name == 'dsprites':
    preprocess_fn = data_utils.get_preprocess_dsprites_fn(
        img_size, num_channels)
  elif name == 'threedident':
    preprocess_fn = data_utils.get_preprocess_threedident_fn(
        img_size, num_channels)
  else:
    raise ValueError(
        f'Dataset name must be one of "dsprites" or "threedident", you provided {name}'
    )

  with tf.io.gfile.GFile(dataframe_path, 'rb') as f:
    full_df = pd.read_csv(f)

  datasets = {}

  # If the source dataframe already carries a train/test split, peel the test
  # rows off into their own entry before any eval splitting happens.
  if 'split' in full_df.columns:
    test_df = full_df[full_df.split == 'test'].copy()
    full_df = full_df[full_df.split == 'train'].copy()
    datasets['test'] = (test_df, data_utils.df_to_ds(test_df, preprocess_fn),
                        len(test_df))

  if eval_split is None:
    datasets['train'] = (full_df, data_utils.df_to_ds(full_df, preprocess_fn),
                         len(full_df))
  else:
    (train_df, num_train_examples, eval_df,
     num_eval_examples) = data_utils.pd_train_eval_split(
         full_df, eval_split, seed, reset_index)
    datasets['train'] = (train_df,
                         data_utils.df_to_ds(train_df, preprocess_fn),
                         num_train_examples)
    datasets['eval'] = (eval_df, data_utils.df_to_ds(eval_df, preprocess_fn),
                        num_eval_examples)

  return datasets
122
def get_contrastive_dataset(name,
                            dataframe_path=None,
                            img_size=None,
                            num_channels=None):
  """Loads existing contrastive dataset.

  Args:
    name: 'dsprites' or 'threedident'.
    dataframe_path: Str, optional; location of specific dataframe to load. If
      not specified, the default is used.
    img_size: 2-tuple, optional; specify image resizing during preprocessing
      step.
    num_channels: Int, optional; specify number of channels for processed
      images.

  Returns:
    Tuple of form (pandas dataframe, tf dataset, number of examples).

  Raises:
    ValueError: if name is not one of the two supported dataset names.
  """
  if name == 'dsprites':
    preprocess_fn = data_utils.get_preprocess_dsprites_fn(
        img_size, num_channels)
  elif name == 'threedident':
    preprocess_fn = data_utils.get_preprocess_threedident_fn(
        img_size, num_channels)
  else:
    raise ValueError(
        f'Dataset name must be one of "dsprites" or "threedident", you provided {name}'
    )

  with tf.io.gfile.GFile(dataframe_path, 'rb') as f:
    # header=[0,1] handles the multi-index; remove this if build_dataset changes
    pairs_df = pd.read_csv(f, header=[0, 1])

  # Top-level columns 'z' and 'zprime' each hold a full set of example
  # columns; zip them into a single dataset of positive pairs.
  contrastive_ds = tf.data.Dataset.zip({
      'z': data_utils.df_to_ds(pairs_df['z'], preprocess_fn),
      'zprime': data_utils.df_to_ds(pairs_df['zprime'], preprocess_fn),
  })

  return pairs_df, contrastive_ds, len(pairs_df)
169
170### functions to recreate a dataset from scratch
171
172
def build_dsprites_dataframe(target_path):
  """Recreates the dsprites dataframe from base tfds version.

  Each image is converted to png and written to the 'images' subfolder of the
  specified target_path.

  The dataframe contains the latent values and labels of each example, a
  one-hot encoding of its shape, and the path to the corresponding image.

  Args:
    target_path: Str, path to where the dataframe and images should be saved.
      If it ends in '.csv' it is used as the dataframe filename and images go
      next to it; otherwise it is treated as a directory and the dataframe is
      saved as 'dsprites_df.csv' inside it.

  Returns:
    Location where dataframe was saved.
  """

  # shuffle_files=False keeps the example order deterministic so the generated
  # image filenames are reproducible across runs.
  tfds_dataset, tfds_info = tfds.load(
      'dsprites', split='train', with_info=True, shuffle_files=False)
  num_examples = tfds_info.splits['train'].num_examples

  # list the features we care about: drop the raw image and the shape fields
  # ('label_shape' is re-encoded as one-hot columns below).
  feature_keys = list(tfds_info.features.keys())
  feature_keys.remove('image')
  feature_keys.remove('value_shape')
  feature_keys.remove('label_shape')
  shapes = ['square', 'ellipse', 'heart']

  # helper function to modify how the data is stored in the tf dataset before
  # we convert it to a pandas dataframe
  def pandas_setup(x):
    # encoding the image as a png byte string turns out to be a convenient way
    # of temporarily storing the images until we can write them to disk.
    img = tf.io.encode_png(x['image'])
    latents = {k: x[k] for k in feature_keys}
    # One-hot encode the shape label into the three named shape columns.
    latents.update(
        {k: int(x['label_shape'] == i) for i, k in enumerate(shapes)})
    latents['png'] = img
    return latents

  temp_ds = tfds_dataset.map(pandas_setup)
  dsprites_df = tfds.as_dataframe(temp_ds)
  dsprites_df = dsprites_df[shapes + feature_keys + ['png']]  # reorder columns

  # setup for saving the pngs to disk
  if os.path.basename(target_path).endswith('.csv'):
    dataset_dir = os.path.dirname(target_path)
    dataframe_location = target_path
  else:
    dataset_dir = target_path
    dataframe_location = os.path.join(target_path, 'dsprites_df.csv')

  images_path = os.path.join(dataset_dir, 'images')
  tf.io.gfile.makedirs(images_path)  # creates any missing parent directories

  # Zero-pad filenames to a fixed width so lexicographic order matches the
  # numeric example order.
  padding = len(str(num_examples))
  temp_index = pd.Series(range(num_examples))

  def create_image_paths(x):
    path_to_file = os.path.join(images_path, str(x).zfill(padding) + '.png')
    return path_to_file

  # create a col in the dataframe for the image file path
  dsprites_df['img_path'] = temp_index.apply(create_image_paths)

  # iterate through the dataframe and save each image to specified folder
  for i, x in dsprites_df.iterrows():
    img = tf.io.decode_image(x['png'])
    with tf.io.gfile.GFile(x['img_path'], 'wb') as f:
      tf.keras.preprocessing.image.save_img(f, img.numpy(), file_format='PNG')
    # i is the 0-based dataframe index here, so i + 1 is the count processed.
    if i % 100 == 0:
      logging.info('%s of %s images processed', i + 1, num_examples)

  # The png byte strings were only needed until the images hit disk.
  dsprites_df.drop(columns=['png'], inplace=True)
  logging.info('finished processing images')

  logging.info('conversion complete, saving...')
  with tf.io.gfile.GFile(dataframe_location, 'wb') as f:
    dsprites_df.to_csv(f, index=False)

  # also make a copy so if you screw up the original df you don't have to run
  # the entire generation process again
  _ = data_utils.make_backup(dataframe_location)

  return dataframe_location
258
def build_contrastive_dataframe(df,
                                save_location,
                                num_samples,
                                sample_fn,
                                seed=None):
  """Builds a contrastive dataframe from scratch.

  Given a universe of examples and a rule for how to choose z_prime conditioned
  on z, generates a dataframe consisting of num_samples positive pairs for use
  in contrastive training.

  Args:
    df: The dataframe of available examples.
    save_location: Str, where to save the new contrastive dataframe.
    num_samples: Int, how many pairs to generate.
    sample_fn: Function that specifies how to choose z_prime given z.
    seed: Int, optional; use if the dataframe construction needs to be
      reproducible.

  Returns:
    Location of new contrastive dataframe.
  """
  # Sample anchors with replacement so num_samples may exceed len(df).
  z_df = df.sample(n=num_samples, random_state=seed, replace=True)

  # Seed numpy too, since sample_fn may draw from np.random internally.
  if seed is not None:
    np.random.seed(seed)

  tf.io.gfile.makedirs(os.path.split(save_location)[0])

  # iterrows yields df's original (non-sequential, possibly repeated) index,
  # so progress is tracked with enumerate rather than the row index. Bug fix:
  # the previous manual counter logged `counter + 1` after incrementing,
  # overstating progress by one ('101 of N' after 100 examples).
  zprime_index = []
  for counter, (_, z) in enumerate(z_df.iterrows(), start=1):
    zprime_index.append(
        data_utils.get_contrastive_example_idx(z, df, sample_fn))
    if counter % 100 == 0:
      logging.info('%s of %s examples generated', counter, num_samples)

  zprime_df = df.loc[zprime_index].reset_index(drop=True)
  z_df.reset_index(drop=True, inplace=True)

  # Two-level columns: top level 'z'/'zprime', second level the original
  # example columns (read back with pd.read_csv(..., header=[0, 1])).
  contrastive_df = pd.concat([z_df, zprime_df], axis=1, keys=['z', 'zprime'])

  with tf.io.gfile.GFile(save_location, 'wb') as f:
    contrastive_df.to_csv(f, index=False)

  # also make a copy so if you screw up the original df you don't have to run
  # the entire generation process again
  _ = data_utils.make_backup(save_location)

  logging.info('dataframe saved at %s', save_location)

  return save_location