# colossalai: test_adam_optim.py
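"""Parity tests for ColossalAI's CPUAdam, FusedAdam, and HybridAdam optimizers.

Each ColossalAI optimizer is stepped side by side with torch.optim.Adam/AdamW on a
BERT sequence-classification model from the test model zoo, and the resulting
parameters are compared after every step.
"""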
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

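# Optimizer/device placements covered by the test: FusedAdam is paired only with CUDA
# parameters; CPUAdam and HybridAdam are exercised with parameters on both CPU and GPU.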
_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device("cuda:0")),
    (CPUAdam, torch.device("cpu")),
    (CPUAdam, torch.device("cuda:0")),
    (HybridAdam, torch.device("cpu")),
    (HybridAdam, torch.device("cuda:0")),
]

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
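    """Group BERT parameters for the optimizer: biases and LayerNorm weights are exempt from weight decay."""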
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
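    """Give both models the same random gradients, casting the ColossalAI model's grads to ``g_dtype``."""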
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # avoid inconsistent grad and param dtype error
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize("adamw", [False, True])
@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(
    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
    device: torch.device,
    adamw: bool,
    p_dtype: torch.dtype,
    g_dtype: torch.dtype,
) -> None:
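    """Compare a ColossalAI Adam variant against torch.optim.Adam/AdamW over a few steps.

    The torch reference stays in fp32; the copy is cast to ``p_dtype`` and receives the
    same gradients in ``g_dtype``. ``adamw_mode`` mirrors the ``adamw`` flag used to pick
    the torch reference optimizer.
    """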
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

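    # Loosen the comparison tolerances when fp16 or bf16 parameters or gradients are involved.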
    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

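    # Run N_STEPS optimizer steps and check the parameters stay close to the torch reference.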
    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # if the step overflows, the weight is not updated, so p should never contain NaN
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
