colossalai

Форк
0
/
test_adam_kernel.py 
171 строка · 5.9 Кб
1
# This test checks the Adam optimizer kernels (fused and CPU implementations).
# The baseline is a pure fp32 PyTorch implementation of the Adam/AdamW update.
3
import math
from abc import ABC, abstractmethod
from typing import Type

import pytest
import torch
from torch import Tensor

from colossalai.accelerator import get_accelerator
from colossalai.utils import multi_tensor_applier
13

14
# (param dtype, grad dtype) combinations exercised against the fused kernel.
# torch.float32 / torch.float16 are the same dtype objects as torch.float / torch.half.
_FUSED_ALLOWED_P_G_TYPES = [
    (torch.float32, torch.float16),
    (torch.float32, torch.float32),
    (torch.float16, torch.float16),
    (torch.float32, torch.bfloat16),
    (torch.bfloat16, torch.bfloat16),
]

# (param dtype, grad dtype) combinations exercised against the CPU kernel;
# unlike the fused list, no bfloat16 pairs are tested here.
_CPU_ALLOWED_P_G_TYPES = [
    (torch.float32, torch.float16),
    (torch.float32, torch.float32),
    (torch.float16, torch.float16),
]
27

28

29
class AdamKernel(ABC):
    """Common interface for the Adam implementations compared in this test.

    Holds the optimizer hyper-parameters; concrete subclasses implement
    :meth:`update`. Deriving from ``ABC`` is what makes ``@abstractmethod``
    effective: without it, the decorator is silently ignored and a subclass
    that forgets ``update`` could still be instantiated.

    Args:
        lr: Learning rate.
        beta1: Decay rate of the first-moment running average.
        beta2: Decay rate of the second-moment running average.
        eps: Term added to the denominator of the update.
        weight_decay: Weight-decay coefficient (0 disables it).
        use_adamw: Selects AdamW-style (decoupled) weight decay instead of
            folding the L2 penalty into the gradient.
    """

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.use_adamw = use_adamw

    @abstractmethod
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        """Apply one optimizer step.

        Args:
            step: Step counter used for bias correction (the test loop starts at 1).
            param: Parameter tensor to update.
            grad: Gradient for ``param``.
            exp_avg: First-moment running average state.
            exp_avg_sq: Second-moment running average state.
        """
41

42

43
class TorchAdamKernel(AdamKernel):
    """Reference Adam/AdamW step written with plain PyTorch tensor ops.

    All state tensors (``param``, ``exp_avg``, ``exp_avg_sq``) are modified in
    place; the caller's ``grad`` is never mutated.
    """

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        lr, beta1, beta2, eps, decay = self.lr, self.beta1, self.beta2, self.eps, self.weight_decay

        if decay != 0 and self.use_adamw:
            # AdamW: decoupled weight decay shrinks the parameter directly.
            param.mul_(1 - lr * decay)
        elif decay != 0:
            # Adam: L2 penalty folded into the gradient (out of place, so the
            # caller's gradient tensor is left untouched).
            grad = grad.add(param, alpha=decay)

        # Running averages of the gradient and of the squared gradient.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias-corrected step:
        #   param -= lr / bc1 * exp_avg / (sqrt(exp_avg_sq) / sqrt(bc2) + eps)
        bc1 = 1 - beta1**step
        bc2 = 1 - beta2**step
        denom = exp_avg_sq.sqrt().div(math.sqrt(bc2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-(lr / bc1))
63

64

65
class FusedAdamKernel(AdamKernel):
    """Adam/AdamW step delegated to ColossalAI's fused ``multi_tensor_adam`` op."""

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        # Imported here rather than at module level; presumably so the extension
        # is only loaded when this kernel is actually constructed.
        from colossalai.kernel.kernel_loader import FusedOptimizerLoader

        self.fused_adam = FusedOptimizerLoader().load().multi_tensor_adam
        # Single int32 flag on the current CUDA device, handed to
        # multi_tensor_applier as its no-op/overflow buffer; never read here.
        self.dummy_overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        tensor_lists = [[grad], [param], [exp_avg], [exp_avg_sq]]
        # Positional arguments forwarded verbatim to the fused op. The literal
        # True presumably enables bias correction and -1 is a "no div scale"
        # sentinel -- NOTE(review): confirm against the op's signature.
        op_args = (
            self.lr,
            self.beta1,
            self.beta2,
            self.eps,
            step,
            self.use_adamw,
            True,
            self.weight_decay,
            -1,
        )
        multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, tensor_lists, *op_args)
89

90

91
class CPUAdamKernel(AdamKernel):
    """Adam/AdamW step delegated to ColossalAI's ``CPUAdamOptimizer`` extension."""

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        # Imported here rather than at module level; presumably so the extension
        # is only loaded when this kernel is actually constructed.
        from colossalai.kernel.kernel_loader import CPUAdamLoader

        cpu_adam_ext = CPUAdamLoader().load()
        self.cpu_adam_op = cpu_adam_ext.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        # The op receives 1-D views; view(-1) shares storage with the originals,
        # so any in-place update the op performs is visible to the caller.
        flat_param, flat_grad, flat_avg, flat_avg_sq = (t.view(-1) for t in (param, grad, exp_avg, exp_avg_sq))
        # The literal True and trailing -1 are passed through unchanged;
        # NOTE(review): their meaning is defined by the extension's step().
        self.cpu_adam_op.step(
            step,
            self.lr,
            self.beta1,
            self.beta2,
            self.eps,
            self.weight_decay,
            True,
            flat_param,
            flat_grad,
            flat_avg,
            flat_avg_sq,
            -1,
        )
115

116

117
def check_adam_kernel(
    kernel: Type[AdamKernel],
    adamw: bool,
    weight_decay: float,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
    device: torch.device,
    n_steps: int,
    rtol: float,
    atol: float,
    *,
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    numel: int = 64,
):
    """Run ``kernel`` next to the PyTorch reference and compare parameters each step.

    The reference (:class:`TorchAdamKernel`) operates on default-dtype (normally
    fp32) "master" tensors; the kernel under test operates on copies of the same
    starting values cast to ``p_dtype`` (param and moment state) and ``g_dtype``
    (gradient). The same gradient is reused for every step.

    Args:
        kernel: AdamKernel subclass under test.
        adamw: Whether to use AdamW-style (decoupled) weight decay.
        weight_decay: Weight-decay coefficient.
        p_dtype: dtype of the parameter and optimizer state under test.
        g_dtype: dtype of the gradient under test.
        device: Device on which all tensors are created.
        n_steps: Number of optimizer steps (steps are numbered from 1).
        rtol: Relative tolerance for ``torch.allclose``.
        atol: Absolute tolerance for ``torch.allclose``.
        lr, beta1, beta2, eps: Optimizer hyper-parameters (keyword-only).
        numel: Number of elements in the test tensors (keyword-only).

    Raises:
        AssertionError: If the kernel produces NaNs or its parameters drift
            outside the tolerance from the reference at any step.
    """
    torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
    adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)

    # Reference state in the default dtype.
    master_p = torch.rand(numel, device=device)
    master_g = torch.rand_like(master_p)
    master_exp_avg = torch.zeros_like(master_p)
    master_exp_avg_sq = torch.zeros_like(master_p)

    # Same starting values, cast to the dtypes under test.
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
    exp_avg = master_exp_avg.clone().to(p_dtype)
    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

    for step in range(1, 1 + n_steps):
        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
        adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
        # if overflow, the weight won't be updated. so there will be no nan in p
        assert not torch.isnan(p).any(), f"step {step}: {kernel.__name__} produced NaN parameters"
        p_fp32 = p.float()
        # The message is only built when the comparison fails.
        assert torch.allclose(master_p, p_fp32, rtol=rtol, atol=atol), (
            f"step {step}: {kernel.__name__} diverged from the reference "
            f"(max abs diff {(master_p - p_fp32).abs().max().item():.3e}, rtol={rtol}, atol={atol})"
        )
148

149

150
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    """Compare the fused kernel with the PyTorch reference on the current accelerator.

    Tolerances loosen with precision: bfloat16 takes priority over float16,
    and pure fp32 runs use the tightest bounds.
    """
    dtypes = (p_dtype, g_dtype)
    if torch.bfloat16 in dtypes:
        rtol = atol = 4e-3
    elif torch.float16 in dtypes:
        rtol = atol = 1e-3
    else:
        rtol, atol = 1e-5, 1e-8

    device = get_accelerator().get_current_device()
    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, device, n_steps=3, rtol=rtol, atol=atol)
162

163

164
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    """Compare the CPU kernel with the PyTorch reference on the CPU.

    Runs involving float16 use looser tolerances than pure fp32 runs.
    """
    half_involved = torch.float16 in (p_dtype, g_dtype)
    rtol, atol = (1e-3, 1e-3) if half_involved else (1e-5, 1e-8)
    check_adam_kernel(
        CPUAdamKernel,
        adamw,
        weight_decay,
        p_dtype,
        g_dtype,
        torch.device("cpu"),
        n_steps=3,
        rtol=rtol,
        atol=atol,
    )
172

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.