# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural Process (NP) with a beta-weighted BBB KL applied only to encoder_w.

Also includes a version of vanilla NP.
"""
from __future__ import print_function
import functools
import os
import pickle
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras.layers import MaxPooling2D
import tensorflow_probability as tfp
from tensorflow_probability.python.layers import util as tfp_layers_util


tf.compat.v1.enable_v2_tensorshape()
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_string('logdir', '/tmp/data',
                    'directory for summaries and checkpoints.')
flags.DEFINE_string('data_dir', None,
                    'Directory of data files.')
get_data_dir = lambda: FLAGS.data_dir
flags.DEFINE_list('data', ['train_data_ins.pkl', 'val_data_ins.pkl'],
                  'data name')
flags.DEFINE_integer('update_batch_size', 15,
                     'number of context (and target) points per class')
flags.DEFINE_integer('meta_batch_size', 10, 'number of tasks per meta-batch')
flags.DEFINE_integer('dim_im', 128, 'image size')
flags.DEFINE_integer('dim_y', 1, 'dimension of y')

## Training options
flags.DEFINE_list('n_hidden_units_g', [100, 100],
                  'number of hidden units per layer in the decoder g')
flags.DEFINE_list('n_hidden_units_r', [100, 100],
                  'number of hidden units per layer in the encoder r')
flags.DEFINE_integer('dim_z', 200, 'dimension of z')
flags.DEFINE_integer('dim_r', 200, 'dimension of r for aggregating')
flags.DEFINE_float('update_lr', 5e-4, 'learning rate')
flags.DEFINE_integer('num_updates', 100000, 'number of training iterations')
flags.DEFINE_integer('trial', 1, 'trial number')
flags.DEFINE_integer(
    'num_classes', 1,
    'number of classes used in classification (e.g. 5-way classification).')
flags.DEFINE_bool('deterministic', True, 'deterministic encoder')

flags.DEFINE_float('beta', 0.001, 'weight on the KL regularizer (IB beta)')
flags.DEFINE_float('var', -3.0,
                   'initial mean of the untransformed posterior scale (BBB)')
flags.DEFINE_integer('dim_w', 20, 'dimension of w')
flags.DEFINE_float('facto', 1.0,
                   'multiplier on z; set to 0 to zero out z (memorization)')
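
# Example invocation (the module name and directories below are placeholders;
# adjust them to your setup):
#   python -m np_bbb --data_dir=/path/to/data --logdir=/tmp/np_bbb \
#       --beta=0.001 --update_lr=5e-4 --num_updates=100000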


def get_batch(x, y):
  """Sample a meta-batch of tasks split into context and target sets."""
  xs, ys, xq, yq = [], [], [], []
  for _ in range(FLAGS.meta_batch_size):
    # sample WAY classes
    classes = np.random.choice(
        range(np.shape(x)[0]), size=FLAGS.num_classes, replace=False)

    support_set = []
    query_set = []
    support_sety = []
    query_sety = []
    for k in list(classes):
      # sample SHOT and QUERY instances
      idx = np.random.choice(
          range(np.shape(x)[1]),
          size=FLAGS.update_batch_size + FLAGS.update_batch_size,
          replace=False)
      x_k = x[k][idx]
      y_k = y[k][idx]

      support_set.append(x_k[:FLAGS.update_batch_size])
      query_set.append(x_k[FLAGS.update_batch_size:])
      support_sety.append(y_k[:FLAGS.update_batch_size])
      query_sety.append(y_k[FLAGS.update_batch_size:])

    xs_k = np.concatenate(support_set, 0)
    xq_k = np.concatenate(query_set, 0)
    ys_k = np.concatenate(support_sety, 0)
    yq_k = np.concatenate(query_sety, 0)

    xs.append(xs_k)
    xq.append(xq_k)
    ys.append(ys_k)
    yq.append(yq_k)

  xs, ys = np.stack(xs, 0), np.stack(ys, 0)
  xq, yq = np.stack(xq, 0), np.stack(yq, 0)

  xs = np.reshape(
      xs,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xq = np.reshape(
      xq,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xs = xs.astype(np.float32) / 255.0
  xq = xq.astype(np.float32) / 255.0
  ys = ys.astype(np.float32) * 10.0
  yq = yq.astype(np.float32) * 10.0
  return xs, ys, xq, yq


def gen(x, y):
  while True:
    yield get_batch(np.array(x), np.array(y))


def sampling(output):
  mu, logstd = tf.split(output, num_or_size_splits=2, axis=-1)
  sigma = tf.nn.softplus(logstd)
  ws = mu + tf.random_normal(tf.shape(mu)) * sigma
  return ws, mu, sigma
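# `sampling` implements the reparameterization trick: `output` is split into a
# mean and an untransformed scale, and the scale is mapped through a softplus
# (not an exp, despite the `logstd` name) before sampling.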


def mse(pred, label):
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))


def encoder_r(xys):
  """Define encoder."""
  with tf.variable_scope('encoder_r', reuse=tf.AUTO_REUSE):
    hidden_layer = xys
    # First layers are relu
    for i, n_hidden_units in enumerate(FLAGS.n_hidden_units_r):
      hidden_layer = tf.layers.dense(
          hidden_layer,
          n_hidden_units,
          activation=tf.nn.relu,
          name='encoder_r_{}'.format(i),
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')

    # Last layer is simple linear
    i = len(FLAGS.n_hidden_units_r)
    r = tf.layers.dense(
        hidden_layer,
        FLAGS.dim_r,
        name='encoder_r_{}'.format(i),
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  return r


def encoder_w(xs, encoder_w0):
  """xs is [n_task, n_im, dim_x]; return [n_task, n_im, dim_w]."""
  n_task = tf.shape(xs)[0]
  n_im = tf.shape(xs)[1]
  xs = tf.reshape(xs, [-1, 128, 128, 1])  # assumes FLAGS.dim_im == 128

  ws = encoder_w0(xs)
  ws = tf.reshape(ws, [n_task, n_im, FLAGS.dim_w])
  return ws


def xy_to_z(xs, ys, encoder_w0):
  r"""ws = T0(xs), rs = T1(ws, ys), r = mean(rs), z \sim N(mu(r), sigma(r))."""
  with tf.variable_scope(''):
    ws = encoder_w(xs, encoder_w0)  # n_task * n_im_per_task * dim_w

  transformed_ys = tf.layers.dense(
      ys,
      FLAGS.dim_w // 4,
      name='lift_y',
      reuse=tf.AUTO_REUSE,
      kernel_initializer='normal')
  wys = tf.concat([ws, transformed_ys],
                  axis=-1)  # n_task * n_im_per_task * (dim_w + dim_transy)

  rs = encoder_r(wys)  # n_task * n_im_per_task * dim_r

  r = tf.reduce_mean(rs, axis=1, keepdims=True)  # n_task * 1 * dim_r

  if FLAGS.deterministic:
    z_sample = tf.layers.dense(
        r,
        FLAGS.dim_z,
        name='r2z',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  else:
    z = tf.layers.dense(
        r,
        FLAGS.dim_z + FLAGS.dim_z,
        name='r2z',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
    z_sample, _, _ = sampling(z)

  # Tile z across target points (assumes num_classes == 1).
  return tf.tile(z_sample, [1, FLAGS.update_batch_size, 1])


def construct_model(input_tensors, encoder_w0, decoder0, prefix=None):
  """Construct model."""
  facto = tf.placeholder_with_default(1.0, ())
  context_xs = input_tensors['inputa']
  context_ys = input_tensors['labela']
  target_xs = input_tensors['inputb']
  target_ys = input_tensors['labelb']

  # sample ws ~ w|(x_all,a), rs = T(ws, ys), r = mean(rs), z = T(r)
  # x_all = tf.concat([context_xs, target_xs], axis=1) #n_task * 20 * (128*128)
  # y_all = tf.concat([context_ys, target_ys], axis=1)

  x_all = context_xs
  y_all = context_ys

  # n_task * [n_im] * d_z
  if 'train' in prefix:
    z_samples = xy_to_z(x_all, y_all, encoder_w0) * facto
  else:
    z_samples = xy_to_z(context_xs, context_ys, encoder_w0) * facto

  target_ws = encoder_w(target_xs, encoder_w0)
  input_zxs = tf.concat([z_samples, target_ws], axis=-1)

  # sample y_hat ~ y|(w,z)
  with tf.variable_scope('decoder'):
    target_yhat_mu = decoder0(input_zxs)  # n_task * n_im * dim_y

  # With the variance of p(y | x, z) fixed, neg-loglik is equivalent to MSE.
  mse_loss = mse(target_yhat_mu, target_ys)

  tf.summary.scalar(prefix + 'mse', mse_loss)
  optimizer1 = tf.train.AdamOptimizer(FLAGS.update_lr)
  optimizer2 = tf.train.AdamOptimizer(FLAGS.update_lr)

  if 'train' in prefix:
    # THETA: decoder and encoder_w variables; PHI: all remaining variables.
    THETA = (  # pylint: disable=invalid-name
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder') +
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder_w'))
    all_var = tf.trainable_variables()
    PHI = [v for v in all_var if v not in THETA]  # pylint: disable=invalid-name

    # Weight KL terms registered by the BBB layers in encoder_w0.
    kl_loss = sum(encoder_w0.losses)  # +sum(decoder0.losses)

    scale_v = [v for v in encoder_w0.trainable_variables if 'scale' in v.name]
    scale_norm = [tf.reduce_mean(v) for v in scale_v]
    scale_norm = tf.reduce_mean(scale_norm)

    loss = mse_loss + FLAGS.beta * kl_loss

    gvs_theta = optimizer1.compute_gradients(loss, THETA)
    train_theta_op = optimizer1.apply_gradients(gvs_theta)

    gvs_phi = optimizer2.compute_gradients(loss, PHI)
    train_phi_op = optimizer2.apply_gradients(gvs_phi)
    with tf.control_dependencies([train_theta_op, train_phi_op]):
      train_op = tf.no_op()
    tf.summary.scalar(prefix + 'full_loss', loss)
    tf.summary.scalar(prefix + 'regularizer', FLAGS.beta * kl_loss)
    tf.summary.scalar(prefix + 'untransformed_scale', scale_norm)
    return mse_loss, train_op, facto
  else:
    return mse_loss


def main(_):
  kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
      untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
          mean=FLAGS.var, stddev=0.1))
  encoder_w0 = tf.keras.Sequential([
      tfp.layers.Convolution2DReparameterization(
          filters=32,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      tfp.layers.Convolution2DReparameterization(
          filters=48,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      MaxPooling2D(pool_size=(2, 2)),
      tfp.layers.Convolution2DReparameterization(
          filters=64,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='SAME',
          kernel_posterior_fn=kernel_posterior_fn),
      tf.keras.layers.Flatten(),
      tfp.layers.DenseReparameterization(
          FLAGS.dim_w, kernel_posterior_fn=kernel_posterior_fn),
  ])
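  # The reparameterization layers above learn mean-field Gaussian posteriors
  # over their weights and register the corresponding KL(q || p) terms in
  # `encoder_w0.losses`; construct_model sums these into `kl_loss`.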

  decoder0 = tf.keras.Sequential([
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(FLAGS.dim_y),
  ])

  dim_output = FLAGS.dim_y
  dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

  exp_name = '%s.beta-%g.update_lr-%g.trial-%d' % ('np_bbb', FLAGS.beta,
                                                   FLAGS.update_lr, FLAGS.trial)
  checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

  x_train, y_train = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[0]), 'rb'))
  x_val, y_val = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[1]), 'rb'))

  x_train, y_train = np.array(x_train), np.array(y_train)
  y_train = y_train[:, :, -1, None]
  x_val, y_val = np.array(x_val), np.array(y_val)
  y_val = y_val[:, :, -1, None]

  ds_train = tf.data.Dataset.from_generator(
      functools.partial(gen, x_train, y_train),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))
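  # gen() already yields complete meta-batches (leading dim = meta_batch_size),
  # so no additional .batch() or .shuffle() is applied to the dataset.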

  ds_val = tf.data.Dataset.from_generator(
      functools.partial(gen, x_val, y_val),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  inputa, labela, inputb, labelb = ds_train.make_one_shot_iterator().get_next()

  input_tensors = {'inputa': inputa,
                   'inputb': inputb,
                   'labela': labela, 'labelb': labelb}

  inputa_val, labela_val, inputb_val, labelb_val = (
      ds_val.make_one_shot_iterator().get_next())

  metaval_input_tensors = {'inputa': inputa_val,
                           'inputb': inputb_val,
                           'labela': labela_val, 'labelb': labelb_val}

  loss, train_op, facto = construct_model(
      input_tensors, encoder_w0, decoder0, prefix='metatrain_')
  loss_val = construct_model(
      metaval_input_tensors, encoder_w0, decoder0, prefix='metaval_')

  ###########

  summ_op = tf.summary.merge_all()
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
  tf.global_variables_initializer().run()

  PRINT_INTERVAL = 50  # pylint: disable=invalid-name
  SUMMARY_INTERVAL = 5  # pylint: disable=invalid-name
  prelosses, prelosses_val = [], []
  old_time = time.time()
  for itr in range(FLAGS.num_updates):

    feed_dict = {facto: FLAGS.facto}

    if itr % SUMMARY_INTERVAL == 0:
      summary, cost, cost_val = sess.run([summ_op, loss, loss_val], feed_dict)
      summary_writer.add_summary(summary, itr)
      prelosses.append(cost)  # 0 step loss on training set
      prelosses_val.append(cost_val)  # 0 step loss on meta_val training set

    sess.run(train_op, feed_dict)

    if (itr != 0) and itr % PRINT_INTERVAL == 0:
      print('Iteration ' + str(itr) + ': ' + str(np.mean(prelosses)), 'time =',
            time.time() - old_time)
      prelosses = []
      old_time = time.time()
      print('Validation results: ' + str(np.mean(prelosses_val)))
      prelosses_val = []


if __name__ == '__main__':
  app.run(main)