paddlenlp

Форк
0
115 строк · 4.0 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import math
16

17
from paddle.optimizer import lr
18
from paddle.optimizer.lr import LRScheduler
19

20
__all__ = [
21
    "CosineAnnealingWithWarmupDecay",
22
    "LinearDecayWithWarmup",
23
    "CosineDecay",
24
]
25

26

27
class CosineAnnealingWithWarmupDecay(LRScheduler):
28
    def __init__(self, max_lr, min_lr, warmup_rate, decay_steps, last_epoch=0, verbose=False, **kwargs):
29

30
        self.decay_steps = decay_steps
31
        self.warmup_step = warmup_rate * decay_steps
32
        self.max_lr = max_lr
33
        self.min_lr = min_lr
34
        self.increment = 0
35
        super(CosineAnnealingWithWarmupDecay, self).__init__(max_lr, last_epoch, verbose)
36
        self.increment = int(kwargs.get("global_batch_size", 0))
37

38
    def get_lr(self):
39
        if self.warmup_step > 0 and self.last_epoch <= self.warmup_step:
40
            return float(self.max_lr) * (self.last_epoch) / self.warmup_step
41

42
        if self.last_epoch > self.decay_steps:
43
            return self.min_lr
44

45
        num_step_ = self.last_epoch - self.warmup_step
46
        decay_steps_ = self.decay_steps - self.warmup_step
47
        decay_ratio = float(num_step_) / float(decay_steps_)
48
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
49
        return self.min_lr + coeff * (self.max_lr - self.min_lr)
50

51
    def step(self, epoch=None):
52
        if epoch is None:
53
            self.last_epoch += self.increment
54
            self.last_lr = self.get_lr()
55
        else:
56
            self.last_epoch += epoch
57
            if hasattr(self, "_get_closed_form_lr"):
58
                self.last_lr = self._get_closed_form_lr()
59
            else:
60
                self.last_lr = self.get_lr()
61

62
        if self.verbose:
63
            print(
64
                "Epoch {}: {} set learning rate to {}.".format(self.last_epoch, self.__class__.__name__, self.last_lr)
65
            )
66

67

68
class LinearDecayWithWarmup(LRScheduler):
69
    def __init__(self, learning_rate, step_each_epoch, epochs, warmup=0, verbose=False, last_epoch=-1, **kwargs):
70
        if kwargs.get("total_steps", -1) > 0:
71
            self.T_max = kwargs.get("total_steps")
72
        else:
73
            self.T_max = epochs * step_each_epoch
74

75
        self.warmup_steps = warmup if isinstance(warmup, int) else int(math.floor(warmup * self.T_max))
76
        super(LinearDecayWithWarmup, self).__init__(learning_rate, last_epoch, verbose)
77

78
    def get_lr(self):
79
        if self.last_epoch < self.warmup_steps:
80
            return self.base_lr * (float(self.last_epoch) / float(max(1, self.warmup_steps)))
81
        return self.base_lr * max(0.0, 1.0 - self.last_epoch / self.T_max)
82

83

84
class CosineDecay(lr.LRScheduler):
85
    def __init__(
86
        self,
87
        learning_rate,
88
        step_each_epoch,
89
        epochs,
90
        update_unit="epoch",
91
        warmups=0,
92
        verbose=False,
93
        last_epoch=-1,
94
        **kwargs
95
    ):
96

97
        self.T_max = epochs if update_unit == "epoch" else step_each_epoch * epochs
98
        self.warmups = warmups if update_unit == "epoch" else step_each_epoch * warmups
99

100
        assert self.warmups < self.T_max
101

102
        self.last_epoch = last_epoch
103
        super(CosineDecay, self).__init__(learning_rate, last_epoch, verbose)
104

105
    def get_lr(self):
106

107
        progress = (self.last_epoch - self.warmups) / float(self.T_max - self.warmups)
108
        progress = min(1.0, max(0.0, progress))
109

110
        if self.warmups:
111
            lr = self.base_lr * min(1.0, self.last_epoch / self.warmups)
112
        else:
113
            lr = 0.5 * self.base_lr * (1.0 + math.cos(math.pi * progress))
114

115
        return lr
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.