# google-research
# 137 lines · 4.8 KB
# coding=utf-8
# Copyright 2021 DeepMind Technologies Limited and the Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of a VDVAE encoder."""

from typing import Mapping, Optional, Sequence, Tuple

import chex
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from vdvae_flax import blocks as blocks_lib

class Encoder(nn.Module):
  """The Encoder of a VDVAE, mapping from images to latents.

  Attributes:
    num_blocks: number of residual blocks in the encoder.
    num_channels: number of channels output by each of the residual blocks.
    bottlenecked_num_channels: number of channels used internally by each
      residual block.
    downsampling_rates: a sequence of tuples (block index, downsampling rate).
      Blocks whose indices are not in this sequence conserve the resolution.
    precision: optional :class:`jax.lax.Precision` to pass to convolutions.
  """

  num_blocks: int
  num_channels: int
  bottlenecked_num_channels: int
  downsampling_rates: Sequence[Tuple[int, int]]
  precision: Optional[jax.lax.Precision] = None

  def _make_res_block(self, index: int, rate: int) -> blocks_lib.ResBlock:
    """Builds the residual block at position `index`.

    Args:
      index: position of the block in the encoder; only used for naming.
      rate: downsampling rate of the block; 1 conserves the resolution.

    Returns:
      A configured `blocks_lib.ResBlock`.
    """
    return blocks_lib.ResBlock(
        self.bottlenecked_num_channels,
        self.num_channels,
        downsampling_rate=rate,
        use_residual_connection=True,
        # Scale the residual branch's last weights by 1/sqrt(num_blocks) so
        # the summed contribution of all blocks stays well-behaved at init.
        last_weights_scale=np.sqrt(1.0 / self.num_blocks),
        precision=self.precision,
        name=f'res_block_{index}')

  def setup(self):
    self._in_conv = blocks_lib.get_vdvae_convolution(
        self.num_channels, (3, 3), name='in_conv', precision=self.precision)

    # Sort by block index so blocks are created in encoder order.
    sampling_rates = sorted(self.downsampling_rates)

    current_sequence_start = 0
    blocks = []
    for block_idx, rate in sampling_rates:
      if rate == 1:
        # A rate of 1 conserves the resolution; such entries are redundant
        # and the block is created as part of a regular run below.
        continue
      # Add the run of resolution-preserving blocks preceding this
      # downsampling block (empty range if there are none).
      for i in range(current_sequence_start, block_idx):
        blocks.append(self._make_res_block(i, 1))
      # Add the downsampling block itself.
      blocks.append(self._make_res_block(block_idx, rate))
      current_sequence_start = block_idx + 1

    # Add the remaining resolution-preserving blocks after the last
    # downsampling block.
    for i in range(current_sequence_start, self.num_blocks):
      blocks.append(self._make_res_block(i, 1))

    self._blocks = blocks

  def __call__(
      self,
      inputs: chex.Array,
      context_vectors: Optional[chex.Array] = None,
  ) -> Mapping[int, chex.Array]:
    """Encodes a batch of input images.

    Args:
      inputs: a batch of input images of shape [B, H, W, C]. They should be
        centered and of type float32.
      context_vectors: optional batch of shape [B, D]. These are typically
        used to condition the VDVAE.

    Returns:
      a mapping from resolution to encoded image. When several blocks share a
      resolution, the entry holds the output of the deepest such block.

    Raises:
      ValueError: if `inputs` is not float32 or not a batch of square images.
    """
    if inputs.dtype != jnp.float32:
      raise ValueError('Expected inputs to be of type float32 but got '
                       f'{inputs.dtype}')
    if len(inputs.shape) != 4 or inputs.shape[1] != inputs.shape[2]:
      raise ValueError('inputs should be a batch of images of shape '
                       f'[B, H, W, C] with H=W, but got {inputs.shape}')
    outputs = self._in_conv(inputs)
    resolution = outputs.shape[1]
    activations = {resolution: outputs}

    for block in self._blocks:
      outputs = block(outputs, context_vectors)
      resolution = outputs.shape[1]
      # Later blocks at the same resolution overwrite earlier entries.
      activations[resolution] = outputs

    return activations