# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for maml_vanilla.py."""
from __future__ import print_function

from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import opt as contrib_opt
from tensorflow.contrib.layers.python import layers as tf_layers

FLAGS = flags.FLAGS


## Network helpers
def conv_block(x, weight, bias, reuse, scope):
  x = tf.nn.conv2d(x, weight, [1, 1, 1, 1], 'SAME') + bias
  x = tf_layers.batch_norm(
      x, activation_fn=tf.nn.relu, reuse=reuse, scope=scope)
  # # pooling
  # x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
  return x


## Loss functions
def mse(pred, label):
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))


class MAML(object):
  """MAML algo object."""

  def __init__(self, dim_input=1, dim_output=1):
    """Must call construct_model() after initializing MAML!"""
    self.dim_input = dim_input
    self.dim_output = dim_output
    self.update_lr = FLAGS.update_lr
    self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())

    self.loss_func = mse

    self.classification = True
    self.dim_hidden = FLAGS.num_filters
    self.forward = self.forward_conv
    self.construct_weights = self.construct_conv_weights

    self.channels = 1
    self.img_size = int(np.sqrt(self.dim_input / self.channels))
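    # Inputs are assumed to be flattened square grayscale images; the encoder
    # weights in construct_conv_weights further assume 128x128 frames
    # (dim_input = 16384).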

  def construct_model(self,
                      input_tensors=None,
                      prefix='metatrain_',
                      test_num_updates=0):
    """a: training data for inner gradient, b: test data for meta gradient."""

    self.inputa = input_tensors['inputa']
    self.inputb = input_tensors['inputb']
    self.labela = input_tensors['labela']
    self.labelb = input_tensors['labelb']

    with tf.variable_scope('model', reuse=None) as training_scope:
      if 'weights' in dir(self):
        training_scope.reuse_variables()
        weights = self.weights
      else:
        # Define the weights
        self.weights = weights = self.construct_weights()

      # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient
      # updates
      num_updates = max(test_num_updates, FLAGS.num_updates)

      def task_metalearn(inp, reuse=True):
        """Run meta learning."""
        TRAIN = 'train' in prefix  # pylint: disable=invalid-name
        # Perform gradient descent for one task in the meta-batch.
        inputa, inputb, labela, labelb = inp
        task_outputbs, task_lossesb = [], []
        task_msesb = []

        # support_pred and loss, (n_data_per_task, out_dim)
        task_outputa = self.forward(
            inputa, weights, reuse=reuse)  # only not reuse on the first iter
        # labela is (n_data_per_task, out_dim)
        task_lossa = self.loss_func(task_outputa, labela)

        # INNER LOOP (no change with ib)
        grads = tf.gradients(task_lossa, list(weights.values()))
        if FLAGS.stop_grad:
          grads = [tf.stop_gradient(grad) for grad in grads]
        gradients = dict(zip(weights.keys(), grads))
        ## theta_pi = theta - alpha * grads
        fast_weights = dict(
            zip(weights.keys(), [
                weights[key] - self.update_lr * gradients[key]
                for key in weights.keys()
            ]))
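        # This is one explicit inner step, theta' = theta - alpha * grad(L_a);
        # with FLAGS.stop_grad set, the gradients are treated as constants,
        # dropping MAML's second-order terms (first-order approximation).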

        # use theta_pi to forward meta-test
        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        # meta-test loss
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))

        def while_body(fast_weights_values):
          """Update params."""
          loss = self.loss_func(
              self.forward(
                  inputa,
                  dict(zip(fast_weights.keys(), fast_weights_values)),
                  reuse=True), labela)
          grads = tf.gradients(loss, fast_weights_values)
          fast_weights_values = [
              v - self.update_lr * g
              for v, g in zip(fast_weights_values, grads)
          ]
          return fast_weights_values
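
        # The remaining num_updates - 1 inner steps run in a tf.while_loop;
        # back_prop=TRAIN disables gradient bookkeeping through the loop at
        # meta-test time.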
        fast_weights_values = tf.while_loop(
            lambda _: True,
            while_body,
            # list() so tf.nest sees a sequence; a Python 3 dict view is not
            # flattened by the loop machinery.
            loop_vars=[list(fast_weights.values())],
            maximum_iterations=num_updates - 1,
            back_prop=TRAIN)
        fast_weights = dict(zip(fast_weights.keys(), fast_weights_values))

        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))
        task_output = [
            task_outputa, task_outputbs, task_lossa, task_lossesb, task_msesb
        ]

        return task_output

      if FLAGS.norm is not None:
        # to initialize the batch norm vars, might want to combine this, and
        # not run idx 0 twice.
        _ = task_metalearn(
            (self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]),
            False)

      out_dtype = [
          tf.float32, [tf.float32] * 2, tf.float32, [tf.float32] * 2,
          [tf.float32] * 2
      ]
      result = tf.map_fn(
          task_metalearn,
          elems=(self.inputa, self.inputb, self.labela, self.labelb),
          dtype=out_dtype,
          parallel_iterations=FLAGS.meta_batch_size)
      outputas, outputbs, lossesa, _, msesb = result
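      # map_fn ran task_metalearn once per task, so each tensor above is
      # stacked along the meta-batch dimension.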

    ## Performance & Optimization
    if 'train' in prefix:
      # lossesa has length meta_batch_size.
      self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(
          FLAGS.meta_batch_size)
      self.total_losses2 = total_losses2 = [
          tf.reduce_sum(msesb[j]) / tf.to_float(FLAGS.meta_batch_size)
          for j in range(len(msesb))
      ]
      # after the map_fn
      self.outputas, self.outputbs = outputas, outputbs

      # OUTER LOOP
      if FLAGS.metatrain_iterations > 0:
        if FLAGS.weight_decay:
          optimizer = contrib_opt.AdamWOptimizer(
              weight_decay=FLAGS.beta, learning_rate=self.meta_lr)
        else:
          optimizer = tf.train.AdamOptimizer(self.meta_lr)
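        # Meta step: differentiate the post-update query loss,
        # total_losses2[-1], w.r.t. the initial weights theta.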
        self.gvs_theta = gvs_theta = optimizer.compute_gradients(
            self.total_losses2[-1])
        self.metatrain_op = optimizer.apply_gradients(gvs_theta)

    else:
      self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(
          lossesa) / tf.to_float(FLAGS.meta_batch_size)
      self.metaval_total_losses2 = total_losses2 = [
          tf.reduce_sum(msesb[j]) / tf.to_float(FLAGS.meta_batch_size)
          for j in range(len(msesb))
      ]

    ## Summaries
    tf.summary.scalar(prefix + 'Pre-mse', total_loss1)
    tf.summary.scalar(prefix + 'Post-mse_' + str(num_updates),
                      total_losses2[-1])

  def construct_conv_weights(self):
    """Construct conv weights."""
    weights = {}

    dtype = tf.float32
    conv_initializer = contrib_layers.xavier_initializer_conv2d(dtype=dtype)
    conv_initializer2 = tf.glorot_uniform_initializer(dtype=dtype)
    k = 3
    weights['en_conv1'] = tf.get_variable(
        'en_conv1', [3, 3, 1, 32], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias1'] = tf.Variable(tf.zeros([32]))
    weights['en_conv2'] = tf.get_variable(
        'en_conv2', [3, 3, 32, 48], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias2'] = tf.Variable(tf.zeros([48]))
    weights['en_conv3'] = tf.get_variable(
        'en_conv3', [3, 3, 48, 64], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias3'] = tf.Variable(tf.zeros([64]))

    weights['en_full1'] = tf.get_variable(
        'en_full1', [4096, 196], initializer=conv_initializer2, dtype=dtype)
    weights['en_bias_full1'] = tf.Variable(tf.zeros([196]))
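    # Shape note: the 4096 -> 196 projection implies 128x128 inputs (three
    # stride-2 convs plus a 2x2 max-pool in forward_conv take 128 down to 8,
    # and 8 * 8 * 64 = 4096); 196 = 14 * 14 matches the reshape there.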

    weights['conv1'] = tf.get_variable(
        'conv1', [k, k, self.channels, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv2'] = tf.get_variable(
        'conv2', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv3'] = tf.get_variable(
        'conv3', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
    weights['conv4'] = tf.get_variable(
        'conv4', [k, k, self.dim_hidden, self.dim_hidden],
        initializer=conv_initializer,
        dtype=dtype)
    weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))

    weights['w5'] = tf.Variable(
        tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
    weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
    return weights

  def forward_conv(self, inp, weights, reuse=False, scope=''):
    """Forward conv."""
    # reuse is for the normalization parameters.
    channels = self.channels
    inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
    en1 = tf.nn.conv2d(inp, weights['en_conv1'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias1']
    en2 = tf.nn.conv2d(en1, weights['en_conv2'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias2']
    pool1 = tf.nn.max_pool(en2, 2, 2, 'VALID')
    en3 = tf.nn.conv2d(pool1, weights['en_conv3'], [1, 2, 2, 1],
                       'SAME') + weights['en_bias3']
    out0 = tf.layers.flatten(en3)
    out1 = tf.nn.relu(
        tf.matmul(out0, weights['en_full1']) + weights['en_bias_full1'])

    out1 = tf.reshape(out1, [-1, 14, 14, 1])
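    # out1 is now a 14x14 single-channel "image" that the four conv_block
    # layers below consume.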

    hidden1 = conv_block(out1, weights['conv1'], weights['b1'], reuse,
                         scope + '0')
    hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse,
                         scope + '1')
    hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse,
                         scope + '2')
    hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse,
                         scope + '3')
    # Average the last hidden layer over its spatial dimensions to a vector.
    hidden4 = tf.reduce_mean(hidden4, [1, 2])
    return tf.matmul(hidden4, weights['w5']) + weights['b5']
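

# A minimal usage sketch (hypothetical tensor names; the real flag definitions
# and input pipeline live in maml_vanilla.py):
#
#   input_tensors = {'inputa': ia, 'inputb': ib, 'labela': la, 'labelb': lb}
#   model = MAML(dim_input=128 * 128, dim_output=dim_output)
#   model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
#   # ...then repeatedly run model.metatrain_op in a tf.Session.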