import copy

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(12, 24)
        self.linear2 = nn.Linear(24, 12)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    # for float32 both tolerances stay None, so assert_close falls back to its dtype-based defaults
    assert_close(a, b, rtol=rtol, atol=atol)


def exam_zero_1_torch_ddp_ckpt():
    """
    Compare the optimizer state dicts of ZeRO and torch DDP, and verify that
    the ZeRO optimizer can load a checkpoint saved by a plain torch optimizer.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # the state dicts of stage 1 and stage 2 are the same
    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
    )
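    # NOTE: roughly, LowLevelZeroOptimizer wraps the inner Adam instance so that
    # its optimizer states are sharded across the data-parallel ranks (stage 1 here);
    # overlap_communication=True overlaps gradient reduction with the backward
    # pass, and reduce_bucket_size bounds how many gradient elements are grouped
    # into one reduction bucket.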

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)
    # create input data
    input_data = torch.rand(4, 12).cuda()

    # forward
    zero_output = zero_model(input_data)
    torch_output = torch_model(input_data)

    # backward
    zero_optimizer.backward(zero_output.mean().float())
    torch_output.mean().backward()
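    # NOTE: on the ZeRO path the scalar loss is handed to zero_optimizer.backward(),
    # which (roughly) applies loss scaling and reduces the bucketed gradients across
    # ranks; the DDP path uses plain autograd backward plus DDP's own allreduce hooks.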

    # step
    zero_optimizer.step()
    torch_optimizer.step()

    torch_state_dict = torch_optimizer.state_dict()
    zero_state_dict = zero_optimizer.state_dict()

    # examine the original state dict
    for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)

    # empty the optimizer state
    zero_optimizer.optim.state = []

    # zero loads a torch checkpoint
    zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
    zero_state_dict = zero_optimizer.state_dict()

    # examine the loaded state dict
    for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

    exam_zero_1_torch_ddp_ckpt()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
    spawn(run_dist, 2)
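    # spawn launches 2 worker processes, each calling run_dist(rank, world_size, port),
    # so the whole checkpoint check runs on a 2-rank group.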


if __name__ == "__main__":
    test_zero_ckpt()