# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NP vanilla."""
from __future__ import print_function
import functools
import os
import pickle
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras.layers import Conv2D
from tensorflow.compat.v1.keras.layers import MaxPooling2D
from tensorflow.contrib import opt as contrib_opt  # used when --weight_decay


tf.compat.v1.enable_v2_tensorshape()
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_float('beta', 0.001, 'weight decay coefficient')
flags.DEFINE_bool('weight_decay', False, 'whether or not to use weight decay')

flags.DEFINE_string('logdir', '/tmp/data',
                    'directory for summaries and checkpoints.')
flags.DEFINE_string('data_dir', None,
                    'Directory of data files.')
get_data_dir = lambda: FLAGS.data_dir
flags.DEFINE_list('data', ['train_data_ins.pkl', 'val_data_ins.pkl'],
                  'names of the train and validation data files')
flags.DEFINE_integer('update_batch_size', 15,
                     'number of context/target examples per class')
flags.DEFINE_integer('meta_batch_size', 10, 'number of tasks')
flags.DEFINE_integer('dim_im', 128, 'image size')
flags.DEFINE_integer('dim_y', 1, 'dimension of y')

## Training options
flags.DEFINE_list('n_hidden_units_g', [100, 100],
                  'number of hidden units in each layer of the decoder')
flags.DEFINE_list('n_hidden_units_r', [100, 100],
                  'number of hidden units in each layer of the encoder r')
flags.DEFINE_integer('dim_z', 64, 'dimension of z')
flags.DEFINE_integer('dim_r', 100, 'dimension of r for aggregating')
flags.DEFINE_float('update_lr', 0.001, 'learning rate')
flags.DEFINE_integer('num_updates', 140000, 'number of training updates')
flags.DEFINE_integer('trial', 1, 'trial number')
flags.DEFINE_integer(
    'num_classes', 1,
    'number of classes used in classification (e.g. 5-way classification).')
flags.DEFINE_bool('deterministic', True, 'deterministic encoder')

## IB options
flags.DEFINE_integer('dim_w', 64, 'dimension of w')
flags.DEFINE_float('facto', 1.0,
                   'multiplier on z; set to 0.0 to zero out the latent z')


def get_batch(x, y):
  """Get data batch."""
  xs, ys, xq, yq = [], [], [], []
  for _ in range(FLAGS.meta_batch_size):
    # sample WAY classes
    classes = np.random.choice(
        range(np.shape(x)[0]), size=FLAGS.num_classes, replace=False)

    support_set = []
    query_set = []
    support_sety = []
    query_sety = []
    for k in list(classes):
      # sample SHOT and QUERY instances
      idx = np.random.choice(
          range(np.shape(x)[1]),
          size=FLAGS.update_batch_size + FLAGS.update_batch_size,
          replace=False)
      x_k = x[k][idx]
      y_k = y[k][idx]

      support_set.append(x_k[:FLAGS.update_batch_size])
      query_set.append(x_k[FLAGS.update_batch_size:])
      support_sety.append(y_k[:FLAGS.update_batch_size])
      query_sety.append(y_k[FLAGS.update_batch_size:])

    xs_k = np.concatenate(support_set, 0)
    xq_k = np.concatenate(query_set, 0)
    ys_k = np.concatenate(support_sety, 0)
    yq_k = np.concatenate(query_sety, 0)

    xs.append(xs_k)
    xq.append(xq_k)
    ys.append(ys_k)
    yq.append(yq_k)

  xs, ys = np.stack(xs, 0), np.stack(ys, 0)
  xq, yq = np.stack(xq, 0), np.stack(yq, 0)

  xs = np.reshape(
      xs,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xq = np.reshape(
      xq,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xs = xs.astype(np.float32) / 255.0
  xq = xq.astype(np.float32) / 255.0
  ys = ys.astype(np.float32) * 10.0
  yq = yq.astype(np.float32) * 10.0
  return xs, ys, xq, yq
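# Shape notes (derived from the code above): xs/xq come back as
# [meta_batch_size, update_batch_size * num_classes, dim_im * dim_im] with
# pixels scaled to [0, 1]; ys/yq are the matching targets, multiplied by 10
# (a fixed rescaling of the regression labels).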


def gen(x, y):
  while True:
    yield get_batch(np.array(x), np.array(y))


def sampling(output):
  mu, logstd = tf.split(output, num_or_size_splits=2, axis=-1)
  sigma = tf.nn.softplus(logstd)
  ws = mu + tf.random_normal(tf.shape(mu)) * sigma
  return ws, mu, sigma
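# sampling() is the reparameterization trick: the encoder output is split
# into (mu, logstd) halves along the last axis, sigma = softplus(logstd)
# keeps the scale positive, and w = mu + eps * sigma with eps ~ N(0, I).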


def mse(pred, label):
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))


def encoder_r(xys):
  """Define encoder."""
  with tf.variable_scope('encoder_r', reuse=tf.AUTO_REUSE):
    hidden_layer = xys
    # First layers are relu
    for i, n_hidden_units in enumerate(FLAGS.n_hidden_units_r):
      hidden_layer = tf.layers.dense(
          hidden_layer,
          n_hidden_units,
          activation=tf.nn.relu,
          name='encoder_r_{}'.format(i),
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')

    # Last layer is simple linear
    i = len(FLAGS.n_hidden_units_r)
    r = tf.layers.dense(
        hidden_layer,
        FLAGS.dim_r,
        name='encoder_r_{}'.format(i),
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  return r


def encoder_w(xs, encoder_w0):
  """xs is [n_task, n_im, dim_x]; return [n_task, n_im, dim_w]."""
  n_task = tf.shape(xs)[0]
  n_im = tf.shape(xs)[1]
  xs = tf.reshape(xs, [-1, 128, 128, 1])

  ws = encoder_w0(xs)
  ws = tf.reshape(ws, [n_task, n_im, FLAGS.dim_w])
  return ws
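# encoder_w folds the task and image axes together so the shared conv net
# encoder_w0 (built in main()) embeds every image at once, then restores the
# [n_task, n_im, dim_w] layout. The 128x128 reshape is hardcoded and assumes
# the default --dim_im=128.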


def xy_to_z(xs, ys, encoder_w0):
  r"""ws = T0(xs), rs = T1(ws, ys), r = mean(rs), z \sim N(mu(r), sigma(r))."""
  with tf.variable_scope(''):
    ws = encoder_w(xs, encoder_w0)  # (n_task * n_im_per_task) * dim_w

    transformed_ys = tf.layers.dense(
        ys,
        FLAGS.dim_w // 4,
        name='lift_y',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
    wys = tf.concat([ws, transformed_ys],
                    axis=-1)  # n_task * n_im_per_task * (dim_w+dim_transy)

    rs = encoder_r(wys)  # n_task * n_im_per_task * dim_r

    r = tf.reduce_mean(rs, axis=1, keepdims=True)  # n_task * 1 * dim_r

    if FLAGS.deterministic:
      z_sample = tf.layers.dense(
          r,
          FLAGS.dim_z,
          name='r2z',
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')
    else:
      z = tf.layers.dense(
          r,
          FLAGS.dim_z + FLAGS.dim_z,
          name='r2z',
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')
      z_sample, _, _ = sampling(z)

  return tf.tile(z_sample, [1, FLAGS.update_batch_size, 1])  # tile n_targets
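# This is the usual neural-process encoder path: per-point representations rs
# are mean-pooled over the context set (a permutation-invariant aggregation),
# and the pooled r maps to z either deterministically or, when
# --deterministic=False, via sampling(). Tiling repeats the per-task z for
# each of the update_batch_size target points so it can be concatenated with
# per-target features.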


def construct_model(input_tensors, encoder_w0, decoder0, prefix=None):
  """Construct model."""
  facto = tf.placeholder_with_default(1.0, ())
  context_xs = input_tensors['inputa']
  context_ys = input_tensors['labela']
  target_xs = input_tensors['inputb']
  target_ys = input_tensors['labelb']

  # sample ws ~ w|(x_all,a), rs = T(ws, ys), r = mean(rs), z = T(r)
  # x_all = tf.concat([context_xs, target_xs], axis=1) #n_task * 20 * (128*128)
  # y_all = tf.concat([context_ys, target_ys], axis=1)

  x_all = context_xs
  y_all = context_ys

  # n_task * [n_im] * d_z
  if 'train' in prefix:
    z_samples = xy_to_z(x_all, y_all, encoder_w0) * facto
  else:
    z_samples = xy_to_z(context_xs, context_ys, encoder_w0) * facto

  target_ws = encoder_w(target_xs, encoder_w0)
  input_zxs = tf.concat([z_samples, target_ws], axis=-1)

  # sample y_hat ~ y|(w,z)
  with tf.variable_scope('decoder'):
    target_yhat_mu = decoder0(input_zxs)  # n_task * n_im * dim_y

  # when var of p(y | x,z) is fixed, neg-loglik <=> MSE
  mse_loss = mse(target_yhat_mu, target_ys)

  tf.summary.scalar(prefix + 'mse', mse_loss)
  optimizer1 = tf.train.AdamOptimizer(FLAGS.update_lr)
  optimizer2 = tf.train.AdamOptimizer(FLAGS.update_lr)

  if 'train' in prefix:
    if FLAGS.weight_decay:
      loss = mse_loss
      optimizer = contrib_opt.AdamWOptimizer(
          weight_decay=FLAGS.beta, learning_rate=FLAGS.update_lr)
      gvs = optimizer.compute_gradients(loss)
      train_op = optimizer.apply_gradients(gvs)
    else:
      THETA = (  # pylint: disable=invalid-name
          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder')
          + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder_w'))
      all_var = tf.trainable_variables()
      PHI = [v for v in all_var if v not in THETA]  # pylint: disable=invalid-name
      loss = mse_loss
      gvs_theta = optimizer1.compute_gradients(loss, THETA)
      train_theta_op = optimizer1.apply_gradients(gvs_theta)
      gvs_phi = optimizer2.compute_gradients(loss, PHI)
      train_phi_op = optimizer2.apply_gradients(gvs_phi)
      with tf.control_dependencies([train_theta_op, train_phi_op]):
        train_op = tf.no_op()
    return mse_loss, train_op, facto
  else:
    return mse_loss
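# In the non-weight-decay branch above, variables are split into THETA (the
# decoder plus the conv feature encoder, matched by variable scope) and PHI
# (everything else, i.e. the latent encoder path); each group gets its own
# Adam optimizer and the two updates are fused into one train_op through
# control dependencies.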


def main(_):

  encoder_w0 = tf.keras.Sequential([
      Conv2D(
          filters=32,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      Conv2D(
          filters=48,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      MaxPooling2D(pool_size=(2, 2)),
      Conv2D(
          filters=64,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(FLAGS.dim_w),
  ])
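  # With 128x128 inputs, the three stride-2 convs and the 2x2 max-pool shrink
  # the spatial grid 128 -> 64 -> 32 -> 16 -> 8, so Flatten() feeds
  # 8 * 8 * 64 = 4096 units into the final Dense projection to dim_w.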

  decoder0 = tf.keras.Sequential([
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(FLAGS.dim_y),
  ])

  dim_output = FLAGS.dim_y
  dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

  if FLAGS.weight_decay:
    exp_name = '%s.update_lr-%g.beta-%g.trial-%d' % (
        'np_vanilla', FLAGS.update_lr, FLAGS.beta, FLAGS.trial)
  else:
    exp_name = '%s.update_lr-%g.trial-%d' % ('np_vanilla', FLAGS.update_lr,
                                             FLAGS.trial)
  checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

  x_train, y_train = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[0]), 'rb'))
  x_val, y_val = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[1]), 'rb'))

  x_train, y_train = np.array(x_train), np.array(y_train)
  y_train = y_train[:, :, -1, None]
  x_val, y_val = np.array(x_val), np.array(y_val)
  y_val = y_val[:, :, -1, None]

  ds_train = tf.data.Dataset.from_generator(
      functools.partial(gen, x_train, y_train),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  ds_val = tf.data.Dataset.from_generator(
      functools.partial(gen, x_val, y_val),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  inputa, labela, inputb, labelb = ds_train.make_one_shot_iterator().get_next()

  input_tensors = {
      'inputa': inputa,
      'inputb': inputb,
      'labela': labela,
      'labelb': labelb
  }

  (inputa_val, labela_val, inputb_val,
   labelb_val) = ds_val.make_one_shot_iterator().get_next()

  metaval_input_tensors = {
      'inputa': inputa_val,
      'inputb': inputb_val,
      'labela': labela_val,
      'labelb': labelb_val
  }

  loss, train_op, facto = construct_model(
      input_tensors, encoder_w0, decoder0, prefix='metatrain_')
  loss_val = construct_model(
      metaval_input_tensors, encoder_w0, decoder0, prefix='metaval_')

  ###########

  summ_op = tf.summary.merge_all()
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
  tf.global_variables_initializer().run()

  PRINT_INTERVAL = 50  # pylint: disable=invalid-name
  SUMMARY_INTERVAL = 5  # pylint: disable=invalid-name
  prelosses, prelosses_val = [], []
  old_time = time.time()
  for itr in range(FLAGS.num_updates):

    feed_dict = {facto: FLAGS.facto}

    if itr % SUMMARY_INTERVAL == 0:
      summary, cost, cost_val = sess.run([summ_op, loss, loss_val], feed_dict)
      summary_writer.add_summary(summary, itr)
      prelosses.append(cost)  # 0 step loss on training set
      prelosses_val.append(cost_val)  # 0 step loss on meta_val training set

    sess.run(train_op, feed_dict)

    if (itr != 0) and itr % PRINT_INTERVAL == 0:
      print('Iteration ' + str(itr) + ': ' + str(np.mean(prelosses)), 'time =',
            time.time() - old_time)
      prelosses = []
      old_time = time.time()
      print('Validation results: ' + str(np.mean(prelosses_val)))
      prelosses_val = []


if __name__ == '__main__':
  app.run(main)
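
# A typical invocation might look like the following (the script name and
# data paths are examples, not fixed anywhere in this file):
#   python np_vanilla.py --data_dir=/path/to/data --logdir=/tmp/np_vanilla \
#       --num_updates=140000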