colossalai

Форк
0
116 строк · 4.1 Кб
1
import pytest
2
import torch
3
import torch.distributed as dist
4
from torch.nn.parallel import DistributedDataParallel as DDP
5
from torch.testing import assert_close
6

7
import colossalai
8
from colossalai.accelerator import get_accelerator
9
from colossalai.legacy.amp import convert_to_apex_amp
10
from colossalai.nn.optimizer import HybridAdam
11
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
12
from colossalai.utils import set_seed
13
from colossalai.zero import GeminiDDP, GeminiOptimizer
14
from colossalai.zero.gemini.chunk import search_chunk_configuration
15
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
16

17
PLACEMENT_CONFIGS = [
18
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
19
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
20
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
21
    {"placement_policy": "auto"},
22
]
23

24

25
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
26
    chunk_manager = model.chunk_manager
27
    param_list = [p for p in model.parameters()]
28
    chunk_list = chunk_manager.get_chunks(param_list)
29
    if not model.reuse_fp16_chunk:
30
        chunk_list = [chunk.grad_chunk for chunk in chunk_list]
31
    for chunk in chunk_list:
32
        chunk_manager.access_chunk(chunk)
33

34
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
35
        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
36

37

38
@parameterize("placement_config", PLACEMENT_CONFIGS)
39
@parameterize("keep_gather", [False, True])
40
@parameterize("model_name", ["transformers_gpt_lm"])
41
@parameterize("use_grad_checkpoint", [False, True])
42
@parameterize("master_weights", [False, True])
43
def exam_gpt_fwd_bwd(
44
    placement_config,
45
    keep_gather,
46
    model_name: str,
47
    use_grad_checkpoint: bool = False,
48
    master_weights: bool = True,
49
):
50
    init_device = get_accelerator().get_current_device()
51
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
52
        iter(model_zoo.get_sub_registry(model_name).values())
53
    )
54

55
    set_seed(42)
56
    model = model_builder()
57

58
    set_seed(42)
59
    torch_model = model_builder().cuda()
60
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
61
        torch_p.data.copy_(p.data)
62

63
    if use_grad_checkpoint:
64
        model.gradient_checkpointing_enable()
65
        torch_model.gradient_checkpointing_enable()
66

67
    world_size = torch.distributed.get_world_size()
68
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
69
    config_dict[world_size]["chunk_size"] = 5000
70
    config_dict[world_size]["keep_gathered"] = keep_gather
71
    model = GeminiDDP(
72
        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
73
    )
74
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
75
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
76

77
    rank = dist.get_rank()
78
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
79
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
80
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
81
    torch_model = DDP(torch_model, device_ids=[rank])
82

83
    set_seed(rank)
84

85
    data = data_gen_fn()
86
    data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
87

88
    torch_optim.zero_grad()
89
    zero_optim.zero_grad()
90

91
    # set random seed is same as torch_model.eval()
92
    set_seed(42)
93
    torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
94
    set_seed(42)
95
    loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
96

97
    assert_close(torch_loss.float(), loss.float())
98

99
    check_grad(model, torch_model)
100

101

102
def run_dist(rank, world_size, port):
103
    config = {}
104
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
105
    exam_gpt_fwd_bwd()
106

107

108
@pytest.mark.dist
109
@pytest.mark.parametrize("world_size", [1, 4])
110
@rerun_if_address_is_in_use()
111
def test_gpt(world_size):
112
    spawn(run_dist, world_size)
113

114

115
if __name__ == "__main__":
116
    test_gpt(1)
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.