# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""NIGT Optimizer.
17
18See the paper https://arxiv.org/abs/2002.03305
19This optimizer uses uses Taylor expansions to approximate variance reduction
20algorithms while using only a single gradient evaluation per iteration.
21"""
22
23from __future__ import absolute_import24from __future__ import division25from __future__ import print_function26
27import re28import six29import tensorflow.compat.v1 as tf30
31
class NIGTOptimizer(tf.train.Optimizer):
  """NIGTOptimizer."""

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta=0.9,
               gamma=1e-3,
               use_igt=True,
               use_adaptive=False,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="NIGTOptimizer"):
    """Constructs an optimizer."""
    super(NIGTOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta = beta
    self.gamma = gamma
    self.use_igt = use_igt
    self.use_adaptive = use_adaptive
    self.exclude_from_weight_decay = exclude_from_weight_decay
    self.exclude_from_layer_adaptation = exclude_from_layer_adaptation

  def compute_x(self, param_name, param, m, prev_w_norm, prev_eta, prev_beta):
    """Compute prev x value on the fly.

    Alternatively, we can store this as a slot but that would double the
    memory usage of our parameters. We don't like that!

    Args:
      param_name: Name of the parameter. Used to check whether to normalize
        the gradients for this layer.
      param: The parameter `Tensor`.
      m: Accumulated momentum `Tensor` of shape same as param.
      prev_w_norm: Scalar tracking norm of the param tensor at previous
        iteration.
      prev_eta: Scalar tracking the learning rate applied at previous
        iteration.
      prev_beta: Scalar tracking momentum applied at previous iteration.

    Returns:
      x: An intermediate `Tensor` of shape same as param. Will be used for the
        final update.
    """
    prev_ratio = 1.0
    if self._do_layer_adaptation(param_name):
      prev_g_norm = tf.norm(m, ord=2)
      prev_ratio = self.gamma * tf.where(
          tf.math.greater(prev_w_norm, 0),
          tf.where(
              tf.math.greater(prev_g_norm, 0),
              (prev_w_norm / prev_g_norm), 1.0), 1.0)
    prev_normalized_m_with_lr = prev_ratio * prev_eta * m
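    # The stored parameters hold the extrapolated gradient-evaluation point
    # w = x - (beta / (1 - beta)) * prev_eta * prev_ratio * m (see the IGT
    # branch of apply_gradients), so inverting that shift recovers x.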
    x = param - tf.divide(
        tf.multiply(prev_beta, prev_normalized_m_with_lr), prev_beta - 1.0)
    return x

  def swap_to_optimal_params(self, params, name=None):
    """Swaps weights to be more optimal after training.

    NIGT evaluates gradients at points that are *different* from the points
    that we expect to have lower loss values. During training, the network
    weights are set to be the points where we evaluate gradients. This
    function returns an operation that will change the network weights to be
    the points that NIGT believes to be more optimal, and should be used
    before evaluation.

    Note that once this function is called, the parameter values need to be
    swapped *back* in order to continue training. This should not be a
    concern if the function is only used in eval jobs.

    Args:
      params: list of parameters to update.
      name: name for operation.

    Returns:
      an operation that changes the parameters in params to better values.
    """
    switch_ops = []
    for param in params:
      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=six.ensure_str(param_name) + "/m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      prev_w_norm = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_w_norm",
          dtype=tf.float32,
          trainable=False,
          initializer=lambda w=param: tf.norm(w.initialized_value(), ord=2))

      prev_eta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_eta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      prev_beta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_beta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      x = self.compute_x(param_name, param, m, prev_w_norm, prev_eta,
                         prev_beta)

      switch_ops.append(param.assign(x))
    return tf.group(*switch_ops, name=name)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=six.ensure_str(param_name) + "/m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Note: shape is not passed here explicitly since tf.get_variable
      # complains when you do that while passing a Tensor as an initializer.
      prev_w_norm = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_w_norm",
          dtype=tf.float32,
          trainable=False,
          initializer=lambda w=param: tf.norm(w.initialized_value(), ord=2))

      prev_eta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_eta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      prev_beta = tf.get_variable(
          name=six.ensure_str(param_name) + "/prev_beta",
          shape=[],
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      if self._do_use_weight_decay(param_name):
        grad += self.weight_decay_rate * param

      if self.use_adaptive:
        grad_squared_sum = tf.get_variable(
            name=six.ensure_str(param_name) + "/grad_squared_sum",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        max_grad = tf.get_variable(
            name=six.ensure_str(param_name) + "/max_grad",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        iteration = tf.get_variable(
            name=six.ensure_str(param_name) + "/iteration",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())

        next_grad_squared_sum = grad_squared_sum + tf.norm(grad, 2)
        next_iteration = iteration + 1
        next_max_grad = tf.maximum(max_grad, tf.norm(grad, 2))
        assignments.extend([
            grad_squared_sum.assign(next_grad_squared_sum),
            iteration.assign(next_iteration),
            max_grad.assign(next_max_grad)
        ])

        # Intuitively we should be able to leave g_sum=next_grad_squared_sum,
        # but current theory needs this extra t^1/4 max_grad term.
        g_sum = next_grad_squared_sum + tf.pow(next_iteration,
                                               0.25) * next_max_grad
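        # Adaptive schedule used below: eta_t = lr / (t^3 * g_sum_t^2)^(1/7)
        # and beta_t = 1 - min(1, 1 / (t * eta_t^2 * g_sum_t)).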
        eta = self.learning_rate / tf.pow(
            tf.pow(next_iteration, 3.0) * tf.pow(g_sum, 2.0), 1.0 / 7.0)
        a = tf.minimum(1.0, 1.0 / (next_iteration * tf.pow(eta, 2.0) * g_sum))
        beta = 1.0 - a
      else:
        eta = self.learning_rate
        beta = self.beta

      next_m = (tf.multiply(beta, m) + tf.multiply(1.0 - beta, grad))

      ratio = 1.0
      w_norm = tf.norm(param, ord=2)
      if self._do_layer_adaptation(param_name):
        g_norm = tf.norm(next_m, ord=2)
        ratio = self.gamma * tf.where(
            tf.math.greater(w_norm, 0),
            tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
      normalized_m_with_lr = ratio * eta * next_m
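      # With use_igt, the stored weights track the extrapolated point
      # w_{t+1} = x_{t+1} + (beta / (1 - beta)) * (x_{t+1} - x_t); gradients at
      # the next step are taken at w, while x is the iterate recovered by
      # compute_x / swap_to_optimal_params.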
      if self.use_igt:
        prev_x = self.compute_x(param_name, param, m, prev_w_norm, prev_eta,
                                prev_beta)
        next_x = prev_x - normalized_m_with_lr
        next_param = next_x + tf.divide(
            tf.multiply(beta, normalized_m_with_lr), beta - 1.0)
      else:
        next_param = param - normalized_m_with_lr
      assignments.extend([
          param.assign(next_param),
          m.assign(next_m),
          prev_w_norm.assign(w_norm),
          prev_eta.assign(eta),
          prev_beta.assign(beta)
      ])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
    if m is not None:
      param_name = m.group(1)
    return param_name
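
# A minimal usage sketch (not part of the original module), assuming TF1-style
# graph execution. The toy variable, loss, and reuse-scope handling below are
# illustrative only; a real training job would build grads_and_vars from its
# own model and typically call swap_to_optimal_params in the eval job.
if __name__ == "__main__":
  tf.disable_eager_execution()
  w = tf.get_variable("w", shape=[10], initializer=tf.ones_initializer())
  loss = tf.reduce_sum(tf.square(w))
  optimizer = NIGTOptimizer(learning_rate=0.1)
  grads_and_vars = list(zip(tf.gradients(loss, [w]), [w]))
  train_op = optimizer.apply_gradients(grads_and_vars)
  # Reuse the slot variables created by apply_gradients when building the
  # swap op in the same graph.
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    swap_op = optimizer.swap_to_optimal_params([w])
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
      sess.run(train_op)
    sess.run(swap_op)  # Move weights to the iterate NIGT expects to be better.
    print("loss at swapped params:", sess.run(loss))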