# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for TensorFlow 2.0.

Partially adapted from:
https://www.tensorflow.org/tutorials/text/image_captioning
"""
# pylint: disable=invalid-name

from __future__ import absolute_import
from __future__ import division

import tensorflow as tf


def film_params(sentence_embedding, n_layer_channel):
  """Generate FiLM parameters from a sentence embedding.

  This method assumes a batch dimension exists.

  Args:
    sentence_embedding: a tensor containing batched sentence embeddings to be
      transformed
    n_layer_channel:    a list of integers specifying how many channels are at
      each hidden layer to be FiLM'ed

  Returns:
    a list of tensors the same length as n_layer_channel. Each element
    contains all gamma_i and beta_i for a single hidden layer.
  """
  # One gamma and one beta per FiLM'ed channel.
  n_total = sum(n_layer_channel) * 2
  all_params = tf.keras.layers.Dense(n_total)(sentence_embedding)
  return tf.split(all_params, [c * 2 for c in n_layer_channel], 1)
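

# Illustrative sketch (not part of the original file): one way the output of
# `film_params` could modulate a [B, H, W, C] feature map, following the FiLM
# formulation gamma * x + beta. The gamma-then-beta packing order and all
# names/shapes here are assumptions made for the example.
def _example_film_usage():
  sentence_embedding = tf.random.normal([4, 32])
  params = film_params(sentence_embedding, [8, 16])
  feature_map = tf.random.normal([4, 10, 10, 8])  # layer with 8 channels
  gamma, beta = tf.split(params[0], 2, axis=1)  # each of shape [4, 8]
  gamma = gamma[:, None, None, :]  # broadcast over spatial dims
  beta = beta[:, None, None, :]
  return gamma * feature_map + beta  # FiLM-modulated features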


def stack_conv_layer(layer_cfg, padding='same'):
  """Stack convolution layers per layer_cfg.

  Args:
    layer_cfg: list of integer tuples specifying the parameters of each layer;
      each tuple should be (channels, kernel size, strides)
    padding: the kind of padding the conv layers use

  Returns:
    a keras model with the stacked conv layers
  """
  layers = []
  # All layers except the last use a ReLU activation.
  for cfg in layer_cfg[:-1]:
    layers.append(
        tf.keras.layers.Conv2D(
            filters=cfg[0],
            kernel_size=cfg[1],
            strides=cfg[2],
            activation=tf.nn.relu,
            padding=padding))
  # The final layer is linear so callers can attach their own activation.
  final_cfg = layer_cfg[-1]
  layers.append(
      tf.keras.layers.Conv2D(
          final_cfg[0], final_cfg[1], final_cfg[2], padding=padding))
  return tf.keras.Sequential(layers)
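

# Usage sketch (not part of the original file): a three-layer CNN that
# downsamples a 64x64 RGB image; the configuration values are invented here.
def _example_stack_conv_layer():
  cnn = stack_conv_layer([(32, 3, 2), (64, 3, 2), (64, 3, 1)])
  images = tf.random.normal([2, 64, 64, 3])
  return cnn(images)  # shape [2, 16, 16, 64] with 'same' padding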


def stack_dense_layer(layer_cfg):
  """Stack Dense layers.

  Args:
    layer_cfg: list of integers specifying the number of units at each layer

  Returns:
    a keras model with the stacked dense layers
  """
  layers = []
  # All layers except the last use a ReLU activation; the last is linear.
  for cfg in layer_cfg[:-1]:
    layers.append(tf.keras.layers.Dense(cfg, activation=tf.nn.relu))
  layers.append(tf.keras.layers.Dense(layer_cfg[-1]))
  return tf.keras.Sequential(layers)
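

# Usage sketch (not part of the original file): a small MLP with two hidden
# ReLU layers and a linear 10-unit output head; sizes are invented here.
def _example_stack_dense_layer():
  mlp = stack_dense_layer([128, 64, 10])
  inputs = tf.random.normal([2, 32])
  return mlp(inputs)  # shape [2, 10]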


def soft_variables_update(source_variables, target_variables, polyak_rate=1.0):
  """Update the target variables using an exponential moving average.

  Specifically, v_t' = polyak_rate * v_t + (1 - polyak_rate) * v_s

  Args:
    source_variables:  the variables supplying the new values
    target_variables:  the moving-average variables to be updated in place
    polyak_rate: rate of the moving average; 1.0 leaves the targets unchanged
      and 0.0 copies the sources outright

  Returns:
    a list of operations that perform the update
  """
  updates = []
  for (v_s, v_t) in zip(source_variables, target_variables):
    v_t.shape.assert_is_compatible_with(v_s.shape)

    def update_fn(v1, v2):
      """Polyak-average v1 toward v2 in place."""
      return v1.assign(polyak_rate * v1 + (1 - polyak_rate) * v2)

    update = update_fn(v_t, v_s)
    updates.append(update)
  return updates
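

# Usage sketch (not part of the original file): a Polyak-style target-network
# update as used in DQN/DDPG-like training loops. The layer shapes and the
# 0.99 rate are invented for the example.
def _example_soft_variables_update():
  online = tf.keras.layers.Dense(4)
  target = tf.keras.layers.Dense(4)
  online.build((None, 8))
  target.build((None, 8))
  # Move each target variable 1% of the way toward its online counterpart.
  soft_variables_update(online.variables, target.variables, polyak_rate=0.99)
  return target.variables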


def vector_tensor_product(a, b):
  """Performs an outer-product-like pairwise concatenation of a and b."""
  # a shape: [B, ?, d], b shape: [B, ?, d]
  shape_layer = tf.keras.layers.Lambda(tf.shape)
  shape = shape_layer(b)
  shape_numpy = b.get_shape()
  variable_length = shape[1]  # variable_len = ?
  expand_dims_layer_1 = tf.keras.layers.Reshape((-1, 1, shape_numpy[-1]))
  expand_dims_layer_2 = tf.keras.layers.Reshape((-1, 1, shape_numpy[-1]))
  a = expand_dims_layer_1(a)  # a shape: [B, ?, 1, d]
  b = expand_dims_layer_2(b)  # b shape: [B, ?, 1, d]
  tile_layer = tf.keras.layers.Lambda(
      lambda inputs: tf.tile(inputs[0], multiples=inputs[1]))
  a = tile_layer((a, [1, 1, variable_length, 1]))  # a shape: [B, ?, ?, d]
  b = tile_layer((b, [1, 1, variable_length, 1]))  # b shape: [B, ?, ?, d]
  b = tf.keras.layers.Permute((2, 1, 3))(b)  # b shape: [B, ?, ?, d]
  return tf.keras.layers.concatenate([a, b])  # shape: [B, ?, ?, 2*d]
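

# Usage sketch (not part of the original file): every pair of row vectors
# (a_i, b_j) is concatenated, yielding one [2*d] vector per (i, j) pair.
def _example_vector_tensor_product():
  a = tf.random.normal([2, 5, 8])
  b = tf.random.normal([2, 5, 8])
  return vector_tensor_product(a, b)  # shape [2, 5, 5, 16]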


class BahdanauAttention(tf.keras.Model):
  """Bahdanau Attention Layer.

  Attributes:
    W1: weights that process the feature
    W2: weights that process the memory state
    V: projection layer that projects the score vector to a scalar
  """

  def __init__(self, units):
    """Initialize Bahdanau attention layer.

    Args:
      units: size of the dense layers
    """
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, features, hidden):
    # features (CNN encoder output) shape == (batch_size, 64, embedding_dim)

    # hidden shape == (batch_size, hidden_size)
    # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
    hidden_with_time_axis = tf.expand_dims(hidden, 1)

    # score shape == (batch_size, 64, units)
    score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))

    # attention_weights shape == (batch_size, 64, 1)
    # the last axis is 1 because self.V projects the score to a scalar
    attention_weights = tf.nn.softmax(self.V(score), axis=1)

    # context_vector shape after sum == (batch_size, embedding_dim)
    context_vector = attention_weights * features
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights
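

# Usage sketch (not part of the original file): attend over 64 spatial
# feature vectors given a query state; all shapes here are invented.
def _example_bahdanau_attention():
  attention = BahdanauAttention(units=32)
  features = tf.random.normal([2, 64, 256])  # e.g. flattened CNN features
  hidden = tf.random.normal([2, 512])  # e.g. a GRU memory state
  context_vector, attention_weights = attention(features, hidden)
  return context_vector, attention_weights  # [2, 256] and [2, 64, 1]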


class GRUEnecoder(tf.keras.Model):
  """TF2.0 GRU encoder.

  Attributes:
    embedding: word embedding matrix
    gru: the GRU layer
  """

  def __init__(self, embedding_dim, units, vocab_size):
    """Initialize the GRU encoder.

    Args:
      embedding_dim: dimension of the word embedding
      units: number of units of the memory state
      vocab_size: total size of the vocabulary
    """
    super(GRUEnecoder, self).__init__()
    self._units = units

    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(
        self._units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    # x shape after passing through embedding ==
    # (batch_size, sequence_length, embedding_dim)
    x = self.embedding(x)
    # Pass the embedded sequence and the initial memory state to the GRU.
    output, state = self.gru(x, initial_state=hidden)
    return output, state

  def reset_state(self, batch_size):
    return tf.zeros((batch_size, self._units))
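

# Usage sketch (not part of the original file): encode a batch of token-id
# sequences; the vocabulary size and dimensions are invented here.
def _example_gru_encoder():
  encoder = GRUEnecoder(embedding_dim=64, units=128, vocab_size=1000)
  tokens = tf.random.uniform([2, 7], maxval=1000, dtype=tf.int32)
  hidden = encoder.reset_state(batch_size=2)
  output, state = encoder(tokens, hidden)
  return output, state  # output: [2, 7, 128], state: [2, 128]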