colossalai

Форк
0
/
test_grad_accum.py 
153 строки · 5.6 Кб
1
import pytest
2
import torch
3
import torch.distributed as dist
4
from apex import amp
5
from torch.nn.parallel import DistributedDataParallel as DDP
6
from torch.testing import assert_close
7

8
import colossalai
9
from colossalai.accelerator import get_accelerator
10
from colossalai.nn.optimizer import HybridAdam
11
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
12
from colossalai.utils import set_seed
13
from colossalai.zero import GeminiDDP, GeminiOptimizer
14
from colossalai.zero.gemini.chunk import search_chunk_configuration
15
from tests.kit.model_zoo import model_zoo, run_fwd
16

17
# GeminiDDP placement keyword arguments exercised by the test. The trailing
# comments name the ZeRO stage each static configuration is meant to mirror.
PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {"placement_policy": "auto"},
]
23

24

25
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
    """Assert that the Gemini model's accumulated gradients match the reference model's.

    Every distinct gradient chunk owned by ``model`` is accessed once. Each Gemini
    parameter is then compared against the matching ``torch_model`` parameter's
    ``.grad``. Finally the chunks are released and moved back to the device
    recorded for them in ``model.grads_device``.

    Args:
        model: The GeminiDDP-wrapped model under test.
        torch_model: The reference model. Its parameters are zipped with
            ``model.parameters()``, so both must yield parameters in the same order.
    """
    chunk_manager = model.chunk_manager
    grad_chunk_list = []
    device_list = []

    # Access each gradient chunk once (several parameters can map to the same
    # chunk), remembering the device it has to be returned to afterwards.
    for p in model.parameters():
        grad_chunk = chunk_manager.get_chunk(p).grad_chunk
        if grad_chunk not in grad_chunk_list:
            chunk_manager.access_chunk(grad_chunk)
            grad_chunk_list.append(grad_chunk)
            device_list.append(model.grads_device[p])

    # Compare gradients.
    # NOTE(review): the Gemini parameter itself (not p0.grad) is compared with
    # p1.grad. Presumably accessing the grad chunk exposes the accumulated
    # gradient through the parameter's data view -- confirm against ChunkManager.
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)

    # Release gradient chunks and move them back to their gradient device.
    for grad_chunk, device in zip(grad_chunk_list, device_list):
        chunk_manager.release_chunk(grad_chunk)
        chunk_manager.move_chunk(grad_chunk, device, force_copy=True)
46

47

48
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True])
def exam_gemini_grad_acc(
    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
):
    """Check Gemini gradient accumulation against an apex-AMP + DDP reference.

    Two copies of the model start from the same weights and train side by side:
    - one wrapped in ``GeminiDDP`` with ``enable_gradient_accumulation=True``;
    - one prepared with apex AMP (``opt_level="O2"``) and wrapped in ``DDP``.

    Every iteration runs forward/backward on the same batch with the loss divided
    by ``accum_iter``, and the losses and gradients are compared. After every
    ``accum_iter``-th iteration both optimizers step, and the parameters are
    compared through their state dicts.

    Args:
        placement_config: Extra ``GeminiDDP`` keyword arguments (one of ``PLACEMENT_CONFIGS``).
        keep_gathered: Value written into the chunk config's ``keep_gathered`` entry.
        model_name: Name of the model-zoo sub-registry the model is built from.
        master_weights: Forwarded to ``GeminiDDP(master_weights=...)``.
        use_grad_checkpoint: Enable gradient checkpointing on both models.
    """
    init_device = get_accelerator().get_current_device()
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )

    # Build both models from the same seed, then copy the weights explicitly so the
    # reference starts from exactly the Gemini model's parameters.
    set_seed(42)
    gemini_model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
        torch_p.data.copy_(p.data)

    if use_grad_checkpoint:
        gemini_model.gradient_checkpointing_enable()
        torch_model.gradient_checkpointing_enable()

    # Small chunks (5000) so the model's parameters are spread over several chunks.
    world_size = dist.get_world_size()
    config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = keep_gathered
    gemini_model = GeminiDDP(
        gemini_model,
        config_dict,
        init_device,
        pin_memory=True,
        enable_gradient_accumulation=True,
        master_weights=master_weights,
        **placement_config,
    )
    optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)

    rank = dist.get_rank()

    # setting master_weights to False will cause overflow after optimizer.step()
    # The loss scale is pinned to 1 so the reference's gradients are not rescaled.
    amp_config = dict(
        opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
    )
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
    torch_model = DDP(torch_model, device_ids=[rank])

    set_seed(rank)
    accum_iter = 4
    train_dataloader = DummyDataloader(data_gen_fn)
    for i, data in enumerate(train_dataloader):
        # Only the last micro-batch of each accumulation window is an optimizer
        # step; apex has to delay unscaling on every other micro-batch.
        is_step = (i + 1) % accum_iter == 0
        delay_unscale = not is_step
        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}

        # Re-seed before each forward pass so both models see identical
        # randomness inside the forward.
        set_seed(42 + rank)
        torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
        torch_loss = torch_loss / accum_iter
        with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()

        set_seed(42 + rank)
        gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
        gemini_loss = gemini_loss / accum_iter
        gemini_optim.backward(gemini_loss)

        assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)

        check_grad(gemini_model, torch_model)

        if is_step:
            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
            torch_optim.step()
            gemini_optim.step()
            torch_optim.zero_grad()

            # check updated param
            torch_dict = torch_model.state_dict()
            gemini_dict = gemini_model.state_dict(only_rank_0=False)

            for key, value in gemini_dict.items():
                # The reference is wrapped in DDP, so its keys carry a "module." prefix.
                torch_key = "module." + key
                torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
                assert_close(value, torch_value, rtol=1e-3, atol=2e-3)

        # Iterations 0..accum_iter inclusive: one micro-batch past the first
        # optimizer step is also run, so the gradients of a fresh accumulation
        # window are checked as well.
        if i == accum_iter:
            break
138

139

140
def run_dist(rank, world_size, port):
    """Per-process entry point: initialise the distributed backend, then run the checks.

    Args:
        rank: Rank of this process (presumably supplied by ``spawn``).
        world_size: Total number of processes.
        port: Rendezvous port on ``localhost``.
    """
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_gemini_grad_acc()
144

145

146
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_accumulation():
    """Run the Gemini gradient-accumulation checks across 2 spawned processes."""
    spawn(run_dist, 2)
150

151

152
# Allow running the distributed test directly, without pytest.
if __name__ == "__main__":
    test_grad_accumulation()
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.