google-research

Форк
0

..
/
weak_disentangle 
README.md

Weakly Supervised Disentanglement with Guarantees

This codebase trains the models analyzed in the paper: Weakly Supervised Disentanglement with Guarantees.

In addition to installing the requirements, make sure to follow the instructions from disentanglement_lib to download the necessary datasets in the parent directory that contains the folder weak_disentangle.

To run this code with a choice of default config (for complete-change-pairing on dsprites) from the parent directory, simply run

python3 -m weak_disentangle.main

Changing the Configuration

The codebase uses gin-config to specify the model being run. A full default specification is provided in weak_disentangle/configs/gan.gin. Below, we explain the behavior of the important configuration parameters:

train.model_type: There are two model types:
  - "gen": specifies the use of a paired GAN (for match pairing and rank pairing).
  - "van": specifies use of a vanilla GAN (for restricted labeling and full labeling).

train.dset_name: We used five datasets in our experiments:
  - "shapes3d"
  - "dsprites"
  - "scream"
  - "norb"
  - "cars3d"

train.s_dim: The number of true underlying factors:
  - 6: for "shapes3d"
  - 5: for "dsprites" and "scream"
  - 4: for "norb"
  - 3: for "cars3d"

train.n_dim: The number of nuisance factors:
  - 0: for "shapes3d", "dsprites", "cars3d"
  - 2: for "scream", "norb"

train.factors: The supervision procedure. Here are some examples:
  - "s=0,1,2,3,4": share-labeling on factors 0,1,2,3,4 individually
  - "c=0,1,2,3,4": change-labeling on factors 0,1,2,3,4 individually
  - "c=012,234": change-labeling on the group of factors {0,1,2} and the group of factors
    {2,3,4}. The intersection rule will result in restrictiveness on factor 2.
  - "r=0,1": ranking on factors 0,1 individually
  - "l=": vanilla GAN with no labels
  - "l=1,2": vanilla GAN with restricted labeling on factors 1 and 2 individually.
  - "cs=0,1": both change and share pairing on factors 0,1 individually
  These examples should hopefully give you a sense of how to set train.factors for your desired setting.

mask_type: There are three mask types:
  - "match"
  - "label"
  - "rank"
  You need to make sure that the mask type match your train.model_type and train.factors choice

In our code, we also have five architectural hyperparameters that we modified:

initializer.method:
  - "keras"
  - "pytorch"

train.enc_lr_mul:
  - 1
  - 2

Discriminator.width:
  - 1
  - 2

Discriminator.share_dense:
  - True
  - False

Discriminator.uncond_bias:
  - True
  - False

When using vanilla GAN, the Discriminator is swapped out for the LabelDiscriminator:

LabelDiscriminator.width:
  - 1
  - 2

LabelDiscriminator.share_dense:
  - True
  - False

LabelDiscriminator.uncond_bias:
  - True
  - False

If you have any questions about the project, reach out to Rui on Twitter @_smileyball or contact via email.

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.