# This test checks the Adam kernels.
# The baseline is a pure fp32 torch Adam optimizer.
import math
from abc import abstractmethod
from typing import Type

import pytest
import torch
from torch import Tensor

from colossalai.accelerator import get_accelerator
from colossalai.utils import multi_tensor_applier

# Allowed (param dtype, grad dtype) combinations for each kernel.
_FUSED_ALLOWED_P_G_TYPES = [
    (torch.float, torch.half),
    (torch.float, torch.float),
    (torch.half, torch.half),
    (torch.float, torch.bfloat16),
    (torch.bfloat16, torch.bfloat16),
]

_CPU_ALLOWED_P_G_TYPES = [
    (torch.float, torch.half),
    (torch.float, torch.float),
    (torch.half, torch.half),
]


class AdamKernel:
    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.use_adamw = use_adamw

    @abstractmethod
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        pass


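# Reference math implemented by TorchAdamKernel below (standard Adam with
# bias correction; AdamW only changes how weight decay is applied):
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
#   p_t = p_{t-1} - lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)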
class TorchAdamKernel(AdamKernel):
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        bias_correction1 = 1 - self.beta1**step
        bias_correction2 = 1 - self.beta2**step

        if self.weight_decay != 0:
            if self.use_adamw:
                # Perform decoupled (step-wise) weight decay, as in AdamW
                param.mul_(1 - self.lr * self.weight_decay)
            else:
                # Classic Adam folds weight decay into the gradient
                grad = grad.add(param, alpha=self.weight_decay)

        # Decay the first and second moment running average coefficients
        exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

        step_size = self.lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)


66def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
67super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
68from colossalai.kernel.kernel_loader import FusedOptimizerLoader
69
70fused_optim = FusedOptimizerLoader().load()
71self.fused_adam = fused_optim.multi_tensor_adam
72self.dummy_overflow_buf = torch.cuda.IntTensor([0])
73
74def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
75multi_tensor_applier(
76self.fused_adam,
77self.dummy_overflow_buf,
78[[grad], [param], [exp_avg], [exp_avg_sq]],
79self.lr,
80self.beta1,
81self.beta2,
82self.eps,
83step,
84self.use_adamw,
85True,
86self.weight_decay,
87-1,
88)
89
90
class CPUAdamKernel(AdamKernel):
    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        from colossalai.kernel.kernel_loader import CPUAdamLoader

        cpu_optim = CPUAdamLoader().load()

        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        self.cpu_adam_op.step(
            step,
            self.lr,
            self.beta1,
            self.beta2,
            self.eps,
            self.weight_decay,
            True,  # enable bias correction
            param.view(-1),
            grad.view(-1),
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            -1,  # div_scale: -1 disables gradient unscaling
        )


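# Each kernel is checked against the fp32 torch baseline: both start from the
# same fp32 master state, the kernel under test runs in (p_dtype, g_dtype),
# and its params must track the master params within (rtol, atol).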
def check_adam_kernel(
    kernel: Type[AdamKernel],
    adamw: bool,
    weight_decay: float,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
    device: torch.device,
    n_steps: int,
    rtol: float,
    atol: float,
):
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
    adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
    master_p = torch.rand(64, device=device)
    master_g = torch.rand_like(master_p)
    master_exp_avg = torch.zeros_like(master_p)
    master_exp_avg_sq = torch.zeros_like(master_p)
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
    exp_avg = master_exp_avg.clone().to(p_dtype)
    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

    for step in range(1, 1 + n_steps):
        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
        adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
        # If the kernel detects overflow, the weight is not updated, so p must contain no NaNs.
        assert not torch.isnan(p).any()
        assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)


@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    # Loosen tolerances for half-precision params/grads
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3
    check_adam_kernel(
        FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol
    )


@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol)
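

if __name__ == "__main__":
    # Quick manual smoke run of one representative case per kernel; a minimal
    # sketch, assuming a CUDA-capable accelerator is available for the fused case.
    test_fused_adam_kernel(True, 0.1, torch.float, torch.half)
    test_cpu_adam_kernel(True, 0.1, torch.float, torch.half)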