# test_moe_checkpoint.py
import importlib
import os
import shutil
import sys

import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn

sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "examples/language/openmoe",
    )
)

OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy


def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
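    """Build a tiny random batch of token ids with a full attention mask; labels mirror input_ids."""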
    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
    }


def run_fwd_bwd(
    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
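    """Run a single forward/backward step, either through the booster's pipeline schedule or as a plain eager pass."""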
    model.train()
    if pipeline:
        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
        y = booster.execute_pipeline(
            train_dataloader_iter,
            model,
            lambda x, y: x.loss,
            optimizer,
            return_loss=True,
            return_outputs=True,
        )
        # Backward and optimize
        if is_pp_last_stage:
            loss = y["loss"]
    else:
        if criterion:
            y = model(data).logits
            loss = criterion(y)
        else:
            loss = model(data, label)
        loss = loss.float()

        if optimizer is not None:
            optimizer.backward(loss)
        else:
            loss.backward()
    return y


def get_config():
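    """Return a tiny OpenMoE config (2 layers, 8 experts, MoE in every layer) to keep the test lightweight."""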
    config = LlamaConfig(
        vocab_size=300,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=4,
        dropout_rate=0.0,
        hidden_act="swiglu",
    )
    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
    return config


def get_model(parallel):
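    """Build an OpenMoE model and Adam optimizer, then boost them with a MoeHybridParallelPlugin
    configured for the requested parallel mode (None, "ep", "ep_zero", or "hybrid")."""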
    config = get_config()
    model = OpenMoeForCausalLM(config)
    optim = torch.optim.Adam(model.parameters())

    if parallel is None:
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=1,
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=dist.get_world_size(),
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep_zero":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=2,
            zero_stage=2,
            extra_dp_size=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "hybrid":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=2,
            ep_size=2,
            zero_stage=1,
            microbatch_size=1,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    booster = Booster(plugin=plugin)
    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
    return model, booster, optim


def _test_moe_checkpoint(rank, parallel):
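    """Save and reload model/optimizer checkpoints, both sharded and unsharded, across three
    independently boosted replicas, then verify the state dicts match."""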
    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    model3, booster3, optim3 = get_model(parallel)

    # param ckpt
    # shard
    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
    booster2.load_model(model2, "./tmp_ckpt1")
    # unshard
    booster1.save_model(model1, "./tmp_ckpt1.pth")
    booster3.load_model(model3, "./tmp_ckpt1.pth")
    # check
    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)

    # optim ckpt
    criterion = lambda x: x.mean()
    data = torch.randint(0, 4, (2, 4)).cuda()
    label = torch.randint(0, 4, (2,)).cuda()
    if parallel == "hybrid":
        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
    else:
        kwargs = {}
    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
    optim1.step()
    optim1.zero_grad()
    # shard
    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_optimizer(optim2, "./tmp_ckpt2")
    # unshard
    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
    # check
    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

    if dist.get_rank() == 0:
        shutil.rmtree("./tmp_ckpt1")
        shutil.rmtree("./tmp_ckpt2")
        os.remove("./tmp_ckpt1.pth")
        os.remove("./tmp_ckpt2.pth")


def _run_dist(rank, world_size, port, parallel):
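    """Per-process entry point: initialize the distributed environment, then run the checkpoint test."""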
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    _test_moe_checkpoint(rank, parallel)


@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
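    """Spawn `world_size` processes and run the checkpoint round-trip for the given parallel mode."""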
    spawn(_run_dist, world_size, parallel=parallel)


if __name__ == "__main__":
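    # Running this file directly exercises the "hybrid" mode with 4 spawned processes (NCCL backend);
    # under pytest the test above is skipped since it is covered in ColossalMOE.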
    test_moe_checkpoint(world_size=4, parallel="hybrid")