colossalai
299 lines · 10.7 KB
1import argparse2import json3import os4
5import torch6import torch.distributed as dist7from huggingface_hub import snapshot_download8from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args9from model.openmoe_policy import OpenMoeForCausalLMPolicy10from torch.utils.data import Dataset11from tqdm import tqdm12from transformers import T5Tokenizer13from transformers.models.llama import LlamaConfig14from utils import PerformanceEvaluator, get_model_numel15
16import colossalai17from colossalai.accelerator import get_accelerator18from colossalai.booster import Booster19from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin20from colossalai.cluster import DistCoordinator21from colossalai.moe.layers import apply_load_balance22from colossalai.moe.manager import MOE_MANAGER23from colossalai.moe.utils import skip_init24from colossalai.nn.optimizer import HybridAdam25
26
def move_to_cuda(batch, device):
    """Return a new dict with every tensor value of *batch* moved onto *device*."""
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
30
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
    """Download the HF checkpoint for *repo_name* and load it into *model* via *booster*.

    Accepts either a single ``pytorch_model.bin`` or a sharded checkpoint
    described by ``pytorch_model.bin.index.json``.
    """
    snapshot_dir = snapshot_download(repo_name)
    single_file = os.path.join(snapshot_dir, "pytorch_model.bin")
    index_file = os.path.join(snapshot_dir, "pytorch_model.bin.index.json")
    if os.path.exists(single_file):
        target = single_file
    elif os.path.exists(index_file):
        target = index_file
    else:
        raise ValueError(f"Invalid checkpoint path: {snapshot_dir}")
    booster.load_model(model, target)
43
class RandomDataset(Dataset):
    """Fixed-length token-sequence dataset for throughput benchmarking.

    If ``./mock_data.json`` exists, its ``text`` fields are tokenized and the
    resulting samples are tiled up to *num_samples*; otherwise uniformly random
    token ids are generated. ``labels`` mirror ``input_ids`` (the model applies
    the causal-LM shift internally).
    """

    def __init__(
        self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: "T5Tokenizer" = None
    ):
        self.num_samples = num_samples
        self.max_length = max_length
        if os.path.exists("./mock_data.json"):
            # Fail fast with a clear message instead of an opaque
            # "'NoneType' object is not callable" deep inside the loop.
            if tokenizer is None:
                raise ValueError("a tokenizer is required when ./mock_data.json is present")
            input_ids = []
            attention_mask = []
            with open("./mock_data.json", "r") as f:
                data = json.load(f)
            for v in data.values():
                encode = tokenizer(
                    "<pad>" + v["text"],
                    return_tensors="pt",
                    add_special_tokens=False,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                )
                input_ids.append(encode["input_ids"])
                attention_mask.append(encode["attention_mask"])
            device = get_accelerator().get_current_device()
            self.input_ids = torch.cat(input_ids, dim=0).to(device)
            self.attention_mask = torch.cat(attention_mask, dim=0).to(device)
            # Tile the mock samples until at least num_samples rows exist, then trim.
            repeat_times = num_samples // self.input_ids.shape[0] + 1
            self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
            self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
        else:
            self.input_ids = torch.randint(
                0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
            )
            self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # labels == input_ids: the causal-LM objective shifts targets inside the model.
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
88
def parse_args():
    """Parse the command-line options of the OpenMoE benchmark script."""
    parser = argparse.ArgumentParser()
    # basic settings
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b"],
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.")
    parser.add_argument("--seq_length", type=int, default=2048, help="sequence length for the training dataloader.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--plugin", type=str, default="hybrid", help="parallel plugin")
    # hybrid-plugin parallel sizes
    parser.add_argument("--pp_size", type=int, default=2, help="pp size")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size")
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
    parser.add_argument("--extra_dp_size", type=int, default=1)
    # fused-kernel optimizations
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
    )
    # benchmark step counts
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--active", type=int, default=20)
    # MoE load balance
    parser.add_argument("--load_balance", action="store_true")
    # overlap communication
    parser.add_argument("--overlap_comm", action="store_true")
    # hierarchical all-to-all
    parser.add_argument("--hierarchical_alltoall", action="store_true")
    return parser.parse_args()
144
def main():
    """Benchmark OpenMoE finetuning under a configurable ColossalAI parallel plugin.

    Fixes vs. the previous version:
    - the "hybrid" branch passed ``zero_stage`` both explicitly and via
      ``**hybrid_dict``, which raises ``TypeError: got multiple values for
      keyword argument 'zero_stage'``; the explicit duplicate is removed.
    - renamed the misspelled local ``exmaple_data`` -> ``example_data``.
    """
    args = parse_args()

    # Launch the ColossalAI distributed environment (reads torchrun env vars).
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Build the parallel plugin selected by --plugin.
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": OpenMoeForCausalLMPolicy(),
        "enable_fused_normalization": args.use_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": "bf16",
        "zero_stage": args.zero_stage,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        # Pure expert parallelism across the whole world.
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    elif args.plugin == "ep_zero":
        # Expert parallelism with an extra data-parallel dimension for ZeRO.
        dp_size = dist.get_world_size()
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size // args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **mgr_dict,
        )
    elif args.plugin == "hybrid":
        # Pipeline + expert + data parallelism with fixed group sizes.
        dp_size = dist.get_world_size() // args.pp_size
        # NOTE: zero_stage is already carried by hybrid_dict; passing it again
        # explicitly raised a duplicate-keyword TypeError.
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin}")

    # Build the OpenMoE model (weights materialized lazily via skip_init).
    repo_name = "hpcai-tech/openmoe-" + args.model_name
    config = LlamaConfig.from_pretrained(repo_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=args.load_balance,
        enable_kernel=args.use_kernel,
        enable_comm_overlap=args.overlap_comm,
        enable_hierarchical_alltoall=args.hierarchical_alltoall,
    )
    with skip_init():
        model = OpenMoeForCausalLM(config)
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing to trade compute for memory.
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader; +1 sample batch is consumed up front.
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    dataset = RandomDataset(
        num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
        max_length=args.seq_length,
        tokenizer=tokenizer,
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)

    # Set optimizer.
    optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)

    model_numel = get_model_numel(model)
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dp_size,
    )

    # Boost model/optimizer/dataloader and load the pretrained checkpoint.
    booster = Booster(plugin=plugin, **booster_kwargs)
    load_ckpt(repo_name, model, booster)
    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")

    # Start finetuning.
    coordinator.print_on_master("Start training")
    model.train()
    train_dataloader_iter = iter(dataloader)
    total_len = len(train_dataloader_iter) - 1
    # One batch is drawn up front; it is only used for token accounting below.
    example_data = next(train_dataloader_iter)
    with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
        for step in pbar:
            performance_evaluator.on_step_start(step)
            if use_pipeline:
                # Forward + backward are fused inside the pipeline schedule.
                outputs = booster.execute_pipeline(
                    train_dataloader_iter,
                    model,
                    lambda x, y: x.loss,
                    optimizer,
                    return_loss=True,
                    return_outputs=True,
                )
                # Loss only exists on the last pipeline stage.
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                # Plain forward/backward path.
                data = next(train_dataloader_iter)
                data = move_to_cuda(data, torch.cuda.current_device())
                outputs = model(**data)
                loss = outputs["loss"]
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(example_data["input_ids"])
            # Rebalance experts once, halfway through warmup, if requested.
            if (step == args.warmup // 2) and args.load_balance:
                coordinator.print_on_master("Apply load balance")
                apply_load_balance(model, optimizer)
    performance_evaluator.on_fit_end()
297
# Script entry point: run the benchmark when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()