google-research
77 lines · 2.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implements custom learning rate schedules."""
17
18import gin
19import tensorflow as tf
20
21
@gin.configurable
class InverseSquareRootDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Inverse-square-root decay with linear warmup (Vaswani et al., 2017).

  The learning rate ramps linearly from `warmup_init_lr` to `lr_max` over
  `warmup_steps` steps, then decays proportionally to
  1 / sqrt(step / warmup_steps), so the two pieces meet at `lr_max` exactly
  at `step == warmup_steps`.
  """

  def __init__(
      self,
      lr_max = 1e-3,
      warmup_init_lr = 0.0,
      warmup_steps = 4000,
      **kwargs):
    """Initializes the schedule.

    Args:
      lr_max: Peak learning rate, reached at `warmup_steps`.
      warmup_init_lr: Learning rate at step 0.
      warmup_steps: Number of linear warmup steps; must be positive.
      **kwargs: Forwarded to the base `LearningRateSchedule`.
    """
    super().__init__(**kwargs)
    self._lr_max = lr_max
    self._warmup_init_lr = warmup_init_lr
    self._warmup_steps = warmup_steps

  def __call__(self, step):
    """Returns the learning rate tensor for the given (integer) `step`."""
    # Cast to float32 so both `tf.cond` branches produce the same dtype
    # regardless of `step`'s dtype (the optimizer's iteration counter is an
    # integer tensor, and integer / python-int division would yield float64).
    norm_step = tf.cast(step, tf.float32) / self._warmup_steps

    def true_fn():
      # Linear warmup: interpolate between warmup_init_lr and lr_max.
      return (self._warmup_init_lr +
              (self._lr_max - self._warmup_init_lr) * norm_step)

    def false_fn():
      # Inverse-square-root decay, normalized so the value is lr_max at
      # the end of warmup (norm_step == 1).
      return self._lr_max * tf.math.rsqrt(norm_step)

    return tf.cond(norm_step <= 1.0, true_fn, false_fn)

  def get_config(self):
    """Returns the config dict for Keras serialization."""
    return {
        'lr_max': self._lr_max,
        'warmup_init_lr': self._warmup_init_lr,
        'warmup_steps': self._warmup_steps,
    }
49
50
@gin.configurable
class NoDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Constant learning rate with a linear warmup period.

  The learning rate ramps linearly from `warmup_init_lr` to `lr_max` over
  `warmup_steps` steps, then stays at `lr_max` forever.
  """

  def __init__(
      self,
      lr_max = 1e-3,
      warmup_init_lr = 0.0,
      warmup_steps = 4000,
      **kwargs):
    """Initializes the schedule.

    Args:
      lr_max: Constant learning rate used after warmup.
      warmup_init_lr: Learning rate at step 0.
      warmup_steps: Number of linear warmup steps; must be positive.
      **kwargs: Forwarded to the base `LearningRateSchedule`.
    """
    super().__init__(**kwargs)
    self._lr_max = lr_max
    self._warmup_init_lr = warmup_init_lr
    self._warmup_steps = warmup_steps

  def __call__(self, step):
    """Returns the learning rate tensor for the given (integer) `step`."""
    # Bug fix: compute in float32 explicitly. Previously `norm_step`
    # inherited `step`'s dtype (float64 for an integer iteration counter),
    # so `true_fn` returned float64 while `false_fn` returned float32, and
    # `tf.cond` raised on the dtype mismatch once warmup had any effect.
    norm_step = tf.cast(step, tf.float32) / self._warmup_steps

    def true_fn():
      # Linear warmup: interpolate between warmup_init_lr and lr_max.
      return (self._warmup_init_lr +
              (self._lr_max - self._warmup_init_lr) * norm_step)

    def false_fn():
      # Constant phase; cast so the branch dtype matches `true_fn`.
      return tf.cast(self._lr_max, tf.float32)

    return tf.cond(norm_step <= 1.0, true_fn, false_fn)

  def get_config(self):
    """Returns the config dict for Keras serialization."""
    return {
        'lr_max': self._lr_max,
        'warmup_init_lr': self._warmup_init_lr,
        'warmup_steps': self._warmup_steps,
    }
78