learning_rate_schedules.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements custom learning rate schedules."""

import gin
import tensorflow as tf


@gin.configurable
class InverseSquareRootDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Implements the learning rate schedule in Vaswani et al. 2017."""

  def __init__(
      self,
      lr_max=1e-3,
      warmup_init_lr=0.0,
      warmup_steps=4000,
      **kwargs):
    super().__init__(**kwargs)
    self._lr_max = lr_max
    self._warmup_init_lr = warmup_init_lr
    self._warmup_steps = warmup_steps

  def __call__(self, step):
    # Express the current step as a fraction of the warmup period; values
    # <= 1.0 lie inside the warmup window.
    norm_step = step / self._warmup_steps

    def true_fn():
      # Warmup: interpolate linearly from warmup_init_lr up to lr_max.
      return (self._warmup_init_lr +
              (self._lr_max - self._warmup_init_lr) * norm_step)

    def false_fn():
      # After warmup: decay as 1 / sqrt(step / warmup_steps), so the rate
      # equals lr_max exactly at the end of warmup and falls off from there.
      return self._lr_max * tf.math.rsqrt(norm_step)

    return tf.cond(norm_step <= 1.0, true_fn, false_fn)
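

# A worked illustration (not part of the original module), assuming the
# default arguments above: the rate rises linearly from warmup_init_lr to
# lr_max over warmup_steps, then decays as lr_max * sqrt(warmup_steps / step).
#
#   schedule = InverseSquareRootDecayWithWarmup()
#   schedule(tf.constant(2000.0))   # ~5e-4, halfway through warmup
#   schedule(tf.constant(4000.0))   # ~1e-3, the peak at the end of warmup
#   schedule(tf.constant(16000.0))  # ~5e-4, i.e. lr_max / sqrt(4)

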
@gin.configurable
class NoDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Implements a constant learning rate with a warmup period."""

  def __init__(
      self,
      lr_max=1e-3,
      warmup_init_lr=0.0,
      warmup_steps=4000,
      **kwargs):
    super().__init__(**kwargs)
    self._lr_max = lr_max
    self._warmup_init_lr = warmup_init_lr
    self._warmup_steps = warmup_steps

  def __call__(self, step):
    # Express the current step as a fraction of the warmup period.
    norm_step = step / self._warmup_steps

    def true_fn():
      # Warmup: interpolate linearly from warmup_init_lr up to lr_max.
      return (self._warmup_init_lr +
              (self._lr_max - self._warmup_init_lr) * norm_step)

    def false_fn():
      # After warmup: hold the rate constant at lr_max, cast to a float
      # tensor so both tf.cond branches return matching types.
      return tf.cast(self._lr_max, tf.float32)

    return tf.cond(norm_step <= 1.0, true_fn, false_fn)
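

# A minimal usage sketch (an assumption, not part of the original module):
# either schedule can be passed straight to a Keras optimizer, which calls it
# with the current training step. The names, optimizer, and hyperparameters
# below are illustrative only.
if __name__ == "__main__":
  demo_schedule = NoDecayWithWarmup(lr_max=1e-3, warmup_steps=1000)
  # Shown only for the hookup; any optimizer accepting a schedule works.
  demo_optimizer = tf.keras.optimizers.Adam(learning_rate=demo_schedule)
  # Querying the schedule directly shows the linear warmup and the plateau.
  for demo_step in [0.0, 500.0, 1000.0, 5000.0]:
    print(demo_step, float(demo_schedule(tf.constant(demo_step))))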