# google-research
# 216 lines · 6.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for Tensorflow 2.0.
17
18Partially adapted from:
19https://www.tensorflow.org/tutorials/text/image_captioning
20"""
21# pylint: disable=invalid-name
22
from __future__ import absolute_import
from __future__ import division

import tensorflow as tf
28
def film_params(sentence_embedding, n_layer_channel):
  """Generate FiLM parameters from a sentence embedding.

  Generate FiLM parameters from a sentence embedding. This method assumes a
  batch dimension exists.

  Args:
    sentence_embedding: a tensor containing batched sentenced embedding to be
      transformed
    n_layer_channel: a list of integers specifying how many channels are at
      each hidden layer to be FiLM'ed

  Returns:
    a tuple of tensors the same length as n_layer_channel. Each element
    contains all gamma_i and beta_i for a single hidden layer.
  """
  # Each FiLM'ed channel needs one gamma and one beta parameter.
  n_total = sum(n_layer_channel) * 2
  # BUG FIX: the original first called TF1-only `tf.layers.dense` and then
  # overwrote the result with an *un-applied* `tf.keras.layers.Dense` layer
  # object (with the garbled expression `2 * sum * (n_layer_channel)`), so
  # `tf.split` would have received a layer instead of a tensor. Build the
  # Keras layer and apply it to the embedding instead.
  all_params = tf.keras.layers.Dense(
      n_total, activation=tf.nn.relu)(sentence_embedding)
  # Split along the channel axis into per-layer (gamma, beta) chunks.
  return tf.split(all_params, [c * 2 for c in n_layer_channel], 1)
51
def stack_conv_layer(layer_cfg, padding='same'):
  """Stack convolution layers per layer_cfg.

  Args:
    layer_cfg: list of integer tuples specifying the parameter each layer;
      each tuple should be (channel, kernel size, strides)
    padding: what kind of padding the conv layers use

  Returns:
    the keras model with stacked conv layers
  """
  # All layers except the last get a ReLU activation.
  conv_stack = [
      tf.keras.layers.Conv2D(
          filters=cfg[0],
          kernel_size=cfg[1],
          strides=cfg[2],
          activation=tf.nn.relu,
          padding=padding)
      for cfg in layer_cfg[:-1]
  ]
  # Final layer is linear so the caller can choose its own activation.
  last = layer_cfg[-1]
  conv_stack.append(
      tf.keras.layers.Conv2D(last[0], last[1], last[2], padding=padding))
  return tf.keras.Sequential(conv_stack)
78
def stack_dense_layer(layer_cfg):
  """Stack Dense layers.

  Args:
    layer_cfg: list of integer specifying the number of units at each layer

  Returns:
    the keras model with stacked dense layers
  """
  # Hidden layers use ReLU; the final layer is linear so the caller
  # decides the output activation.
  dense_stack = [
      tf.keras.layers.Dense(units, activation=tf.nn.relu)
      for units in layer_cfg[:-1]
  ]
  dense_stack.append(tf.keras.layers.Dense(layer_cfg[-1]))
  return tf.keras.Sequential(dense_stack)
94
def soft_variables_update(source_variables, target_variables, polyak_rate=1.0):
  """Update the target variables using exponential moving average.

  Each target variable is assigned
  v_t' = polyak_rate * v_t + (1 - polyak_rate) * v_s,
  i.e. blended toward its paired source variable; polyak_rate == 1.0
  leaves the targets unchanged.

  Args:
    source_variables: the moving average variables
    target_variables: the new observations
    polyak_rate: rate of moving average

  Returns:
    Operation that does the update
  """
  updates = []
  for v_s, v_t in zip(source_variables, target_variables):
    # Fail fast on mismatched pairs before assigning anything.
    v_t.shape.assert_is_compatible_with(v_s.shape)
    updates.append(v_t.assign(polyak_rate * v_t + (1 - polyak_rate) * v_s))
  return updates
121
def vector_tensor_product(a, b):
  """Pairs every vector in `a` with every vector in `b` ("outer product").

  For inputs of shape [B, n, d], builds Keras ops producing a tensor of
  shape [B, n, n, 2*d] whose entry (i, j) is the concatenation of
  a[:, i] and b[:, j].

  Args:
    a: tensor of shape [B, ?, d]; static last dim must be known.
    b: tensor of shape [B, ?, d]; its dynamic length sets the pair grid.

  Returns:
    tensor of shape [B, ?, ?, 2*d] with all pairwise concatenations.
  """
  # a shape: [B, ?, d], b shape: [B, ?, d]
  # Dynamic shape via a Lambda layer so this works in a functional graph.
  shape_layer = tf.keras.layers.Lambda(tf.shape)
  shape = shape_layer(b)
  shape_numpy = b.get_shape()  # static shape; only the last dim is used
  variable_length = shape[1]  # variable_len = ?
  expand_dims_layer_1 = tf.keras.layers.Reshape((-1, 1, shape_numpy[-1]))
  expand_dims_layer_2 = tf.keras.layers.Reshape((-1, 1, shape_numpy[-1]))
  a = expand_dims_layer_1(a)  # a shape: [B, ?, 1, d]
  b = expand_dims_layer_2(b)  # b shape: [B, ?, 1, d]
  tile_layer = tf.keras.layers.Lambda(
      lambda inputs: tf.tile(inputs[0], multiples=inputs[1]))
  a = tile_layer((a, [1, 1, variable_length, 1]))  # a shape: [B, ?, ?, d]
  b = tile_layer((b, [1, 1, variable_length, 1]))  # b shape: [B, ?, ?, d]
  # Swap b's two middle axes so position (i, j) pairs a_i with b_j.
  b = tf.keras.layers.Permute((2, 1, 3))(b)  # b shape: [B, ?, ?, d]
  return tf.keras.layers.concatenate([a, b])  # shape: [B, ?, ?, 2*d]
140
class BahdanauAttention(tf.keras.Model):
  """Bahdanau (additive) attention layer.

  Attributes:
    W1: dense layer that processes the feature
    W2: dense layer that processes the memory state
    V: projection layer that projects the score vector to a scalar
  """

  def __init__(self, units):
    """Initialize Bahdanau attention layer.

    Args:
      units: size of the dense layers
    """
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, features, hidden):
    # features (encoder output): (batch, locations, embedding_dim);
    # hidden (query state): (batch, hidden_size).
    query = tf.expand_dims(hidden, 1)  # (batch, 1, hidden_size)
    # Additive score: (batch, locations, units).
    score = tf.nn.tanh(self.W1(features) + self.W2(query))
    # V gives one logit per location; softmax normalizes over locations,
    # hence axis=1 -> (batch, locations, 1).
    attention_weights = tf.nn.softmax(self.V(score), axis=1)
    # Weighted sum of features over locations.
    context_vector = tf.reduce_sum(attention_weights * features, axis=1)
    return context_vector, attention_weights
181
class GRUEnecoder(tf.keras.Model):
  """TF2.0 GRU encoder.

  Attributes:
    embedding: word embedding matrix
    gru: the GRU layer
  """

  def __init__(self, embedding_dim, units, vocab_size):
    """Initialize the GRU encoder.

    Args:
      embedding_dim: dimension of word embedding
      units: number of units of the memory state
      vocab_size: total number of vocabulary
    """
    super(GRUEnecoder, self).__init__()
    self._units = units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # BUG FIX: the original passed `self.units`, which is never assigned
    # (only `self._units` is), so constructing the model raised
    # AttributeError before the GRU layer could be built.
    self.gru = tf.keras.layers.GRU(
        self._units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)
    # NOTE(review): `hidden` is accepted but not forwarded to the GRU as an
    # initial state; kept as-is since callers may rely on this signature and
    # behavior — confirm whether it should be passed as `initial_state`.
    output, state = self.gru(x)
    return output, state

  def reset_state(self, batch_size):
    """Returns an all-zero initial memory state for a batch."""
    return tf.zeros((batch_size, self._units))