# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow.compat.v1 as tf

from mobilebert import distill_util


class LAMBOptimizer(tf.train.Optimizer):
  """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
  # A new optimizer that includes correct L2 weight decay, adaptive
  # element-wise updating, and layer-wise justification. The LAMB optimizer
  # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
  # James Demmel, and Cho-Jui Hsieh in the paper "Reducing BERT Pre-Training
  # Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962).
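  #
  # Sketch of the update implemented in apply_gradients below (ignoring the
  # optional layer-wise warmup gating), for a parameter w with gradient g:
  #   m      <- beta_1 * m + (1 - beta_1) * g
  #   v      <- beta_2 * v + (1 - beta_2) * g^2
  #   update <- m / (sqrt(v) + epsilon) + weight_decay_rate * w
  #   w      <- w - learning_rate * (||w|| / ||update||) * update
  # where the norm ratio (the "trust ratio") is only applied to variables
  # not excluded from layer adaptation and defaults to 1 when either norm
  # is zero.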

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="LAMBOptimizer",
               use_layer_wise_warmup=False,
               total_warmup_phases=0,
               num_train_steps=0):
    """Constructs a LAMBOptimizer."""
    super(LAMBOptimizer, self).__init__(False, name)
    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay
    self.use_layer_wise_warmup = use_layer_wise_warmup
    if total_warmup_phases == 0:
      self.steps_per_phase = 1
    else:
      self.steps_per_phase = num_train_steps // total_warmup_phases

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    background_lr = distill_util.get_background_lr(
        global_step=global_step, steps_per_phase=self.steps_per_phase)
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue
      param_name = self._get_variable_name(param.name)
      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
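      # Optional layer-wise warmup: when enabled, an encoder layer only
      # receives moment and parameter updates once its layer-wise learning
      # rate becomes positive in the current warmup phase; until then its
      # updates are gated to zero. Variables outside bert/encoder/layer_*
      # are not updated in this mode.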
      if self.use_layer_wise_warmup:
        # Use model-specific name spaces to get layer id.
        if param_name.startswith("bert/encoder/layer_"):
          layer_id = int(
              param_name[len("bert/encoder/layer_"):].split("/", 1)[0])
          layer_wise_lr = distill_util.layer_wise_learning_rate(
              layer_id=layer_id,
              steps_per_phase=self.steps_per_phase,
              background_lr=background_lr)
          layer_wise_gate = tf.where(
              tf.math.greater(layer_wise_lr, 0.0), 1.0, 0.0)
        else:
          layer_wise_lr = 0.0
          layer_wise_gate = 0.0
      else:
        layer_wise_lr = 1.0
        layer_wise_gate = 1.0
      # Standard Adam update.
      next_m = layer_wise_gate * (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = layer_wise_gate * (
          tf.multiply(self.beta_2, v) +
          tf.multiply(1.0 - self.beta_2, tf.square(grad)))
      update = next_m / (tf.sqrt(next_v) + self.epsilon)
      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want to decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += layer_wise_gate * self.weight_decay_rate * param
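      # Layer-wise adaptation ("trust ratio"): scale the update by
      # ||w|| / ||update||, falling back to 1.0 when either norm is zero or
      # when the variable is excluded from layer adaptation.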
      ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.linalg.norm(param, ord=2)
        g_norm = tf.linalg.norm(update, ord=2)
        ratio = tf.where(
            tf.math.greater(w_norm, 0),
            tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
      update_with_lr = layer_wise_lr * ratio * self.learning_rate * update
      next_param = param - update_with_lr
      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name
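

if __name__ == "__main__":
  # A minimal usage sketch: wires the optimizer into a toy TF1-style graph
  # with tf.gradients + apply_gradients. The variable, loss, learning rate,
  # and exclusion patterns here are illustrative placeholders, not the
  # settings used for MobileBERT training.
  tf.disable_v2_behavior()
  global_step = tf.train.get_or_create_global_step()
  weight = tf.get_variable(
      "toy/weight", shape=[4], initializer=tf.ones_initializer())
  loss = tf.reduce_sum(tf.square(weight))
  optimizer = LAMBOptimizer(
      learning_rate=0.1,
      weight_decay_rate=0.01,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  grads = tf.gradients(loss, [weight])
  train_op = optimizer.apply_gradients(
      list(zip(grads, [weight])), global_step=global_step)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
      # train_op fetches return None; the first fetch is the loss value.
      print("loss:", sess.run([loss, train_op])[0])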