google-research
321 строка · 11.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Slot Attention model for object discovery and set prediction."""
17import numpy as np
18import tensorflow as tf
19import tensorflow.keras.layers as layers
20
21
class SlotAttention(layers.Layer):
  """Slot Attention module.

  Iteratively refines a set of randomly initialized slot vectors so that
  slots compete (via a softmax over the slot axis) to explain elements of
  the input set.
  """

  def __init__(self, num_iterations, num_slots, slot_size, mlp_hidden_size,
               epsilon=1e-8):
    """Builds the Slot Attention module.

    Args:
      num_iterations: Number of iterations.
      num_slots: Number of slots.
      slot_size: Dimensionality of slot feature vectors.
      mlp_hidden_size: Hidden layer size of MLP.
      epsilon: Offset for attention coefficients before normalization.
    """
    super().__init__()
    self.num_iterations = num_iterations
    self.num_slots = num_slots
    self.slot_size = slot_size
    self.mlp_hidden_size = mlp_hidden_size
    self.epsilon = epsilon

    self.norm_inputs = layers.LayerNormalization()
    self.norm_slots = layers.LayerNormalization()
    self.norm_mlp = layers.LayerNormalization()

    # Parameters for Gaussian init (shared by all slots).
    # NOTE(review): `glorot_uniform` for a log-sigma parameter is unusual
    # (it can yield negative values, i.e. sigma < 1) — presumably
    # intentional from the original authors; confirm before changing.
    self.slots_mu = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_mu")
    self.slots_log_sigma = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_log_sigma")

    # Linear maps for the attention module.
    self.project_q = layers.Dense(self.slot_size, use_bias=False, name="q")
    self.project_k = layers.Dense(self.slot_size, use_bias=False, name="k")
    self.project_v = layers.Dense(self.slot_size, use_bias=False, name="v")

    # Slot update functions.
    self.gru = layers.GRUCell(self.slot_size)
    self.mlp = tf.keras.Sequential([
        layers.Dense(self.mlp_hidden_size, activation="relu"),
        layers.Dense(self.slot_size)
    ], name="mlp")

  def call(self, inputs):
    # `inputs` has shape [batch_size, num_inputs, inputs_size].
    inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
    k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].
    v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].

    # Initialize the slots by sampling from a learned diagonal Gaussian.
    # Shape: [batch_size, num_slots, slot_size].
    slots = self.slots_mu + tf.exp(self.slots_log_sigma) * tf.random.normal(
        [tf.shape(inputs)[0], self.num_slots, self.slot_size])

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):
      slots_prev = slots
      slots = self.norm_slots(slots)

      # Attention.
      q = self.project_q(slots)  # Shape: [batch_size, num_slots, slot_size].
      q *= self.slot_size ** -0.5  # Normalization (scaled dot-product).
      attn_logits = tf.keras.backend.batch_dot(k, q, axes=-1)
      # Softmax over the *slot* axis: slots compete for each input element.
      attn = tf.nn.softmax(attn_logits, axis=-1)
      # `attn` has shape: [batch_size, num_inputs, num_slots].

      # Weighted mean: renormalize over inputs (axis=-2) so each slot's
      # update is a convex combination of values; epsilon guards against
      # division by zero for slots that received ~no attention.
      attn += self.epsilon
      attn /= tf.reduce_sum(attn, axis=-2, keepdims=True)
      updates = tf.keras.backend.batch_dot(attn, v, axes=-2)
      # `updates` has shape: [batch_size, num_slots, slot_size].

      # Slot update: GRU takes the *pre-norm* slots as recurrent state,
      # followed by a residual MLP.
      slots, _ = self.gru(updates, [slots_prev])
      slots += self.mlp(self.norm_mlp(slots))

    return slots
104
105
def spatial_broadcast(slots, resolution):
  """Broadcast slot features to a 2D grid and collapse slot dimension."""
  # `slots` has shape: [batch_size, num_slots, slot_size].
  slot_size = slots.shape[-1]
  # Fold slots into the batch dimension, then add two singleton spatial axes.
  collapsed = tf.reshape(slots, [-1, slot_size])
  collapsed = collapsed[:, None, None, :]
  # Repeat every slot vector across all spatial positions of the grid.
  grid = tf.tile(collapsed, [1, resolution[0], resolution[1], 1])
  # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
  return grid
113
114
def spatial_flatten(x):
  """Merge the two spatial axes of `x` into a single set axis."""
  num_positions = x.shape[1] * x.shape[2]
  return tf.reshape(x, [-1, num_positions, x.shape[-1]])
117
118
def unstack_and_split(x, batch_size, num_channels=3):
  """Unstack batch dimension and split into channels and alpha mask."""
  # Recover the [batch, slot] factorization of the fused leading dimension.
  trailing_dims = x.shape.as_list()[1:]
  unstacked = tf.reshape(x, [batch_size, -1] + trailing_dims)
  # Last axis holds `num_channels` image channels followed by 1 alpha channel.
  channels, masks = tf.split(unstacked, [num_channels, 1], axis=-1)
  return channels, masks
124
125
class SlotAttentionAutoEncoder(layers.Layer):
  """Slot Attention-based auto-encoder for object discovery."""

  def __init__(self, resolution, num_slots, num_iterations):
    """Builds the Slot Attention-based auto-encoder.

    Args:
      resolution: Tuple of integers specifying width and height of input image.
      num_slots: Number of slots in Slot Attention.
      num_iterations: Number of iterations in Slot Attention.
    """
    super().__init__()
    self.resolution = resolution
    self.num_slots = num_slots
    self.num_iterations = num_iterations

    # Encoder backbone: all convolutions are stride-1, so the spatial
    # resolution of the feature map equals the input resolution.
    self.encoder_cnn = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
    ], name="encoder_cnn")

    # Spatial-broadcast decoder: starts from an 8x8 grid and upsamples 2x
    # four times (8 -> 128); the final layer emits num_channels+1 = 4
    # channels (RGB + alpha mask) with no activation.
    # NOTE(review): the four stride-2 layers assume a 128x128 target
    # resolution — confirm against the training config.
    self.decoder_initial_size = (8, 8)
    self.decoder_cnn = tf.keras.Sequential([
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(1, 1), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            4, 3, strides=(1, 1), padding="SAME", activation=None)
    ], name="decoder_cnn")

    # Learnable soft position embeddings for encoder features and the
    # broadcast decoder grid.
    self.encoder_pos = SoftPositionEmbed(64, self.resolution)
    self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)

    # Per-location feedforward network applied before Slot Attention.
    self.layer_norm = layers.LayerNormalization()
    self.mlp = tf.keras.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64)
    ], name="feedforward")

    self.slot_attention = SlotAttention(
        num_iterations=self.num_iterations,
        num_slots=self.num_slots,
        slot_size=64,
        mlp_hidden_size=128)

  def call(self, image):
    # `image` has shape: [batch_size, width, height, num_channels].

    # Convolutional encoder with position embedding.
    x = self.encoder_cnn(image)  # CNN Backbone.
    x = self.encoder_pos(x)  # Position embedding.
    x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
    x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
    # `x` has shape: [batch_size, width*height, input_size].

    # Slot Attention module.
    slots = self.slot_attention(x)
    # `slots` has shape: [batch_size, num_slots, slot_size].

    # Spatial broadcast decoder: each slot is decoded independently by
    # folding the slot axis into the batch axis.
    x = spatial_broadcast(slots, self.decoder_initial_size)
    # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
    x = self.decoder_pos(x)
    x = self.decoder_cnn(x)
    # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

    # Undo combination of slot and batch dimension; split alpha masks.
    # NOTE(review): relies on a static batch dimension (`image.shape[0]`),
    # which `build_model` provides via `tf.keras.Input(..., batch_size)`.
    recons, masks = unstack_and_split(x, batch_size=image.shape[0])
    # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
    # `masks` has shape: [batch_size, num_slots, width, height, 1].

    # Normalize alpha masks over slots.
    masks = tf.nn.softmax(masks, axis=1)
    recon_combined = tf.reduce_sum(recons * masks, axis=1)  # Recombine image.
    # `recon_combined` has shape: [batch_size, width, height, num_channels].

    return recon_combined, recons, masks, slots
212
213
class SlotAttentionClassifier(layers.Layer):
  """Slot Attention-based classifier for property prediction."""

  def __init__(self, resolution, num_slots, num_iterations):
    """Builds the Slot Attention-based classifier.

    Args:
      resolution: Tuple of integers specifying width and height of input image.
      num_slots: Number of slots in Slot Attention.
      num_iterations: Number of iterations in Slot Attention.
    """
    super().__init__()
    self.resolution = resolution
    self.num_slots = num_slots
    self.num_iterations = num_iterations

    # Encoder backbone: two stride-2 convolutions downsample by 4x overall.
    self.encoder_cnn = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, strides=(2, 2),
                      padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, strides=(2, 2),
                      padding="SAME", activation="relu"),
        layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
    ], name="encoder_cnn")

    # Position embedding over the downsampled feature grid.
    # NOTE(review): the hard-coded (32, 32) presumes a 128x128 input
    # (128 / 4 = 32) — confirm against the caller's `resolution`.
    self.encoder_pos = SoftPositionEmbed(64, (32, 32))

    # Per-location feedforward network applied before Slot Attention.
    self.layer_norm = layers.LayerNormalization()
    self.mlp = tf.keras.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64)
    ], name="feedforward")

    self.slot_attention = SlotAttention(
        num_iterations=self.num_iterations,
        num_slots=self.num_slots,
        slot_size=64,
        mlp_hidden_size=128)

    # Shared per-slot head; sigmoid outputs one score per target attribute.
    self.mlp_classifier = tf.keras.Sequential(
        [layers.Dense(64, activation="relu"),
         layers.Dense(19, activation="sigmoid")],  # Number of targets in CLEVR.
        name="mlp_classifier")

  def call(self, image):
    # `image` has shape: [batch_size, width, height, num_channels].

    # Convolutional encoder with position embedding.
    x = self.encoder_cnn(image)  # CNN Backbone.
    x = self.encoder_pos(x)  # Position embedding.
    x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
    x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
    # `x` has shape: [batch_size, width*height, input_size].

    # Slot Attention module.
    slots = self.slot_attention(x)
    # `slots` has shape: [batch_size, num_slots, slot_size].

    # Apply classifier per slot. The predictions have shape
    # [batch_size, num_slots, set_dimension].
    predictions = self.mlp_classifier(slots)

    return predictions
278
279
def build_grid(resolution):
  """Build a [1, H, W, 4] float32 grid of (x, y, 1-x, 1-y) coordinates."""
  # One linspace in [0, 1] per spatial axis.
  axes = [np.linspace(0., 1., num=res) for res in resolution]
  coords = np.stack(np.meshgrid(*axes, sparse=False, indexing="ij"), axis=-1)
  coords = coords.reshape([resolution[0], resolution[1], -1])
  # Add a leading (broadcastable) batch axis and fix the dtype.
  coords = coords[np.newaxis, Ellipsis].astype(np.float32)
  # Append the complement of each coordinate, giving 4 channels per position.
  return np.concatenate([coords, 1.0 - coords], axis=-1)
288
289
class SoftPositionEmbed(layers.Layer):
  """Adds soft positional embedding with learnable projection."""

  def __init__(self, hidden_size, resolution):
    """Builds the soft position embedding layer.

    Args:
      hidden_size: Size of input feature dimension.
      resolution: Tuple of integers specifying width and height of grid.
    """
    super().__init__()
    # The dense layer is the only learnable part; the grid itself is fixed.
    self.dense = layers.Dense(hidden_size, use_bias=True)
    self.grid = build_grid(resolution)

  def call(self, inputs):
    # Project the 4-channel coordinate grid to `hidden_size` features and
    # add it to the inputs (broadcast over the batch dimension).
    embedding = self.dense(self.grid)
    return inputs + embedding
306
307
def build_model(resolution, batch_size, num_slots, num_iterations,
                num_channels=3, model_type="object_discovery"):
  """Build keras model.

  Args:
    resolution: Tuple of integers specifying width and height of input image.
    batch_size: Static batch size baked into the model's input shape.
    num_slots: Number of slots in Slot Attention.
    num_iterations: Number of iterations in Slot Attention.
    num_channels: Number of image channels. Defaults to 3 (RGB).
    model_type: Either "object_discovery" (auto-encoder) or
      "set_prediction" (classifier).

  Returns:
    A `tf.keras.Model` mapping an image batch to the selected head's outputs.

  Raises:
    ValueError: If `model_type` is not one of the supported types.
  """
  # Dispatch table instead of an if/elif chain; also lets the error message
  # name the offending value, which the original message omitted.
  model_defs = {
      "object_discovery": SlotAttentionAutoEncoder,
      "set_prediction": SlotAttentionClassifier,
  }
  if model_type not in model_defs:
    raise ValueError(
        "Invalid name for model type: {!r}. Expected one of {}.".format(
            model_type, sorted(model_defs)))
  model_def = model_defs[model_type]

  image = tf.keras.Input(list(resolution) + [num_channels], batch_size)
  outputs = model_def(resolution, num_slots, num_iterations)(image)
  model = tf.keras.Model(inputs=image, outputs=outputs)
  return model
322
323
324
325
326