google-research

Форк
0
120 строк · 4.4 Кб
1
# coding=utf-8
2
# Copyright 2021 DeepMind Technologies Limited and the Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implementation of Very Deep VAEs (https://arxiv.org/abs/2011.10650)."""
17

18
from typing import Optional, Any
19

20
import chex
21
import flax
22
from flax import linen as nn
23
import jax
24
import jax.numpy as jnp
25
import ml_collections
26

27
from vdvae_flax import decoder
28
from vdvae_flax import encoder
29
from vdvae_flax import vdvae_utils
30

31

32
@flax.struct.dataclass
class VdvaeOutput:
  """Results of a single `Vdvae.__call__` evaluation.

  `samples` is always populated. `elbo` and `reconstruction_loss` are set to
  `None` by `Vdvae.__call__` when it is called without inputs.
  """

  # Images produced by the sampler network.
  samples: chex.Array  # [B, H, W, C]
  # `reconstruction_loss` plus the per-block KL terms summed over blocks.
  elbo: Optional[chex.Array]  # [B]
  # Negative log-likelihood reported by the sampler network.
  reconstruction_loss: Optional[chex.Array]  # [B]
  # KL value returned by the decoder, passed through unchanged.
  kl_per_decoder_block: Optional[chex.Array]  # [num_decoder_blocks, B]
38

39

40
class Vdvae(nn.Module):
  """Very Deep VAE.

  Combines an encoder, a decoder and a quantized logistic mixture sampler,
  each constructed from the matching entry of `config`.
  """

  # Must provide `encoder`, `decoder` and `sampler` entries; each one is
  # unpacked as keyword arguments into the corresponding sub-module.
  config: ml_collections.ConfigDict

  @nn.compact
  def __call__(
      self,
      sample_rng: chex.PRNGKey,
      num_samples_to_generate: int,
      inputs: Optional[chex.Array] = None,
      context_vectors: Optional[chex.Array] = None,
      temperature: float = 1.,
  ) -> VdvaeOutput:
    """Evaluates a VDVAE.

    Args:
      sample_rng: random key for sampling.
      num_samples_to_generate: number of images to generate from the prior
        distribution, conditioned only on optional context vectors. This
        argument should be provided only when inputs is not provided. If
        provided, it should be positive. It is forwarded to the decoder
        unchanged; no validation is performed here.
      inputs: an optional batch of input RGB images of shape [B, H, W, C], where
        H=W. These should be provided when training the VDVAE. These inputs
        should be of type uint8.
      context_vectors: an optional batch of input context vectors of shape [B,
        D] that the VDVAE is conditioned on. These can be omitted, in which case
        the samples will be conditioned on the inputs only if they are provided,
        or not conditioned on anything otherwise.
      temperature: when inputs are not provided, each decoder block samples a
        latent unconditionally using the mean of the prior distribution, and its
        log_std + log(temperature).

    Returns:
      A VdvaeOutput object containing sampled images. If inputs were provided,
      the sampled images are sampled using the posterior distribution,
      conditioned on the inputs (and optional context vectors). In this case,
      the output also contains the elbo, reconstruction loss and KL divergences
      between prior and posterior for each block of the decoder. If inputs are
      not provided, samples are produced using the prior distribution,
      conditioned only on the optional context vectors; `elbo` and
      `reconstruction_loss` are then None, while `kl_per_decoder_block` is
      whatever the decoder returned.

      Note that `elbo` is computed as the sampler's negative log-likelihood
      plus the KL summed over decoder blocks, i.e. it is a quantity to
      minimize (the opposite sign of the conventional ELBO).

    Raises:
      ValueError: if inputs are provided and their dtype is not uint8.
    """
    encoder_model = encoder.Encoder(**self.config.encoder)
    decoder_model = decoder.Decoder(**self.config.decoder)
    sampler_model = vdvae_utils.QuantizedLogisticMixtureNetwork(
        **self.config.sampler)

    if inputs is None:
      # Prior-sampling mode: the decoder receives no encoder activations.
      encoder_outputs = None
    else:
      if inputs.dtype != jnp.uint8:
        raise ValueError("Expected inputs to be of type uint8 but got "
                         f"{inputs.dtype}")
      preprocessed_inputs = vdvae_utils.cast_to_float_center_and_normalize(
          inputs)
      encoder_outputs = encoder_model(
          preprocessed_inputs, context_vectors=context_vectors)

    # Split off a fresh key for each consumer so that the decoder and the
    # sampler never share randomness.
    sample_rng, sample_rng_now = jax.random.split(sample_rng)
    decoder_outputs, kl = decoder_model(
        sample_rng=sample_rng_now,
        num_samples_to_generate=num_samples_to_generate,
        context_vectors=context_vectors,
        encoder_outputs=encoder_outputs,
        temperature=temperature)

    sample_rng, sample_rng_now = jax.random.split(sample_rng)
    # The raw (uint8) inputs, not the preprocessed ones, go to the sampler.
    sampler_output = sampler_model(sample_rng_now, decoder_outputs, inputs)
    if inputs is not None:
      reconstruction_loss = sampler_output.negative_log_likelihood
      # Collapse the leading (per-decoder-block) axis of the KL terms.
      total_kl = jnp.sum(kl, axis=0)
      elbo = reconstruction_loss + total_kl
    else:
      reconstruction_loss = None
      elbo = None

    return VdvaeOutput(
        samples=sampler_output.samples,
        elbo=elbo,
        kl_per_decoder_block=kl,
        reconstruction_loss=reconstruction_loss)
121

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.