# google-research
# 523 lines · 20.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""The runners."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import os
21import time
22import numpy as np
23
24import tensorflow.compat.v1 as tf
25from capsule_em import model as f_model
26from capsule_em.mnist \
27import mnist_record
28from capsule_em.norb \
29import norb_record
30
31
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('num_prime_capsules', 32,
                            'Number of first layer capsules.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate')
tf.app.flags.DEFINE_integer('routing_iteration', 3,
                            'Number of iterations for softmax routing')
tf.app.flags.DEFINE_float(
    'routing_rate', 1,
    'ratio for combining routing logits and routing feedback')
tf.app.flags.DEFINE_float('decay_rate', 0.96, 'ratio for learning rate decay')
tf.app.flags.DEFINE_integer('decay_steps', 20000,
                            'number of steps for learning rate decay')
tf.app.flags.DEFINE_bool('normalize_kernels', False,
                         'Normalize the capsule weight kernels')
tf.app.flags.DEFINE_integer('num_second_atoms', 16,
                            'number of capsule atoms for the second layer')
tf.app.flags.DEFINE_integer('num_primary_atoms', 16,
                            'number of capsule atoms for the first layer')
tf.app.flags.DEFINE_integer('num_start_conv', 32,
                            'number of channels for the start layer')
tf.app.flags.DEFINE_integer('kernel_size', 5,
                            'kernel size for the start layer.')
tf.app.flags.DEFINE_integer(
    'routing_iteration_prime', 1,
    'number of routing iterations for primary capsules.')
tf.app.flags.DEFINE_integer('max_steps', 2000000,
                            'Number of steps to run trainer.')
tf.app.flags.DEFINE_string('data_dir', '/datasets/mnist/',
                           'Directory for storing input data')
tf.app.flags.DEFINE_string('summary_dir',
                           '/tmp/tensorflow/mnist/logs/mnist_with_summaries',
                           'Summaries log directory')
tf.app.flags.DEFINE_bool('train', True, 'train or test.')
tf.app.flags.DEFINE_integer(
    'checkpoint_steps', 1500,
    'number of steps before saving a training checkpoint.')
tf.app.flags.DEFINE_bool('verbose_image', False, 'whether to show images.')
tf.app.flags.DEFINE_bool('multi', True,
                         'whether to use multiple digit dataset.')
tf.app.flags.DEFINE_bool('eval_once', False,
                         'whether to evaluate once on the ckpnt file.')
tf.app.flags.DEFINE_integer('eval_size', 24300,
                            'number of examples to evaluate.')
tf.app.flags.DEFINE_string(
    'ckpnt',
    '/tmp/tensorflow/mnist/logs/mnist_with_summaries/train/model.ckpnt',
    'The checkpoint to load and evaluate once.')
# Help text fixed: this flag controls checkpoint retention, not eval size.
tf.app.flags.DEFINE_integer('keep_ckpt', 5,
                            'number of training checkpoints to keep.')
tf.app.flags.DEFINE_bool(
    'clip_lr', False, 'whether to clip learning rate to not go below 1e-5.')
tf.app.flags.DEFINE_integer('stride_1', 2,
                            'stride for the first convolutional layer.')
tf.app.flags.DEFINE_integer('kernel_2', 9,
                            'kernel size for the second convolutional layer.')
tf.app.flags.DEFINE_integer('stride_2', 2,
                            'stride for the second convolutional layer.')
tf.app.flags.DEFINE_string('padding', 'VALID',
                           'the padding method for conv layers.')
tf.app.flags.DEFINE_integer('extra_caps', 2, 'number of extra conv capsules.')
tf.app.flags.DEFINE_string('caps_dims', '32,32',
                           'output dim for extra conv capsules.')
tf.app.flags.DEFINE_string('caps_strides', '2,1',
                           'stride for extra conv capsules.')
tf.app.flags.DEFINE_string('caps_kernels', '3,3',
                           'kernel size for extra conv capsules.')
tf.app.flags.DEFINE_integer('extra_conv', 0, 'number of extra conv layers.')

tf.app.flags.DEFINE_string('conv_dims', '', 'output dim for extra conv layers.')
tf.app.flags.DEFINE_string('conv_strides', '', 'stride for extra conv layers.')
tf.app.flags.DEFINE_string('conv_kernels', '',
                           'kernel size for extra conv layers.')
tf.app.flags.DEFINE_bool('leaky', False, 'Use leaky routing.')
tf.app.flags.DEFINE_bool('fast', False, 'Use the new faster implementation.')
tf.app.flags.DEFINE_bool('cpu_way', False,
                         'If set, use NHWC ordering instead of NCHW.')
tf.app.flags.DEFINE_bool('jit_scopes', False,
                         'Use xla jit_scopes to compile. Not supported.')
tf.app.flags.DEFINE_bool('staircase', False, 'Use staircase decay.')
tf.app.flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train.')
tf.app.flags.DEFINE_bool('adam', True, 'Use Adam optimizer.')
tf.app.flags.DEFINE_bool('pooling', False, 'Pooling after convolution.')
tf.app.flags.DEFINE_bool('use_caps', True, 'Use capsule layers.')
tf.app.flags.DEFINE_integer(
    'extra_fc', 512, 'number of units in the extra fc layer in no caps mode.')
tf.app.flags.DEFINE_bool('dropout', False, 'Dropout before last layer.')
tf.app.flags.DEFINE_bool('tweak', False, 'During eval recons from tweaked rep.')
tf.app.flags.DEFINE_bool('softmax', False, 'softmax loss in no caps.')
tf.app.flags.DEFINE_bool('c_dropout', False, 'dropout after conv capsules.')
tf.app.flags.DEFINE_bool(
    'distort', True,
    'distort mnist images by cropping to 24 * 24 and rotating by 15 degrees.')
tf.app.flags.DEFINE_bool('restart', False, 'Clean train checkpoints.')
tf.app.flags.DEFINE_bool('use_em', True,
                         'If set use em capsules with em routing.')
tf.app.flags.DEFINE_float('final_beta', 0.01, 'Temperature at the sigmoid.')
tf.app.flags.DEFINE_bool('eval_ensemble', False, 'eval over aggregated logits.')
tf.app.flags.DEFINE_string('part1', 'ok', 'ok')
tf.app.flags.DEFINE_string('part2', 'ok', 'ok')
tf.app.flags.DEFINE_bool('reduce_mean', False,
                         'If set normalize mean of each image.')
tf.app.flags.DEFINE_float('loss_rate', 1.0,
                          'classification to regularization rate.')
tf.app.flags.DEFINE_integer('batch_size', 64, 'Batch size.')
# Help text fixed: this is the smallNORB image size, not the batch size.
tf.app.flags.DEFINE_integer('norb_pixel', 48,
                            'smallNORB input image size in pixels.')
tf.app.flags.DEFINE_bool('patching', True, 'If set use patching for eval.')

tf.app.flags.DEFINE_string('data_set', 'norb', 'the data set to use.')
tf.app.flags.DEFINE_string('cifar_data_dir', '/tmp/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_string('norb_data_dir', '/root/datasets/smallNORB/',
                           """Path to the norb data directory.""")
tf.app.flags.DEFINE_string('affnist_data_dir', '/tmp/affnist_data',
                           """Path to the affnist data directory.""")
146
# Number of output classes for each supported dataset name (matches the
# values accepted by FLAGS.data_set / get_features).
# NOTE(review): not referenced in this file's visible code — presumably read
# by capsule_em.model to size the final layer; verify against that module.
num_classes = {
    'mnist': 10,
    'cifar10': 10,
    'mnist_multi': 10,
    'svhn': 10,
    'affnist': 10,
    'expanded_mnist': 10,
    'norb': 5,
}
156
157
def get_features(train, total_batch):
  """Build the batched input pipeline, one feature dict per GPU tower.

  Args:
    train: Whether to read the training split (otherwise the test split).
    total_batch: Combined batch size, divided evenly across FLAGS.num_gpus.

  Returns:
    A list of FLAGS.num_gpus feature dictionaries as produced by the
    dataset readers (norb_record.inputs / mnist_record.inputs).
  """
  print(FLAGS.data_set)
  tower_batch = total_batch // max(1, FLAGS.num_gpus)
  split = 'train' if train else 'test'
  features = []
  for tower in range(FLAGS.num_gpus):
    # Input readers are always placed on the CPU.
    with tf.device('/cpu:0'), tf.name_scope('input_tower_%d' % tower):
      if FLAGS.data_set == 'norb':
        tower_input = norb_record.inputs(
            train_dir=FLAGS.norb_data_dir,
            batch_size=tower_batch,
            split=split,
            multi=FLAGS.multi,
            image_pixel=FLAGS.norb_pixel,
            distort=FLAGS.distort,
            patching=FLAGS.patching,
        )
      elif FLAGS.data_set == 'affnist':
        # affNIST is evaluation-only here: it always reads test.tfrecords.
        tower_input = mnist_record.inputs(
            train_dir=FLAGS.affnist_data_dir,
            batch_size=tower_batch,
            split=split,
            multi=FLAGS.multi,
            shift=0,
            height=40,
            train_file='test.tfrecords')
      elif FLAGS.data_set == 'expanded_mnist':
        tower_input = mnist_record.inputs(
            train_dir=FLAGS.data_dir,
            batch_size=tower_batch,
            split=split,
            multi=FLAGS.multi,
            height=40,
            train_file='train_6shifted_6padded_mnist.tfrecords',
            shift=6)
      else:
        # Plain MNIST: apply a 2-pixel shift only when training without the
        # crop/rotate distortion pipeline.
        tower_input = mnist_record.inputs(
            train_dir=FLAGS.data_dir,
            batch_size=tower_batch,
            split=split,
            multi=FLAGS.multi,
            shift=2 if (train and not FLAGS.distort) else 0,
            distort=FLAGS.distort)
      features.append(tower_input)
  print(features)
  return features
217
218
def run_training():
  """Build the multi-GPU model and run the training loop.

  Resumes from the newest checkpoint under FLAGS.summary_dir/train unless
  FLAGS.restart is set (in which case the train dir is wiped). Writes a
  summary every step and a checkpoint every FLAGS.checkpoint_steps steps.
  """
  with tf.Graph().as_default():
    # Input pipeline and replicated model.
    features = get_features(True, FLAGS.batch_size)
    model = f_model.multi_gpu_model
    print('so far so good!')
    result = model(features)

    # TODO(sasabour): merge jit scopes after jit scopes where enabled.
    merged = result['summary']
    train_step = result['train']

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt)

    train_dir = FLAGS.summary_dir + '/train'
    if tf.gfile.Exists(train_dir):
      ckpt = tf.train.get_checkpoint_state(train_dir + '/')
      print(ckpt)
      if (not FLAGS.restart) and ckpt and ckpt.model_checkpoint_path:
        print('hesllo')
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Checkpoint files are named model.ckpt-<step>; resume from there.
        prev_step = int(
            ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
      else:
        print('what??')
        tf.gfile.DeleteRecursively(train_dir)
        tf.gfile.MakeDirs(train_dir)
        prev_step = 0
    else:
      tf.gfile.MakeDirs(train_dir)
      prev_step = 0

    train_writer = tf.summary.FileWriter(train_dir, sess.graph)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    steps_done = 0
    try:
      for global_step in range(prev_step, FLAGS.max_steps):
        steps_done += 1
        summary, _ = sess.run([merged, train_step])
        train_writer.add_summary(summary, global_step)
        if (global_step + 1) % FLAGS.checkpoint_steps == 0:
          saver.save(
              sess,
              os.path.join(train_dir, 'model.ckpt'),
              global_step=global_step + 1)
    except tf.errors.OutOfRangeError:
      print('Done training for %d steps.' % steps_done)
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
      train_writer.close()
    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
280
281
def run_eval():
  """Evaluate on test or validation.

  Polls FLAGS.summary_dir/train for new checkpoints, and for each unseen
  checkpoint runs a full pass over FLAGS.eval_size examples (batch size 5),
  writing aggregate correct/wrong counts as summaries. Gives up after
  roughly 360 minutes without a new checkpoint.
  """
  with tf.Graph().as_default():
    # Input images and labels.
    features = get_features(False, 5)
    model = f_model.multi_gpu_model
    result = model(features)
    merged = result['summary']
    correct_prediction_sum = result['correct']
    almost_correct_sum = result['almost']
    saver = tf.train.Saver()
    test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')
    seen_step = -1
    # Give training a head start before the first poll.
    time.sleep(3 * 60)
    paused = 0  # minutes spent waiting since the last new checkpoint
    while paused < 360:
      ckpt = tf.train.get_checkpoint_state(FLAGS.summary_dir + '/train/')
      if ckpt and ckpt.model_checkpoint_path:
        # Restores from checkpoint; path is model.ckpt-<step>.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
      else:
        time.sleep(2 * 60)
        paused += 2
        continue
      # Wait until a checkpoint newer than the last evaluated one appears.
      while seen_step == int(global_step):
        time.sleep(2 * 60)
        ckpt = tf.train.get_checkpoint_state(FLAGS.summary_dir + '/train/')
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        paused += 2
        if paused > 360:
          test_writer.close()
          return
      paused = 0

      seen_step = int(global_step)
      print(seen_step)
      sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
      saver.restore(sess, ckpt.model_checkpoint_path)
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      try:
        total_tp = 0
        total_almost = 0
        # Batch size is 5, so eval_size // 5 steps covers eval_size examples.
        for i in range(FLAGS.eval_size // 5):
          summary_j, tp, almost = sess.run(
              [merged, correct_prediction_sum, almost_correct_sum])
          total_tp += tp
          total_almost += almost

        total_false = FLAGS.eval_size - total_tp
        total_almost_false = FLAGS.eval_size - total_almost
        # Append the aggregate counters to the last per-batch summary.
        summary_tp = tf.Summary.FromString(summary_j)
        summary_tp.value.add(tag='correct_prediction', simple_value=total_tp)
        summary_tp.value.add(tag='wrong_prediction', simple_value=total_false)
        summary_tp.value.add(
            tag='almost_wrong_prediction', simple_value=total_almost_false)
        # NOTE(review): global_step is still a string here; add_summary is
        # presumably tolerant of that — confirm against the FileWriter API.
        test_writer.add_summary(summary_tp, global_step)
        print('write done')
      except tf.errors.OutOfRangeError:
        print('Done eval for %d steps.' % i)
      finally:
        # When done, ask the threads to stop.
        coord.request_stop()
        # Wait for threads to finish.
        coord.join(threads)
      sess.close()
    test_writer.close()
349
350
def softmax(x):
  """Compute softmax values along the last axis of x.

  Numerically stable via max subtraction. For a 1-D input this matches the
  previous global softmax exactly; for 2-D (batch, classes) logits — the
  shape the ensemble-eval call site produces — each row is now normalized
  independently, as intended, instead of over the whole array.

  Args:
    x: Array-like of scores.

  Returns:
    np.ndarray of the same shape with last-axis entries summing to 1.
  """
  x = np.asarray(x)
  e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return e_x / e_x.sum(axis=-1, keepdims=True)
355
356
def eval_ensemble(ckpnts):
  """Evaluate on an ensemble of checkpoints.

  Reads FLAGS.eval_size test examples once (batches of 100), caches them in
  numpy arrays, then replays the same batches through the model restored
  from each checkpoint in `ckpnts`, summing the logits. The final prediction
  is the argmax of the summed logits; prints the total number of errors.

  Args:
    ckpnts: Iterable of checkpoint paths to restore and aggregate.
  """
  with tf.Graph().as_default():
    # Build one real input pipeline only to learn the image geometry and to
    # read the fixed evaluation batches below.
    first_features = get_features(False, 100)[0]
    h = first_features['height']
    d = first_features['depth']
    # The model itself is fed through placeholders so the identical batches
    # can be replayed for every checkpoint.
    features = {
        'images': tf.placeholder(tf.float32, shape=(100, d, h, h)),
        'labels': tf.placeholder(tf.float32, shape=(100, 10)),
        'recons_image': tf.placeholder(tf.float32, shape=(100, d, h, h)),
        'recons_label': tf.placeholder(tf.int32, shape=(100)),
        'height': first_features['height'],
        'depth': first_features['depth']
    }

    model = f_model.multi_gpu_model
    result = model([features])
    logits = result['logits']
    config = tf.ConfigProto(allow_soft_placement=True)
    # saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpnt))
    # Caches for the fixed evaluation set: [num_batches, 100, ...].
    batch_logits = np.zeros((FLAGS.eval_size // 100, 100, 10), dtype=np.float32)
    batch_recons_label = np.zeros((FLAGS.eval_size // 100, 100),
                                  dtype=np.float32)
    batch_labels = np.zeros((FLAGS.eval_size // 100, 100, 10), dtype=np.float32)
    batch_images = np.zeros((FLAGS.eval_size // 100, 100, d, h, h),
                            dtype=np.float32)
    batch_recons_image = np.zeros((FLAGS.eval_size // 100, 100, d, h, h),
                                  dtype=np.float32)
    saver = tf.train.Saver()
    sess = tf.Session(config=config)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      # Pass 1: materialize the whole evaluation set into the numpy caches.
      for i in range(FLAGS.eval_size // 100):
        (batch_recons_label[i, Ellipsis], batch_labels[i, Ellipsis], batch_images[i, Ellipsis],
         batch_recons_image[i, Ellipsis]) = sess.run([
             first_features['recons_label'], first_features['labels'],
             first_features['images'], first_features['recons_image']
         ])
      # Pass 2: for each ensemble member, replay the cached batches and
      # accumulate raw logits per example.
      for ckpnt in ckpnts:
        saver.restore(sess, ckpnt)
        for i in range(FLAGS.eval_size // 100):
          logits_i = sess.run(
              logits,
              feed_dict={
                  features['recons_label']: batch_recons_label[i, Ellipsis],
                  features['labels']: batch_labels[i, Ellipsis],
                  features['images']: batch_images[i, Ellipsis],
                  features['recons_image']: batch_recons_image[i, Ellipsis]
              })
          # batch_logits[i, ...] += softmax(logits_i)
          batch_logits[i, Ellipsis] += logits_i
    except tf.errors.OutOfRangeError:
      print('Done eval for %d steps.' % i)
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
      # Wait for threads to finish.
      coord.join(threads)
    sess.close()
    # Ensemble prediction: argmax over classes of the summed logits.
    batch_pred = np.argmax(batch_logits, axis=2)
    total_wrong = np.sum(np.not_equal(batch_pred, batch_recons_label))
    print(total_wrong)
420
421
def eval_once(ckpnt):
  """Evaluate on one checkpoint once.

  Runs FLAGS.eval_size evaluation steps against the given checkpoint. With
  FLAGS.patching each step sees all crops of one example and the per-crop
  logits are averaged (after per-crop normalization) before the argmax.

  Args:
    ckpnt: Path of the checkpoint to restore.
  """
  # Build 14x14 binary 5x5-window masks over a 32x32 image (stride 2).
  # NOTE(review): `ptches` is not used below in this function — presumably
  # left over from a patching experiment; verify before removing.
  ptches = np.zeros((14, 14, 32, 32))
  for i in range(14):
    for j in range(14):
      ind_x = i * 2
      ind_y = j * 2
      for k in range(5):
        for h in range(5):
          ptches[i, j, ind_x + k, ind_y + h] = 1
  ptches = np.reshape(ptches, (14 * 14, 32, 32))

  with tf.Graph().as_default():
    features = get_features(False, 1)[0]
    if FLAGS.patching:
      # Swap in the cropped-copies tensors so one "batch" is all crops of
      # a single example.
      features['images'] = features['cc_images']
      features['recons_label'] = features['cc_recons_label']
      features['labels'] = features['cc_labels']
    model = f_model.multi_gpu_model
    result = model([features])
    # merged = result['summary']
    correct_prediction_sum = result['correct']
    # almost_correct_sum = result['almost']
    # mid_act = result['mid_act']
    logits = result['logits']

    saver = tf.train.Saver()
    test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test_once')
    config = tf.ConfigProto(allow_soft_placement=True)
    # Cap GPU memory so eval can share a device with training.
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.Session(config=config)
    # saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpnt))
    saver.restore(sess, ckpnt)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    try:
      total_tp = 0
      for i in range(FLAGS.eval_size):
        #, g_ac, ac
        lb, tp, lg = sess.run([
            features['recons_label'],
            correct_prediction_sum,
            logits,
        ])
        if FLAGS.patching:
          # Normalize each crop's logits to sum to 1, average over crops,
          # then predict; overrides the per-crop correctness count.
          batched_lg = np.sum(lg / np.sum(lg, axis=1, keepdims=True), axis=0)
          batch_pred = np.argmax(batched_lg)
          tp = np.equal(batch_pred, lb[0])

        total_tp += tp
      total_false = FLAGS.eval_size - total_tp
      print('false:{}, true:{}'.format(total_false, total_tp))
      # summary_tp = tf.Summary.FromString(summary_j)
      # summary_tp.value.add(tag='correct_prediction', simple_value=total_tp)
      # summary_tp.value.add(tag='wrong_prediction', simple_value=total_false)
      # summary_tp.value.add(
      #     tag='almost_wrong_prediction', simple_value=total_almost_false)
      # test_writer.add_summary(summary_tp, i + 1)
    except tf.errors.OutOfRangeError:
      print('Done eval for %d steps.' % i)
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
      # Wait for threads to finish.
      coord.join(threads)
    sess.close()
    test_writer.close()
490
491
def _reset_dir(path):
  """Delete `path` if it exists, then recreate it empty."""
  if tf.gfile.Exists(path):
    tf.gfile.DeleteRecursively(path)
  tf.gfile.MakeDirs(path)


def main(_):
  """Dispatch on flags: ensemble eval, one-shot eval, training, or eval loop.

  The delete-and-recreate summary-directory pattern was repeated three
  times; it is factored into _reset_dir with identical behavior.
  """
  if FLAGS.eval_ensemble:
    _reset_dir(FLAGS.summary_dir + '/test_ensemble')
    # Collect up to 11 ensemble member checkpoints named part1<i>part2.
    ensem = []
    for i in range(1, 12):
      f_name = '/tmp/cifar10/{}{}{}-600000'.format(FLAGS.part1, i, FLAGS.part2)
      if tf.train.checkpoint_exists(f_name):
        ensem += [f_name]

    print(len(ensem))
    eval_ensemble(ensem)
  elif FLAGS.eval_once:
    _reset_dir(FLAGS.summary_dir + '/test_once')
    eval_once(FLAGS.ckpnt)
  elif FLAGS.train:
    run_training()
  else:
    _reset_dir(FLAGS.summary_dir + '/test_once')
    _reset_dir(FLAGS.summary_dir + '/test')
    run_eval()


if __name__ == '__main__':
  tf.app.run()
524