Experiments with contrastive learning

This codebase supports zeef@ and nishanthd@'s experiments with contrastive learning and images defined by sets of latent values. Our goal is to explore whether we can recover the latent values from representations generated by self-supervised models such as SimCLR. The work is inspired by arXiv:2103.06875 and arXiv:2102.08850.

The code is based on TensorFlow 2.x and uses both tf.data.Dataset objects and pandas DataFrames to manage the datasets.

Datasets

We focus on dsprites and 3dident, since both consist of images defined by a set of latent values. dsprites is available through TensorFlow Datasets (tfds), but the tfds format makes it difficult to look up specific examples or to search for similar examples, which is needed for generating positive pairs for contrastive training. 3dident is not currently available through tfds at all.

datasets.py contains functions to convert dsprites to a pandas DataFrame format, to load either dsprites or 3dident in a standardized format ready for experiments, and to generate and load contrastive training sets from either dataset. There is currently no function to convert 3dident to this format, but it can be downloaded from here, extracted to the folder of your choosing, and then loaded using datasets.get_standard_dataset.

data_utils.py contains helper functions for working with both datasets, including dataset-specific functions for preprocessing and searching for similar examples.
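To illustrate why a DataFrame of latent values makes the search for similar examples straightforward, here is a minimal sketch of picking a positive partner for an anchor example. It is not the implementation in data_utils.py: the function name `find_positive_index`, the column names, and the choice of a k-nearest-neighbour search over min-max scaled latents are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def find_positive_index(df, anchor_index, latent_columns, rng, k=5):
    """Pick a random example among the k nearest to the anchor in latent space.

    Hypothetical helper: `df` holds one row per image and one column per
    latent value, with a unique index. Returns an index label of `df`.
    """
    if len(df) < 2:
        raise ValueError("Need at least two examples to form a pair.")

    latents = df[latent_columns].to_numpy(dtype=float)

    # Min-max scale every latent so that no single latent dominates the distance.
    low = latents.min(axis=0)
    span = latents.max(axis=0) - low
    span[span == 0] = 1.0  # constant columns contribute zero distance
    scaled = (latents - low) / span

    anchor_pos = df.index.get_loc(anchor_index)
    distances = np.linalg.norm(scaled - scaled[anchor_pos], axis=1)
    distances[anchor_pos] = np.inf  # never pair an example with itself

    # Positions of the k smallest distances (unordered within the k).
    k = min(k, len(df) - 1)
    nearest = np.argpartition(distances, k - 1)[:k]
    return df.index[rng.choice(nearest)]


# Toy usage with made-up latent columns.
toy = pd.DataFrame({
    "x_position": [0.0, 0.1, 0.9, 1.0],
    "y_position": [0.0, 0.1, 0.9, 1.0],
})
partner = find_positive_index(
    toy, anchor_index=0, latent_columns=["x_position", "y_position"],
    rng=np.random.default_rng(0), k=1)
```

This brute-force search costs one pass over the DataFrame per anchor, which is fine for a sketch; a full-size dataset would likely call for an index structure or a vectorized batch search.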

Other datasets could easily be added to this framework by mimicking the preprocessing functions in data_utils.py and adding the appropriate logic to datasets.get_standard_dataset.

Training

train_linear_layer.py implements a custom training loop to train a linear layer on top of a pre-trained SimCLR model. It can be run on either a single GPU (although the limiting factor here is whether the pre-trained SimCLR can fit in the GPU's memory) or on multiple TPUs via tf.distribute.
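The general shape of such a training loop can be sketched as follows. This is a single-device illustration under stated assumptions, not the code in train_linear_layer.py: the stand-in encoder built by `make_frozen_encoder`, the function `train_linear_probe`, the mean-squared-error loss, and the SGD optimizer are all hypothetical choices, and the tf.distribute handling is omitted entirely.

```python
import tensorflow as tf


def make_frozen_encoder(input_dim, feature_dim):
    """Hypothetical stand-in for a pre-trained SimCLR encoder."""
    encoder = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(feature_dim, activation="relu"),
    ])
    encoder.trainable = False  # the encoder weights are never updated
    return encoder


def train_linear_probe(encoder, dataset, num_latents, epochs=5,
                       learning_rate=0.05):
    """Train a single linear layer to regress latents from frozen features."""
    linear = tf.keras.layers.Dense(num_latents)
    optimizer = tf.keras.optimizers.SGD(learning_rate)
    loss_fn = tf.keras.losses.MeanSquaredError()
    history = []

    for _ in range(epochs):
        epoch_loss = tf.keras.metrics.Mean()
        for images, latents in dataset:
            # Features are computed outside the tape: no gradient reaches
            # the encoder.
            features = encoder(images, training=False)
            with tf.GradientTape() as tape:
                predictions = linear(features)
                loss = loss_fn(latents, predictions)
            grads = tape.gradient(loss, linear.trainable_variables)
            optimizer.apply_gradients(zip(grads, linear.trainable_variables))
            epoch_loss.update_state(loss)
        history.append(float(epoch_loss.result()))

    return linear, history
```

Only the linear layer's variables are passed to the optimizer, so the representation being probed stays fixed throughout training.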

We provide helper functions for measuring training progress on the dsprites dataset in metrics_utils.py. These include custom accuracy measurements (e.g. defining a latent value to be "accurately" predicted if it is closer to the correct latent value than adjacent ones) and a class to handle metrics logging of each latent during training.
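One way to read the "closer to the correct latent value than adjacent ones" criterion is to snap each prediction to the nearest value the latent is allowed to take and check whether that value is the correct one. The sketch below follows that reading; the function name `nearest_value_accuracy` and its signature are assumptions and may differ from what metrics_utils.py actually does.

```python
import numpy as np


def nearest_value_accuracy(predictions, targets, allowed_values):
    """Fraction of predictions whose nearest allowed value is the target's.

    Hypothetical helper: `allowed_values` is the discrete grid of values a
    latent can take, and `targets` are expected to lie on that grid.
    """
    allowed = np.sort(np.asarray(allowed_values, dtype=float))
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)

    # For each prediction and each target, find the index of the closest
    # allowed value.
    predicted_idx = np.abs(predictions[:, None] - allowed[None, :]).argmin(axis=1)
    target_idx = np.abs(targets[:, None] - allowed[None, :]).argmin(axis=1)

    return float(np.mean(predicted_idx == target_idx))


# Illustrative grid of six evenly spaced latent values.
grid = np.linspace(0.5, 1.0, 6)
accuracy = nearest_value_accuracy(
    predictions=[0.52, 0.64, 0.76, 1.3],
    targets=[0.5, 0.6, 0.7, 1.0],
    allowed_values=grid,
)
```

In this toy call, the third prediction (0.76) is closer to 0.8 than to its target of 0.7, so it counts as inaccurate, while the out-of-range prediction 1.3 still snaps to the correct value 1.0.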

Further metrics can easily be added by modifying the setup_metrics and update_metrics methods of the appropriate MetricsInterface class.

To launch an experiment, use python -m train_linear_layer --pretrained_model_path=..., filling in the path to where your pretrained SimCLR (or equivalent) model weights are stored. Note that the -m flag takes a module name, so the .py extension is omitted.
