# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and evaluation using Distance Based Learning from Errors (DBLE)."""

from __future__ import print_function

import argparse
import logging
import os
import sys
import time

import numpy as np
import pathlib
import tensorflow.compat.v1 as tf
from tqdm import trange

sys.path.insert(0, '..')
# pylint: disable=g-import-not-at-top
from dble import data_loader
from dble import mlp
from dble import resnet
from dble import utils
from dble import vgg
from tensorflow.contrib import slim as contrib_slim

tf.logging.set_verbosity(tf.logging.INFO)
logging.basicConfig(level=logging.INFO)

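# Dimensionality of the embedding space shared by all feature extractors,
# their confidence (variance) heads and the class centers.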
feature_dim = 512


class Namespace(object):

  def __init__(self, adict):
    self.__dict__.update(adict)


def get_arguments():
  """Processes all parameters."""
  parser = argparse.ArgumentParser()

  # Dataset parameters
  parser.add_argument(
      '--data_dir',
      type=str,
      default='',
      help='Path to the data, only for Tiny-ImageNet.')
  parser.add_argument(
      '--val_data_dir',
      type=str,
      default='',
      help='Path to the validation data, only for Tiny-ImageNet.')

  # Training parameters
  parser.add_argument(
      '--number_of_steps',
      type=int,
      default=int(120000),
      help='Number of training steps.')
  parser.add_argument(
      '--number_of_steps_to_early_stop',
      type=int,
      default=int(75000),
      help='Number of training steps after half way to early stop.')
  parser.add_argument(
      '--log_dir',
      type=str,
      default='/tmp/cifar_10/',
      help='Experiment directory.')
  parser.add_argument(
      '--num_tasks_per_batch',
      type=int,
      default=2,
      help='Number of few shot episodes per batch.')
  parser.add_argument(
      '--init_learning_rate',
      type=float,
      default=0.1,
      help='Initial learning rate.')
  parser.add_argument(
      '--save_summaries_secs',
      type=int,
      default=300,
      help='Time between saving summaries.')
  parser.add_argument(
      '--save_interval_secs',
      type=int,
      default=300,
      help='Time between saving models.')
  parser.add_argument(
      '--optimizer', type=str, default='sgd', choices=['sgd', 'adam'])

  # Optimization parameters
  parser.add_argument(
      '--lr_anneal',
      type=str,
      default='pwc',
      choices=['const', 'pwc', 'cos', 'exp'])
  parser.add_argument('--n_lr_decay', type=int, default=3)
  parser.add_argument('--lr_decay_rate', type=float, default=10.0)
  parser.add_argument(
      '--num_steps_decay_pwc',
      type=int,
      default=10000,
      help='Decay learning rate every num_steps_decay_pwc')
  parser.add_argument(
      '--clip_gradient_norm',
      type=float,
      default=1.0,
      help='Gradient clip norm.')
  parser.add_argument(
      '--weights_initializer_factor',
      type=float,
      default=0.1,
      help='Multiplier in the variance of the initialization noise.')

  # Evaluation parameters
  parser.add_argument(
      '--eval_interval_secs',
      type=int,
      default=0,
      help='Time between evaluating model.')
  parser.add_argument(
      '--eval_interval_steps',
      type=int,
      default=2000,
      help='Number of train steps between evaluating model in training loop.')
  parser.add_argument(
      '--eval_interval_fine_steps',
      type=int,
      default=1000,
      help='Number of train steps between evaluating model in the final phase.')

  # Architecture parameters
  parser.add_argument('--conv_keepdim', type=float, default=0.5)
  parser.add_argument('--neck', type=bool, default=False, help='')
  parser.add_argument('--num_forward', type=int, default=10, help='')
  parser.add_argument('--weight_decay', type=float, default=0.0005)
  parser.add_argument('--num_cases_train', type=int, default=50000)
  parser.add_argument('--num_cases_test', type=int, default=10000)
  parser.add_argument('--model_name', type=str, default='vgg')
  parser.add_argument('--dataset', type=str, default='cifar10')
  parser.add_argument(
      '--num_samples_per_class', type=int, default=5000, help='')
  parser.add_argument(
      '--num_classes_total',
      type=int,
      default=10,
      help='Number of classes in total of the data set.')
  parser.add_argument(
      '--num_classes_test',
      type=int,
      default=10,
      help='Number of classes in the test phase.')
  parser.add_argument(
      '--num_classes_train',
      type=int,
      default=10,
      help='Number of classes in a prototypical episode.')
  parser.add_argument(
      '--num_shots_train',
      type=int,
      default=10,
      help='Number of shots (support samples) in a prototypical episode.')
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='The size of the query batch in a prototypical episode.')

  args, _ = parser.parse_known_args()
  print(args)
  return args


def build_feature_extractor_graph(inputs,
                                  flags,
                                  is_variance,
                                  is_training=False,
                                  model=None):
  """Calculates the representations and variances for inputs.

  Args:
    inputs: The input batch with shape (batch_size, height, width,
      num_channels). The batch_size of a support batch is num_classes_per_task
      *num_supports_per_class*num_tasks. The batch_size of a query batch is
      query_batch_size_per_task*num_tasks.
    flags: The hyperparameter dictionary.
    is_variance: The bool value of whether to calculate variances for every
      training sample. For support samples, calculating variances is not
      required.
    is_training: The bool value of whether to use training mode.
    model: The representation model defined in function train(flags).

  Returns:
    h: The representations of the input batch with shape
    (num_tasks, batch_size_per_task, feature_dim).
    variance: The variances of the input batch with shape
    (num_tasks, batch_size_per_task, feature_dim).
  """
  variance = None
  with tf.variable_scope('feature_extractor', reuse=tf.AUTO_REUSE):
    h = model.encoder(inputs, training=is_training)
    if is_variance:
      variance = model.confidence_model(h, training=is_training)
    embedding_shape = h.get_shape().as_list()
    if is_training:
      h = tf.reshape(
          h,
          shape=(flags.num_tasks_per_batch,
                 embedding_shape[0] // flags.num_tasks_per_batch, -1),
          name='reshape_to_multi_task_format')
      if is_variance:
        variance = tf.reshape(
            variance,
            shape=(flags.num_tasks_per_batch,
                   embedding_shape[0] // flags.num_tasks_per_batch, -1),
            name='reshape_to_multi_task_format')
    else:
      h = tf.reshape(
          h,
          shape=(1, embedding_shape[0], -1),
          name='reshape_to_multi_task_format')
      if is_variance:
        variance = tf.reshape(
            variance,
            shape=(1, embedding_shape[0], -1),
            name='reshape_to_multi_task_format')

    return h, variance


def calculate_class_center(support_embeddings,
                           flags,
                           is_training,
                           scope='class_center_calculator'):
  """Calculates the class centers for every episode given support embeddings.

  Args:
    support_embeddings: The support embeddings with shape
      (num_classes_per_task*num_supports_per_class*num_tasks, height, width,
      num_channels).
    flags: The hyperparameter dictionary.
    is_training: The bool value of whether to use training mode.
    scope: The name of the variable scope.

  Returns:
    class_center: The representations of the class centers with shape
    (num_tasks, num_classes_per_task, feature_dim).
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    class_center = support_embeddings
    if is_training:
      class_center = tf.reshape(
          class_center,
          shape=(flags.num_tasks_per_batch, flags.num_classes_train,
                 flags.num_shots_train, -1),
          name='reshape_to_multi_task_format')
    else:
      class_center = tf.reshape(
          class_center,
          shape=(1, flags.num_classes_test, flags.num_shots_train, -1),
          name='reshape_to_multi_task_format')
    class_center = tf.reduce_mean(class_center, axis=2, keep_dims=False)

    return class_center


def build_euclidean_calculator(query_representation,
                               class_center,
                               flags,
                               scope='prototypical_head'):
  """Calculates the negative Euclidean distance of queries to class centers.

  Args:
    query_representation: The query embeddings with shape (num_tasks,
      query_batch_size_per_task, feature_dim).
    class_center: The representations of class centers with shape (num_tasks,
      num_training_class, feature_dim).
    flags: The hyperparameter dictionary.
    scope: The name of the variable scope.

  Returns:
    negative_euclidean: The negative Euclidean distance of queries to the class
     centers in their episodes. The shape of negative_euclidean is
     (num_tasks*query_batch_size_per_task, num_training_class).
  """
  with tf.variable_scope(scope):
    if len(query_representation.get_shape().as_list()) == 2:
      query_representation = tf.expand_dims(query_representation, axis=0)
    if len(class_center.get_shape().as_list()) == 2:
      class_center = tf.expand_dims(class_center, axis=0)

    # j is the number of classes in each episode
    # i is the number of queries in each episode
    j = class_center.get_shape().as_list()[1]
    i = query_representation.get_shape().as_list()[1]
    print('task_encoding_shape:' + str(j))

    # tile to be able to produce weight matrix alpha in (i,j) space
    query_representation = tf.expand_dims(query_representation, axis=2)
    class_center = tf.expand_dims(class_center, axis=1)
    # features_generic changes over i and is constant over j
    # task_encoding changes over j and is constant over i
    class_center_tile = tf.tile(class_center, (1, i, 1, 1))
    query_representation_tile = tf.tile(query_representation, (1, 1, j, 1))
    negative_euclidean = -tf.norm(
        (class_center_tile - query_representation_tile),
        name='neg_euclidean_distance',
        axis=-1)
    negative_euclidean = tf.reshape(
        negative_euclidean, shape=(flags.num_tasks_per_batch * i, -1))

    return negative_euclidean


def build_proto_train_graph(images_query, images_support, flags, is_training,
                            model):
  """Builds the tf graph of dble's prototypical training.

  Args:
    images_query: The processed query batch with shape
      (query_batch_size_per_task*num_tasks, height, width, num_channels).
    images_support: The processed support batch with shape
      (num_classes_per_task*num_supports_per_class*num_tasks, height, width,
      num_channels).
    flags: The hyperparameter dictionary.
    is_training: The bool value of whether to use training mode.
    model: The model defined in the main train function.

  Returns:
    logits: The logits before softmax (negative Euclidean) of the batch
    calculated with the original representations (mu in the paper) of queries.
    logits_z: The logits before softmax (negative Euclidean) of the batch
    calculated with the sampled representations (z in the paper) of queries.
  """

  with tf.variable_scope('Proto_training'):
    support_representation, _ = build_feature_extractor_graph(
        inputs=images_support,
        flags=flags,
        is_variance=False,
        is_training=is_training,
        model=model)
    class_center = calculate_class_center(
        support_embeddings=support_representation,
        flags=flags,
        is_training=is_training)
    query_representation, query_variance = build_feature_extractor_graph(
        inputs=images_query,
        flags=flags,
        is_variance=True,
        is_training=is_training,
        model=model)

    logits = build_euclidean_calculator(query_representation, class_center,
                                        flags)
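    # Reparameterization trick: query_variance is treated as a per-dimension
    # log-variance, so z = mu + eps * exp(log_var / 2) is a sample from
    # N(mu, diag(exp(log_var))) around the query representation mu.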
    eps = tf.random.normal(shape=query_representation.shape)
    z = eps * tf.exp((query_variance) * .5) + query_representation
    logits_z = build_euclidean_calculator(z, class_center, flags)

  return logits, logits_z


def placeholder_inputs(batch_size, image_size, scope):
  """Builds the placeholders for the training images and labels."""
  with tf.variable_scope(scope):
    if image_size != 28:  # not MNIST
      images_placeholder = tf.placeholder(
          tf.float32,
          shape=(batch_size, image_size, image_size, 3),
          name='images')
    else:
      images_placeholder = tf.placeholder(
          tf.float32, shape=(batch_size, 784), name='images')
    labels_placeholder = tf.placeholder(
        tf.int32, shape=(batch_size), name='labels')
    return images_placeholder, labels_placeholder


def build_episode_placeholder(flags):
  """Builds the placeholders for the support and query input batches."""
  image_size = data_loader.get_image_size(flags.dataset)
  images_query_pl, labels_query_pl = placeholder_inputs(
      batch_size=flags.num_tasks_per_batch * flags.train_batch_size,
      image_size=image_size,
      scope='inputs/query')
  images_support_pl, labels_support_pl = placeholder_inputs(
      batch_size=flags.num_tasks_per_batch * flags.num_classes_train *
      flags.num_shots_train,
      image_size=image_size,
      scope='inputs/support')

  return images_query_pl, labels_query_pl, images_support_pl, labels_support_pl


def build_model(flags):
  """Builds the model according to flags.

  For image data types, we considered ResNet and VGG models. One can use DBLE
  with other data types by choosing a model with an appropriate inductive bias
  for feature extraction, e.g. WaveNet for speech or BERT for text.
  Args:
    flags: The hyperparameter dictionary.

  Returns:
    The model instance (vgg, mlp or resnet) selected by flags.model_name.
  """
  if flags.model_name == 'vgg':
    # Primary task operations
    vgg_model = vgg.vgg11(
        keep_prob=flags.conv_keepdim,
        wd=flags.weight_decay,
        neck=flags.neck,
        feature_dim=feature_dim)
    return vgg_model
  elif flags.model_name == 'mlp':
    # Assumes dble/mlp.py exposes its constructor as `mlp`; the bare module
    # object imported above is not itself callable.
    mlp_model = mlp.mlp(
        keep_prob=flags.conv_keepdim,
        feature_dim=feature_dim,
        wd=flags.weight_decay)
    return mlp_model
  elif flags.model_name == 'resnet':
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      resnet_model = resnet.Model(
          wd=flags.weight_decay,
          resnet_size=50,
          bottleneck=True,
          num_classes=flags.num_classes_train,
          num_filters=16,
          kernel_size=3,
          conv_stride=1,
          first_pool_size=None,
          first_pool_stride=None,
          block_sizes=[8, 8, 8],
          block_strides=[1, 2, 2],
          data_format='channels_last',
          feature_dim=feature_dim)
    else:
      resnet_model = resnet.Model(
          wd=flags.weight_decay,
          resnet_size=50,
          bottleneck=True,
          num_classes=flags.num_classes_train,
          num_filters=16,
          kernel_size=3,
          conv_stride=1,
          first_pool_size=3,
          first_pool_stride=1,
          block_sizes=[3, 4, 6, 3],
          block_strides=[1, 2, 2, 2],
          data_format='channels_last',
          feature_dim=feature_dim)
    return resnet_model


def train(flags):
  """Training entry point."""
  log_dir = flags.log_dir
  flags.pretrained_model_dir = log_dir
  log_dir = os.path.join(log_dir, 'train')
  flags.eval_interval_secs = 0
  with tf.Graph().as_default():
    global_step = tf.Variable(
        0, trainable=False, name='global_step', dtype=tf.int64)
    global_step_confidence = tf.Variable(
        0, trainable=False, name='global_step_confidence', dtype=tf.int64)

    model = build_model(flags)
    images_query_pl, labels_query_pl, \
    images_support_pl, labels_support_pl = \
      build_episode_placeholder(flags)

    # Augments the input.
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      images_query_pl_aug = data_loader.augment_cifar(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_cifar(
          images_support_pl, is_training=True)
    elif flags.dataset == 'tinyimagenet':
      images_query_pl_aug = data_loader.augment_tinyimagenet(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_tinyimagenet(
          images_support_pl, is_training=True)

    logits, logits_z = build_proto_train_graph(
        images_query=images_query_pl_aug,
        images_support=images_support_pl_aug,
        flags=flags,
        is_training=True,
        model=model)
    # Losses and optimizer
    ## Classification loss
    loss_classification = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.one_hot(labels_query_pl, flags.num_classes_train)))

    # Confidence loss
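    # The confidence (variance) head is trained only on queries the prototype
    # classifier got wrong, using the logits of the sampled representations z;
    # tf.cond below falls back to a zero loss when a batch has no errors.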
    _, top_k_indices = tf.nn.top_k(logits, k=1)
    pred = tf.squeeze(top_k_indices)
    incorrect_mask = tf.math.logical_not(tf.math.equal(pred, labels_query_pl))
    incorrect_logits_z = tf.boolean_mask(logits_z, incorrect_mask)
    incorrect_labels_z = tf.boolean_mask(labels_query_pl, incorrect_mask)
    signal_variance = tf.math.reduce_sum(tf.cast(incorrect_mask, tf.int32))
    loss_variance_incorrect = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=incorrect_logits_z,
            labels=tf.one_hot(incorrect_labels_z, flags.num_classes_train)))
    loss_variance_zero = 0.0
    loss_confidence = tf.cond(
        tf.greater(signal_variance, 0), lambda: loss_variance_incorrect,
        lambda: loss_variance_zero)

    regu_losses = tf.losses.get_regularization_losses()
    loss = tf.add_n([loss_classification] + regu_losses)

    # Learning rate
    if flags.lr_anneal == 'const':
      learning_rate = flags.init_learning_rate
    elif flags.lr_anneal == 'pwc':
      learning_rate = get_pwc_learning_rate(global_step, flags)
    elif flags.lr_anneal == 'exp':
      lr_decay_step = flags.number_of_steps // flags.n_lr_decay
      learning_rate = tf.train.exponential_decay(
          flags.init_learning_rate,
          global_step,
          lr_decay_step,
          1.0 / flags.lr_decay_rate,
          staircase=True)
    else:
      raise Exception('Not implemented')

    # Optimizer
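    # Two train ops share the same learning-rate schedule: train_op updates all
    # weights with the classification loss, while train_op_confidence updates
    # only the 'fc_variance' variables with the confidence loss.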
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)
    optimizer_confidence = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)

    train_op = contrib_slim.learning.create_train_op(
        total_loss=loss,
        optimizer=optimizer,
        global_step=global_step,
        clip_gradient_norm=flags.clip_gradient_norm)
    variable_variance = []
    for v in tf.trainable_variables():
      if 'fc_variance' in v.name:
        variable_variance.append(v)
    train_op_confidence = contrib_slim.learning.create_train_op(
        total_loss=loss_confidence,
        optimizer=optimizer_confidence,
        global_step=global_step_confidence,
        clip_gradient_norm=flags.clip_gradient_norm,
        variables_to_train=variable_variance)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_classification', loss_classification)
    tf.summary.scalar('loss_variance', loss_confidence)
    tf.summary.scalar('regu_loss', tf.add_n(regu_losses))
    tf.summary.scalar('learning_rate', learning_rate)
    # Merges all summaries except for pretrain
    summary = tf.summary.merge(
        tf.get_collection('summaries', scope='(?!pretrain).*'))

    # Gets datasets
    few_shot_data_train, test_dataset, train_dataset = get_train_datasets(flags)
    # Defines session and logging
    summary_writer_train = tf.summary.FileWriter(log_dir, flush_secs=1)
    saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
    print(saver.saver_def.filename_tensor_name)
    print(saver.saver_def.restore_op_name)
    # pylint: disable=unused-variable
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    supervisor = tf.train.Supervisor(
        logdir=log_dir,
        init_feed_dict=None,
        summary_op=None,
        init_op=tf.global_variables_initializer(),
        summary_writer=summary_writer_train,
        saver=saver,
        global_step=global_step,
        save_summaries_secs=flags.save_summaries_secs,
        save_model_secs=0)

    with supervisor.managed_session() as sess:
      checkpoint_step = sess.run(global_step)
      if checkpoint_step > 0:
        checkpoint_step += 1
      eval_interval_steps = flags.eval_interval_steps
      for step in range(checkpoint_step, flags.number_of_steps):
        # Computes the classification loss using a batch of data.
        images_query, labels_query,\
        images_support, labels_support = \
          few_shot_data_train.next_few_shot_batch(
              query_batch_size_per_task=flags.train_batch_size,
              num_classes_per_task=flags.num_classes_train,
              num_supports_per_class=flags.num_shots_train,
              num_tasks=flags.num_tasks_per_batch)

        feed_dict = {
            images_query_pl: images_query.astype(dtype=np.float32),
            labels_query_pl: labels_query,
            images_support_pl: images_support.astype(dtype=np.float32),
            labels_support_pl: labels_support
        }

        t_batch = time.time()
        dt_batch = time.time() - t_batch

        t_train = time.time()
        loss, loss_confidence = sess.run([train_op, train_op_confidence],
                                         feed_dict=feed_dict)
        dt_train = time.time() - t_train

        if step % 100 == 0:
          summary_str = sess.run(summary, feed_dict=feed_dict)
          summary_writer_train.add_summary(summary_str, step)
          summary_writer_train.flush()
          logging.info('step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs', step,
                       loss, dt_train, dt_batch)

        if float(step) / flags.number_of_steps > 0.5:
          eval_interval_steps = flags.eval_interval_fine_steps

        if eval_interval_steps > 0 and step % eval_interval_steps == 0:
          saver.save(sess, os.path.join(log_dir, 'model'), global_step=step)
          eval(
              flags=flags,
              train_dataset=train_dataset,
              test_dataset=test_dataset)

        if float(
            step
        ) > 0.5 * flags.number_of_steps + flags.number_of_steps_to_early_stop:
          break


def get_class_center_for_evaluation(train_bs, test_bs, num_classes):
  """Builds the tf graph computing class centers at eval from training data."""
  x_train = tf.placeholder(shape=[train_bs, feature_dim], dtype=tf.float32)
  x_test = tf.placeholder(shape=[test_bs, feature_dim], dtype=tf.float32)
  y_train = tf.placeholder(
      shape=[
          train_bs,
      ], dtype=tf.int32)
  y_test = tf.placeholder(
      shape=[
          test_bs,
      ], dtype=tf.int32)

  # Finds the class centers for the training data; class labels are assumed
  # to be 0..num_classes-1.
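  # The "centroid" tensors built below are per-class sums of embeddings; the
  # caller later divides the accumulated result by the number of samples per
  # class to obtain the actual class means.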
  ind_c = tf.squeeze(tf.where(tf.equal(y_train, 0)))
  train_input_c = tf.gather(x_train, ind_c)
  train_input_c = tf.expand_dims(train_input_c, 0)
  centroid = tf.reduce_sum(train_input_c, 1)

  for i in range(1, num_classes):
    ind_c = tf.squeeze(tf.where(tf.equal(y_train, i)))
    tmp_input_c = tf.gather(x_train, ind_c)
    tmp_input_c = tf.expand_dims(tmp_input_c, 0)
    tmp_centroid = tf.reduce_sum(tmp_input_c, 1)
    centroid = tf.concat([centroid, tmp_centroid], 0)

  return x_train, x_test, y_train, y_test, centroid


def make_predictions_for_evaluation(centroid, x_test, y_test, flags):
  """The tf graph for making predictions given class centers and test data."""
  # Expands the centroids so they can be broadcast against the test batch
  centroid_expand = tf.expand_dims(centroid, axis=0)
  # Calculates the test sample - centroid distance
  i = x_test.get_shape().as_list()[0]
  # Tiles to be able to produce weight matrix alpha in (i,j) space
  x_data_test = tf.expand_dims(x_test, axis=1)
  x_data_test = tf.tile(x_data_test, (1, flags.num_classes_total, 1))
  centroid_expand_test = tf.tile(centroid_expand, (i, 1, 1))
  euclidean = tf.norm(
      x_data_test - centroid_expand_test, name='euclidean_distance', axis=-1)
  # Prediction based on the nearest class center
  _, top_k_indices = tf.nn.top_k(-euclidean, k=1)
  pred = tf.squeeze(top_k_indices)
  correct_mask = tf.cast(tf.math.equal(pred, y_test), tf.float32)
  correct = tf.reduce_sum(correct_mask, axis=0)

  return pred, correct


def calculate_nll(y_test, softmaxed_logits, num_classes):
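  """Returns the summed negative log-likelihood of a batch of predictions."""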
  y_test = tf.one_hot(y_test, depth=num_classes)
  nll = tf.reduce_sum(-tf.log(tf.reduce_sum(y_test * softmaxed_logits, axis=1)))

  return nll


def confidence_estimation_and_evaluation(centroid, x_test, x_test_variance,
                                         y_test, flags):
  """The tf graph for confidence estimation and NLL for a batch of test data."""
  centroid_expand = tf.expand_dims(centroid, axis=0)
  i = x_test.get_shape().as_list()[0]
  # Tiles to be able to produce weight matrix alpha in (i,j) space
  x_test_tile = tf.expand_dims(x_test, axis=1)
  x_test_tile = tf.tile(x_test_tile, (1, flags.num_classes_total, 1))
  centroid_expand_tile = tf.tile(centroid_expand, (i, 1, 1))
  distance = tf.norm(
      x_test_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)

  softmaxed_distance = tf.math.softmax(-distance)
  _, top_k_indices = tf.nn.top_k(softmaxed_distance, k=1)
  pred = tf.squeeze(top_k_indices)

  # Row indices used later to gather the probability of each predicted class
  row = tf.constant(np.arange(pred.get_shape().as_list()[0]))
  row = tf.cast(row, tf.int32)

  eps = tf.random.normal(shape=x_test.shape)
  z = eps * tf.exp((x_test_variance) * .5) + x_test
  z_tile = tf.expand_dims(z, axis=1)
  z_tile = tf.tile(z_tile, (1, flags.num_classes_total, 1))
  distance = tf.norm(
      z_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)
  softmaxed_distance = tf.math.softmax(-distance)

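  # Monte-Carlo estimate of the predictive distribution: softmax probabilities
  # over negative distances are averaged across flags.num_forward samples of z
  # (the first sample is drawn above; the loop below draws the remaining ones).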
  for _ in range(1, flags.num_forward):
    eps = tf.random.normal(shape=x_test.shape)
    z = eps * tf.exp((x_test_variance) * .5) + x_test
    z_tile = tf.expand_dims(z, axis=1)
    z_tile = tf.tile(z_tile, (1, flags.num_classes_total, 1))
    distance = tf.norm(
        z_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)
    softmaxed_distance += tf.math.softmax(-distance)
  softmaxed_distance = softmaxed_distance / flags.num_forward
  nll_sum = calculate_nll(y_test, softmaxed_distance, flags.num_classes_total)
  ind_tensor = tf.transpose(tf.stack([row, pred]))
  confidence = tf.gather_nd(softmaxed_distance, ind_tensor)

  return nll_sum, confidence


def get_train_datasets(flags):
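  """Returns the episodic sampler plus the raw train/test splits."""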
  if flags.dataset == 'cifar10':
    train_set, test_set = data_loader.load_cifar10()
  elif flags.dataset == 'cifar100':
    train_set, test_set = data_loader.load_cifar100()
  elif flags.dataset == 'tinyimagenet':
    train_set, test_set = data_loader.load_tiny_imagenet(
        flags.data_dir, flags.val_data_dir)
  episodic_train = data_loader.Dataset(train_set)
  return episodic_train, test_set, train_set


def get_pwc_learning_rate(global_step, flags):
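  """Piecewise-constant schedule: the initial rate for the first half of
  training, then 0.1x, 0.01x and 0.001x at num_steps_decay_pwc intervals."""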
  learning_rate = tf.train.piecewise_constant(global_step, [
      np.int64(flags.number_of_steps / 2),
      np.int64(flags.number_of_steps / 2 + flags.num_steps_decay_pwc),
      np.int64(flags.number_of_steps / 2 + 2 * flags.num_steps_decay_pwc)
  ], [
      flags.init_learning_rate, flags.init_learning_rate * 0.1,
      flags.init_learning_rate * 0.01, flags.init_learning_rate * 0.001
  ])
  return learning_rate


class ModelLoader:
  """The class definition for the evaluation module."""

  def __init__(self, model_path, batch_size, train_dataset, test_dataset):
    self.train_batch_size = batch_size
    self.test_batch_size = batch_size
    self.test_dataset = test_dataset
    self.train_dataset = train_dataset

    latest_checkpoint = tf.train.latest_checkpoint(
        checkpoint_dir=os.path.join(model_path, 'train'))
    print(latest_checkpoint)
    step = int(os.path.basename(latest_checkpoint).split('-')[1])
    flags = Namespace(
        utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
    image_size = data_loader.get_image_size(flags.dataset)
    self.flags = flags

    with tf.Graph().as_default():
      self.tensor_images, self.tensor_labels = placeholder_inputs(
          batch_size=self.train_batch_size,
          image_size=image_size,
          scope='inputs')
      if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
        tensor_images_aug = data_loader.augment_cifar(
            self.tensor_images, is_training=False)
      else:
        tensor_images_aug = data_loader.augment_tinyimagenet(
            self.tensor_images, is_training=False)
      model = build_model(flags)
      with tf.variable_scope('Proto_training'):
        self.representation, self.variance = build_feature_extractor_graph(
            inputs=tensor_images_aug,
            flags=flags,
            is_variance=True,
            is_training=False,
            model=model)
      self.tensor_train_rep, self.tensor_test_rep, \
      self.tensor_train_rep_label, self.tensor_test_rep_label,\
      self.center = get_class_center_for_evaluation(
          self.train_batch_size, self.test_batch_size, flags.num_classes_total)

      self.prediction, self.acc \
        = make_predictions_for_evaluation(self.center,
                                          self.tensor_test_rep,
                                          self.tensor_test_rep_label,
                                          self.flags)
      self.tensor_test_variance = tf.placeholder(
          shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
      self.nll, self.confidence = confidence_estimation_and_evaluation(
          self.center, self.tensor_test_rep, self.tensor_test_variance,
          self.tensor_test_rep_label, flags)

      config = tf.ConfigProto(allow_soft_placement=True)
      config.gpu_options.allow_growth = True
      self.sess = tf.Session(config=config)
      # Runs init before loading the weights
      self.sess.run(tf.global_variables_initializer())
      # Loads weights
      saver = tf.train.Saver()
      saver.restore(self.sess, latest_checkpoint)
      self.flags = flags
      self.step = step
      log_dir = flags.log_dir
      graphpb_txt = str(tf.get_default_graph().as_graph_def())
      with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
        f.write(graphpb_txt)

  def eval_ece(self, pred_logits_np, pred_np, label_np, num_bins):
    """Calculates ECE.

    Args:
      pred_logits_np: the predicted (maximum) softmax probability of each test
        sample.
      pred_np:  the numpy array of the predicted labels of test samples.
      label_np:  the numpy array of the ground-truth labels of test samples.
      num_bins: the number of bins to partition all samples. We set it to 15.

    Returns:
      ece: the calculated ECE value.
    """
    acc_tab = np.zeros(num_bins)  # Empirical (true) confidence
    mean_conf = np.zeros(num_bins)  # Predicted confidence
    nb_items_bin = np.zeros(num_bins)  # Number of items in the bins
    tau_tab = np.linspace(
        min(pred_logits_np), max(pred_logits_np),
        num_bins + 1)  # Confidence bins
    tau_tab = np.linspace(0, 1, num_bins + 1)  # Fixed [0, 1] bins used below
    for i in np.arange(num_bins):  # Iterates over the bins
      # Selects the items where the predicted max probability falls in the bin
      # [tau_tab[i], tau_tab[i + 1])
      sec = (tau_tab[i + 1] > pred_logits_np) & (pred_logits_np >= tau_tab[i])
      nb_items_bin[i] = np.sum(sec)  # Number of items in the bin
      # Selects the predicted classes, and the true classes
      class_pred_sec, y_sec = pred_np[sec], label_np[sec]
      # Averages of the predicted max probabilities
      mean_conf[i] = np.mean(
          pred_logits_np[sec]) if nb_items_bin[i] > 0 else np.nan
      # Computes the empirical confidence
      acc_tab[i] = np.mean(
          class_pred_sec == y_sec) if nb_items_bin[i] > 0 else np.nan
    # Cleaning
    mean_conf = mean_conf[nb_items_bin > 0]
    acc_tab = acc_tab[nb_items_bin > 0]
    nb_items_bin = nb_items_bin[nb_items_bin > 0]
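    # ECE is the weighted average |confidence - accuracy| over the non-empty
    # bins, i.e. sum_b (n_b / N) * |mean_conf_b - acc_b|.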
    if sum(nb_items_bin) != 0:
      ece = np.average(
          np.absolute(mean_conf - acc_tab),
          weights=nb_items_bin.astype(float) / np.sum(nb_items_bin))
    else:
      ece = 0.0
    return ece

  def eval_acc_nll_ece(self, num_cases_train, num_cases_test):
    """Returns evaluation metrics.

    Args:
      num_cases_train: the total number of training samples.
      num_cases_test:  the total number of test samples.

    Returns:
      num_correct / num_cases_test: the accuracy of the evaluation.
      nll: the calculated NLL value.
      ece: the calculated ECE value.
    """
    num_batches_train = num_cases_train // self.train_batch_size
    num_batches_test = num_cases_test // self.test_batch_size
    num_correct = 0.0
    features_train_np = []
    features_test_np = []
    variance_test_np = []
    for i in trange(num_batches_train):
      images_train = self.train_dataset[0][(
          i * self.train_batch_size):((i + 1) * self.train_batch_size)]
      feed_dict = {self.tensor_images: images_train.astype(dtype=np.float32)}
      [features_train_batch] = self.sess.run([self.representation], feed_dict)
      features_train_np.extend(features_train_batch)
    features_train_np = np.concatenate(features_train_np, axis=0)
    for i in trange(num_batches_test):
      images_test = self.test_dataset[0][(
          i * self.test_batch_size):((i + 1) * self.test_batch_size)]
      feed_dict = {self.tensor_images: images_test.astype(dtype=np.float32)}
      [features_test_batch, variances_test_batch
      ] = self.sess.run([self.representation, self.variance], feed_dict)
      features_test_np.extend(features_test_batch)
      variance_test_np.extend(variances_test_batch)
    features_test_np = np.concatenate(features_test_np, axis=0)
    variance_test_np = np.concatenate(variance_test_np, axis=0)

    # Computes class centers.
    features_train_batch = features_train_np[:self.train_batch_size]
    feed_dict = {
        self.tensor_train_rep:
            features_train_batch,
        self.tensor_train_rep_label:
            self.train_dataset[1][:self.train_batch_size]
    }
    [centroid] = self.sess.run([self.center], feed_dict)
    for i in trange(1, num_batches_train):
      features_train_batch = features_train_np[(
          i * self.train_batch_size):((i + 1) * self.train_batch_size)]
      feed_dict = {
          self.tensor_train_rep:
              features_train_batch,
          self.tensor_train_rep_label:
              self.train_dataset[1]
              [(i * self.train_batch_size):((i + 1) * self.train_batch_size)]
      }
      [centroid_batch] = self.sess.run([self.center], feed_dict)
      centroid = centroid + centroid_batch

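    # The accumulated "centroid" is a per-class sum of training embeddings, so
    # dividing by the number of samples per class yields the class means.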
    centroid = centroid / self.flags.num_samples_per_class
    pred_list = []
    confidence_list = []
    nll_sum = 0
    for i in trange(num_batches_test):
      features_test_batch = features_test_np[(
          i * self.test_batch_size):((i + 1) * self.test_batch_size)]
      variance_test_batch = variance_test_np[(
          i * self.test_batch_size):((i + 1) * self.test_batch_size)]
      feed_dict = {
          self.center:
              centroid,
          self.tensor_test_rep:
              features_test_batch,
          self.tensor_test_variance:
              variance_test_batch,
          self.tensor_test_rep_label:
              self.test_dataset[1]
              [(i * self.test_batch_size):((i + 1) * self.test_batch_size)]
      }
      [prediction,
       num_correct_per_batch] = self.sess.run([self.prediction, self.acc],
                                              feed_dict)
      num_correct += num_correct_per_batch
      pred_list.append(prediction)

      [nll, confidence] = self.sess.run([self.nll, self.confidence], feed_dict)
      confidence_list.append(confidence)
      nll_sum += nll
    pred_np = np.concatenate(pred_list, axis=0)
    confidence_np = np.concatenate(confidence_list, axis=0)

    # The definition of NLL can be found at "On calibration of modern neural
    # networks." Guo, Chuan, et al. Proceedings of the 34th International
    # Conference on Machine Learning-Volume 70. JMLR. org, 2017. NLL averages
    # the negative log-likelihood of all test samples.

    nll = nll_sum / num_cases_test

    # The definition of ECE can be found at "On calibration of modern neural
    # networks." Guo, Chuan, et al. Proceedings of the 34th International
    # Conference on Machine Learning-Volume 70. JMLR. org, 2017. ECE
    # approximates the expectation of the difference between accuracy and
    # confidence. It partitions the confidence estimations (the likelihood of
    # the predicted label) of all test samples into L equally-spaced bins and
    # calculates the average confidence and accuracy of test samples lying in
    # each bin.

    ece = self.eval_ece(confidence_np, pred_np, self.test_dataset[1], 15)
    print('acc: ' + str(num_correct / num_cases_test))
    print('nll: ')
    print(nll)
    print('ece: ')
    print(ece)

    return num_correct / num_cases_test, nll, ece


def eval(flags, train_dataset, test_dataset):
  """Evaluation entry point."""
  # pylint: disable=redefined-builtin
  log_dir = flags.log_dir
  eval_writer = utils.summary_writer(log_dir + '/eval')
  results = {}
  model = ModelLoader(
      model_path=flags.pretrained_model_dir,
      batch_size=10000,
      train_dataset=train_dataset,
      test_dataset=test_dataset)
  acc_tst, nll, ece \
      = model.eval_acc_nll_ece(flags.num_cases_train, flags.num_cases_test)

  results['accuracy_target_tst'] = acc_tst
  results['nll'] = nll
  results['ece'] = ece
  eval_writer(model.step, **results)
  logging.info('accuracy_%s: %.3g.', 'target_tst', acc_tst)


def main(argv=None):
  # pylint: disable=unused-argument
  # pylint: disable=unused-variable
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.per_process_gpu_memory_fraction = 1.0
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)

  # Gets parameters.
  default_params = get_arguments()

  # Creates the experiment directory.
  log_dir = default_params.log_dir
  ad = pathlib.Path(log_dir)
  if not ad.exists():
    ad.mkdir(parents=True)

  # Main function for training and evaluation.
  flags = Namespace(utils.load_and_save_params(vars(default_params),
                                               log_dir, ignore_existing=True))
  train(flags=flags)


if __name__ == '__main__':
  tf.app.run()