CSS-LM / optimization.py (267 lines · 11.3 KB)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import logging
import math
from typing import Callable, Iterable, Tuple

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


logger = logging.getLogger(__name__)


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
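

# Usage sketch (illustrative, not part of the original module): wires the constant
# schedule to a toy optimizer. Any torch Optimizer works; the model shape and
# learning rate below are arbitrary assumptions.
def _example_constant_schedule():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-5)
    scheduler = get_constant_schedule(optimizer)
    for _ in range(10):
        optimizer.step()  # the learning rate stays at 5e-5 on every step
        scheduler.step()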


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
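

# Usage sketch (illustrative, not part of the original module): warms up for 100
# steps, then holds the learning rate constant. Hyperparameters are placeholders.
def _example_constant_schedule_with_warmup():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-5)
    scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=100)
    lrs = []
    for _ in range(200):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
    # lrs ramps linearly from ~0 to 5e-5 over the first 100 steps, then stays flat.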


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
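

# Usage sketch (illustrative, not part of the original module): the schedule most
# commonly used for BERT-style fine-tuning, warming up over the first 10% of
# training and then decaying linearly to 0. All numbers below are arbitrary.
def _example_linear_schedule_with_warmup():
    num_training_steps = 1000
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_training_steps // 10,
        num_training_steps=num_training_steps,
    )
    for _ in range(num_training_steps):
        optimizer.step()
        scheduler.step()
    # By the last step the learning rate has decayed (approximately) to 0.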


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
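

# Usage sketch (illustrative, not part of the original module): after a 50-step
# warmup, the lr multiplier follows 0.5 * (1 + cos(pi * progress)) with the
# default num_cycles=0.5, so the lr decays from 1e-4 to 0 over the remaining
# 450 steps. Hyperparameters are placeholders.
def _example_cosine_schedule_with_warmup():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=50, num_training_steps=500
    )
    for _ in range(500):
        optimizer.step()
        scheduler.step()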


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`int`, `optional`, defaults to 1):
            The number of hard restarts to use.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
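

# Usage sketch (illustrative, not part of the original module): with num_cycles=3
# the multiplier completes three cosine decays, jumping back to the full lr at
# each restart. Hyperparameters are placeholders.
def _example_cosine_with_hard_restarts_schedule_with_warmup():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer, num_warmup_steps=50, num_training_steps=500, num_cycles=3
    )
    for _ in range(500):
        optimizer.step()
        scheduler.step()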


class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in
    `Decoupled Weight Decay Regularization <https://arxiv.org/abs/1711.05101>`__.

    Parameters:
        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (:obj:`float`, `optional`, defaults to 1e-3):
            The learning rate to use.
        betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2).
        eps (:obj:`float`, `optional`, defaults to 1e-6):
            Adam's epsilon for numerical stability.
        weight_decay (:obj:`float`, `optional`, defaults to 0):
            Decoupled weight decay to apply.
        correct_bias (:obj:`bool`, `optional`, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in the BERT TF repository they use :obj:`False`).
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # the original BERT TF implementation skips bias correction
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
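

# End-to-end sketch (illustrative, not part of the original module): a minimal
# training loop combining AdamW with the linear warmup/decay schedule above.
# The model, data and hyperparameters are arbitrary placeholders; excluding
# biases from weight decay mirrors common BERT fine-tuning practice.
def _example_adamw_training_loop():
    model = torch.nn.Linear(16, 2)
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if name.endswith("bias") else decay).append(param)
    optimizer = AdamW(
        [
            {"params": decay, "weight_decay": 0.01},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=5e-5,
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
    inputs = torch.randn(32, 16)
    targets = torch.randint(0, 2, (32,))
    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()   # decoupled weight decay is applied here, after the Adam update
        scheduler.step()   # the scheduler is stepped once per optimizer step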