Experiments with contrastive learning
This codebase supports zeef@'s and nishanthd@'s experiments with contrastive learning on images defined by sets of latent values. Our goal is to explore whether the latent values can be recovered from representations generated by self-supervised models such as SimCLR, inspired by the work of arXiv:2103.06875 and arXiv:2102.08850.
The code is based on TensorFlow 2.x and uses both `tf.data.Dataset`s and pandas `DataFrame`s to manage the datasets.
Datasets
We focus on dsprites and 3dident, since both consist of images defined by a set of latent values. However, while dsprites is available through tfds, the TensorFlow dataset format makes it difficult to look up specific examples or to search for similar examples (needed to generate positive pairs for contrastive training), and 3dident is not currently available through tfds at all.
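As a minimal sketch of why a DataFrame helps here: once the latent values are columns of a pandas `DataFrame`, finding every example that shares particular latents is a simple boolean mask. The latent names below follow dsprites, but the table itself is a synthetic stand-in, not the repo's actual conversion output.

```python
import numpy as np
import pandas as pd

# Toy stand-in for a dsprites latent table: one row per image.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "shape": rng.integers(0, 3, size=1000),       # square / ellipse / heart
    "scale": rng.integers(0, 6, size=1000),
    "orientation": rng.integers(0, 40, size=1000),
    "pos_x": rng.integers(0, 32, size=1000),
    "pos_y": rng.integers(0, 32, size=1000),
})

# Exact lookup: all images sharing a given shape and scale.
matches = df[(df["shape"] == 1) & (df["scale"] == 3)]
print(len(matches), "candidate examples")
```

The same mask-based lookup is what makes generating positive pairs cheap compared to scanning a serialized `tf.data.Dataset`.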
`datasets.py` contains functions to convert dsprites to a pandas DataFrame format, to load either dsprites or 3dident in a standardized format ready for experiments, and to generate and load contrastive training sets from either dataset. There is currently no function to convert 3dident to this format, but it can be downloaded from here, extracted to a folder of your choosing, and then loaded with `datasets.get_standard_dataset`.
`data_utils.py` contains helper functions for working with both datasets, including dataset-specific functions for preprocessing and for searching for similar examples.
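The search for similar examples can be sketched as a nearest-neighbour lookup in latent space. This is an illustrative simplification, not the repo's dataset-specific search helpers, and the latent layout is synthetic.

```python
import numpy as np

def nearest_latent_neighbor(latents, anchor_idx):
    """Return the index of the example whose latent vector is closest
    (Euclidean) to the anchor's, excluding the anchor itself."""
    dists = np.linalg.norm(latents - latents[anchor_idx], axis=1)
    dists[anchor_idx] = np.inf  # never pair an example with itself
    return int(np.argmin(dists))

rng = np.random.default_rng(1)
latents = rng.uniform(size=(500, 5))  # 500 examples, 5 latent values each
pair_idx = nearest_latent_neighbor(latents, anchor_idx=0)
print("positive pair for example 0:", pair_idx)
```

A real implementation would typically restrict the search to examples that match on some latents exactly (e.g. the same shape) and differ slightly on others.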
Other datasets can be added to this framework by mimicking the preprocessing functions in `data_utils.py` and adding the appropriate logic to `datasets.get_standard_dataset`.
Training
`train_linear_layer.py` implements a custom training loop that trains a linear layer on top of a pre-trained SimCLR model. It can be run either on a single GPU (although the limiting factor is whether the pre-trained SimCLR model fits in the GPU's memory) or on multiple TPUs via `tf.distribute`.
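The core idea, training a linear readout on top of frozen representations, can be sketched in plain NumPy. This is an illustrative stand-in, not the repo's `tf.distribute` training loop; the frozen "SimCLR features" here are synthetic.

```python
import numpy as np

rng = np.random.default_rng(42)
n, feat_dim, n_latents = 256, 32, 4

# Synthetic frozen "SimCLR" features and the latent targets to regress.
features = rng.normal(size=(n, feat_dim))
true_w = rng.normal(size=(feat_dim, n_latents))
latents = features @ true_w + 0.01 * rng.normal(size=(n, n_latents))

# Train only the linear layer, by gradient descent on the MSE loss;
# the feature extractor stays frozen throughout.
w = np.zeros((feat_dim, n_latents))
lr = 0.05
losses = []
for _ in range(200):
    err = features @ w - latents
    losses.append(float(np.mean(err ** 2)))
    w -= lr * (2.0 / n) * features.T @ err

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.4f}")
```

If the latents are (approximately) linearly decodable from the representations, the readout loss drops close to the noise floor, which is the property the experiments probe.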
We provide helper functions for measuring training progress on the dsprites dataset in `metrics_utils.py`. These include custom accuracy measurements (e.g., a latent value counts as "accurately" predicted if the prediction is closer to the correct latent value than to the adjacent ones) and a class that handles per-latent metrics logging during training.
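The "closer to the correct value than to adjacent ones" criterion can be illustrated as follows. The grid, values, and function name are hypothetical, not `metrics_utils.py`'s exact API.

```python
import numpy as np

def adjacent_accuracy(preds, targets, grid):
    """A prediction counts as accurate when it lies closer to the correct
    grid value than to either neighbouring grid value."""
    # Half the grid spacing: the error tolerance implied by "adjacent".
    tol = 0.5 * np.min(np.diff(grid))
    return float(np.mean(np.abs(preds - targets) < tol))

# A scale-like latent: 6 evenly spaced values in [0.5, 1.0], spacing 0.1.
grid = np.linspace(0.5, 1.0, 6)
targets = np.array([0.5, 0.6, 0.7, 0.8])
preds = np.array([0.52, 0.64, 0.74, 0.99])  # last prediction is way off
print(adjacent_accuracy(preds, targets, grid))  # -> 0.75
```

This is more forgiving than exact-match accuracy on a regression output, while still penalizing predictions that land nearer a neighbouring latent value.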
Further metrics can be added by modifying the `setup_metrics` and `update_metrics` methods of the appropriate `MetricsInterface` class.
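As an illustrative sketch of that shape (a hypothetical stand-in, not the repo's actual `MetricsInterface`): a new metric only needs an accumulator created in `setup_metrics` and fed in `update_metrics`.

```python
import numpy as np

class PerLatentMetrics:
    """Minimal stand-in for a per-latent metrics logger."""

    def __init__(self, latent_names):
        self.latent_names = latent_names
        self.setup_metrics()

    def setup_metrics(self):
        # One running-sum accumulator per latent; new metrics go here.
        self.sums = {name: 0.0 for name in self.latent_names}
        self.counts = {name: 0 for name in self.latent_names}

    def update_metrics(self, name, errors):
        # Feed a batch of per-example errors for one latent.
        self.sums[name] += float(np.sum(errors))
        self.counts[name] += len(errors)

    def result(self, name):
        return self.sums[name] / max(self.counts[name], 1)

metrics = PerLatentMetrics(["scale", "orientation"])
metrics.update_metrics("scale", np.array([0.1, 0.3]))
print(metrics.result("scale"))  # -> 0.2
```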
To launch an experiment, run `python -m train_linear_layer --pretrained_model_path=...`, filling in the path where your pre-trained SimCLR (or equivalent) model weights are stored.