nigt_optimizer.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NIGT Optimizer.

See the paper https://arxiv.org/abs/2002.03305
This optimizer uses Taylor expansions to approximate variance reduction
algorithms while using only a single gradient evaluation per iteration.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import six
import tensorflow.compat.v1 as tf


class NIGTOptimizer(tf.train.Optimizer):
  """NIGTOptimizer."""

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta=0.9,
               gamma=1e-3,
               use_igt=True,
               use_adaptive=False,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="NIGTOptimizer"):
    """Constructs an optimizer."""
    super(NIGTOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta = beta
    self.gamma = gamma
    self.use_igt = use_igt
    self.use_adaptive = use_adaptive
    self.exclude_from_weight_decay = exclude_from_weight_decay
    self.exclude_from_layer_adaptation = exclude_from_layer_adaptation

  def compute_x(self, param_name, param, m, prev_w_norm, prev_eta, prev_beta):
    """Computes the previous x value on the fly.

    Alternatively, we can store this as a slot but that would double the
    memory usage of our parameters. We don't like that!

    Args:
      param_name: Name of the parameter. Used to check whether to normalize the
        gradients for this layer.
      param: The parameter `Tensor`.
      m: Accumulated momentum `Tensor` of shape same as param.
      prev_w_norm: Scalar tracking norm of the param tensor at previous
        iteration.
      prev_eta: Scalar tracking the learning rate applied at previous iteration.
      prev_beta: Scalar tracking momentum applied at previous iteration.

    Returns:
      x: An intermediate `Tensor` of shape same as param. Will be used for the
        final update.
    """
    prev_ratio = 1.0
    if self._do_layer_adaptation(param_name):
      prev_g_norm = tf.norm(m, ord=2)
      prev_ratio = self.gamma * tf.where(
          tf.math.greater(prev_w_norm, 0),
          tf.where(
              tf.math.greater(prev_g_norm, 0),
              (prev_w_norm / prev_g_norm), 1.0), 1.0)
    prev_normalized_m_with_lr = prev_ratio * prev_eta * m

    x = param - tf.divide(
        tf.multiply(prev_beta, prev_normalized_m_with_lr), prev_beta - 1.0)
    return x

  def swap_to_optimal_params(self, params, name=None):
    """Swaps weights to be more optimal after training.

    NIGT evaluates gradients at points that are *different* from the points
    that we expect to have lower loss values. During training, the network
    weights are set to be the points where we evaluate gradients. This function
    returns an operation that will change the network weights to be the points
    that NIGT believes to be more optimal, and should be used before
    evaluation.

    Note that once this function is called, the parameter values need to be
    swapped *back* in order to continue training. This should not be a
    concern if the function is only used in eval jobs.

    Args:
      params: list of parameters to update.
      name: name for operation.

    Returns:
      an operation that changes the parameters in params to better values.
    """
    switch_ops = []
    for param in params:
      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=six.ensure_str(param_name) + "/m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      prev_w_norm = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_w_norm",
          dtype=tf.float32,
          trainable=False,
          initializer=lambda w=param: tf.norm(w.initialized_value(), ord=2))

      prev_eta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_eta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      prev_beta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_beta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      x = self.compute_x(param_name, param, m, prev_w_norm, prev_eta, prev_beta)

      switch_ops.append(param.assign(x))
    return tf.group(*switch_ops, name=name)
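
  # Sketch of typical eval-time use (hypothetical `optimizer`, `saver`, `sess`,
  # and `checkpoint_path`; assumes the training graph has been rebuilt):
  #   swap_op = optimizer.swap_to_optimal_params(tf.trainable_variables())
  #   saver.restore(sess, checkpoint_path)
  #   sess.run(swap_op)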

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=six.ensure_str(param_name) + "/m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Note: shape is not passed here explicitly since tf.get_variable
      # complains when you do that while passing a Tensor as an initializer.
      prev_w_norm = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_w_norm",
          dtype=tf.float32,
          trainable=False,
          initializer=lambda w=param: tf.norm(w.initialized_value(), ord=2))

      prev_eta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_eta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      prev_beta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_beta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      if self._do_use_weight_decay(param_name):
        grad += self.weight_decay_rate * param

      if self.use_adaptive:
        grad_squared_sum = tf.get_variable(
            name=six.ensure_str(param_name) + "/grad_squared_sum",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        max_grad = tf.get_variable(
            name=six.ensure_str(param_name) + "/max_grad",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        iteration = tf.get_variable(
            name=six.ensure_str(param_name) + "/iteration",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        next_grad_squared_sum = grad_squared_sum + tf.norm(grad, 2)
        next_iteration = iteration + 1
        next_max_grad = tf.maximum(max_grad, tf.norm(grad, 2))
        assignments.extend([
            grad_squared_sum.assign(next_grad_squared_sum),
            iteration.assign(next_iteration),
            max_grad.assign(next_max_grad)
        ])

        # Intuitively we should be able to leave g_sum=next_grad_squared_sum,
        # but current theory needs this extra t^1/4 max_grad term.
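        # Written out, the schedule below is
        #   eta  = learning_rate / (t^3 * G^2)^(1/7),
        #   beta = 1 - min(1, 1 / (t * eta^2 * G)),
        # with t = next_iteration and G = g_sum.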
        g_sum = next_grad_squared_sum + tf.pow(next_iteration,
                                               0.25) * next_max_grad

        eta = self.learning_rate / tf.pow(
            tf.pow(next_iteration, 3.0) * tf.pow(g_sum, 2.0), 1.0 / 7.0)
        a = tf.minimum(1.0, 1.0 / (next_iteration * tf.pow(eta, 2.0) * g_sum))
        beta = 1.0 - a
      else:
        eta = self.learning_rate
        beta = self.beta

      next_m = (tf.multiply(beta, m) + tf.multiply(1.0 - beta, grad))

      ratio = 1.0
      w_norm = tf.norm(param, ord=2)
      if self._do_layer_adaptation(param_name):
        g_norm = tf.norm(next_m, ord=2)
        ratio = self.gamma * tf.where(
            tf.math.greater(w_norm, 0),
            tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
      normalized_m_with_lr = ratio * eta * next_m
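
      # In the IGT branch below, the update is applied to the recovered
      # averaged iterate x, and the stored parameter is then re-shifted to the
      # extrapolated point w = x - beta / (1 - beta) * normalized_m_with_lr
      # (the same transform that compute_x inverts).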
      if self.use_igt:
        prev_x = self.compute_x(param_name, param, m, prev_w_norm, prev_eta,
                                prev_beta)
        next_x = prev_x - normalized_m_with_lr
        next_param = next_x + tf.divide(
            tf.multiply(beta, normalized_m_with_lr), beta - 1.0)
      else:
        next_param = param - normalized_m_with_lr
      assignments.extend([
          param.assign(next_param),
          m.assign(next_m),
          prev_w_norm.assign(w_norm),
          prev_eta.assign(eta),
          prev_beta.assign(beta)
      ])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
    if m is not None:
      param_name = m.group(1)
    return param_name