lamb_optimizer.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow.compat.v1 as tf

from mobilebert import distill_util


class LAMBOptimizer(tf.train.Optimizer):
  """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
  # A new optimizer that includes correct L2 weight decay, adaptive
  # element-wise updating, and layer-wise adaptation. The LAMB optimizer
  # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
  # James Demmel, and Cho-Jui Hsieh in the paper "Reducing BERT
  # Pre-Training Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962).
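  #
  # Schematically, the per-variable update implemented below is:
  #   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
  #   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
  #   u_t = m_t / (sqrt(v_t) + epsilon) + weight_decay_rate * w_{t-1}
  #   r_t = ||w_{t-1}|| / ||u_t||   (the layer-wise "trust ratio")
  #   w_t = w_{t-1} - learning_rate * r_t * u_t
  # No bias correction is applied to m_t / v_t in this implementation, and
  # MobileBERT's layer-wise warmup additionally scales and gates the update.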

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="LAMBOptimizer",
               use_layer_wise_warmup=False,
               total_warmup_phases=0,
               num_train_steps=0):
    """Constructs a LAMBOptimizer.

    Args:
      learning_rate: Base learning rate (scalar or Tensor).
      weight_decay_rate: Decoupled weight decay rate; 0.0 disables decay.
      beta_1: Exponential decay rate for the first moment estimates.
      beta_2: Exponential decay rate for the second moment estimates.
      epsilon: Small constant added to the denominator for numerical stability.
      exclude_from_weight_decay: Optional list of regex patterns; matching
        variables are not decayed (e.g. ["LayerNorm", "bias"]).
      exclude_from_layer_adaptation: Optional list of regex patterns; matching
        variables skip the layer-wise trust-ratio scaling. Defaults to
        `exclude_from_weight_decay` when None.
      name: Name of the optimizer.
      use_layer_wise_warmup: If True, use the layer-wise learning rate warmup
        from `distill_util` for `bert/encoder/layer_*` variables.
      total_warmup_phases: Number of warmup phases; together with
        `num_train_steps` this determines `steps_per_phase`.
      num_train_steps: Total number of training steps.
    """
    super(LAMBOptimizer, self).__init__(False, name)
    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay
    self.use_layer_wise_warmup = use_layer_wise_warmup
    if total_warmup_phases == 0:
      self.steps_per_phase = 1
    else:
      self.steps_per_phase = num_train_steps // total_warmup_phases

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    background_lr = distill_util.get_background_lr(
        global_step=global_step, steps_per_phase=self.steps_per_phase)
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue
      param_name = self._get_variable_name(param.name)
      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      if self.use_layer_wise_warmup:
        # Use model-specific name spaces to get layer id.
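        # For example, an (illustrative) variable name such as
        # "bert/encoder/layer_3/attention/self/query/kernel" yields
        # layer_id == 3.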
        if param_name.startswith("bert/encoder/layer_"):
          layer_id = int(
              param_name[len("bert/encoder/layer_"):].split("/", 1)[0])
          layer_wise_lr = distill_util.layer_wise_learning_rate(
              layer_id=layer_id,
              steps_per_phase=self.steps_per_phase,
              background_lr=background_lr)
          layer_wise_gate = tf.where(
              tf.math.greater(layer_wise_lr, 0.0), 1.0, 0.0)
        else:
          layer_wise_lr = 0.0
          layer_wise_gate = 0.0
      else:
        layer_wise_lr = 1.0
        layer_wise_gate = 1.0
      # Standard Adam update.
      next_m = layer_wise_gate * (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = layer_wise_gate * (
          tf.multiply(self.beta_2, v) +
          tf.multiply(1.0 - self.beta_2, tf.square(grad)))
      update = next_m / (tf.sqrt(next_v) + self.epsilon)
      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want to decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
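      # This decoupled form of weight decay (applied directly to the update
      # rather than to the loss) is the same idea used by AdamW.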
      if self._do_use_weight_decay(param_name):
        update += layer_wise_gate * self.weight_decay_rate * param
      ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.linalg.norm(param, ord=2)
        g_norm = tf.linalg.norm(update, ord=2)
        # Trust ratio ||w|| / ||update||; fall back to 1.0 whenever either
        # norm is zero.
        ratio = tf.where(tf.math.greater(w_norm, 0), tf.where(
            tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
      update_with_lr = layer_wise_lr * ratio * self.learning_rate * update
      next_param = param - update_with_lr
      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    # Note: `global_step` is only used above to schedule the layer-wise
    # warmup; it is not incremented here, so callers must update it
    # themselves.
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    # e.g. "bert/embeddings/word_embeddings:0" -> "bert/embeddings/word_embeddings".
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name
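
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes a TF1-style graph in which `mobilebert.distill_util` is importable;
# `loss` below is a hypothetical placeholder for a scalar training loss.
#
#   global_step = tf.train.get_or_create_global_step()
#   optimizer = LAMBOptimizer(
#       learning_rate=1e-3,
#       weight_decay_rate=0.01,
#       exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
#   tvars = tf.trainable_variables()
#   grads = tf.gradients(loss, tvars)
#   train_op = tf.group(
#       optimizer.apply_gradients(
#           list(zip(grads, tvars)), global_step=global_step),
#       global_step.assign_add(1))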
