# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Add KL, beta-BBB, just on encoder_w and include a version of vanilla NP."""
from __future__ import print_function
import functools
import os
import pickle
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras.layers import MaxPooling2D
import tensorflow_probability as tfp
from tensorflow_probability.python.layers import util as tfp_layers_util


tf.compat.v1.enable_v2_tensorshape()
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_string('logdir', '/tmp/data',
                    'Directory for summaries and checkpoints.')
flags.DEFINE_string('data_dir', None, 'Directory of data files.')
get_data_dir = lambda: FLAGS.data_dir
flags.DEFINE_list('data', ['train_data_ins.pkl', 'val_data_ins.pkl'],
                  'Names of the train/validation data files.')
flags.DEFINE_integer('update_batch_size', 15,
                     'Number of context/target points per task.')
flags.DEFINE_integer('meta_batch_size', 10, 'Number of tasks per meta-batch.')
flags.DEFINE_integer('dim_im', 128, 'Image size (square).')
flags.DEFINE_integer('dim_y', 1, 'Dimension of y.')

## Training options
flags.DEFINE_list('n_hidden_units_g', [100, 100],
                  'Hidden layer sizes of the decoder g.')
flags.DEFINE_list('n_hidden_units_r', [100, 100],
                  'Hidden layer sizes of the encoder r.')
flags.DEFINE_integer('dim_z', 200, 'Dimension of z.')
flags.DEFINE_integer('dim_r', 200, 'Dimension of r for aggregating.')
flags.DEFINE_float('update_lr', 5e-4, 'Learning rate.')
flags.DEFINE_integer('num_updates', 100000, 'Number of training updates.')
flags.DEFINE_integer('trial', 1, 'Trial number.')
flags.DEFINE_integer(
    'num_classes', 1,
    'number of classes used in classification (e.g. 5-way classification).')
flags.DEFINE_bool('deterministic', True, 'Use a deterministic encoder.')

flags.DEFINE_float('beta', 0.001, 'Beta weight on the KL/IB penalty.')
flags.DEFINE_float('var', -3.0,
                   'Initial mean of the untransformed posterior scale.')
flags.DEFINE_integer('dim_w', 20, 'Dimension of w.')
flags.DEFINE_float('facto', 1.0,
                   'Multiplier on z; set to 0 to zero out z (memorization).')
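# Example invocation (a sketch; the script name and --data_dir path are
# assumptions about the local layout, the pickle names are the defaults):
#   python np_bbb.py --data_dir=/path/to/pose_data --beta=0.001 --trial=1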


def get_batch(x, y):
  """Get data batch."""
  xs, ys, xq, yq = [], [], [], []
  for _ in range(FLAGS.meta_batch_size):
    # sample WAY classes
    classes = np.random.choice(
        range(np.shape(x)[0]), size=FLAGS.num_classes, replace=False)

    support_set = []
    query_set = []
    support_sety = []
    query_sety = []
    for k in list(classes):
      # sample SHOT and QUERY instances
      idx = np.random.choice(
          range(np.shape(x)[1]),
          size=FLAGS.update_batch_size + FLAGS.update_batch_size,
          replace=False)
      x_k = x[k][idx]
      y_k = y[k][idx]

      support_set.append(x_k[:FLAGS.update_batch_size])
      query_set.append(x_k[FLAGS.update_batch_size:])
      support_sety.append(y_k[:FLAGS.update_batch_size])
      query_sety.append(y_k[FLAGS.update_batch_size:])

    xs_k = np.concatenate(support_set, 0)
    xq_k = np.concatenate(query_set, 0)
    ys_k = np.concatenate(support_sety, 0)
    yq_k = np.concatenate(query_sety, 0)

    xs.append(xs_k)
    xq.append(xq_k)
    ys.append(ys_k)
    yq.append(yq_k)

  xs, ys = np.stack(xs, 0), np.stack(ys, 0)
  xq, yq = np.stack(xq, 0), np.stack(yq, 0)

  xs = np.reshape(
      xs,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xq = np.reshape(
      xq,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xs = xs.astype(np.float32) / 255.0
  xq = xq.astype(np.float32) / 255.0
  ys = ys.astype(np.float32) * 10.0
  yq = yq.astype(np.float32) * 10.0
  return xs, ys, xq, yq
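# Returned shapes, as a sketch with the default flags (meta_batch_size=10,
# update_batch_size=15, num_classes=1, dim_im=128): xs and xq are
# [10, 15, 128 * 128] context/target images; ys and yq are [10, 15, 1],
# assuming y carries a trailing singleton dimension as prepared in main().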


def gen(x, y):
  """Yield an endless stream of batches from (x, y)."""
  while True:
    yield get_batch(np.array(x), np.array(y))


def sampling(output):
  """Split output into (mu, log-scale) and draw a reparameterized sample."""
  mu, logstd = tf.split(output, num_or_size_splits=2, axis=-1)
  sigma = tf.nn.softplus(logstd)
  ws = mu + tf.random_normal(tf.shape(mu)) * sigma
  return ws, mu, sigma
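# Reparameterization trick: with eps ~ N(0, I), ws = mu + eps * softplus(logstd)
# is a sample from N(mu, softplus(logstd)^2) that stays differentiable in mu
# and logstd, so gradients flow through the sampling step.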


def mse(pred, label):
  """Mean squared error over all elements of pred and label."""
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))


def encoder_r(xys):
  """Map concatenated (w, y) pairs to per-point representations r."""
  with tf.variable_scope('encoder_r', reuse=tf.AUTO_REUSE):
    hidden_layer = xys
    # First layers are relu
    for i, n_hidden_units in enumerate(FLAGS.n_hidden_units_r):
      hidden_layer = tf.layers.dense(
          hidden_layer,
          int(n_hidden_units),  # list flags parse as strings from the CLI
          activation=tf.nn.relu,
          name='encoder_r_{}'.format(i),
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')

    # Last layer is simple linear
    i = len(FLAGS.n_hidden_units_r)
    r = tf.layers.dense(
        hidden_layer,
        FLAGS.dim_r,
        name='encoder_r_{}'.format(i),
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  return r


def encoder_w(xs, encoder_w0):
  """xs is [n_task, n_im, dim_x]; returns ws of shape [n_task, n_im, dim_w]."""
  n_task = tf.shape(xs)[0]
  n_im = tf.shape(xs)[1]
  xs = tf.reshape(xs, [-1, FLAGS.dim_im, FLAGS.dim_im, 1])

  ws = encoder_w0(xs)
  ws = tf.reshape(ws, [n_task, n_im, FLAGS.dim_w])
  return ws
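# encoder_w0 is the Bayes-by-Backprop conv net built in main(); every call
# through it draws a fresh weight sample from the variational posterior, so
# ws is stochastic even for a fixed xs.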


def xy_to_z(xs, ys, encoder_w0):
  r"""ws = T0(xs), rs = T1(ws, ys), r = mean(rs), z \sim N(mu(r), sigma(r))."""
  with tf.variable_scope(''):
    ws = encoder_w(xs, encoder_w0)  # n_task * n_im_per_task * dim_w

    transformed_ys = tf.layers.dense(
        ys,
        FLAGS.dim_w // 4,
        name='lift_y',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
    wys = tf.concat([ws, transformed_ys],
                    axis=-1)  # n_task * n_im_per_task * (dim_w + dim_transy)

    rs = encoder_r(wys)  # n_task * n_im_per_task * dim_r

    r = tf.reduce_mean(rs, axis=1, keepdims=True)  # n_task * 1 * dim_r

    if FLAGS.deterministic:
      z_sample = tf.layers.dense(
          r,
          FLAGS.dim_z,
          name='r2z',
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')
    else:
      z = tf.layers.dense(
          r,
          FLAGS.dim_z + FLAGS.dim_z,
          name='r2z',
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')
      z_sample, _, _ = sampling(z)

  return tf.tile(z_sample, [1, FLAGS.update_batch_size, 1])  # tile n_targets
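# Mean-pooling rs over the context dimension is the standard NP aggregation:
# it makes z invariant to the order (and, at test time, the number) of
# context points; z_sample is then tiled so every target point sees it.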


def construct_model(input_tensors, encoder_w0, decoder0, prefix=None):
  """Construct model."""
  facto = tf.placeholder_with_default(1.0, ())
  context_xs = input_tensors['inputa']
  context_ys = input_tensors['labela']
  target_xs = input_tensors['inputb']
  target_ys = input_tensors['labelb']

  # sample ws ~ w|(x_all, a), rs = T(ws, ys), r = mean(rs), z = T(r)
  # x_all = tf.concat([context_xs, target_xs], axis=1)  # n_task*20*(128*128)
  # y_all = tf.concat([context_ys, target_ys], axis=1)

  x_all = context_xs
  y_all = context_ys

  # n_task * [n_im] * d_z
  if 'train' in prefix:
    z_samples = xy_to_z(x_all, y_all, encoder_w0) * facto
  else:
    z_samples = xy_to_z(context_xs, context_ys, encoder_w0) * facto

  target_ws = encoder_w(target_xs, encoder_w0)
  input_zxs = tf.concat([z_samples, target_ws], axis=-1)

  # sample y_hat ~ y|(w, z)
  with tf.variable_scope('decoder'):
    target_yhat_mu = decoder0(input_zxs)  # n_task * n_im * dim_y

  # when the variance of p(y | x, z) is fixed, neg-loglik <=> MSE
  mse_loss = mse(target_yhat_mu, target_ys)

  tf.summary.scalar(prefix + 'mse', mse_loss)
  optimizer1 = tf.train.AdamOptimizer(FLAGS.update_lr)
  optimizer2 = tf.train.AdamOptimizer(FLAGS.update_lr)

  if 'train' in prefix:
    THETA = (  # pylint: disable=invalid-name
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder') +
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder_w'))
    all_var = tf.trainable_variables()
    PHI = [v for v in all_var if v not in THETA]  # pylint: disable=invalid-name

    kl_loss = sum(encoder_w0.losses)  # + sum(decoder0.losses)

    scale_v = [v for v in encoder_w0.trainable_variables if 'scale' in v.name]
    scale_norm = [tf.reduce_mean(v) for v in scale_v]
    scale_norm = tf.reduce_mean(scale_norm)

    loss = mse_loss + FLAGS.beta * kl_loss

    gvs_theta = optimizer1.compute_gradients(loss, THETA)
    train_theta_op = optimizer1.apply_gradients(gvs_theta)

    gvs_phi = optimizer2.compute_gradients(loss, PHI)
    train_phi_op = optimizer2.apply_gradients(gvs_phi)
    with tf.control_dependencies([train_theta_op, train_phi_op]):
      train_op = tf.no_op()
    tf.summary.scalar(prefix + 'full_loss', loss)
    tf.summary.scalar(prefix + 'regularizer', FLAGS.beta * kl_loss)
    tf.summary.scalar(prefix + 'untransformed_scale', scale_norm)
    return mse_loss, train_op, facto
  else:
    return mse_loss
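# THETA (decoder + the BBB encoder_w posterior parameters) and PHI (everything
# else: lift_y, encoder_r, r2z) are trained on the same loss but by separate
# Adam instances, so each parameter group keeps its own moment estimates.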


def main(_):
  kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
      untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
          mean=FLAGS.var, stddev=0.1))
  encoder_w0 = tf.keras.Sequential([
      tfp.layers.Convolution2DReparameterization(
          filters=32,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      tfp.layers.Convolution2DReparameterization(
          filters=48,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      MaxPooling2D(pool_size=(2, 2)),
      tfp.layers.Convolution2DReparameterization(
          filters=64,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      tf.keras.layers.Flatten(),
      tfp.layers.DenseReparameterization(
          FLAGS.dim_w, kernel_posterior_fn=kernel_posterior_fn),
  ])
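  # Each *Reparameterization layer above adds the KL divergence between its
  # weight posterior and prior to encoder_w0.losses; construct_model sums
  # these terms into kl_loss, which is weighted by FLAGS.beta.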

  decoder0 = tf.keras.Sequential([
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(FLAGS.dim_y),
  ])

  dim_output = FLAGS.dim_y
  dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

  exp_name = '%s.beta-%g.update_lr-%g.trial-%d' % ('np_bbb', FLAGS.beta,
                                                   FLAGS.update_lr, FLAGS.trial)
  checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

  x_train, y_train = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[0]), 'rb'))
  x_val, y_val = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[1]), 'rb'))

  x_train, y_train = np.array(x_train), np.array(y_train)
  y_train = y_train[:, :, -1, None]
  x_val, y_val = np.array(x_val), np.array(y_val)
  y_val = y_val[:, :, -1, None]

  ds_train = tf.data.Dataset.from_generator(
      functools.partial(gen, x_train, y_train),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  ds_val = tf.data.Dataset.from_generator(
      functools.partial(gen, x_val, y_val),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  inputa, labela, inputb, labelb = ds_train.make_one_shot_iterator().get_next()

  input_tensors = {
      'inputa': inputa,
      'inputb': inputb,
      'labela': labela,
      'labelb': labelb
  }

  (inputa_val, labela_val, inputb_val,
   labelb_val) = ds_val.make_one_shot_iterator().get_next()

  metaval_input_tensors = {
      'inputa': inputa_val,
      'inputb': inputb_val,
      'labela': labela_val,
      'labelb': labelb_val
  }

  loss, train_op, facto = construct_model(
      input_tensors, encoder_w0, decoder0, prefix='metatrain_')
  loss_val = construct_model(
      metaval_input_tensors, encoder_w0, decoder0, prefix='metaval_')
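  # Both calls share the same encoder_w0/decoder0 objects and AUTO_REUSE
  # variable scopes, so the metaval graph reuses the training weights rather
  # than creating a second set of variables.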

  ###########

  summ_op = tf.summary.merge_all()
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
  tf.global_variables_initializer().run()

  PRINT_INTERVAL = 50  # pylint: disable=invalid-name
  SUMMARY_INTERVAL = 5  # pylint: disable=invalid-name
  prelosses, prelosses_val = [], []
  old_time = time.time()
  for itr in range(FLAGS.num_updates):

    feed_dict = {facto: FLAGS.facto}

    if itr % SUMMARY_INTERVAL == 0:
      summary, cost, cost_val = sess.run([summ_op, loss, loss_val], feed_dict)
      summary_writer.add_summary(summary, itr)
      prelosses.append(cost)  # 0-step loss on the training tasks
      prelosses_val.append(cost_val)  # 0-step loss on the meta-val tasks

    sess.run(train_op, feed_dict)

    if (itr != 0) and itr % PRINT_INTERVAL == 0:
      print('Iteration ' + str(itr) + ': ' + str(np.mean(prelosses)),
            'time =', time.time() - old_time)
      prelosses = []
      old_time = time.time()
      print('Validation results: ' + str(np.mean(prelosses_val)))
      prelosses_val = []


if __name__ == '__main__':
  app.run(main)