# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NP vanilla."""
from __future__ import print_function
import functools
import os
import pickle
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras.layers import Conv2D
from tensorflow.compat.v1.keras.layers import MaxPooling2D
# AdamWOptimizer (used in the --weight_decay branch of construct_model) lives
# in contrib under TF 1.x.
from tensorflow.contrib import opt as contrib_opt


tf.compat.v1.enable_v2_tensorshape()
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_float('beta', 0.001, 'the beta for weight decay')
flags.DEFINE_bool('weight_decay', False, 'whether or not to use weight decay')

flags.DEFINE_string('logdir', '/tmp/data',
                    'directory for summaries and checkpoints.')
flags.DEFINE_string('data_dir', None,
                    'Directory of data files.')
get_data_dir = lambda: FLAGS.data_dir
flags.DEFINE_list('data', ['train_data_ins.pkl', 'val_data_ins.pkl'],
                  'names of the train/val data files in data_dir')
flags.DEFINE_integer('update_batch_size', 15,
                     'number of context points (and target points) per task')
flags.DEFINE_integer('meta_batch_size', 10, 'number of tasks per meta-batch')
flags.DEFINE_integer('dim_im', 128, 'image size')
flags.DEFINE_integer('dim_y', 1, 'dimension of y')

## Training options
flags.DEFINE_list('n_hidden_units_g', [100, 100],
                  'number of hidden units in each layer of the decoder g')
flags.DEFINE_list('n_hidden_units_r', [100, 100],
                  'number of hidden units in each layer of the encoder r')
flags.DEFINE_integer('dim_z', 64, 'dimension of z')
flags.DEFINE_integer('dim_r', 100, 'dimension of r for aggregating')
flags.DEFINE_float('update_lr', 0.001, 'learning rate')
flags.DEFINE_integer('num_updates', 140000, 'number of training iterations')
flags.DEFINE_integer('trial', 1, 'trial number')
flags.DEFINE_integer(
    'num_classes', 1,
    'number of classes used in classification (e.g. 5-way classification).')
flags.DEFINE_bool('deterministic', True, 'use a deterministic encoder for z')

## IB options
flags.DEFINE_integer('dim_w', 64, 'dimension of w')
flags.DEFINE_float('facto', 1.0,
                   'multiplicative factor on z (set to 0.0 to zero out z)')


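# get_batch builds one meta-batch: for each of meta_batch_size tasks it picks
# num_classes classes and, per class, 2 * update_batch_size (image, label)
# pairs; the first half becomes the context (support) set and the second half
# the target (query) set. Images are flattened and scaled to [0, 1]; labels
# are scaled by 10.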
def get_batch(x, y):
  """Get data batch."""
  xs, ys, xq, yq = [], [], [], []
  for _ in range(FLAGS.meta_batch_size):
    # sample WAY classes
    classes = np.random.choice(
        range(np.shape(x)[0]), size=FLAGS.num_classes, replace=False)

    support_set = []
    query_set = []
    support_sety = []
    query_sety = []
    for k in list(classes):
      # sample SHOT and QUERY instances
      idx = np.random.choice(
          range(np.shape(x)[1]),
          size=FLAGS.update_batch_size + FLAGS.update_batch_size,
          replace=False)
      x_k = x[k][idx]
      y_k = y[k][idx]

      support_set.append(x_k[:FLAGS.update_batch_size])
      query_set.append(x_k[FLAGS.update_batch_size:])
      support_sety.append(y_k[:FLAGS.update_batch_size])
      query_sety.append(y_k[FLAGS.update_batch_size:])

    xs_k = np.concatenate(support_set, 0)
    xq_k = np.concatenate(query_set, 0)
    ys_k = np.concatenate(support_sety, 0)
    yq_k = np.concatenate(query_sety, 0)

    xs.append(xs_k)
    xq.append(xq_k)
    ys.append(ys_k)
    yq.append(yq_k)

  xs, ys = np.stack(xs, 0), np.stack(ys, 0)
  xq, yq = np.stack(xq, 0), np.stack(yq, 0)

  xs = np.reshape(
      xs,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xq = np.reshape(
      xq,
      [FLAGS.meta_batch_size, FLAGS.update_batch_size * FLAGS.num_classes, -1])
  xs = xs.astype(np.float32) / 255.0
  xq = xq.astype(np.float32) / 255.0
  ys = ys.astype(np.float32) * 10.0
  yq = yq.astype(np.float32) * 10.0
  return xs, ys, xq, yq


def gen(x, y):
  while True:
    yield get_batch(np.array(x), np.array(y))


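# sampling applies the reparameterization trick: the input is split into a
# mean and a pre-activation scale, sigma = softplus(logstd), and a sample is
# drawn as mu + sigma * eps with eps ~ N(0, I).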
def sampling(output):
  mu, logstd = tf.split(output, num_or_size_splits=2, axis=-1)
  sigma = tf.nn.softplus(logstd)
  ws = mu + tf.random_normal(tf.shape(mu)) * sigma
  return ws, mu, sigma


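# Mean squared error over all target points; with a fixed-variance Gaussian
# likelihood p(y | x, z), minimizing MSE is equivalent to minimizing the
# negative log-likelihood (see the comment in construct_model).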
def mse(pred, label):
  pred = tf.reshape(pred, [-1])
  label = tf.reshape(label, [-1])
  return tf.reduce_mean(tf.square(pred - label))


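# encoder_r is an MLP with len(n_hidden_units_r) ReLU hidden layers and a
# linear output of size dim_r; it maps each concatenated (w, y) pair to a
# per-point representation that is later mean-pooled into r.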
def encoder_r(xys):
  """Define encoder."""
  with tf.variable_scope('encoder_r', reuse=tf.AUTO_REUSE):
    hidden_layer = xys
    # First layers are relu
    for i, n_hidden_units in enumerate(FLAGS.n_hidden_units_r):
      hidden_layer = tf.layers.dense(
          hidden_layer,
          n_hidden_units,
          activation=tf.nn.relu,
          name='encoder_r_{}'.format(i),
          reuse=tf.AUTO_REUSE,
          kernel_initializer='normal')

    # Last layer is simple linear
    i = len(FLAGS.n_hidden_units_r)
    r = tf.layers.dense(
        hidden_layer,
        FLAGS.dim_r,
        name='encoder_r_{}'.format(i),
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  return r


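# encoder_w applies the shared CNN encoder_w0 to every image of every task:
# inputs are reshaped to 128x128x1, embedded to dim_w features, and reshaped
# back to [n_task, n_im, dim_w].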
def encoder_w(xs, encoder_w0):
  """xs is [n_task, n_im, dim_x]; return [n_task, n_im, dim_w]."""
  n_task = tf.shape(xs)[0]
  n_im = tf.shape(xs)[1]
  xs = tf.reshape(xs, [-1, 128, 128, 1])

  ws = encoder_w0(xs)
  ws = tf.reshape(ws, [n_task, n_im, FLAGS.dim_w])
  return ws


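# xy_to_z is the Neural Process encoder path: per-point features w = T0(x)
# and a lifted y are concatenated, encoded to r_i, mean-pooled over the
# context points, and mapped to z (deterministically, or as a Gaussian sample
# when --deterministic is False). The per-task z is tiled so that every
# target point receives the same z.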
def xy_to_z(xs, ys, encoder_w0):
  r"""ws = T0(xs), rs = T1(ws, ys), r = mean(rs), z \sim N(mu(r), sigma(r))."""
  with tf.variable_scope(''):
    ws = encoder_w(xs, encoder_w0)  # (n_task * n_im_per_task) * dim_w

  transformed_ys = tf.layers.dense(
      ys,
      FLAGS.dim_w // 4,
      name='lift_y',
      reuse=tf.AUTO_REUSE,
      kernel_initializer='normal')
  wys = tf.concat([ws, transformed_ys],
                  axis=-1)  # n_task * n_im_per_task * (dim_w+dim_transy)

  rs = encoder_r(wys)  # n_task * n_im_per_task * dim_r

  r = tf.reduce_mean(rs, axis=1, keepdims=True)  # n_task * 1 * dim_r

  if FLAGS.deterministic:
    z_sample = tf.layers.dense(
        r,
        FLAGS.dim_z,
        name='r2z',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
  else:
    z = tf.layers.dense(
        r,
        FLAGS.dim_z + FLAGS.dim_z,
        name='r2z',
        reuse=tf.AUTO_REUSE,
        kernel_initializer='normal')
    z_sample, _, _ = sampling(z)

  return tf.tile(z_sample, [1, FLAGS.update_batch_size, 1])  # tile n_targets


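# construct_model wires the forward pass and loss: z is computed from the
# context set, the decoder predicts target y from [z, w(target_x)], and the
# MSE over target points is the objective. For the training graph it also
# builds the train op, either AdamW (when --weight_decay is set) or two Adam
# optimizers that update the variables in the 'decoder' and 'encoder_w'
# scopes (THETA) and the remaining variables (PHI) separately.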
def construct_model(input_tensors, encoder_w0, decoder0, prefix=None):
  """Construct model."""
  facto = tf.placeholder_with_default(1.0, ())
  context_xs = input_tensors['inputa']
  context_ys = input_tensors['labela']
  target_xs = input_tensors['inputb']
  target_ys = input_tensors['labelb']

  # sample ws ~ w|(x_all,a), rs = T(ws, ys), r = mean(rs), z = T(r)
  # x_all = tf.concat([context_xs, target_xs], axis=1) #n_task * 20 * (128*128)
  # y_all = tf.concat([context_ys, target_ys], axis=1)

  x_all = context_xs
  y_all = context_ys

  # n_task * [n_im] * d_z
  if 'train' in prefix:
    z_samples = xy_to_z(x_all, y_all, encoder_w0) * facto
  else:
    z_samples = xy_to_z(context_xs, context_ys, encoder_w0) * facto

  target_ws = encoder_w(target_xs, encoder_w0)
  input_zxs = tf.concat([z_samples, target_ws], axis=-1)

  # sample y_hat ~ y|(w,z)
  with tf.variable_scope('decoder'):
    target_yhat_mu = decoder0(input_zxs)  # n_task * n_im * dim_y

  # when the variance of p(y | x,z) is fixed, neg-loglik <=> MSE
  mse_loss = mse(target_yhat_mu, target_ys)

  tf.summary.scalar(prefix + 'mse', mse_loss)
  optimizer1 = tf.train.AdamOptimizer(FLAGS.update_lr)
  optimizer2 = tf.train.AdamOptimizer(FLAGS.update_lr)

  if 'train' in prefix:
    if FLAGS.weight_decay:
      loss = mse_loss
      optimizer = contrib_opt.AdamWOptimizer(
          weight_decay=FLAGS.beta, learning_rate=FLAGS.update_lr)
      gvs = optimizer.compute_gradients(loss)
      train_op = optimizer.apply_gradients(gvs)
    else:
      THETA = (  # pylint: disable=invalid-name
          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder')
          + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder_w'))
      all_var = tf.trainable_variables()
      PHI = [v for v in all_var if v not in THETA]  # pylint: disable=invalid-name
      loss = mse_loss
      gvs_theta = optimizer1.compute_gradients(loss, THETA)
      train_theta_op = optimizer1.apply_gradients(gvs_theta)
      gvs_phi = optimizer2.compute_gradients(loss, PHI)
      train_phi_op = optimizer2.apply_gradients(gvs_phi)
      with tf.control_dependencies([train_theta_op, train_phi_op]):
        train_op = tf.no_op()
    return mse_loss, train_op, facto
  else:
    return mse_loss


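# main builds the shared CNN feature encoder (encoder_w0) and the MLP decoder
# (decoder0), loads the pickled train/val data, wraps the numpy batch
# generator in tf.data pipelines, constructs the train and val graphs, and
# runs the training loop.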
def main(_):

  encoder_w0 = tf.keras.Sequential([
      Conv2D(
          filters=32,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      Conv2D(
          filters=48,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      MaxPooling2D(pool_size=(2, 2)),
      Conv2D(
          filters=64,
          kernel_size=3,
          strides=(2, 2),
          activation='relu',
          padding='same'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(FLAGS.dim_w),
  ])

  decoder0 = tf.keras.Sequential([
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(100, activation=tf.nn.relu),
      tf.keras.layers.Dense(FLAGS.dim_y),
  ])

  dim_output = FLAGS.dim_y
  dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

  if FLAGS.weight_decay:
    exp_name = '%s.update_lr-%g.beta-%g.trial-%d' % (
        'np_vanilla', FLAGS.update_lr, FLAGS.beta, FLAGS.trial)
  else:
    exp_name = '%s.update_lr-%g.trial-%d' % ('np_vanilla', FLAGS.update_lr,
                                             FLAGS.trial)
  checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

  x_train, y_train = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[0]), 'rb'))
  x_val, y_val = pickle.load(
      tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[1]), 'rb'))

  x_train, y_train = np.array(x_train), np.array(y_train)
  y_train = y_train[:, :, -1, None]
  x_val, y_val = np.array(x_val), np.array(y_val)
  y_val = y_val[:, :, -1, None]

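  # Each generator yields (xs, ys, xq, yq) with shapes
  # [meta_batch_size, update_batch_size * num_classes, dim_input] for inputs
  # and [..., dim_output] for labels; the leading dimension is left as None
  # in the TensorShapes below.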
  ds_train = tf.data.Dataset.from_generator(
      functools.partial(gen, x_train, y_train),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  ds_val = tf.data.Dataset.from_generator(
      functools.partial(gen, x_val, y_val),
      (tf.float32, tf.float32, tf.float32, tf.float32),
      (tf.TensorShape(
          [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
       tf.TensorShape(
           [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

  inputa, labela, inputb, labelb = ds_train.make_one_shot_iterator().get_next()

  input_tensors = {
      'inputa': inputa,
      'inputb': inputb,
      'labela': labela,
      'labelb': labelb
  }

  inputa_val, labela_val, inputb_val, labelb_val = (
      ds_val.make_one_shot_iterator().get_next())

  metaval_input_tensors = {
      'inputa': inputa_val,
      'inputb': inputb_val,
      'labela': labela_val,
      'labelb': labelb_val
  }

  loss, train_op, facto = construct_model(
      input_tensors, encoder_w0, decoder0, prefix='metatrain_')
  loss_val = construct_model(
      metaval_input_tensors, encoder_w0, decoder0, prefix='metaval_')

  ###########

  summ_op = tf.summary.merge_all()
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
  tf.global_variables_initializer().run()

  PRINT_INTERVAL = 50  # pylint: disable=invalid-name
  SUMMARY_INTERVAL = 5  # pylint: disable=invalid-name
  prelosses, prelosses_val = [], []
  old_time = time.time()
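  # Training loop: every SUMMARY_INTERVAL steps the train/val losses are
  # evaluated and logged to TensorBoard, and every PRINT_INTERVAL steps the
  # running means are printed and reset.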
  for itr in range(FLAGS.num_updates):

    feed_dict = {facto: FLAGS.facto}

    if itr % SUMMARY_INTERVAL == 0:
      summary, cost, cost_val = sess.run([summ_op, loss, loss_val], feed_dict)
      summary_writer.add_summary(summary, itr)
      prelosses.append(cost)  # 0 step loss on training set
      prelosses_val.append(cost_val)  # 0 step loss on meta_val training set

    sess.run(train_op, feed_dict)

    if (itr != 0) and itr % PRINT_INTERVAL == 0:
      print('Iteration ' + str(itr) + ': ' + str(np.mean(prelosses)), 'time =',
            time.time() - old_time)
      prelosses = []
      old_time = time.time()
      print('Validation results: ' + str(np.mean(prelosses_val)))
      prelosses_val = []


if __name__ == '__main__':
  app.run(main)
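# Example invocation (illustrative paths and script name; the data files are
# the pickled train/val sets named by --data):
#   python np_vanilla.py --data_dir=/path/to/pose_data --logdir=/tmp/np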
