from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device("cuda:0")),
    (CPUAdam, torch.device("cpu")),
    (CPUAdam, torch.device("cuda:0")),
    (HybridAdam, torch.device("cpu")),
    (HybridAdam, torch.device("cuda:0")),
]

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
    """Split BERT parameters into weight-decay and no-weight-decay groups."""
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    """Assign identical random gradients to both models, casting `model`'s gradients to `g_dtype`."""
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # temporarily swap p.data so the grad assignment does not trigger a grad/param dtype mismatch error
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
    """Run a few optimizer steps on a BERT model and compare ColossalAI Adam variants against torch.optim.Adam(W)."""
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    # loosen tolerances for low-precision gradients
    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # if an overflow occurred, the weight would not be updated, so p should contain no NaN
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
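

# Minimal direct-run sketch: exercises a single allowed combination (CPUAdam on CPU,
# pure fp32, AdamW mode) without going through pytest's parametrization. The chosen
# arguments are an arbitrary example drawn from _ALLOWED_OPTIM_DEVICES and
# _ALLOWED_P_G_TYPES above, not a canonical default.
if __name__ == "__main__":
    test_adam_optim_on_bert(
        optim_cls=CPUAdam,
        device=torch.device("cpu"),
        adamw=True,
        p_dtype=torch.float,
        g_dtype=torch.float,
    )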