gpt-neox / learning_rates.py
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate decay functions."""

import math

from megatron import print_rank_0


class AnnealingLR(object):
    """Anneals the learning rate."""

    def __init__(
        self,
        optimizer,
        start_lr,
        warmup_iter,
        total_iters,
        decay_style,
        last_iter,
        min_lr=0.0,
        use_checkpoint_lr_scheduler=True,
        override_lr_scheduler=False,
        use_mup=False,
    ):

        # Class values.
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.warmup_iter = warmup_iter
        self.num_iters = last_iter
        self.end_iter = total_iters
        assert self.end_iter > 0
        self.decay_style = decay_style
        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        self.use_mup = use_mup
        if self.override_lr_scheduler:
            assert not self.use_checkpoint_lr_scheduler, (
                "both override and " "use-checkpoint are set."
            )
        # Set the learning rate
        self.step(self.num_iters)

        print_rank_0("> learning rate decay style: {}".format(self.decay_style))

    def get_lr(self):
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        num_iters_ = self.num_iters
        # Warmup.
        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
            return float(self.start_lr) * num_iters_ / self.warmup_iter

        num_iters_ = num_iters_ - self.warmup_iter
        if self.decay_style == "linear":
            end_iter_ = self.end_iter - self.warmup_iter
            lr = self.start_lr * (end_iter_ - num_iters_) / end_iter_
        elif self.decay_style == "cosine":
            end_iter_ = self.end_iter - self.warmup_iter
            lr = self.min_lr + (
                (self.start_lr - self.min_lr)
                / 2.0
                * (math.cos(math.pi * num_iters_ / end_iter_) + 1)
            )
        elif self.decay_style == "exponential":
            # exp(-0.693) = 1/2
            end_iter = self.end_iter - self.warmup_iter
            lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter)
        else:
            lr = self.start_lr
        return max(lr, self.min_lr)
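
    # Worked example for the cosine branch (hypothetical numbers): with
    # start_lr=1e-3, min_lr=1e-5, warmup_iter=100 and end_iter=1100, at
    # num_iters=600 the post-warmup position is (600 - 100) / (1100 - 100)
    # = 0.5, so lr = 1e-5 + (1e-3 - 1e-5) / 2 * (cos(pi * 0.5) + 1) ≈ 5.05e-4.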

    def step(self, step_num=None):
        """Set lr for all parameter groups."""
        if step_num is None:
            step_num = self.num_iters + 1
        self.num_iters = step_num
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            if self.use_mup and "width_mult" in group:
                group["lr"] = new_lr / group["width_mult"]
            else:
                group["lr"] = new_lr

    def state_dict(self):
        state_dict = {
            "start_lr": self.start_lr,
            "warmup_iter": self.warmup_iter,
            "num_iters": self.num_iters,
            "decay_style": self.decay_style,
            "end_iter": self.end_iter,
            "min_lr": self.min_lr,
        }
        return state_dict
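
    # Illustrative snapshot of the returned dict (hypothetical values):
    # {"start_lr": 1e-3, "warmup_iter": 100, "num_iters": 600,
    #  "decay_style": "cosine", "end_iter": 1100, "min_lr": 1e-5}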

    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_lr_scheduler:
            print_rank_0(" > overriding {} value to {}".format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_lr_scheduler:
            assert cls_value == sd_value, (
                "AnnealingLR: class input value "
                "and checkpoint values for {} do not match".format(name)
            )
        print_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
        return sd_value
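
    # Precedence: override_lr_scheduler forces the constructor values;
    # otherwise the checkpoint values win, asserted equal to the constructor
    # values unless use_checkpoint_lr_scheduler is set.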

    def load_state_dict(self, sd):

        self.start_lr = self._check_and_set(
            self.start_lr, sd["start_lr"], "learning rate"
        )
        self.min_lr = self._check_and_set(
            self.min_lr, sd["min_lr"], "minimum learning rate"
        )
        self.warmup_iter = self._check_and_set(
            self.warmup_iter, sd["warmup_iter"], "warmup iterations"
        )
        self.end_iter = self._check_and_set(
            self.end_iter, sd["end_iter"], "total number of iterations"
        )
        self.decay_style = self._check_and_set(
            self.decay_style, sd["decay_style"], "decay style"
        )

        self.num_iters = sd["num_iters"]
        self.step(self.num_iters)
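
# A minimal usage sketch (an assumption-laden illustration, not part of the
# original module): the toy torch model/optimizer below stand in for the real
# Megatron setup, and all hyperparameters are made up.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
    lr_scheduler = AnnealingLR(
        optimizer=optimizer,
        start_lr=1e-3,
        warmup_iter=100,
        total_iters=1100,
        decay_style="cosine",
        last_iter=0,
        min_lr=1e-5,
    )
    for _ in range(1100):
        # forward / backward / optimizer.step() would go here
        lr_scheduler.step()  # advances num_iters by 1 and updates every group's lr
    print(optimizer.param_groups[0]["lr"])  # decayed to min_lr = 1e-5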