# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Multi GPU model (sync gradient updates.)
17"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22import tensorflow.compat.v1 as tf
23from capsule_em import em_model
24from capsule_em import layers
25from capsule_em import simple_model
26from capsule_em import utils
27
28FLAGS = tf.app.flags.FLAGS
29
30
def _average_gradients(tower_grads):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers. The inner list is over the (gradient, variable) pairs
      computed on each tower.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been averaged
    across all towers.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
    # Report any variable that received no gradient on some tower; such a
    # variable cannot be averaged below.
    for g, v in grad_and_vars:
      if g is None:
        print('No gradient for variable:', v.op.name)

    # Add a leading 'tower' dimension to each gradient and stack them so we
    # can average over that dimension below.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
    grad = tf.concat(grads, 0)
    grad = tf.reduce_mean(grad, 0)
    capped_grad = tf.clip_by_value(grad, -200., 200.)

    # The Variables are redundant because they are shared across towers, so
    # it suffices to return the first tower's pointer to the Variable.
    v = grad_and_vars[0][1]
    average_grads.append((capped_grad, v))
  return average_grads
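
# Illustrative sketch of _average_gradients (hypothetical values, not part of
# the training pipeline): with two towers and one shared variable,
#
#   v = tf.get_variable('v', [2])
#   tower_grads = [[(tf.constant([1., 3.]), v)],  # (grad, var) from tower 0
#                  [(tf.constant([3., 5.]), v)]]  # (grad, var) from tower 1
#   avg = _average_gradients(tower_grads)
#
# avg holds one (gradient, variable) pair whose gradient is the element-wise
# mean [2., 4.], clipped to [-200, 200].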


def multi_gpu_model(features):
  """Build the Graph and train the model on multiple GPUs."""
  if FLAGS.use_caps:
    if FLAGS.use_em:
      inference = em_model.inference
    else:
      raise NotImplementedError(
          'Only EM routing is supported when FLAGS.use_caps is set.')
  else:
    inference = simple_model.conv_inference
  with tf.device('/cpu:0'):
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0),
        trainable=False)

    lr = tf.train.exponential_decay(
        FLAGS.learning_rate,
        global_step,
        FLAGS.decay_steps,
        FLAGS.decay_rate,
        staircase=FLAGS.staircase)
    if FLAGS.clip_lr:
      lr = tf.maximum(lr, 1e-6)
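
    # For reference: tf.train.exponential_decay computes
    #   lr = FLAGS.learning_rate *
    #        FLAGS.decay_rate ** (global_step / FLAGS.decay_steps)
    # with the exponent truncated to an integer when FLAGS.staircase is True;
    # the 1e-6 floor above (applied when FLAGS.clip_lr is set) only takes
    # effect once the decayed rate falls below it.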

    if FLAGS.adam:
      opt = tf.train.AdamOptimizer(lr)
    else:
      opt = tf.train.GradientDescentOptimizer(lr)

    tower_grads = []
    corrects = []
    almosts = []
    result = {}
    with tf.variable_scope(tf.get_variable_scope()):
      for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('tower_%d' % i) as scope:
            label_ = features[i]['labels']
            # Note: the recons and mid_act entries are overwritten on each
            # iteration, so the last tower's values are the ones returned.
            y, result['recons_1'], result['recons_2'], result[
                'mid_act'] = inference(features[i])
            result['logits'] = y

            losses, correct, almost = layers.optimizer(
                logits=y,
                labels=label_,
                multi=FLAGS.multi and FLAGS.data_set == 'mnist',
                scope=scope,
                softmax=FLAGS.softmax,
                rate=FLAGS.loss_rate,
                step=global_step,
            )
            # Reuse variables so that every tower shares the same weights.
            tf.get_variable_scope().reuse_variables()
            corrects.append(correct)
            almosts.append(almost)
            # GATE_NONE computes the tower's gradients with maximum
            # parallelism, at the cost of some reproducibility.
            grads = opt.compute_gradients(
                losses,
                gate_gradients=tf.train.Optimizer.GATE_NONE,
            )
            tower_grads.append(grads)

    with utils.maybe_jit_scope(), tf.name_scope('average_gradients'):
      grads = _average_gradients(tower_grads)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    if FLAGS.verbose:
      for grad, var in grads:
        if grad is not None:
          summaries.append(
              tf.summary.histogram(var.op.name + '/gradients', grad))
    summaries.append(tf.summary.scalar('learning_rate', lr))
    result['summary'] = tf.summary.merge(summaries)
    result['train'] = opt.apply_gradients(grads, global_step=global_step)

    # Sum the per-tower `correct` and `almost` counts across all towers.
    cors = tf.stack(corrects)
    alms = tf.stack(almosts)
    result['correct'] = tf.reduce_sum(cors, 0)
    result['almost'] = tf.reduce_sum(alms, 0)

  return result
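
# Hedged usage sketch (illustrative only; `make_tower_features` is a
# hypothetical stand-in for this package's real input pipeline):
#
#   features = [make_tower_features(i) for i in range(FLAGS.num_gpus)]
#   result = multi_gpu_model(features)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     _, num_correct = sess.run([result['train'], result['correct']])
#
# Each features[i] must provide a 'labels' tensor plus whatever tensors the
# chosen `inference` function consumes.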