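# Finetune OpenMoE (an open Mixture-of-Experts language model) with
# ColossalAI's MoE parallel plugins: expert parallelism ("ep"), expert
# parallelism with an extra ZeRO data-parallel group ("ep_zero"), or hybrid
# pipeline parallelism ("hybrid").
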
import argparse
import os
from functools import partial
from typing import Dict

import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam

def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}

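# load_ckpt pulls the checkpoint repo from the Hugging Face Hub
# (snapshot_download caches it locally) and points the booster at either the
# single-file or the sharded-index checkpoint layout.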
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
    ckpt_path = snapshot_download(repo_name)
    # single checkpoint file
    if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
    # sharded checkpoint index
    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
    booster.load_model(model, ckpt_path)

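# tokenize_data joins each sample's prompt and completion into one sequence,
# manually prefixed with "<pad>" (automatic special tokens are disabled via
# add_special_tokens=False; the <pad> prefix presumably stands in for a BOS
# token, in the style of T5's decoder start token).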
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
    texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
    data = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    data = {k: v.cuda() for k, v in data.items()}
    data["labels"] = data["input_ids"].clone()
    return data

class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }

def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b", "test"],
        help="OpenMoE model size; resolves to the hpcai-tech/openmoe-<model_name> repo on huggingface.co. 'test' runs a tiny config on random data.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["ep", "ep_zero", "hybrid"],
        help="Parallel method. ep_zero is recommended for general cases; ep provides the lowest memory consumption and hybrid suits large-scale training.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="yizhongw/self_instruct",
        help="dataset name from `datasets` repo.",
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="super_natural_instructions",
        help="task of corresponding dataset.",
    )
    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # ep_zero plugin
    parser.add_argument(
        "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
    )
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optimizations. Flash attention and triton must be installed to enable all of them; they are skipped if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use the layernorm kernel. Requires apex; raises an error if it is not installed.",
    )
    # loss
    parser.add_argument(
        "--router_aux_loss_factor",
        type=float,
        default=0.01,
        help="MoE router aux loss factor. You can refer to ST-MoE for details.",
    )
    parser.add_argument(
        "--router_z_loss_factor",
        type=float,
        default=0.0001,
        help="MoE router z loss factor. You can refer to ST-MoE for details.",
    )
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
    )
    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communication overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )
    args = parser.parse_args()
    return args

def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    test_mode = args.model_name == "test"
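    # Plugin wiring: "ep" spans expert parallelism across the whole world size;
    # "ep_zero" splits it into an EP group of world_size // extra_dp_size plus
    # an extra ZeRO data-parallel group; "hybrid" adds pipeline stages with
    # fixed dp/ep/pp sizes taken from the CLI flags.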
    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": OpenMoeForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "zero_stage": args.zero_stage,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    elif args.plugin == "ep_zero":
        dp_size = dist.get_world_size()
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size // args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **mgr_dict,
        )
    elif args.plugin == "hybrid":
        dp_size = dist.get_world_size() // args.pp_size
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
    # Build OpenMoe model
    if test_mode:
        config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
        config.hidden_size = 128
        config.intermediate_size = 256
        config.vocab_size = 32000
    else:
        repo_name = "hpcai-tech/openmoe-" + args.model_name
        config = LlamaConfig.from_pretrained(repo_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        router_aux_loss_factor=args.router_aux_loss_factor,
        router_z_loss_factor=args.router_z_loss_factor,
        z_loss_factor=args.z_loss_factor,
        enable_load_balance=args.load_balance,
        enable_comm_overlap=args.comm_overlap,
        enable_hierarchical_alltoall=args.hierarchical_alltoall,
        enable_kernel=args.use_kernel,
    )
    with skip_init():
        model = OpenMoeForCausalLM(config)
    coordinator.print_on_master(f"Finish init model with config:\n{config}")
    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()
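    # Checkpointing recomputes activations in the backward pass instead of
    # storing them, trading extra compute for lower activation memory.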
    # Prepare tokenizer and dataloader
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    if test_mode:
        dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
        collate_fn = None
    else:
        dataset = load_dataset(args.dataset, args.task_name)
        dataset = dataset["train"]
        collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )
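    # prepare_dataloader is expected to attach a distributed sampler so that
    # each data-parallel rank draws a distinct shard of the dataset.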
    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    if not test_mode:
        load_ckpt(repo_name, model, booster)
    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finish init booster")
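    # booster.boost wraps the model, optimizer, and dataloader with the chosen
    # plugin's parallelism; note that the pretrained checkpoint is loaded via
    # the booster before boosting.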
    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master(),
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        pbar.set_postfix({"loss": loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})
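                # Note: in pipeline mode execute_pipeline runs both forward and
                # backward internally, and only the last stage receives a loss
                # to report; the non-pipeline path calls booster.backward
                # explicitly.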
                optimizer.step()
                optimizer.zero_grad()
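                # The following block presumably redistributes experts across
                # EP ranks based on recent routing statistics to even out
                # per-expert load; see apply_load_balance in colossalai.moe.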
                # Apply load balance
                if (
                    args.load_balance
                    and args.load_balance_interval > 0
                    and (step + 1) % args.load_balance_interval == 0
                ):
                    coordinator.print_on_master("Apply load balance")
                    apply_load_balance(model, optimizer)
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    booster.save_model(model, args.output_path, shard=True)
        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
    # Finish training
    coordinator.print_on_master("Finish training")

if __name__ == "__main__":
    main()
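# A typical launch, assuming this script is saved as train.py on a machine
# with 8 GPUs (the exact flags are illustrative, not the only valid
# combination):
#   torchrun --standalone --nproc_per_node 8 train.py \
#       --model_name base --plugin ep_zero --extra_dp_size 2 --load_balance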