#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.initialize import launch
from colossalai.legacy.utils import is_using_pp
from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn


def build_pipeline(model):
    # Keep only the layers owned by this pipeline stage; every other position is replaced
    # with nn.Identity() so the module layout stays identical on all ranks.
    from colossalai.legacy.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))

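# Illustrative sketch (not part of the original test): what build_pipeline produces for a
# two-layer model split over two pipeline stages. partition_uniform(depth, pipeline_size, 1)
# gives each rank a contiguous [start, end) slice of layer indices, and layers outside that
# slice are swapped for nn.Identity(), so both stages share the same module layout.
def _example_pipeline_split():
    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    # With depth=2 and pipeline_size=2, rank 0 keeps layer 0 and rank 1 keeps layer 1.
    stage0 = nn.Sequential(model[0], nn.Identity())
    stage1 = nn.Sequential(nn.Identity(), model[1])
    return stage0, stage1
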
def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2p5d(rank, world_size, port):
    # 2 pipeline stages x 4-way 2.5D tensor parallelism (depth=1) -> 8 processes in total.
    config = dict(
        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),
    )

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Save a checkpoint from a plain, unparallelized reference model.
    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    # Load the same checkpoint into a 2.5D tensor-parallel (and optionally pipelined) model.
    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    # Gather the pipeline-sharded state dict on the tensor-parallel rank-0 processes.
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
    print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    # The global rank-0 process compares the reloaded weights with the originals.
    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))

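# Hedged single-process sketch (not in the original test): the same save/load round trip
# without any parallelism, using only plain torch.save/torch.load. The file name is arbitrary.
def _example_plain_roundtrip(path="plain_test.pt"):
    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    torch.save(m1.state_dict(), path)
    m2 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    m2.load_state_dict(torch.load(path))
    for k, v in m1.state_dict().items():
        check_equal(v, m2.state_dict()[k])
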
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2p5d():
    spawn(check_checkpoint_2p5d, 8)


if __name__ == "__main__":
    test_checkpoint_2p5d()