# CSS-LM
# 267 lines · 11.3 KB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
16
import logging
import math
from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
25
# Logger for this module, registered under the module's dotted import name.
logger = logging.getLogger(name=__name__)
28
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Build a scheduler that never changes the learning rate configured on the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose parameter groups keep their initial learning rate.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def _unit_multiplier(_step):
        # LambdaLR scales each group's initial lr by this value; 1 keeps it unchanged forever.
        return 1

    return LambdaLR(optimizer, _unit_multiplier, last_epoch=last_epoch)
44
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a scheduler that ramps the learning rate linearly from 0 up to the optimizer's initial lr over
    ``num_warmup_steps`` steps, then holds it constant.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the linear warmup phase, in steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # Once warmup is over the multiplier saturates at 1.0.
        if current_step >= num_warmup_steps:
            return 1.0
        # During warmup: step / warmup, with the denominator floored at 1 to avoid dividing by zero.
        return float(current_step) / float(max(1.0, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
69
def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            # Linear warmup; denominator floored at 1 so a zero-length warmup cannot divide by zero.
            return float(current_step) / float(max(1, num_warmup_steps))
        # Linear decay from 1.0 at the end of warmup down to 0.0 at num_training_steps,
        # clamped at 0.0 for any step past the end of training.
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
98
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Build a scheduler with a linear warmup from 0 up to the optimizer's initial lr, followed by a cosine-shaped
    decay of the multiplier.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the linear warmup phase, in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps; the cosine phase spans the steps after warmup.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            Number of cosine waves over the post-warmup span. The default 0.5 is a half-cosine,
            going from the max value down to 0.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Past ``num_training_steps`` the cosine is not clamped, so the multiplier keeps following the wave.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def _warmup_factor(step):
        # Linear ramp; denominator floored at 1 to avoid division by zero when warmup is empty.
        return float(step) / float(max(1, num_warmup_steps))

    def _cosine_factor(step):
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / decay_span
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    def lr_lambda(current_step):
        return _warmup_factor(current_step) if current_step < num_warmup_steps else _cosine_factor(current_step)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
132
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a scheduler with a linear warmup from 0 up to the optimizer's initial lr, followed by ``num_cycles``
    half-cosine decays, each restarting abruptly at the maximum value.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the linear warmup phase, in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps; from this step on the multiplier is 0.0.
        num_cycles (:obj:`int`, `optional`, defaults to 1):
            Number of hard restarts to use.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        warming_up = current_step < num_warmup_steps
        if warming_up:
            return float(current_step) / float(max(1, num_warmup_steps))
        remaining_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / remaining_span
        if progress < 1.0:
            # Position within the current cycle, in [0, 1); it wraps to 0 at each hard restart.
            cycle_position = (float(num_cycles) * progress) % 1.0
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))
        # Training is over: the learning rate is held at zero.
        return 0.0

    return LambdaLR(optimizer, lr_lambda, last_epoch)
167
class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in
    `Decoupled Weight Decay Regularization <https://arxiv.org/abs/1711.05101>`__.

    Parameters:
        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (:obj:`float`, `optional`, defaults to 1e-3):
            The learning rate to use.
        betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2).
        eps (:obj:`float`, `optional`, defaults to 1e-6):
            Adam's epsilon for numerical stability.
        weight_decay (:obj:`float`, `optional`, defaults to 0):
            Decoupled weight decay to apply. Must be >= 0.0.
        correct_bias (:obj:`bool`, `optional`, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`).

    Raises:
        ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, or if either beta lies outside ``[0.0, 1.0)``.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        # step() only applies decay when weight_decay > 0.0, so a negative value would otherwise be
        # silently ignored instead of flagged as a misconfiguration.
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version): it scales the already-updated weights.
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss