import pytest
import torch
import torch.distributed as dist
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd
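# Placement configs exercised by the test. "shard_param_frac" is the fraction
# of each parameter chunk kept sharded across ranks: 0.0 leaves parameters
# replicated (ZeRO-2-like, only gradients/optimizer states are sharded) and
# 1.0 shards them fully (ZeRO-3-like). "auto" lets Gemini place chunks based
# on runtime memory statistics.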
PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {"placement_policy": "auto"},
]


def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
    chunk_manager = model.chunk_manager
    grad_chunk_list = []
    device_list = []

    # Access gradient chunks.
    for p in model.parameters():
        grad_chunk = chunk_manager.get_chunk(p).grad_chunk
        if grad_chunk not in grad_chunk_list:
            chunk_manager.access_chunk(grad_chunk)
            grad_chunk_list.append(grad_chunk)
            device_list.append(model.grads_device[p])
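    # NOTE: accessing a grad chunk gathers its full payload on the current
    # device; the Gemini parameter tensors then read from that payload, which
    # is why plain p0 (rather than p0.grad) is compared against p1.grad below.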
    # Compare gradients.
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)

    # Release gradient chunks and move them to gradient device.
    for grad_chunk, device in zip(grad_chunk_list, device_list):
        chunk_manager.release_chunk(grad_chunk)
        chunk_manager.move_chunk(grad_chunk, device, force_copy=True)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True])
def exam_gemini_grad_acc(
    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
):
    init_device = get_accelerator().get_current_device()
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )

    set_seed(42)
    gemini_model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
        torch_p.data.copy_(p.data)

    if use_grad_checkpoint:
        gemini_model.gradient_checkpointing_enable()
        torch_model.gradient_checkpointing_enable()

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = keep_gathered
    gemini_model = GeminiDDP(
        gemini_model,
        config_dict,
        init_device,
        pin_memory=True,
        enable_gradient_accumulation=True,
        master_weights=master_weights,
        **placement_config,
    )
    optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
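    # The chunk size is forced to 5000 so the small test model still produces
    # several chunks; enable_gradient_accumulation makes Gemini keep dedicated
    # grad chunks alive across micro-steps. initial_scale=1 and max_norm=1.0
    # mirror the fixed loss scale and grad clipping used for the apex baseline
    # below.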
    rank = dist.get_rank()

    # setting master_weights to False will cause overflow after optimizer.step()
    amp_config = dict(
        opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
    )
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
    torch_model = DDP(torch_model, device_ids=[rank])
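    # The apex O2 + DDP model serves as the fp16 reference: pinning the loss
    # scale to 1 removes dynamic loss scaling as a source of divergence
    # between the two implementations.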
    set_seed(rank)
    accum_iter = 4
    train_dataloader = DummyDataloader(data_gen_fn)
    for i, data in enumerate(train_dataloader):
        delay_unscale = (i + 1) % accum_iter != 0
        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
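        # delay_unscale=True tells apex to leave the scaled gradients in place
        # on intermediate micro-steps; they are only unscaled on the boundary
        # step where the optimizers actually update.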
        set_seed(42 + rank)
        torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
        torch_loss = torch_loss / accum_iter
        with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()

        set_seed(42 + rank)
        gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
        gemini_loss = gemini_loss / accum_iter
        gemini_optim.backward(gemini_loss)
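        # Each backward adds grad(loss_i / accum_iter), so after accum_iter
        # micro-steps both models hold the gradient of the mean micro-batch
        # loss: sum_i grad(loss_i / accum_iter) = grad(mean_i loss_i).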
        assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)

        check_grad(gemini_model, torch_model)

        if (i + 1) % accum_iter == 0:
            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
            torch_optim.step()
            gemini_optim.step()
            torch_optim.zero_grad()
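            # gemini_optim.step() is expected to unscale, clip to max_norm=1.0
            # and reset its accumulated grad chunks internally, so no explicit
            # clip/zero_grad calls are needed on the Gemini side.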
            # check updated param
            torch_dict = torch_model.state_dict()
            gemini_dict = gemini_model.state_dict(only_rank_0=False)

            for key, value in gemini_dict.items():
                torch_key = "module." + key
                torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
                assert_close(value, torch_value, rtol=1e-3, atol=2e-3)

        if i == accum_iter:
            break


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_gemini_grad_acc()
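# Run on 2 processes (one device each) so that cross-rank chunk sharding and
# gradient reduction are actually exercised.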
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_accumulation():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_grad_accumulation()