# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi GPU model (sync gradient updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf

from capsule_em import em_model
from capsule_em import layers
from capsule_em import simple_model
from capsule_em import utils

FLAGS = tf.app.flags.FLAGS
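# The flags read below (num_gpus, use_caps, use_em, learning_rate, decay_steps,
# decay_rate, staircase, clip_lr, adam, multi, data_set, softmax, loss_rate,
# verbose) are assumed to be defined elsewhere in the capsule_em package
# (e.g. the experiment entry point); this module only reads them.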


def _average_gradients(tower_grads):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over individual gradients. The inner list is over the gradient
      calculation for each tower.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been averaged
    across all towers.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Each grad_and_vars looks like:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, v in grad_and_vars:
      # A None gradient means the variable is disconnected from the loss on
      # this tower; report it, since tf.expand_dims below would fail on it.
      if g is None:
        print('No gradient for variable:', v.op.name)
      # Prepend a 'tower' dimension which we will average over below.
      grads.append(tf.expand_dims(g, 0))

    # Average over the 'tower' dimension and clip extreme values.
    grad = tf.concat(grads, 0)
    grad = tf.reduce_mean(grad, 0)
    capped_grad = tf.clip_by_value(grad, -200., 200.)

    # The Variables are redundant because they are shared across towers, so
    # return the first tower's pointer to the Variable.
    v = grad_and_vars[0][1]
    average_grads.append((capped_grad, v))
  return average_grads
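

# Illustrative sketch (not part of the original module): with two towers each
# contributing a gradient for the same shared variable, _average_gradients
# returns the element-wise mean of the tower gradients, clipped to
# [-200, 200]. The names here are made up for the example:
#
#   v = tf.get_variable('w', shape=[1])
#   tower_grads = [[(tf.constant([2.0]), v)], [(tf.constant([4.0]), v)]]
#   avg = _average_gradients(tower_grads)  # one (grad, var) pair, grad == [3.0]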


def multi_gpu_model(features):
  """Build the Graph and train the model on multiple gpus."""
  if FLAGS.use_caps:
    if FLAGS.use_em:
      inference = em_model.inference
    else:
      raise NotImplementedError(
          'Only EM routing is supported when use_caps is set.')
  else:
    inference = simple_model.conv_inference
  with tf.device('/cpu:0'):
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0),
        trainable=False)

    lr = tf.train.exponential_decay(
        FLAGS.learning_rate,
        global_step,
        FLAGS.decay_steps,
        FLAGS.decay_rate,
        staircase=FLAGS.staircase)
    if FLAGS.clip_lr:
      lr = tf.maximum(lr, 1e-6)

    if FLAGS.adam:
      opt = tf.train.AdamOptimizer(lr)
    else:
      opt = tf.train.GradientDescentOptimizer(lr)

    tower_grads = []
    corrects = []
    almosts = []
    result = {}
    with tf.variable_scope(tf.get_variable_scope()):
      for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('tower_%d' % i) as scope:
            label_ = features[i]['labels']
            y, result['recons_1'], result['recons_2'], result[
                'mid_act'] = inference(features[i])
            result['logits'] = y

            losses, correct, almost = layers.optimizer(
                logits=y,
                labels=label_,
                multi=FLAGS.multi and FLAGS.data_set == 'mnist',
                scope=scope,
                softmax=FLAGS.softmax,
                rate=FLAGS.loss_rate,
                step=global_step,
            )
            # Reuse variables so every tower shares the same weights.
            tf.get_variable_scope().reuse_variables()
            corrects.append(correct)
            almosts.append(almost)
            grads = opt.compute_gradients(
                losses,
                gate_gradients=tf.train.Optimizer.GATE_NONE,
            )
            tower_grads.append(grads)

    with utils.maybe_jit_scope(), tf.name_scope('average_gradients'):
      grads = _average_gradients(tower_grads)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    if FLAGS.verbose:
      for grad, var in grads:
        if grad is not None:
          summaries.append(
              tf.summary.histogram(var.op.name + '/gradients', grad))
    summaries.append(tf.summary.scalar('learning_rate', lr))
    result['summary'] = tf.summary.merge(summaries)
    result['train'] = opt.apply_gradients(grads, global_step=global_step)

    # Aggregate per-tower accuracy counts.
    cors = tf.stack(corrects)
    alms = tf.stack(almosts)
    result['correct'] = tf.reduce_sum(cors, 0)
    result['almost'] = tf.reduce_sum(alms, 0)

    return result
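

# Usage sketch (an assumption, not confirmed by this file alone): the training
# driver is expected to build one feature dict per GPU and run the returned
# ops. `input_fn` below is a hypothetical per-tower input pipeline.
#
#   features = [input_fn(tower=i) for i in range(FLAGS.num_gpus)]
#   result = multi_gpu_model(features)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     _, summary = sess.run([result['train'], result['summary']])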