# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for maml_vanilla.py."""
from __future__ import print_function

from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import opt as contrib_opt
from tensorflow.contrib.layers.python import layers as tf_layers

FLAGS = flags.FLAGS


## Network helpers
def conv_block(x, weight, bias, reuse, scope):
  """Convolution, batch norm and ReLU activation for one conv layer."""
  x = tf.nn.conv2d(x, weight, [1, 1, 1, 1], 'SAME') + bias
  x = tf_layers.batch_norm(
      x, activation_fn=tf.nn.relu, reuse=reuse, scope=scope)
  # Pooling (disabled):
  # x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
  return x


## Loss functions
def mse(pred, label):
  """Mean squared error between flattened predictions and labels."""
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))
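

# Illustrative sketch (an assumption, not used by the class below): the MAML
# inner-loop ("fast weights") update that task_metalearn performs, written as
# a standalone helper.  `weights` is a dict of parameter tensors, `loss` a
# scalar support-set loss, and `lr` the inner-loop step size.
def _fast_weights_sketch(weights, loss, lr):
  """Returns theta' = theta - lr * d(loss)/d(theta) as a new dict."""
  grads = tf.gradients(loss, list(weights.values()))
  gradients = dict(zip(weights.keys(), grads))
  return {key: weights[key] - lr * gradients[key] for key in weights}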


class MAML(object):
  """MAML algo object."""

  def __init__(self, dim_input=1, dim_output=1):
    """Must call construct_model() after initializing MAML."""
    self.dim_input = dim_input
    self.dim_output = dim_output
    self.update_lr = FLAGS.update_lr
    self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
    self.classification = False

    self.loss_func = mse

    self.classification = True
    self.dim_hidden = FLAGS.num_filters
    self.forward = self.forward_conv
    self.construct_weights = self.construct_conv_weights

    self.channels = 1
    self.img_size = int(np.sqrt(self.dim_input / self.channels))

  def construct_model(self,
                      input_tensors=None,
                      prefix='metatrain_',
                      test_num_updates=0):
    """a: training data for inner gradient, b: test data for meta gradient."""

    self.inputa = input_tensors['inputa']
    self.inputb = input_tensors['inputb']
    self.labela = input_tensors['labela']
    self.labelb = input_tensors['labelb']

    with tf.variable_scope('model', reuse=None) as training_scope:
      if 'weights' in dir(self):
        training_scope.reuse_variables()
        weights = self.weights
      else:
        # Define the weights.
        self.weights = weights = self.construct_weights()

      # outputbs[i] and lossesb[i] are the output and loss after i+1 gradient
      # updates.
      num_updates = max(test_num_updates, FLAGS.num_updates)

      def task_metalearn(inp, reuse=True):
        """Runs meta learning for a single task in the meta-batch."""
        TRAIN = 'train' in prefix  # pylint: disable=invalid-name
        # Perform gradient descent for one task in the meta-batch.
        inputa, inputb, labela, labelb = inp
        task_outputbs, task_lossesb = [], []
        task_msesb = []

        # Support prediction and loss, (n_data_per_task, out_dim).
        task_outputa = self.forward(
            inputa, weights, reuse=reuse)  # reuse is False only on the first call
        # labela is (n_data_per_task, out_dim).
        task_lossa = self.loss_func(task_outputa, labela)

        # INNER LOOP (no change with ib)
        grads = tf.gradients(task_lossa, list(weights.values()))
        if FLAGS.stop_grad:
          grads = [tf.stop_gradient(grad) for grad in grads]
        gradients = dict(zip(weights.keys(), grads))
        ## theta_pi = theta - alpha * grads
        fast_weights = dict(
            zip(weights.keys(), [
                weights[key] - self.update_lr * gradients[key]
                for key in weights.keys()
            ]))

        # Use theta_pi to forward the meta-test (query) data.
        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        # Meta-test loss.
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))

        def while_body(fast_weights_values):
          """Updates the fast weights with one more inner gradient step."""
          loss = self.loss_func(
              self.forward(
                  inputa,
                  dict(zip(fast_weights.keys(), fast_weights_values)),
                  reuse=True), labela)
          grads = tf.gradients(loss, fast_weights_values)
          fast_weights_values = [
              v - self.update_lr * g for v, g in zip(fast_weights_values, grads)
          ]
          return fast_weights_values

        # Run the remaining num_updates - 1 inner steps.
        fast_weights_values = tf.while_loop(
            lambda _: True,
            while_body,
            loop_vars=[list(fast_weights.values())],
            maximum_iterations=num_updates - 1,
            back_prop=TRAIN)
        fast_weights = dict(zip(fast_weights.keys(), fast_weights_values))

        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))
        task_output = [
            task_outputa, task_outputbs, task_lossa, task_lossesb, task_msesb
        ]

        return task_output

      if FLAGS.norm is not None:
        # To initialize the batch norm vars; might want to combine this and
        # not run idx 0 twice.
        _ = task_metalearn(
            (self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]),
            False)

      out_dtype = [
          tf.float32, [tf.float32] * 2, tf.float32, [tf.float32] * 2,
          [tf.float32] * 2
      ]
      result = tf.map_fn(
          task_metalearn,
          elems=(self.inputa, self.inputb, self.labela, self.labelb),
          dtype=out_dtype,
          parallel_iterations=FLAGS.meta_batch_size)
      outputas, outputbs, lossesa, _, msesb = result

    ## Performance & Optimization
    if 'train' in prefix:
      # lossesa has length meta_batch_size.
      self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(
          FLAGS.meta_batch_size)
      self.total_losses2 = total_losses2 = [
          tf.reduce_sum(msesb[j]) / tf.to_float(FLAGS.meta_batch_size)
          for j in range(len(msesb))
      ]
      # After the map_fn.
      self.outputas, self.outputbs = outputas, outputbs

      # OUTER LOOP
      if FLAGS.metatrain_iterations > 0:
        if FLAGS.weight_decay:
          optimizer = contrib_opt.AdamWOptimizer(
              weight_decay=FLAGS.beta, learning_rate=self.meta_lr)
        else:
          optimizer = tf.train.AdamOptimizer(self.meta_lr)
        self.gvs_theta = gvs_theta = optimizer.compute_gradients(
            self.total_losses2[-1])
        self.metatrain_op = optimizer.apply_gradients(gvs_theta)

    else:
      self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(
          lossesa) / tf.to_float(FLAGS.meta_batch_size)
      self.metaval_total_losses2 = total_losses2 = [
          tf.reduce_sum(msesb[j]) / tf.to_float(FLAGS.meta_batch_size)
          for j in range(len(msesb))
      ]

    ## Summaries
    tf.summary.scalar(prefix + 'Pre-mse', total_loss1)
    tf.summary.scalar(prefix + 'Post-mse_' + str(num_updates),
                      total_losses2[-1])
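
  # In math, the graph built above implements the MAML meta-objective
  #   min_theta (1 / meta_batch_size) * sum_tasks L_b(theta - alpha * grad_theta L_a(theta)),
  # where L_a is the support-set (inputa) loss, L_b the query-set (inputb) MSE
  # after the inner updates, and alpha = FLAGS.update_lr; the outer optimizer
  # differentiates total_losses2[-1], i.e. the post-update query loss.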

  def construct_conv_weights(self):
    """Construct conv weights."""
    weights = {}

    dtype = tf.float32
    conv_initializer = contrib_layers.xavier_initializer_conv2d(dtype=dtype)
    conv_initializer2 = tf.glorot_uniform_initializer(dtype=dtype)
    k = 3
    weights['en_conv1'] = tf.get_variable(
        'en_conv1', [3, 3, 1, 32], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias1'] = tf.Variable(tf.zeros([32]))
    weights['en_conv2'] = tf.get_variable(
        'en_conv2', [3, 3, 32, 48], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias2'] = tf.Variable(tf.zeros([48]))
    weights['en_conv3'] = tf.get_variable(
        'en_conv3', [3, 3, 48, 64], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias3'] = tf.Variable(tf.zeros([64]))

    weights['en_full1'] = tf.get_variable(
        'en_full1', [4096, 196], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias_full1'] = tf.Variable(tf.zeros([196]))

    weights['conv1'] = tf.get_variable(
        'conv1', [k, k, self.channels, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv2'] = tf.get_variable(
        'conv2', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv3'] = tf.get_variable(
        'conv3', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv4'] = tf.get_variable(
        'conv4', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))

    weights['w5'] = tf.Variable(
        tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
    weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
    return weights

  def forward_conv(self, inp, weights, reuse=False, scope=''):
    """Forward conv."""
    # reuse is for the normalization parameters.
    channels = self.channels
    inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
    en1 = tf.nn.conv2d(inp, weights['en_conv1'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias1']
    en2 = tf.nn.conv2d(en1, weights['en_conv2'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias2']
    pool1 = tf.nn.max_pool(en2, 2, 2, 'VALID')
    en3 = tf.nn.conv2d(pool1, weights['en_conv3'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias3']
    out0 = tf.layers.flatten(en3)
    out1 = tf.nn.relu(
        tf.matmul(out0, weights['en_full1']) + weights['en_bias_full1'])

    out1 = tf.reshape(out1, [-1, 14, 14, 1])

    hidden1 = conv_block(out1, weights['conv1'], weights['b1'], reuse,
                         scope + '0')
    hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse,
                         scope + '1')
    hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse,
                         scope + '2')
    hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse,
                         scope + '3')
    # Average the last hidden layer over its spatial dimensions to get a
    # feature vector.
    hidden4 = tf.reduce_mean(hidden4, [1, 2])
    return tf.matmul(hidden4, weights['w5']) + weights['b5']
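
# ----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original file).
# `build_input_tensors()` is a hypothetical helper returning the dict of
# 'inputa'/'inputb'/'labela'/'labelb' batches, and the flags referenced above
# (update_lr, meta_lr, num_filters, num_updates, meta_batch_size,
# metatrain_iterations, ...) are assumed to be defined by maml_vanilla.py.
# The encoder's fully connected layer expects a 4096-d flattened feature,
# which corresponds to 128x128 single-channel inputs.
#
#   model = MAML(dim_input=128 * 128, dim_output=1)
#   model.construct_model(input_tensors=build_input_tensors(),
#                         prefix='metatrain_')
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(model.metatrain_op)
# ----------------------------------------------------------------------------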