# coding=utf-8
# Copyright 2021 DeepMind Technologies Limited and the Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of a VDVAE encoder."""

from typing import Mapping, Optional, Sequence, Tuple

import chex
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from vdvae_flax import blocks as blocks_lib


class Encoder(nn.Module):
  """The Encoder of a VDVAE, mapping from images to latents.

  Attributes:
    num_blocks: number of residual blocks in the encoder.
    num_channels: number of channels output by each of the residual blocks.
    bottlenecked_num_channels: number of channels used internally by each
      residual block.
    downsampling_rates: a sequence of tuples (block index, downsampling rate).
      Blocks whose indices do not appear in this sequence preserve the
      resolution of their inputs.
    precision: optional :class:`jax.lax.Precision` to pass to convolutions.
    name: optional name of the module (inherited from flax.linen.Module).
  """

  num_blocks: int
  num_channels: int
  bottlenecked_num_channels: int
  downsampling_rates: Sequence[Tuple[int, int]]
  precision: Optional[jax.lax.Precision] = None

  def setup(self):
    self._in_conv = blocks_lib.get_vdvae_convolution(
        self.num_channels, (3, 3), name='in_conv', precision=self.precision)

    sampling_rates = sorted(self.downsampling_rates)
    num_blocks = self.num_blocks

    current_sequence_start = 0
    blocks = []
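    # Note: every ResBlock below scales its final weights by
    # sqrt(1 / num_blocks), so the sum of all residual branches keeps roughly
    # constant variance at initialization regardless of encoder depth, the
    # initialization scheme used by VDVAE.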
    for block_idx, rate in sampling_rates:
      if rate == 1:
        continue
      sequence_length = block_idx - current_sequence_start
      if sequence_length > 0:
        # Add the run of resolution-preserving blocks that precedes this
        # downsampling block.
        for i in range(current_sequence_start, block_idx):
          blocks.append(
              blocks_lib.ResBlock(
                  self.bottlenecked_num_channels,
                  self.num_channels,
                  downsampling_rate=1,
                  use_residual_connection=True,
                  last_weights_scale=np.sqrt(1.0 / self.num_blocks),
                  precision=self.precision,
                  name=f'res_block_{i}'))

      # Add the downsampling block itself.
      blocks.append(
          blocks_lib.ResBlock(
              self.bottlenecked_num_channels,
              self.num_channels,
              downsampling_rate=rate,
              use_residual_connection=True,
              last_weights_scale=np.sqrt(1.0 / self.num_blocks),
              precision=self.precision,
              name=f'res_block_{block_idx}'))
      # Continue the next run of blocks after this downsampling block.
      current_sequence_start = block_idx + 1
    # Add the remaining blocks after the last downsampling block; they all
    # preserve the resolution.
    sequence_length = num_blocks - current_sequence_start
    if sequence_length > 0:
      for i in range(current_sequence_start, num_blocks):
        blocks.append(
            blocks_lib.ResBlock(
                self.bottlenecked_num_channels,
                self.num_channels,
                downsampling_rate=1,
                use_residual_connection=True,
                last_weights_scale=np.sqrt(1.0 / self.num_blocks),
                precision=self.precision,
                name=f'res_block_{i}'))

    self._blocks = blocks

  def __call__(
      self,
      inputs: chex.Array,
      context_vectors: Optional[chex.Array] = None,
  ) -> Mapping[int, chex.Array]:
    """Encodes a batch of input images.

    Args:
      inputs: a batch of input images of shape [B, H, W, C]. They should be
        centered and of type float32.
      context_vectors: optional batch of shape [B, D]. These are typically
        used to condition the VDVAE.

    Returns:
      a mapping from resolution to the output of the last block operating at
      that resolution.
    """

    if inputs.dtype != jnp.float32:
      raise ValueError('Expected inputs to be of type float32 but got '
                       f'{inputs.dtype}')
    if len(inputs.shape) != 4 or inputs.shape[1] != inputs.shape[2]:
      raise ValueError('inputs should be a batch of images of shape '
                       f'[B, H, W, C] with H=W, but got {inputs.shape}')
    outputs = self._in_conv(inputs)
    resolution = outputs.shape[1]
    activations = {resolution: outputs}

    for block in self._blocks:
      outputs = block(outputs, context_vectors)
      resolution = outputs.shape[1]
      # Later blocks at the same resolution overwrite earlier entries, so the
      # mapping keeps the deepest activation at each resolution.
      activations[resolution] = outputs

    return activations
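

if __name__ == '__main__':
  # Minimal usage sketch, not part of the original module; the hyperparameter
  # values below are hypothetical. Encoder.init/Encoder.apply are the standard
  # flax.linen entry points.
  encoder = Encoder(
      num_blocks=4,
      num_channels=16,
      bottlenecked_num_channels=8,
      downsampling_rates=[(2, 2)],
  )
  images = jnp.zeros((2, 32, 32, 3), dtype=jnp.float32)  # centered float32
  variables = encoder.init(jax.random.PRNGKey(0), images)
  activations = encoder.apply(variables, images)
  # Block 2 halves the resolution, so a 32x32 input yields activations at
  # resolutions 32 and 16.
  print({res: act.shape for res, act in activations.items()})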