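"""Checkpoint I/O test for OpenMoe with ColossalAI's MoeHybridParallelPlugin.

Saves model and optimizer checkpoints (sharded and unsharded) from one boosted model and
verifies that freshly boosted replicas reload identical state under several parallel modes:
plain ZeRO-2, expert parallelism, expert parallelism with an extra ZeRO data-parallel group,
and hybrid pipeline + expert parallelism.
"""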
import importlib
import os
import shutil
import sys

import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "examples/language/openmoe",
    )
)
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy

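# Build a tiny random batch on the current accelerator device; the input ids double as labels.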
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
    }

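# Run one forward/backward step. In pipeline mode the booster schedules micro-batches and the
# loss is only materialized on the last pipeline stage; otherwise the loss is computed locally
# and backpropagated through the (optionally booster-wrapped) optimizer.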
def run_fwd_bwd(
    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
    model.train()
    if pipeline:
        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
        y = booster.execute_pipeline(
            train_dataloader_iter,
            model,
            lambda x, y: x.loss,
            optimizer,
            return_loss=True,
            return_outputs=True,
        )
        # Backward and optimize
        if is_pp_last_stage:
            loss = y["loss"]
    else:
        if criterion:
            y = model(data).logits
            loss = criterion(y)
        else:
            loss = model(data, label)
        loss = loss.float()

        if optimizer is not None:
            optimizer.backward(loss)
        else:
            loss.backward()
    return y

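# Tiny OpenMoe/Llama config (2 layers, 8 experts, MoE in every layer) to keep the test cheap.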
def get_config():
    config = LlamaConfig(
        vocab_size=300,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=4,
        dropout_rate=0.0,
        hidden_act="swiglu",
    )
    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
    return config

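# Build an OpenMoe model plus an Adam optimizer and boost them with a MoeHybridParallelPlugin
# configured for the requested parallel mode:
#   None      - no TP/PP/EP, ZeRO-2 only
#   "ep"      - expert parallelism across all ranks, ZeRO-2
#   "ep_zero" - expert parallelism (size 2) with an extra ZeRO data-parallel group, ZeRO-2
#   "hybrid"  - pipeline parallelism (size 2) + expert parallelism (size 2), ZeRO-1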
def get_model(parallel):
    config = get_config()
    model = OpenMoeForCausalLM(config)
    optim = torch.optim.Adam(model.parameters())

    if parallel is None:
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=1,
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=dist.get_world_size(),
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep_zero":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=2,
            zero_stage=2,
            extra_dp_size=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "hybrid":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=2,
            ep_size=2,
            zero_stage=1,
            microbatch_size=1,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    booster = Booster(plugin=plugin)
    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
    return model, booster, optim

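# Checkpoint round trip: model1 is the reference; model2 reloads the sharded checkpoint and
# model3 the unsharded one, then both state dicts are compared against model1's.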
def _test_moe_checkpoint(rank, parallel):
    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    model3, booster3, optim3 = get_model(parallel)

    # param ckpt
    # shard
    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
    booster2.load_model(model2, "./tmp_ckpt1")
    # unshard
    booster1.save_model(model1, "./tmp_ckpt1.pth")
    booster3.load_model(model3, "./tmp_ckpt1.pth")
    # check
    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)

    # optim ckpt
    criterion = lambda x: x.mean()
    data = torch.randint(0, 4, (2, 4)).cuda()
    label = torch.randint(0, 4, (2,)).cuda()
    if parallel == "hybrid":
        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
    else:
        kwargs = {}
    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
    optim1.step()
    optim1.zero_grad()
    # shard
    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_optimizer(optim2, "./tmp_ckpt2")
    # unshard
    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
    # check
    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

    # clean up checkpoint files on rank 0 only
    if dist.get_rank() == 0:
        shutil.rmtree("./tmp_ckpt1")
        shutil.rmtree("./tmp_ckpt2")
        os.remove("./tmp_ckpt1.pth")
        os.remove("./tmp_ckpt2.pth")

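# Per-process entry point: initialize the distributed environment, then run the checkpoint test.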
def _run_dist(rank, world_size, port, parallel):
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    _test_moe_checkpoint(rank, parallel)

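# Spawn a 4-process group for each parallel mode. Skipped under pytest because the same
# scenario is covered in ColossalMOE (see the skip marker below).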
@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
    spawn(_run_dist, world_size, parallel=parallel)

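# Run the hybrid-parallel case directly when executed as a script (the skip mark only applies under pytest).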
if __name__ == "__main__":
    test_moe_checkpoint(world_size=4, parallel="hybrid")