# google-research
# (scraped listing header: 1054 lines · 38.7 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Training and evaluation using Distance Based Learning from Errors (DBLE)."""
17
18from __future__ import print_function
19
20import argparse
21import logging
22import os
23import sys
24import time
25
26import numpy as np
27import pathlib
28import tensorflow.compat.v1 as tf
29from tqdm import trange
30
31sys.path.insert(0, '..')
32# pylint: disable=g-import-not-at-top
33from dble import data_loader
34from dble import mlp
35from dble import resnet
36from dble import utils
37from dble import vgg
38from tensorflow.contrib import slim as contrib_slim
39
# Route TF and stdlib logging at INFO level so training progress is visible.
tf.logging.set_verbosity(tf.logging.INFO)
logging.basicConfig(level=logging.INFO)

# Dimensionality of the representation space produced by every backbone
# (vgg / mlp / resnet); also used by the evaluation placeholders below.
feature_dim = 512
44
45
class Namespace(object):
  """Lightweight attribute container: exposes a dict's keys as attributes."""

  def __init__(self, adict):
    for key, value in adict.items():
      setattr(self, key, value)
50
51
52def get_arguments():
53"""Processes all parameters."""
54parser = argparse.ArgumentParser()
55
56# Dataset parameters
57parser.add_argument(
58'--data_dir',
59type=str,
60default='',
61help='Path to the data, only for Tiny-ImageNet.')
62parser.add_argument(
63'--val_data_dir',
64type=str,
65default='',
66help='Path to the validation data, only for Tiny-ImageNet.')
67
68# Training parameters
69parser.add_argument(
70'--number_of_steps',
71type=int,
72default=int(120000),
73help='Number of training steps.')
74parser.add_argument(
75'--number_of_steps_to_early_stop',
76type=int,
77default=int(75000),
78help='Number of training steps after half way to early stop.')
79parser.add_argument(
80'--log_dir',
81type=str,
82default='/tmp/cifar_10/',
83help='experiment directory.')
84parser.add_argument(
85'--num_tasks_per_batch',
86type=int,
87default=2,
88help='Number of few shot episodes per batch.')
89parser.add_argument(
90'--init_learning_rate',
91type=float,
92default=0.1,
93help='Initial learning rate.')
94parser.add_argument(
95'--save_summaries_secs',
96type=int,
97default=300,
98help='Time between saving summaries')
99parser.add_argument(
100'--save_interval_secs',
101type=int,
102default=300,
103help='Time between saving models.')
104parser.add_argument(
105'--optimizer', type=str, default='sgd', choices=['sgd', 'adam'])
106
107# Optimization parameters
108parser.add_argument(
109'--lr_anneal',
110type=str,
111default='pwc',
112choices=['const', 'pwc', 'cos', 'exp'])
113parser.add_argument('--n_lr_decay', type=int, default=3)
114parser.add_argument('--lr_decay_rate', type=float, default=10.0)
115parser.add_argument(
116'--num_steps_decay_pwc',
117type=int,
118default=10000,
119help='Decay learning rate every num_steps_decay_pwc')
120parser.add_argument(
121'--clip_gradient_norm',
122type=float,
123default=1.0,
124help='gradient clip norm.')
125parser.add_argument(
126'--weights_initializer_factor',
127type=float,
128default=0.1,
129help='multiplier in the variance of the initialization noise.')
130
131# Evaluation parameters
132parser.add_argument(
133'--eval_interval_secs',
134type=int,
135default=0,
136help='Time between evaluating model.')
137parser.add_argument(
138'--eval_interval_steps',
139type=int,
140default=2000,
141help='Number of train steps between evaluating model in training loop.')
142parser.add_argument(
143'--eval_interval_fine_steps',
144type=int,
145default=1000,
146help='Number of train steps between evaluating model in the final phase.')
147
148# Architecture parameters
149parser.add_argument('--conv_keepdim', type=float, default=0.5)
150parser.add_argument('--neck', type=bool, default=False, help='')
151parser.add_argument('--num_forward', type=int, default=10, help='')
152parser.add_argument('--weight_decay', type=float, default=0.0005)
153parser.add_argument('--num_cases_train', type=int, default=50000)
154parser.add_argument('--num_cases_test', type=int, default=10000)
155parser.add_argument('--model_name', type=str, default='vgg')
156parser.add_argument('--dataset', type=str, default='cifar10')
157parser.add_argument(
158'--num_samples_per_class', type=int, default=5000, help='')
159parser.add_argument(
160'--num_classes_total',
161type=int,
162default=10,
163help='Number of classes in total of the data set.')
164parser.add_argument(
165'--num_classes_test',
166type=int,
167default=10,
168help='Number of classes in the test phase. ')
169parser.add_argument(
170'--num_classes_train',
171type=int,
172default=10,
173help='Number of classes in a protoypical episode.')
174parser.add_argument(
175'--num_shots_train',
176type=int,
177default=10,
178help='Number of shots (support samples) in a prototypical episode.')
179parser.add_argument(
180'--train_batch_size',
181type=int,
182default=100,
183help='The size of the query batch in a prototypical episode.')
184
185args, _ = parser.parse_known_args()
186print(args)
187return args
188
189
def build_feature_extractor_graph(inputs,
                                  flags,
                                  is_variance,
                                  is_training=False,
                                  model=None):
  """Calculates the representations and variances for inputs.

  Args:
    inputs: The input batch with shape (batch_size, height, width,
      num_channels). The batch_size of a support batch is
      num_classes_per_task*num_supports_per_class*num_tasks. The batch_size
      of a query batch is query_batch_size_per_task*num_tasks.
    flags: The hyperparameter dictionary.
    is_variance: The bool value of whether to calculate variances for every
      training sample. For support samples, calculating variances is not
      required.
    is_training: The bool value of whether to use training mode.
    model: The representation model defined in function train(flags).

  Returns:
    h: The representations of the input batch with shape
      (num_tasks, batch_size_per_task, feature_dim).
    variance: The variances of the input batch with the same shape as h, or
      None when is_variance is False.
  """
  variance = None
  # AUTO_REUSE lets the same extractor weights serve both the support and
  # the query branches of the episode graph.
  with tf.variable_scope('feature_extractor', reuse=tf.AUTO_REUSE):
    h = model.encoder(inputs, training=is_training)
    if is_variance:
      # The confidence head predicts a per-dimension variance for each
      # embedding (used for the stochastic z-samples downstream).
      variance = model.confidence_model(h, training=is_training)
    embedding_shape = h.get_shape().as_list()
    if is_training:
      # Splits the flat batch into (num_tasks, per_task_batch, feature_dim).
      h = tf.reshape(
          h,
          shape=(flags.num_tasks_per_batch,
                 embedding_shape[0] // flags.num_tasks_per_batch, -1),
          name='reshape_to_multi_task_format')
      if is_variance:
        variance = tf.reshape(
            variance,
            shape=(flags.num_tasks_per_batch,
                   embedding_shape[0] // flags.num_tasks_per_batch, -1),
            name='reshape_to_multi_task_format')
    else:
      # Evaluation uses a single pseudo-task: (1, batch_size, feature_dim).
      h = tf.reshape(
          h,
          shape=(1, embedding_shape[0], -1),
          name='reshape_to_multi_task_format')
      if is_variance:
        variance = tf.reshape(
            variance,
            shape=(1, embedding_shape[0], -1),
            name='reshape_to_multi_task_format')

  return h, variance
245
246
def calculate_class_center(support_embeddings,
                           flags,
                           is_training,
                           scope='class_center_calculator'):
  """Calculates the class centers for every episode given support embeddings.

  Args:
    support_embeddings: The support embeddings with shape
      (num_classes_per_task*num_supports_per_class*num_tasks, height, width,
      num_channels).
    flags: The hyperparameter dictionary.
    is_training: The bool value of whether to use training mode.
    scope: The name of the variable scope.

  Returns:
    class_center: The representations of the class centers with shape
      (num_tasks, num_classes, feature_dim) — the mean of each class's
      support embeddings.
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    class_center = support_embeddings
    if is_training:
      # Groups supports as (tasks, classes, shots, feature_dim) so the mean
      # over axis 2 averages the shots of each class.
      class_center = tf.reshape(
          class_center,
          shape=(flags.num_tasks_per_batch, flags.num_classes_train,
                 flags.num_shots_train, -1),
          name='reshape_to_multi_task_format')
    else:
      # Evaluation is a single pseudo-task over the test classes.
      class_center = tf.reshape(
          class_center,
          shape=(1, flags.num_classes_test, flags.num_shots_train, -1),
          name='reshape_to_multi_task_format')
    # Mean over the shots axis yields one center per class per task.
    class_center = tf.reduce_mean(class_center, axis=2, keep_dims=False)

  return class_center
281
282
def build_euclidean_calculator(query_representation,
                               class_center,
                               flags,
                               scope='prototypical_head'):
  """Calculates the negative Euclidean distance of queries to class centers.

  Args:
    query_representation: The query embeddings with shape (num_tasks,
      query_batch_size_per_task, feature_dim).
    class_center: The representations of class centers with shape (num_tasks,
      num_training_class, feature_dim).
    flags: The hyperparameter dictionary.
    scope: The name of the variable scope.

  Returns:
    negative_euclidean: The negative euclidean distance of queries to the
      class centers in their episodes, reshaped to
      (num_tasks_per_batch * query_batch_size_per_task, num_training_class).
  """
  with tf.variable_scope(scope):
    # Promotes 2-D inputs to the (task, item, feature) layout.
    if len(query_representation.get_shape().as_list()) == 2:
      query_representation = tf.expand_dims(query_representation, axis=0)
    if len(class_center.get_shape().as_list()) == 2:
      class_center = tf.expand_dims(class_center, axis=0)

    # j is the number of classes in each episode
    # i is the number of queries in each episode
    j = class_center.get_shape().as_list()[1]
    i = query_representation.get_shape().as_list()[1]
    print('task_encoding_shape:' + str(j))

    # tile to be able to produce weight matrix alpha in (i,j) space
    query_representation = tf.expand_dims(query_representation, axis=2)
    class_center = tf.expand_dims(class_center, axis=1)
    # features_generic changes over i and is constant over j
    # task_encoding changes over j and is constant over i
    class_center_tile = tf.tile(class_center, (1, i, 1, 1))
    query_representation_tile = tf.tile(query_representation, (1, 1, j, 1))
    negative_euclidean = -tf.norm(
        (class_center_tile - query_representation_tile),
        name='neg_euclidean_distance',
        axis=-1)
    # NOTE(review): this reshape assumes the leading dim equals
    # flags.num_tasks_per_batch; in this file the function is only called
    # from the training graph, where that holds.
    negative_euclidean = tf.reshape(
        negative_euclidean, shape=(flags.num_tasks_per_batch * i, -1))

  return negative_euclidean
329
330
def build_proto_train_graph(images_query, images_support, flags, is_training,
                            model):
  """Builds the tf graph of dble's prototypical training.

  Args:
    images_query: The processed query batch with shape
      (query_batch_size_per_task*num_tasks, height, width, num_channels).
    images_support: The processed support batch with shape
      (num_classes_per_task*num_supports_per_class*num_tasks, height, width,
      num_channels).
    flags: The hyperparameter dictionary.
    is_training: The bool value of whether to use training mode.
    model: The model defined in the main train function.

  Returns:
    logits: The logits before softmax (negative Euclidean) of the batch
      calculated with the original representations (mu in the paper) of
      queries.
    logits_z: The logits before softmax (negative Euclidean) of the batch
      calculated with the sampled representations (z in the paper) of
      queries.
  """

  with tf.variable_scope('Proto_training'):
    # Supports only need embeddings (no variance head).
    support_representation, _ = build_feature_extractor_graph(
        inputs=images_support,
        flags=flags,
        is_variance=False,
        is_training=is_training,
        model=model)
    class_center = calculate_class_center(
        support_embeddings=support_representation,
        flags=flags,
        is_training=is_training)
    # Queries also get a predicted variance for the stochastic sample below.
    query_representation, query_variance = build_feature_extractor_graph(
        inputs=images_query,
        flags=flags,
        is_variance=True,
        is_training=is_training,
        model=model)

    logits = build_euclidean_calculator(query_representation, class_center,
                                        flags)
    # Reparameterized sample z = mu + eps * exp(variance / 2); the
    # exp(0.5 * v) factor suggests query_variance holds a log-variance —
    # TODO(review): confirm against the confidence_model definition.
    eps = tf.random.normal(shape=query_representation.shape)
    z = eps * tf.exp((query_variance) * .5) + query_representation
    logits_z = build_euclidean_calculator(z, class_center, flags)

  return logits, logits_z
377
378
def placeholder_inputs(batch_size, image_size, scope):
  """Builds the placeholders for the training images and labels.

  Images are (batch, size, size, 3) except for 28x28 inputs (MNIST), which
  arrive flattened to 784-dim vectors. Labels are int32 class ids.
  """
  with tf.variable_scope(scope):
    if image_size == 28:  # MNIST: flat 784-vector input.
      images_placeholder = tf.placeholder(
          tf.float32, shape=(batch_size, 784), name='images')
    else:
      images_placeholder = tf.placeholder(
          tf.float32,
          shape=(batch_size, image_size, image_size, 3),
          name='images')
    labels_placeholder = tf.placeholder(
        tf.int32, shape=(batch_size), name='labels')
    return images_placeholder, labels_placeholder
393
394
def build_episode_placeholder(flags):
  """Builds the placeholders for the support and query input batches."""
  image_size = data_loader.get_image_size(flags.dataset)
  query_batch = flags.num_tasks_per_batch * flags.train_batch_size
  support_batch = (flags.num_tasks_per_batch * flags.num_classes_train *
                   flags.num_shots_train)
  images_query_pl, labels_query_pl = placeholder_inputs(
      batch_size=query_batch,
      image_size=image_size,
      scope='inputs/query')
  images_support_pl, labels_support_pl = placeholder_inputs(
      batch_size=support_batch,
      image_size=image_size,
      scope='inputs/support')

  return images_query_pl, labels_query_pl, images_support_pl, labels_support_pl
409
410
def build_model(flags):
  """Builds model according to flags.

  For image data types, we considered ResNet and VGG models. One can use DBLE
  with other data types, by choosing a model with appropriate inductive bias
  for feature extraction, e.g. WaveNet for speech or BERT for text.

  Args:
    flags: The hyperparameter dictionary.

  Returns:
    The model instance (vgg / mlp / resnet) selected by flags.model_name.

  Raises:
    ValueError: if flags.model_name is not one of 'vgg', 'mlp', 'resnet'.
  """
  if flags.model_name == 'vgg':
    # Primary task operations
    vgg_model = vgg.vgg11(
        keep_prob=flags.conv_keepdim,
        wd=flags.weight_decay,
        neck=flags.neck,
        feature_dim=feature_dim)
    return vgg_model
  elif flags.model_name == 'mlp':
    # Bug fix: `mlp` is the imported *module* and is not callable; the
    # constructor lives inside it. TODO(review): confirm the constructor
    # name against dble/mlp.py.
    mlp_model = mlp.mlp(
        keep_prob=flags.conv_keepdim,
        feature_dim=feature_dim,
        wd=flags.weight_decay)
    return mlp_model
  elif flags.model_name == 'resnet':
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      # CIFAR variant: 3 stages of 8 blocks, no initial pooling.
      resnet_model = resnet.Model(
          wd=flags.weight_decay,
          resnet_size=50,
          bottleneck=True,
          num_classes=flags.num_classes_train,
          num_filters=16,
          kernel_size=3,
          conv_stride=1,
          first_pool_size=None,
          first_pool_stride=None,
          block_sizes=[8, 8, 8],
          block_strides=[1, 2, 2],
          data_format='channels_last',
          feature_dim=feature_dim)
    else:
      # ImageNet-style variant: 4 stages [3, 4, 6, 3] with initial pooling.
      resnet_model = resnet.Model(
          wd=flags.weight_decay,
          resnet_size=50,
          bottleneck=True,
          num_classes=flags.num_classes_train,
          num_filters=16,
          kernel_size=3,
          conv_stride=1,
          first_pool_size=3,
          first_pool_stride=1,
          block_sizes=[3, 4, 6, 3],
          block_strides=[1, 2, 2, 2],
          data_format='channels_last',
          feature_dim=feature_dim)
    return resnet_model
  else:
    # Previously an unknown model name silently returned None.
    raise ValueError('Unsupported model_name: %s' % flags.model_name)
471
472
def train(flags):
  """Training entry point.

  Builds the DBLE episodic training graph, then runs the train loop under a
  tf.train.Supervisor, periodically saving checkpoints and triggering
  evaluation.

  Args:
    flags: The hyperparameter namespace produced by get_arguments().
  """
  log_dir = flags.log_dir
  flags.pretrained_model_dir = log_dir
  log_dir = os.path.join(log_dir, 'train')
  flags.eval_interval_secs = 0
  with tf.Graph().as_default():
    # Two step counters: one for the main classifier updates and one for the
    # confidence-branch (variance head) updates.
    global_step = tf.Variable(
        0, trainable=False, name='global_step', dtype=tf.int64)
    global_step_confidence = tf.Variable(
        0, trainable=False, name='global_step_confidence', dtype=tf.int64)

    model = build_model(flags)
    images_query_pl, labels_query_pl, \
        images_support_pl, labels_support_pl = \
        build_episode_placeholder(flags)

    # Augments the input.
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      images_query_pl_aug = data_loader.augment_cifar(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_cifar(
          images_support_pl, is_training=True)
    elif flags.dataset == 'tinyimagenet':
      images_query_pl_aug = data_loader.augment_tinyimagenet(
          images_query_pl, is_training=True)
      images_support_pl_aug = data_loader.augment_tinyimagenet(
          images_support_pl, is_training=True)

    logits, logits_z = build_proto_train_graph(
        images_query=images_query_pl_aug,
        images_support=images_support_pl_aug,
        flags=flags,
        is_training=True,
        model=model)
    # Losses and optimizer
    ## Classification loss
    loss_classification = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.one_hot(labels_query_pl, flags.num_classes_train)))

    # Confidence loss: cross-entropy on the stochastic logits (logits_z),
    # restricted to queries the deterministic head got WRONG.
    _, top_k_indices = tf.nn.top_k(logits, k=1)
    pred = tf.squeeze(top_k_indices)
    incorrect_mask = tf.math.logical_not(tf.math.equal(pred, labels_query_pl))
    incorrect_logits_z = tf.boolean_mask(logits_z, incorrect_mask)
    incorrect_labels_z = tf.boolean_mask(labels_query_pl, incorrect_mask)
    signal_variance = tf.math.reduce_sum(tf.cast(incorrect_mask, tf.int32))
    loss_variance_incorrect = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=incorrect_logits_z,
            labels=tf.one_hot(incorrect_labels_z, flags.num_classes_train)))
    loss_variance_zero = 0.0
    # Guard against an all-correct batch, where the masked mean would be NaN.
    loss_confidence = tf.cond(
        tf.greater(signal_variance, 0), lambda: loss_variance_incorrect,
        lambda: loss_variance_zero)

    regu_losses = tf.losses.get_regularization_losses()
    loss = tf.add_n([loss_classification] + regu_losses)

    # Learning rate
    if flags.lr_anneal == 'const':
      learning_rate = flags.init_learning_rate
    elif flags.lr_anneal == 'pwc':
      learning_rate = get_pwc_learning_rate(global_step, flags)
    elif flags.lr_anneal == 'exp':
      lr_decay_step = flags.number_of_steps // flags.n_lr_decay
      learning_rate = tf.train.exponential_decay(
          flags.init_learning_rate,
          global_step,
          lr_decay_step,
          1.0 / flags.lr_decay_rate,
          staircase=True)
    else:
      # NOTE(review): 'cos' is accepted by get_arguments but lands here.
      raise Exception('Not implemented')

    # Optimizer: SGD with momentum for both branches, sharing the same
    # learning-rate schedule.
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)
    optimizer_confidence = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=0.9)

    train_op = contrib_slim.learning.create_train_op(
        total_loss=loss,
        optimizer=optimizer,
        global_step=global_step,
        clip_gradient_norm=flags.clip_gradient_norm)
    # The confidence step only updates the variance-head variables.
    variable_variance = []
    for v in tf.trainable_variables():
      if 'fc_variance' in v.name:
        variable_variance.append(v)
    train_op_confidence = contrib_slim.learning.create_train_op(
        total_loss=loss_confidence,
        optimizer=optimizer_confidence,
        global_step=global_step_confidence,
        clip_gradient_norm=flags.clip_gradient_norm,
        variables_to_train=variable_variance)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_classification', loss_classification)
    tf.summary.scalar('loss_variance', loss_confidence)
    tf.summary.scalar('regu_loss', tf.add_n(regu_losses))
    tf.summary.scalar('learning_rate', learning_rate)
    # Merges all summaries except for pretrain
    summary = tf.summary.merge(
        tf.get_collection('summaries', scope='(?!pretrain).*'))

    # Gets datasets
    few_shot_data_train, test_dataset, train_dataset = get_train_datasets(flags)
    # Defines session and logging
    summary_writer_train = tf.summary.FileWriter(log_dir, flush_secs=1)
    saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
    print(saver.saver_def.filename_tensor_name)
    print(saver.saver_def.restore_op_name)
    # pylint: disable=unused-variable
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    supervisor = tf.train.Supervisor(
        logdir=log_dir,
        init_feed_dict=None,
        summary_op=None,
        init_op=tf.global_variables_initializer(),
        summary_writer=summary_writer_train,
        saver=saver,
        global_step=global_step,
        save_summaries_secs=flags.save_summaries_secs,
        save_model_secs=0)

    with supervisor.managed_session() as sess:
      checkpoint_step = sess.run(global_step)
      if checkpoint_step > 0:
        checkpoint_step += 1
      eval_interval_steps = flags.eval_interval_steps
      for step in range(checkpoint_step, flags.number_of_steps):
        # Computes the classification loss using a batch of data.
        images_query, labels_query,\
            images_support, labels_support = \
            few_shot_data_train.next_few_shot_batch(
                query_batch_size_per_task=flags.train_batch_size,
                num_classes_per_task=flags.num_classes_train,
                num_supports_per_class=flags.num_shots_train,
                num_tasks=flags.num_tasks_per_batch)

        feed_dict = {
            images_query_pl: images_query.astype(dtype=np.float32),
            labels_query_pl: labels_query,
            images_support_pl: images_support.astype(dtype=np.float32),
            labels_support_pl: labels_support
        }

        # NOTE(review): dt_batch brackets no work (the batch was fetched
        # above), so it always logs ~0.
        t_batch = time.time()
        dt_batch = time.time() - t_batch

        t_train = time.time()
        loss, loss_confidence = sess.run([train_op, train_op_confidence],
                                         feed_dict=feed_dict)
        dt_train = time.time() - t_train

        if step % 100 == 0:
          summary_str = sess.run(summary, feed_dict=feed_dict)
          summary_writer_train.add_summary(summary_str, step)
          summary_writer_train.flush()
          logging.info('step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs', step,
                       loss, dt_train, dt_batch)

        # Switch to the finer eval cadence after the halfway point.
        if float(step) / flags.number_of_steps > 0.5:
          eval_interval_steps = flags.eval_interval_fine_steps

        if eval_interval_steps > 0 and step % eval_interval_steps == 0:
          saver.save(sess, os.path.join(log_dir, 'model'), global_step=step)
          # `eval` here is presumably the module-level evaluation function
          # defined elsewhere in this file (it shadows the builtin).
          eval(
              flags=flags,
              train_dataset=train_dataset,
              test_dataset=test_dataset)

        # Early stop once we are number_of_steps_to_early_stop past halfway.
        if float(
            step
        ) > 0.5 * flags.number_of_steps + flags.number_of_steps_to_early_stop:
          break
653
654
def get_class_center_for_evaluation(train_bs, test_bs, num_classes):
  """The tf graph of calculating class centers at eval given training data.

  Args:
    train_bs: Batch size of the training-feature placeholder.
    test_bs: Batch size of the test-feature placeholder.
    num_classes: Total number of classes; labels must be in [0, num_classes).

  Returns:
    x_train, x_test, y_train, y_test: The feature/label placeholders.
    centroid: A (num_classes, feature_dim) tensor holding the per-class SUM
      of this batch's features. Note this is a sum, not a mean — the caller
      accumulates sums over batches and divides by num_samples_per_class.
  """
  x_train = tf.placeholder(shape=[train_bs, feature_dim], dtype=tf.float32)
  x_test = tf.placeholder(shape=[test_bs, feature_dim], dtype=tf.float32)
  y_train = tf.placeholder(
      shape=[
          train_bs,
      ], dtype=tf.int32)
  y_test = tf.placeholder(
      shape=[
          test_bs,
      ], dtype=tf.int32)

  # Finds the class centers for the training data. class label should be 0-N
  ind_c = tf.squeeze(tf.where(tf.equal(y_train, 0)))
  train_input_c = tf.gather(x_train, ind_c)
  train_input_c = tf.expand_dims(train_input_c, 0)
  centroid = tf.reduce_sum(train_input_c, 1)

  # Stacks one summed row per remaining class, giving (num_classes, dim).
  for i in range(1, num_classes):
    ind_c = tf.squeeze(tf.where(tf.equal(y_train, i)))
    tmp_input_c = tf.gather(x_train, ind_c)
    tmp_input_c = tf.expand_dims(tmp_input_c, 0)
    tmp_centroid = tf.reduce_sum(tmp_input_c, 1)
    centroid = tf.concat([centroid, tmp_centroid], 0)

  return x_train, x_test, y_train, y_test, centroid
682
683
def make_predictions_for_evaluation(centroid, x_test, y_test, flags):
  """The tf graph for making predictions given class centers and test data.

  Args:
    centroid: (num_classes, feature_dim) class centers.
    x_test: (batch, feature_dim) test features.
    y_test: (batch,) ground-truth labels.
    flags: The hyperparameter dictionary.

  Returns:
    pred: (batch,) predicted labels (nearest-center rule).
    correct: Scalar count of correct predictions in the batch.
  """
  # Calculates the centroid pair-wise distance
  centroid_expand = tf.expand_dims(centroid, axis=0)
  # Calculates the test sample - centroid distance
  i = x_test.get_shape().as_list()[0]
  # Tiles to be able to produce weight matrix alpha in (i,j) space
  x_data_test = tf.expand_dims(x_test, axis=1)
  x_data_test = tf.tile(x_data_test, (1, flags.num_classes_total, 1))
  centroid_expand_test = tf.tile(centroid_expand, (i, 1, 1))
  euclidean = tf.norm(
      x_data_test - centroid_expand_test, name='euclidean_distance', axis=-1)
  # Prediction based on nearest neighbors
  # (top_k on the NEGATED distance picks the closest center).
  _, top_k_indices = tf.nn.top_k(-euclidean, k=1)
  pred = tf.squeeze(top_k_indices)
  correct_mask = tf.cast(tf.math.equal(pred, y_test), tf.float32)
  correct = tf.reduce_sum(correct_mask, axis=0)

  return pred, correct
703
704
def calculate_nll(y_test, softmaxed_logits, num_classes):
  """Sums the negative log-likelihood of the true classes over the batch."""
  one_hot_labels = tf.one_hot(y_test, depth=num_classes)
  true_class_prob = tf.reduce_sum(one_hot_labels * softmaxed_logits, axis=1)
  return tf.reduce_sum(-tf.log(true_class_prob))
710
711
def confidence_estimation_and_evaluation(centroid, x_test, x_test_variance,
                                         y_test, flags):
  """The tf graph for confidence estimation and NLL for a batch of test data.

  Args:
    centroid: (num_classes, feature_dim) class centers.
    x_test: (batch, feature_dim) test-sample mean embeddings.
    x_test_variance: Variance-head output for x_test; exp(0.5 * v) is used
      as the sampling std, so it appears to hold a log-variance —
      TODO(review): confirm.
    y_test: (batch,) ground-truth labels.
    flags: The hyperparameter dictionary.

  Returns:
    nll_sum: Summed negative log-likelihood over the batch, computed from
      the Monte-Carlo-averaged softmax.
    confidence: Per-sample averaged probability at the PREDICTED class.
  """
  centroid_expand = tf.expand_dims(centroid, axis=0)
  i = x_test.get_shape().as_list()[0]
  # Tiles to be able to produce weight matrix alpha in (i,j) space
  x_test_tile = tf.expand_dims(x_test, axis=1)
  x_test_tile = tf.tile(x_test_tile, (1, flags.num_classes_total, 1))
  centroid_expand_tile = tf.tile(centroid_expand, (i, 1, 1))
  distance = tf.norm(
      x_test_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)

  # Prediction comes from the deterministic (mean) embeddings.
  softmaxed_distance = tf.math.softmax(-distance)
  _, top_k_indices = tf.nn.top_k(softmaxed_distance, k=1)
  pred = tf.squeeze(top_k_indices)

  # Calculates the distance with the correct label
  # (row indices paired with pred for the gather_nd below).
  row = tf.constant(np.arange(pred.get_shape().as_list()[0]))
  row = tf.cast(row, tf.int32)

  # First stochastic forward pass: z = mu + eps * exp(variance / 2).
  eps = tf.random.normal(shape=x_test.shape)
  z = eps * tf.exp((x_test_variance) * .5) + x_test
  z_tile = tf.expand_dims(z, axis=1)
  z_tile = tf.tile(z_tile, (1, flags.num_classes_total, 1))
  distance = tf.norm(
      z_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)
  softmaxed_distance = tf.math.softmax(-distance)

  # Remaining num_forward - 1 passes accumulate softmax probabilities.
  for _ in range(1, flags.num_forward):
    eps = tf.random.normal(shape=x_test.shape)
    z = eps * tf.exp((x_test_variance) * .5) + x_test
    z_tile = tf.expand_dims(z, axis=1)
    z_tile = tf.tile(z_tile, (1, flags.num_classes_total, 1))
    distance = tf.norm(
        z_tile - centroid_expand_tile, name='euclidean_distance', axis=-1)
    softmaxed_distance += tf.math.softmax(-distance)
  # Monte-Carlo average over the num_forward samples.
  softmaxed_distance = softmaxed_distance / flags.num_forward
  nll_sum = calculate_nll(y_test, softmaxed_distance, flags.num_classes_total)
  # Confidence = averaged probability of each sample's predicted class.
  ind_tensor = tf.transpose(tf.stack([row, pred]))
  confidence = tf.gather_nd(softmaxed_distance, ind_tensor)

  return nll_sum, confidence
754
755
def get_train_datasets(flags):
  """Loads the configured dataset for training.

  Args:
    flags: The hyperparameter dictionary.

  Returns:
    episodic_train: A data_loader.Dataset wrapping the training split for
      episodic (few-shot) sampling.
    test_set: The raw test split.
    train_set: The raw training split.

  Raises:
    ValueError: if flags.dataset is not a supported dataset name.
  """
  if flags.dataset == 'cifar10':
    train_set, test_set = data_loader.load_cifar10()
  elif flags.dataset == 'cifar100':
    train_set, test_set = data_loader.load_cifar100()
  elif flags.dataset == 'tinyimagenet':
    train_set, test_set = data_loader.load_tiny_imagenet(
        flags.data_dir, flags.val_data_dir)
  else:
    # Previously an unknown name fell through to a confusing NameError on
    # train_set below.
    raise ValueError('Unsupported dataset: %s' % flags.dataset)
  episodic_train = data_loader.Dataset(train_set)
  return episodic_train, test_set, train_set
766
767
def get_pwc_learning_rate(global_step, flags):
  """Piecewise-constant schedule: 10x decays at three points past halfway."""
  halfway = flags.number_of_steps / 2
  boundaries = [
      np.int64(halfway),
      np.int64(halfway + flags.num_steps_decay_pwc),
      np.int64(halfway + 2 * flags.num_steps_decay_pwc),
  ]
  values = [
      flags.init_learning_rate,
      flags.init_learning_rate * 0.1,
      flags.init_learning_rate * 0.01,
      flags.init_learning_rate * 0.001,
  ]
  return tf.train.piecewise_constant(global_step, boundaries, values)
778
779
780class ModelLoader:
781"""The class definition for the evaluation module."""
782
def __init__(self, model_path, batch_size, train_dataset, test_dataset):
  """Restores the latest checkpoint and builds the evaluation graph.

  Args:
    model_path: Experiment directory; the checkpoint is read from its
      'train' subdirectory and flags from the params saved there.
    batch_size: Batch size used for both the train and test feature passes.
    train_dataset: Training split; indexed as dataset[0] (images) and
      dataset[1] (labels) by eval_acc_nll_ece.
    test_dataset: Test split with the same layout.
  """
  self.train_batch_size = batch_size
  self.test_batch_size = batch_size
  self.test_dataset = test_dataset
  self.train_dataset = train_dataset

  latest_checkpoint = tf.train.latest_checkpoint(
      checkpoint_dir=os.path.join(model_path, 'train'))
  print(latest_checkpoint)
  # Checkpoint files are named 'model-<step>'; recover the step number.
  step = int(os.path.basename(latest_checkpoint).split('-')[1])
  flags = Namespace(
      utils.load_and_save_params(default_params=dict(), exp_dir=model_path))
  image_size = data_loader.get_image_size(flags.dataset)
  self.flags = flags

  with tf.Graph().as_default():
    self.tensor_images, self.tensor_labels = placeholder_inputs(
        batch_size=self.train_batch_size,
        image_size=image_size,
        scope='inputs')
    # Evaluation-mode preprocessing (no random augmentation).
    if flags.dataset == 'cifar10' or flags.dataset == 'cifar100':
      tensor_images_aug = data_loader.augment_cifar(
          self.tensor_images, is_training=False)
    else:
      tensor_images_aug = data_loader.augment_tinyimagenet(
          self.tensor_images, is_training=False)
    model = build_model(flags)
    # Same scope name as training so checkpoint variables resolve.
    with tf.variable_scope('Proto_training'):
      self.representation, self.variance = build_feature_extractor_graph(
          inputs=tensor_images_aug,
          flags=flags,
          is_variance=True,
          is_training=False,
          model=model)
    self.tensor_train_rep, self.tensor_test_rep, \
        self.tensor_train_rep_label, self.tensor_test_rep_label,\
        self.center = get_class_center_for_evaluation(
            self.train_batch_size, self.test_batch_size,
            flags.num_classes_total)

    self.prediction, self.acc \
        = make_predictions_for_evaluation(self.center,
                                          self.tensor_test_rep,
                                          self.tensor_test_rep_label,
                                          self.flags)
    self.tensor_test_variance = tf.placeholder(
        shape=[self.test_batch_size, feature_dim], dtype=tf.float32)
    self.nll, self.confidence = confidence_estimation_and_evaluation(
        self.center, self.tensor_test_rep, self.tensor_test_variance,
        self.tensor_test_rep_label, flags)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)
    # Runs init before loading the weights
    self.sess.run(tf.global_variables_initializer())
    # Loads weights
    saver = tf.train.Saver()
    saver.restore(self.sess, latest_checkpoint)
    self.flags = flags
    self.step = step
    log_dir = flags.log_dir
    # Dump the eval graph for debugging/inspection.
    graphpb_txt = str(tf.get_default_graph().as_graph_def())
    with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f:
      f.write(graphpb_txt)
847
def eval_ece(self, pred_logits_np, pred_np, label_np, num_bins):
  """Calculates ECE (expected calibration error).

  Args:
    pred_logits_np: the softmax output at the dimension of the predicted
      labels of test samples.
    pred_np: the numpy array of the predicted labels of test samples.
    label_np: the numpy array of the ground-truth labels of test samples.
    num_bins: the number of bins to partition all samples. we set it as 15.

  Returns:
    ece: the calculated ECE value.
  """
  acc_tab = np.zeros(num_bins)  # Empirical (true) confidence
  mean_conf = np.zeros(num_bins)  # Predicted confidence
  nb_items_bin = np.zeros(num_bins)  # Number of items in the bins
  # Fixed [0, 1] confidence bins. (A data-dependent binning via
  # linspace(min, max) used to be computed here but was dead code —
  # immediately overwritten by this line — so it has been removed.)
  tau_tab = np.linspace(0, 1, num_bins + 1)  # Confidence bins
  for i in np.arange(num_bins):  # Iterates over the bins
    # Selects the items where the predicted max probability falls in the
    # half-open bin [tau_tab[i], tau_tab[i + 1]).
    sec = (tau_tab[i + 1] > pred_logits_np) & (pred_logits_np >= tau_tab[i])
    nb_items_bin[i] = np.sum(sec)  # Number of items in the bin
    # Selects the predicted classes, and the true classes
    class_pred_sec, y_sec = pred_np[sec], label_np[sec]
    # Averages of the predicted max probabilities
    mean_conf[i] = np.mean(
        pred_logits_np[sec]) if nb_items_bin[i] > 0 else np.nan
    # Computes the empirical confidence
    acc_tab[i] = np.mean(
        class_pred_sec == y_sec) if nb_items_bin[i] > 0 else np.nan
  # Cleaning: drop empty bins before the weighted average.
  mean_conf = mean_conf[nb_items_bin > 0]
  acc_tab = acc_tab[nb_items_bin > 0]
  nb_items_bin = nb_items_bin[nb_items_bin > 0]
  if sum(nb_items_bin) != 0:
    # ECE = sum_b (n_b / N) * |confidence_b - accuracy_b|
    ece = np.average(
        np.absolute(mean_conf - acc_tab),
        weights=nb_items_bin.astype(float) / np.sum(nb_items_bin))
  else:
    ece = 0.0
  return ece
892
def eval_acc_nll_ece(self, num_cases_train, num_cases_test):
  """Returns evaluation metrics.

  Embeds the full training set to build class centers, then scores every
  test batch against those centers to obtain accuracy, NLL and ECE.

  Args:
    num_cases_train: the total number of training samples.
    num_cases_test: the total number of test samples.

  Returns:
    num_correct / num_cases_test: the accuracy of the evaluation.
    nll: the calculated NLL value.
    ece: the calculated ECE value.
  """
  num_batches_train = num_cases_train // self.train_batch_size
  num_batches_test = num_cases_test // self.test_batch_size
  num_correct = 0.0
  features_train_np = []
  features_test_np = []
  variance_test_np = []
  # Embeds every training batch; only the representation is fetched here.
  for i in trange(num_batches_train):
    images_train = self.train_dataset[0][(
        i * self.train_batch_size):((i + 1) * self.train_batch_size)]
    feed_dict = {self.tensor_images: images_train.astype(dtype=np.float32)}
    [features_train_batch] = self.sess.run([self.representation], feed_dict)
    features_train_np.extend(features_train_batch)
  features_train_np = np.concatenate(features_train_np, axis=0)
  # Embeds every test batch; the predicted variance is also kept because it
  # is fed back below when computing test-time confidence.
  for i in trange(num_batches_test):
    images_test = self.test_dataset[0][(
        i * self.test_batch_size):((i + 1) * self.test_batch_size)]
    feed_dict = {self.tensor_images: images_test.astype(dtype=np.float32)}
    [features_test_batch, variances_test_batch
    ] = self.sess.run([self.representation, self.variance], feed_dict)
    features_test_np.extend(features_test_batch)
    variance_test_np.extend(variances_test_batch)
  features_test_np = np.concatenate(features_test_np, axis=0)
  variance_test_np = np.concatenate(variance_test_np, axis=0)

  # Computes class centers: the first batch initializes the running sum,
  # the remaining batches are accumulated into it.
  features_train_batch = features_train_np[:self.train_batch_size]
  feed_dict = {
      self.tensor_train_rep:
          features_train_batch,
      self.tensor_train_rep_label:
          self.train_dataset[1][:self.train_batch_size]
  }
  [centroid] = self.sess.run([self.center], feed_dict)
  for i in trange(1, num_batches_train):
    features_train_batch = features_train_np[(
        i * self.train_batch_size):((i + 1) * self.train_batch_size)]
    feed_dict = {
        self.tensor_train_rep:
            features_train_batch,
        self.tensor_train_rep_label:
            self.train_dataset[1]
            [(i * self.train_batch_size):((i + 1) * self.train_batch_size)]
    }
    [centroid_batch] = self.sess.run([self.center], feed_dict)
    centroid = centroid + centroid_batch

  # NOTE(review): the summed per-batch centers are normalized by
  # num_samples_per_class, mirroring the original implementation — the
  # normalization constant's derivation is defined by self.center's graph,
  # which is not visible here; confirm against the model builder.
  centroid = centroid / self.flags.num_samples_per_class
  pred_list = []
  confidence_list = []
  nll_sum = 0
  for i in trange(num_batches_test):
    features_test_batch = features_test_np[(
        i * self.test_batch_size):((i + 1) * self.test_batch_size)]
    variance_test_batch = variance_test_np[(
        i * self.test_batch_size):((i + 1) * self.test_batch_size)]
    feed_dict = {
        self.center:
            centroid,
        self.tensor_test_rep:
            features_test_batch,
        self.tensor_test_variance:
            variance_test_batch,
        self.tensor_test_rep_label:
            self.test_dataset[1]
            [(i * self.test_batch_size):((i + 1) * self.test_batch_size)]
    }
    # A single session run fetches all four metrics at once. The original
    # issued two sess.run calls with the identical feed_dict per batch,
    # executing the graph twice for no benefit.
    [prediction, num_correct_per_batch, nll, confidence] = self.sess.run(
        [self.prediction, self.acc, self.nll, self.confidence], feed_dict)
    num_correct += num_correct_per_batch
    pred_list.append(prediction)
    confidence_list.append(confidence)
    nll_sum += nll
  pred_np = np.concatenate(pred_list, axis=0)
  confidence_np = np.concatenate(confidence_list, axis=0)

  # The definition of NLL can be found at "On calibration of modern neural
  # networks." Guo, Chuan, et al. Proceedings of the 34th International
  # Conference on Machine Learning-Volume 70. JMLR. org, 2017. NLL averages
  # the negative log-likelihood of all test samples.

  nll = nll_sum / num_cases_test

  # The definition of ECE can be found at "On calibration of modern neural
  # networks." Guo, Chuan, et al. Proceedings of the 34th International
  # Conference on Machine Learning-Volume 70. JMLR. org, 2017. ECE
  # approximates the expectation of the difference between accuracy and
  # confidence. It partitions the confidence estimations (the likelihood of
  # the predicted label) of all test samples into L equally-spaced bins and
  # calculates the average confidence and accuracy of test samples lying in
  # each bin.

  ece = self.eval_ece(confidence_np, pred_np, self.test_dataset[1], 15)
  print('acc: ' + str(num_correct / num_cases_test))
  print('nll: ')
  print(nll)
  print('ece: ')
  print(ece)

  return num_correct / num_cases_test, nll, ece
1007
1008
def eval(flags, train_dataset, test_dataset):
  """Evaluation entry point: restores a model and logs test-set metrics."""
  # pylint: disable=redefined-builtin
  eval_writer = utils.summary_writer(flags.log_dir + '/eval')
  loader = ModelLoader(
      model_path=flags.pretrained_model_dir,
      batch_size=10000,
      train_dataset=train_dataset,
      test_dataset=test_dataset)
  acc_tst, nll, ece = loader.eval_acc_nll_ece(flags.num_cases_train,
                                              flags.num_cases_test)

  # Writes all three metrics under the restored model's global step.
  results = {
      'accuracy_target_tst': acc_tst,
      'nll': nll,
      'ece': ece,
  }
  eval_writer(loader.step, **results)
  logging.info('accuracy_%s: %.3g.', 'target_tst', acc_tst)
1028
1029
def main(argv=None):
  """Entry point: parses parameters, prepares the log dir, runs training.

  Args:
    argv: unused; present for the tf.app.run() calling convention.
  """
  # pylint: disable=unused-argument
  # pylint: disable=unused-variable
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.per_process_gpu_memory_fraction = 1.0
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)

  # Gets parameters.
  default_params = get_arguments()

  # Creates the experiment directory. exist_ok avoids the check-then-act
  # race of the previous `if not exists(): mkdir()` pattern.
  log_dir = default_params.log_dir
  pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

  # Main function for training and evaluation.
  flags = Namespace(utils.load_and_save_params(vars(default_params),
                                               log_dir, ignore_existing=True))
  train(flags=flags)
1051
1052
if __name__ == '__main__':
  # Delegates flag handling and execution to the TF1 app runner, which
  # invokes main() defined above.
  tf.app.run()
1055