google-research
120 строк · 4.4 Кб
1# coding=utf-8
2# Copyright 2021 DeepMind Technologies Limited and the Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implementation of Very Deep VAEs (https://arxiv.org/abs/2011.10650)."""
17
18from typing import Optional, Any19
20import chex21import flax22from flax import linen as nn23import jax24import jax.numpy as jnp25import ml_collections26
27from vdvae_flax import decoder28from vdvae_flax import encoder29from vdvae_flax import vdvae_utils30
31
@flax.struct.dataclass
class VdvaeOutput:
  """Immutable record bundling everything a VDVAE forward pass returns.

  Only `samples` is a plain array; the other fields are declared Optional and
  may be None.
  """

  # Sampled images, shape [B, H, W, C].
  samples: chex.Array
  # Per-example sum of the reconstruction loss and the KL terms, shape [B].
  elbo: Optional[chex.Array]
  # Per-example reconstruction loss, shape [B].
  reconstruction_loss: Optional[chex.Array]
  # KL term of every decoder block, shape [num_decoder_blocks, B].
  kl_per_decoder_block: Optional[chex.Array]
39
class Vdvae(nn.Module):
  """Very Deep VAE.

  Chains an encoder, a decoder and a quantized logistic mixture sampler, each
  of which is built from the matching sub-config of `config`.
  """

  # Must expose `encoder`, `decoder` and `sampler` entries; each one is
  # unpacked as keyword arguments into the corresponding sub-module.
  config: ml_collections.ConfigDict

  @nn.compact
  def __call__(
      self,
      sample_rng,
      num_samples_to_generate,
      inputs=None,
      context_vectors=None,
      temperature=1.,
  ):
    """Runs the VDVAE, either reconstructing `inputs` or sampling the prior.

    Args:
      sample_rng: random key used for sampling. It is split internally so that
        the decoder and the output sampler receive different keys.
      num_samples_to_generate: how many images to draw from the prior
        distribution (conditioned only on the optional context vectors). Meant
        to be supplied only when `inputs` is omitted, in which case it should
        be positive. It is forwarded to the decoder unchanged.
      inputs: optional batch of uint8 RGB images of shape [B, H, W, C] with
        H == W. Provide these when training the VDVAE.
      context_vectors: optional batch of conditioning vectors of shape [B, D].
        When omitted, samples are conditioned on `inputs` alone if those are
        given, and on nothing otherwise.
      temperature: used when `inputs` is omitted: each decoder block then
        samples its latent from the prior using the prior mean and
        log_std + log(temperature).

    Returns:
      A VdvaeOutput holding the sampled images. With `inputs`, the images are
      sampled from the posterior conditioned on the inputs (and the optional
      context vectors), and the output also carries the reconstruction loss,
      the per-decoder-block KL between prior and posterior, and `elbo`.
      Without `inputs`, the images are sampled from the prior conditioned only
      on the optional context vectors, and `elbo` and `reconstruction_loss`
      are None.

    Raises:
      ValueError: if `inputs` is given and its dtype is not uint8.
    """
    cfg = self.config
    encoder_net = encoder.Encoder(**cfg.encoder)
    decoder_net = decoder.Decoder(**cfg.decoder)
    sampler_net = vdvae_utils.QuantizedLogisticMixtureNetwork(**cfg.sampler)

    has_inputs = inputs is not None

    # The encoder only runs when there are images to encode.
    encoder_outputs = None
    if has_inputs:
      if inputs.dtype != jnp.uint8:
        raise ValueError("Expected inputs to be of type uint8 but got "
                         f"{inputs.dtype}")
      encoder_outputs = encoder_net(
          vdvae_utils.cast_to_float_center_and_normalize(inputs),
          context_vectors=context_vectors)

    # Key derivation: the decoder takes the second half of the first split;
    # the sampler takes the second half of a further split of the first half.
    carry_rng, decoder_rng = jax.random.split(sample_rng)
    decoder_outputs, kl_per_block = decoder_net(
        sample_rng=decoder_rng,
        num_samples_to_generate=num_samples_to_generate,
        context_vectors=context_vectors,
        encoder_outputs=encoder_outputs,
        temperature=temperature)

    _, sampler_rng = jax.random.split(carry_rng)
    sampler_output = sampler_net(sampler_rng, decoder_outputs, inputs)

    # Loss terms are only defined when there is something to reconstruct.
    reconstruction_loss = None
    elbo = None
    if has_inputs:
      reconstruction_loss = sampler_output.negative_log_likelihood
      # Sum the KL over the decoder-block axis.
      # NOTE(review): NLL + KL is the negated evidence lower bound, so despite
      # its name this is presumably the quantity to minimise -- confirm with
      # the training code.
      elbo = reconstruction_loss + jnp.sum(kl_per_block, axis=0)

    return VdvaeOutput(
        samples=sampler_output.samples,
        elbo=elbo,
        kl_per_decoder_block=kl_per_block,
        reconstruction_loss=reconstruction_loss)