SupCon - Supervised Contrastive Learning

Authors: Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, Dilip Krishnan

Corresponding author: Aaron Sarna (sarna@google.com)

This repo contains the TensorFlow code used to train the models from the paper Supervised Contrastive Learning, presented at NeurIPS 2020.

It is implemented in TensorFlow v1 using the TPUEstimator framework, although much of the code, including the loss function, is TensorFlow v2 compatible. It is intended to be run on Google Cloud TPU v3. The code will also run on GPU or CPU, but those platforms may not support the batch sizes we trained with, particularly since the batch norm and loss function implementations can only aggregate statistics across multiple chips when running on TPU. We have found that adding a memory bank, similar to MoCo, can compensate for smaller batch sizes in this setting, but that is not implemented in this repo.

There is a PyTorch implementation of our paper available at https://github.com/HobbitLong/SupContrast. Note that the CIFAR10 numbers in the paper come from the PyTorch implementation, and the TensorFlow implementation currently slightly underperforms it. The ImageNet numbers in the paper come from this TensorFlow implementation.

Self-supervised contrastive learning in the style of SimCLR is essentially a special case of SupCon where the label for each sample is unique within the global batch. Therefore, this implementation also reproduces SimCLR, simply by setting --use_labels=False.
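
To make the relationship concrete, the following is a minimal, single-view TensorFlow v2 sketch of the supervised contrastive objective, written just for this README. The helper name is hypothetical; the repo's actual loss additionally handles multiple augmented views per sample and cross-replica aggregation on TPU.

import tensorflow as tf

def supervised_contrastive_loss(features, labels, temperature=0.1):
  # features: [batch, dim] embeddings; labels: [batch] integer class ids.
  # Hypothetical standalone helper, not the repo's API.
  features = tf.math.l2_normalize(features, axis=1)
  logits = tf.matmul(features, features, transpose_b=True) / temperature

  # Exclude each sample from its own denominator.
  batch_size = tf.shape(features)[0]
  diagonal = tf.cast(tf.eye(batch_size), tf.bool)
  logits = tf.where(diagonal, tf.fill(tf.shape(logits), -1e9), logits)

  # Positives are the other samples that share the anchor's label.
  same_label = tf.equal(labels[:, None], labels[None, :])
  positives = tf.cast(tf.logical_and(same_label, tf.logical_not(diagonal)), tf.float32)

  # Average log-probability of the positives for each anchor; anchors with
  # no positives contribute zero via divide_no_nan.
  log_prob = logits - tf.reduce_logsumexp(logits, axis=1, keepdims=True)
  mean_log_prob_pos = tf.math.divide_no_nan(
      tf.reduce_sum(positives * log_prob, axis=1),
      tf.reduce_sum(positives, axis=1))
  return -tf.reduce_mean(mean_log_prob_pos)

When each sample's label is unique within the batch and the only positive is a second augmented view of the same image, this reduces to the SimCLR objective, which is effectively what --use_labels=False does.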

Running

Environment setup

First review the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.

Make sure that all dependencies are installed by running:

pip install -r requirements.txt

Finally, once you have set up your virtual machine with Cloud TPUs, if you would like to train on ImageNet, you must follow the instructions for downloading the dataset in a format compatible with TensorFlow Datasets.
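
As a rough illustration (the TensorFlow Datasets documentation is the authoritative reference, and the paths below are placeholders), the preparation step looks roughly like this once the raw ImageNet archives have been downloaded manually:

import tensorflow_datasets as tfds

# imagenet2012 requires a manual download of the raw archives; TFDS then
# converts them into the on-disk format the training code reads from DATA_DIR.
builder = tfds.builder('imagenet2012', data_dir='gs://<storage-bucket>/<path-to-tensorflow-dataset>')
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(
        manual_dir='/path/to/downloaded/imagenet/archives'))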

You can then set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>

Training/Evaluation

The scripts directory contains a number of configurations for training SupCon, SimCLR and cross-entropy models on ImageNet and CIFAR10. Each script contains a comment at the top indicating the number of TPU cores we used to train that configuration. An example of how to run one of the scripts is:

scripts/supcon_imagenet_resnet50.sh --mode=train_then_eval \
--tpu_name=$TPU_NAME --data_dir=$DATA_DIR --model_dir=$MODEL_DIR

For training on GPU or CPU, you should additionally pass --use_tpu=False.

The command above specifies --mode=train_then_eval, which trains the full model and then runs a single evaluation pass at the very end within the same job. You can also pass --mode=train and --mode=eval to separate invocations of the script if you would like to launch a separate continuous evaluation job.

There are many hyperparameters that can be tuned. The scripts provide the values we used for the paper; if you would like to try others, take a look at hparams_flags.py and pass any of those flags to the script to override the defaults it sets.

The code writes metric summaries that can be visualized in TensorBoard by pointing its --logdir flag at the MODEL_DIR directory.

Pretrained Models

Pretrained models trained on ImageNet are available via TF-Hub for ResNet 50, ResNet 101, and ResNet 200. See the TF-Hub module pages for usage instructions.
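
As a hedged sketch of typical TF-Hub usage (the handle below is a placeholder for one of the module URLs, and the exact input size and output signature are described on each module's page):

import tensorflow as tf
import tensorflow_hub as hub

# '<tfhub-handle>' stands in for one of the SupCon TF-Hub module URLs.
encoder = hub.KerasLayer('<tfhub-handle>', trainable=False)
images = tf.random.uniform([1, 224, 224, 3])  # batch of RGB images in [0, 1]
embeddings = encoder(images)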
