#!/usr/bin/env python
# -*- encoding: utf-8 -*-
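"""Distributed test: save a checkpoint from a plain torch.nn model, reload it
into an equivalent model sharded with 1D tensor parallelism (optionally split
into pipeline stages), and verify the gathered state dict matches the original."""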
import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.legacy.nn as col_nn
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.initialize import launch
from colossalai.legacy.utils import is_using_pp
from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn

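# Wrap a sequential model for pipeline parallelism: each pipeline stage keeps
# only its uniformly partitioned slice of layers and replaces the rest with
# nn.Identity(), so layer indices (and hence state-dict keys) are preserved.
# E.g. with the 2-layer model below and 2 pipeline stages, stage 0 keeps
# layer 0 and stage 1 keeps layer 1.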
def build_pipeline(model):
    from colossalai.legacy.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*layers)


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)

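# Worker run on each of the 8 spawned ranks: 2 pipeline stages x 4-way 1D
# tensor parallelism. A vanilla torch model is saved with save_checkpoint,
# reloaded into the sharded colossalai model, and rank 0 verifies the
# gathered weights against the originals.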
def check_checkpoint_1d(rank, world_size, port):
    config = dict(
        parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),
    )

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Reference model: plain torch layers, identical on every rank.
    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    # Parallel model: same architecture, but built from 1D tensor-parallel
    # layers and optionally split into pipeline stages.
    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    # The reloaded (and gathered) weights must match the originals.
    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))

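# Requires 8 GPUs (2 pipeline stages x 4 tensor-parallel ranks). Skipped by
# default because the full spawn takes too long; remove the skip marker to run it.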
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_1d():
    spawn(check_checkpoint_1d, 8)


if __name__ == "__main__":
    test_checkpoint_1d()