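"""Finetune OpenMoE with ColossalAI's MoeHybridParallelPlugin.

Example launch (a sketch only: it assumes this file is saved as train.py and that 4 GPUs
are available; adjust the file name, GPU count, and flags to your setup):

    torchrun --standalone --nproc_per_node 4 train.py --model_name base --plugin ep_zero --extra_dp_size 2 --precision bf16 --load_balance

The script is initialized via colossalai.launch_from_torch, so any torchrun-compatible
launcher should work.
"""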
import argparse
import os
from functools import partial
from typing import Dict

import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
    ckpt_path = snapshot_download(repo_name)
    # single checkpoint file
    if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
    # sharded checkpoint
    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
    booster.load_model(model, ckpt_path)


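# The collate function below receives a list of raw dataset rows (each with "prompt" and
# "completion" fields, as in yizhongw/self_instruct) and turns them into padded tensors.
# The leading "<pad>" token appears to serve as the sequence-start token for the umt5
# tokenizer used here (an assumption based on how the inputs are formatted in this script).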
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
    texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
    data = tokenizer(
        texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    data = {k: v.cuda() for k, v in data.items()}
    # Labels are a copy of the inputs; the causal LM handles the next-token shift when computing the loss.
    data["labels"] = data["input_ids"].clone()
    return data


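# RandomDataset generates synthetic token ids directly on the accelerator. It is only used
# when --model_name test is selected, so a quick smoke test needs neither the real dataset
# nor pretrained weights.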
class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b", "test"],
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["ep", "ep_zero", "hybrid"],
        help="Parallel method. ep_zero is recommended for general cases, ep provides the least memory consumption, and hybrid suits large-scale training.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="Mixed precision mode for training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="yizhongw/self_instruct",
        help="Dataset name from the `datasets` hub.",
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="super_natural_instructions",
        help="Task of the corresponding dataset.",
    )

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # ep_zero plugin
    parser.add_argument(
        "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
    )
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # loss
    parser.add_argument(
        "--router_aux_loss_factor",
        type=float,
        default=0.01,
        help="MoE router aux loss factor. You can refer to the ST-MoE paper for details.",
    )
    parser.add_argument(
        "--router_z_loss_factor",
        type=float,
        default=0.0001,
        help="MoE router z loss factor. You can refer to the ST-MoE paper for details.",
    )
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communication overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    test_mode = args.model_name == "test"

    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": OpenMoeForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "zero_stage": args.zero_stage,
    }
    mgr_dict = {}
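    # Plugin selection (a brief summary based on the argparse help and the setup calls below):
    #   ep      - pure expert parallelism, experts sharded across every rank.
    #   ep_zero - expert parallelism plus an extra data-parallel group of size
    #             --extra_dp_size for the expert parameters.
    #   hybrid  - expert parallelism combined with pipeline parallelism, using the
    #             fixed --dp_size / --ep_size / --pp_size layout.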
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    elif args.plugin == "ep_zero":
        dp_size = dist.get_world_size()
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size // args.extra_dp_size,
            use_ep_inside=use_ep_inside,
            **mgr_dict,
        )
    elif args.plugin == "hybrid":
        dp_size = dist.get_world_size() // args.pp_size
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build OpenMoe model
    if test_mode:
        config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
        config.hidden_size = 128
        config.intermediate_size = 256
        config.vocab_size = 32000
    else:
        repo_name = "hpcai-tech/openmoe-" + args.model_name
        config = LlamaConfig.from_pretrained(repo_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        router_aux_loss_factor=args.router_aux_loss_factor,
        router_z_loss_factor=args.router_z_loss_factor,
        z_loss_factor=args.z_loss_factor,
        enable_load_balance=args.load_balance,
        enable_comm_overlap=args.comm_overlap,
        enable_hierarchical_alltoall=args.hierarchical_alltoall,
        enable_kernel=args.use_kernel,
    )
    with skip_init():
        model = OpenMoeForCausalLM(config)
    coordinator.print_on_master(f"Finished initializing model with config:\n{config}")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    if test_mode:
        dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
        collate_fn = None
    else:
        dataset = load_dataset(args.dataset, args.task_name)
        dataset = dataset["train"]
        collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )
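    # The dataloader returned by plugin.prepare_dataloader is expected to shard the dataset
    # across the data-parallel group (via a distributed sampler), so each rank trains on a
    # disjoint subset.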

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    if not test_mode:
        load_ckpt(repo_name, model, booster)
    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finished initializing booster")

    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master(),
        ) as pbar:
            for step in pbar:
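                # Two execution paths: with pipeline parallelism the booster schedules the
                # micro-batched forward/backward itself; otherwise this is a plain forward
                # pass followed by booster.backward.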
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # The loss is only returned on the last pipeline stage
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        pbar.set_postfix({"loss": loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                optimizer.zero_grad()

                # Apply expert load balance periodically
                if (
                    args.load_balance
                    and args.load_balance_interval > 0
                    and (step + 1) % args.load_balance_interval == 0
                ):
                    coordinator.print_on_master("Applying load balance")
                    apply_load_balance(model, optimizer)
                # Save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    booster.save_model(model, args.output_path, shard=True)

        # Save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master("Finished training")


if __name__ == "__main__":
    main()