# gpt-neox
#
# Fork
# 0
# /
# training.py
# 1140 lines · 39.3 KB
1
# Copyright (c) 2024, EleutherAI
2
# This file is based on code by the authors denoted below and has been modified from its original version.
3
#
4
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
5
#
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
# you may not use this file except in compliance with the License.
8
# You may obtain a copy of the License at
9
#
10
#     http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing, software
13
# distributed under the License is distributed on an "AS IS" BASIS,
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
# See the License for the specific language governing permissions and
16
# limitations under the License.
17
#
18
# This file has been modified from its original version
19
#
20

21
"""Pretrain utilities."""
22
from datetime import datetime
23
from functools import partial
24

25
import math
26
import sys
27

28
import torch
29
import deepspeed
30
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
31
import numpy as np
32

33
from megatron.utils import (
34
    Timers,
35
    init_wandb,
36
    get_ltor_masks_and_position_ids,
37
    reduce_losses,
38
)
39

40
from megatron import print_rank_0, mpu
41
from megatron.model import (
42
    GPT2ModelPipe,
43
    SoftEmbedding,
44
    get_params_for_weight_decay_optimization,
45
)
46
from megatron.checkpointing import load_checkpoint, save_checkpoint
47
from megatron.data.data_utils import build_train_valid_test_data_iterators
48
from megatron.initialize import initialize_megatron
49
from megatron.learning_rates import AnnealingLR
50
from megatron.logging import tb_wandb_log, training_log
51
from megatron.utils import (
52
    OverflowMonitor,
53
    get_noise_scale_logger,
54
    get_total_params,
55
    CharCounter,
56
)
57
from megatron.model.gpt2_model import cross_entropy
58

59
from pickle import dump
60
import os
61

62

63
def mup_weights_reinit(neox_args, model):
    """Re-run muP-aware initialization on every submodule of ``model``.

    For each module yielded by ``model.modules()``:
      * if it has a truthy ``mup_rescale_parameters`` attribute, its
        ``_rescale_parameters()`` method is called first;
      * then, if it exposes a callable ``mup_reinitialize_weights``, that
        method is called with ``neox_args``.

    Args:
        neox_args: configuration object forwarded to each module's
            ``mup_reinitialize_weights``.
        model: object whose ``modules()`` iterator is walked.
    """
    for module in model.modules():
        # This normally would happen in set_base_shapes if the MuReadout
        # class could actually be used, so it is done by hand here.
        if getattr(module, "mup_rescale_parameters", False):
            module._rescale_parameters()

        reinit = getattr(module, "mup_reinitialize_weights", None)
        if callable(reinit):
            reinit(neox_args)
74

75

76
def save_base_shapes(neox_args, base_shapes, use_cache):
    """Write this rank's muP base-shape file, then terminate the process.

    Two throwaway models are built: one at the configured ``hidden_size`` and
    one at ``hidden_size * neox_args.mup_width_scale``. Their parameter
    shapes (``mup.get_shapes``) are passed to ``mup.make_base_shapes``, which
    writes ``{neox_args.base_shapes_file}.{rank}``.

    Side effects on ``neox_args``: ``use_mup`` is forced to False while the
    models are built and set to True afterwards; ``hidden_size`` is put back
    to its original value.

    Args:
        neox_args: configuration object (temporarily mutated, see above).
        base_shapes: not read; the output path is recomputed from
            ``neox_args.base_shapes_file`` and the distributed rank.
        use_cache: forwarded to ``GPT2ModelPipe``.

    Raises:
        Exception: if the ``mup`` package is not installed.
        SystemExit: always, with status 1, after the file is written.
    """

    def build_shape_model():
        # Same construction as get_model, minus soft prompts and mup init.
        net = GPT2ModelPipe(
            neox_args=neox_args,
            num_tokentypes=0,
            parallel_output=True,
            topology=mpu.get_topology(),
            use_cache=use_cache,
        )
        return net if neox_args.is_pipe_parallel else net.to_sequential()

    # The init functions (init_functions.py) fail before set_base_shapes has
    # been called on the model, so muP is switched off while building here.
    neox_args.use_mup = False

    base_model = build_shape_model()

    try:
        import mup
    except ModuleNotFoundError:
        print("Please install mup https://github.com/microsoft/mup")
        raise Exception

    base_shapes = mup.get_shapes(base_model)
    del base_model

    # Second model at the scaled width provides the "delta" shapes.
    saved_width = neox_args.hidden_size
    neox_args.hidden_size = saved_width * neox_args.mup_width_scale
    delta_shapes = mup.get_shapes(build_shape_model())

    # change back
    neox_args.use_mup = True
    neox_args.hidden_size = saved_width

    target_path = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
    print(f"saving base shapes at {target_path}")
    mup.make_base_shapes(base_shapes, delta_shapes, savefile=target_path)
    print("base shapes saved...exiting")
    sys.exit(1)
127

128

129
def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator):
    """Produce muP vs. standard-parametrization coordinate-check plots, then exit.

    Model factories are registered for seven widths
    (``num_attention_heads * 2**p`` for ``p`` in 2..8). ``get_coord_data`` is
    run once with ``use_mup=True`` and once with ``use_mup=False``. The
    results are saved as ``coord_check_up.<rank>.jpg`` and
    ``coord_check_sp.<rank>.jpg``.

    Side effect: ``neox_args.use_mup`` is left False.

    Raises:
        SystemExit: always, with status 1, after the plots are saved.
    """
    from megatron.mup_substitute import get_coord_data
    from mup.coord_check import plot_coord_data

    def make_builder(width):
        # Zero-argument factory so the model at this width is built lazily.
        def build():
            saved_width = neox_args.hidden_size
            neox_args.hidden_size = width

            built_model, _, _ = setup_model_and_optimizer(
                neox_args=neox_args, use_cache=False
            )

            neox_args.hidden_size = saved_width
            return built_model

        return build

    # Hidden size needs to be divisible by num attention heads
    widths = [neox_args.num_attention_heads * 2**p for p in range(2, 9)]
    models = {width: make_builder(width) for width in widths}

    frames = {}
    for tag, mup_enabled in (("up", True), ("sp", False)):
        neox_args.use_mup = mup_enabled
        frames[tag] = get_coord_data(
            neox_args,
            timers,
            lr_scheduler,
            models,
            train_data_iterator,
            mup=mup_enabled,
        )

    for tag, frame in frames.items():
        plot_coord_data(
            frame, save_to=f"coord_check_{tag}.{torch.distributed.get_rank()}.jpg"
        )

    print_rank_0("Saved coord check plots... exiting")
    sys.exit(1)
168

169

170
def pretrain(neox_args):
    """Main training program.

    Runs, in order:
        1) wandb/timer setup and Megatron initialization;
        2) model, optimizer and learning-rate scheduler setup;
        3) construction of the train/valid/test data iterators;
        4) an optional muP coordinate check (``use_mup`` and ``coord_check``);
        5) an optional step-0 checkpoint, training, validation, a final
           checkpoint and testing, each gated by its ``neox_args`` flag.

    Arguments:
        neox_args: an instance of NeoXArgs containing the configuration for pretrain
    """
    # Logging and timers come first so that the setup phases can be timed.
    init_wandb(neox_args=neox_args)
    timers = Timers(
        use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer
    )

    initialize_megatron(neox_args=neox_args)

    setup_label = "model and optimizer"
    timers(setup_label).start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        neox_args=neox_args, use_cache=False, iteration=neox_args.iteration
    )
    timers(setup_label).stop()

    data_label = "train/valid/test data iterators"
    timers(data_label).start()
    train_data_iterator, valid_data_iterator, test_data_iterator = (
        build_train_valid_test_data_iterators(neox_args=neox_args)
    )
    timers(data_label).stop()

    if neox_args.use_mup and neox_args.coord_check:
        mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)

    print_rank_0("done with setups ...")
    timers.log([setup_label, data_label])
    print_rank_0("training ...")

    def checkpoint_at(step):
        # Save model, optimizer and scheduler state for the given step.
        save_checkpoint(
            neox_args=neox_args,
            iteration=step,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    def evaluate_on(data_iterator, prefix, step, verbose, **extra):
        # Shared wrapper for the end-of-training validation / test passes.
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=data_iterator,
            model=model,
            iteration=step,
            verbose=verbose,
            timers=timers,
            **extra,
        )

    # setup_model_and_optimizer has set this (0, or the loaded checkpoint step).
    iteration = neox_args.iteration

    # Edge case: save a step-0 checkpoint if requested and starting from step 0.
    if neox_args.save and 0 in neox_args.save_iters and iteration == 0:
        checkpoint_at(iteration)

    if neox_args.do_train and neox_args.train_iters > 0:
        iteration = train(
            neox_args=neox_args,
            timers=timers,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_data_iterator=train_data_iterator,
            valid_data_iterator=valid_data_iterator,
        )

    if neox_args.do_valid:
        evaluate_on(
            valid_data_iterator,
            "the end of training for val data",
            iteration,
            verbose=False,
        )

    if neox_args.save and iteration != 0:
        checkpoint_at(iteration)

    if neox_args.do_test:
        evaluate_on(
            test_data_iterator,
            "the end of training for test data",
            iteration,
            verbose=True,
            chart_name="test",
        )
274

275

276
def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Broadcast a raw batch and split it into model inputs and targets.

    Shared by get_batch and get_batch_pipe (to avoid code repetition).

    The ``"text"`` tensor is shifted by one position: ``tokens`` drops the
    last column and the targets drop the first. If a ``"label"`` entry is
    present, it supplies the targets instead; negative label values are
    replaced by 0 in ``labels``, and ``loss_mask`` is rebuilt as
    ``label >= 0``, replacing the mask from get_ltor_masks_and_position_ids.

    Args:
        neox_args: configuration providing ``tokenizer.eod``,
            ``eod_mask_loss`` and ``sliding_window_width``.
        tokenizer: not read here; ``neox_args.tokenizer`` is used instead.
        keys: dictionary keys to broadcast (``"text"`` and optionally
            ``"label"``).
        data: the batch dict on ranks that hold data (may be None elsewhere,
            presumably filled in by ``mpu.broadcast_data`` -- confirm).
        datatype: dtype passed to ``mpu.broadcast_data``.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
    """
    broadcast = mpu.broadcast_data(keys, data, datatype)
    has_labels = "label" in broadcast

    # Unpack.
    text = broadcast["text"].long()
    if has_labels:
        raw_labels = broadcast["label"].long()
        labels = torch.where(
            raw_labels >= 0, raw_labels, torch.zeros_like(raw_labels)
        )[:, 1:].contiguous()
    else:
        labels = text[:, 1:].contiguous()
    tokens = text[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
        sliding_window_width=neox_args.sliding_window_width,
    )

    # With explicit labels, any label < 0 (e.g. -100, the torch default)
    # is excluded from the loss.
    if has_labels:
        loss_mask = (broadcast["label"][:, 1:] >= 0).to(loss_mask.dtype)

    return tokens, labels, loss_mask, attention_mask, position_ids
303

304

305
def get_batch(neox_args, data_iterator):
    """Fetch the next batch from ``data_iterator`` and unpack it.

    The ``"label"`` key is requested only when ``neox_args.label_data_paths``
    is set. When ``data_iterator`` is None, ``None`` is handed to _get_batch
    as the data (presumably for ranks that receive the batch through the
    broadcast -- confirm against mpu.broadcast_data).

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
    """
    wanted_keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
    raw_batch = None if data_iterator is None else next(data_iterator)

    return _get_batch(
        neox_args=neox_args,
        tokenizer=neox_args.tokenizer,
        keys=wanted_keys,
        data=raw_batch,
        datatype=torch.int64,
    )
324

325

326
def get_batch_pipe(data, neox_args, curr_scheduler=None):
    """Unpack an already-fetched batch for DeepSpeed pipeline parallelism.

    Unlike get_batch, this receives the batch itself rather than an
    iterator. When ``curr_scheduler`` is given, the curriculum difficulty is
    advanced to ``neox_args.iteration + 1``. If the resulting sequence length
    is shorter than the batch, every returned tensor is truncated to it.

    Returns:
        ``((tokens, position_ids, attention_mask), (labels, loss_mask))``.
    """
    key_list = ["text", "label"] if neox_args.label_data_paths else ["text"]
    tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
        neox_args, neox_args.tokenizer, key_list, data, torch.int64
    )

    if curr_scheduler is None:
        return (tokens, position_ids, attention_mask), (labels, loss_mask)

    # iteration + 1 to align with how/when DeepSpeed updates the buffers
    seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
    if seqlen < tokens.size(1):
        # seqlen-based curriculum learning: tokens, position_ids, labels and
        # loss_mask are [batch size, seqlen]; attention_mask is
        # [1, 1, seqlen, seqlen].
        tokens = tokens[:, :seqlen].contiguous()
        position_ids = position_ids[:, :seqlen].contiguous()
        if labels is not None:
            labels = labels[:, :seqlen].contiguous()
        if loss_mask is not None:
            loss_mask = loss_mask[:, :seqlen].contiguous()
        attention_mask = attention_mask[:, :, :seqlen, :seqlen].contiguous()

    return (tokens, position_ids, attention_mask), (labels, loss_mask)
355

356

357
def get_batch_sequential(forward_input, neox_args):
    """Recompute the attention mask for a sequential (non-pipeline) forward.

    Installed as the sequential model's batch function in
    setup_model_and_optimizer. ``forward_input`` is presumably the
    ``(tokens, position_ids, attention_mask)`` tuple that forward_step passes
    to the model -- confirm. The incoming mask is discarded and rebuilt from
    the tokens.

    ``sliding_window_width`` is passed so this mask matches the one built in
    _get_batch. Without it, the recomputed mask would not carry the
    sliding-window restriction.

    Args:
        forward_input: sequence whose first two entries are tokens and
            position ids.
        neox_args: configuration providing ``tokenizer.eod``,
            ``eod_mask_loss`` and ``sliding_window_width``.

    Returns:
        ``(tokens, position_ids, attention_mask)``.
    """
    tokens, position_ids = forward_input[0], forward_input[1]
    attention_mask, _, _ = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
        sliding_window_width=neox_args.sliding_window_width,
    )
    return (tokens, position_ids, attention_mask)
365

366

367
def forward_step(
    data_iterator, model, neox_args, timers, return_logits=False, is_train=False
):
    """Forward step.

    With pipeline parallelism, the step is delegated to ``model.eval_batch``.
    Otherwise a batch is fetched with get_batch, run through the model, and
    scored with cross_entropy. When ``neox_args.num_experts > 1``, the MoE
    auxiliary losses returned by the model are summed, scaled by
    ``moe_loss_coeff`` and added to the loss.

    Args:
        data_iterator: iterator passed to get_batch / ``model.eval_batch``.
        model: the DeepSpeed-wrapped model.
        neox_args: configuration object.
        timers: Timers instance, or None to skip batch timing.
        return_logits: also return the model outputs.
        is_train: when True and curriculum learning shortens the sequence,
            labels and loss mask are truncated to ``curriculum_seqlen``.

    Returns:
        ``loss``, or ``(loss, outputs)`` if ``return_logits`` (non-pipeline
        path). On the pipeline path, whatever ``model.eval_batch`` returns.
    """
    if neox_args.is_pipe_parallel:
        return model.eval_batch(data_iterator, return_logits=return_logits)

    # Get the batch. The NVTX push uses the same condition as the matching
    # pop below so the range stack stays balanced.
    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_push("Get batch")
    if timers is not None:
        timers("batch generator").start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        neox_args=neox_args, data_iterator=data_iterator
    )
    if timers is not None:
        timers("batch generator").stop()
    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_pop()

    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_push("Forward pass")
    # Sequential returns moe_losses, but this is not yet supported by pipe parallel
    maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
    if type(maybe_tuple) is tuple:
        outputs, moe_losses = maybe_tuple
    else:
        outputs = maybe_tuple
        moe_losses = []

    if (
        is_train
        and neox_args.curriculum_learning
        and neox_args.curriculum_seqlen < neox_args.seq_length
    ):
        loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
        labels = labels[:, : neox_args.curriculum_seqlen].contiguous()

    main_loss = cross_entropy(
        outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
    )
    if neox_args.num_experts > 1:
        # Keep the auxiliary losses as tensors. Calling .item() on them would
        # turn them into Python floats that change the reported loss but are
        # detached from autograd, so they would contribute no gradient.
        moe_loss = neox_args.moe_loss_coeff * sum(moe_losses)
    else:
        moe_loss = 0.0
    loss = main_loss + moe_loss

    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_pop()
    if return_logits:
        return loss, outputs
    return loss
417

418

419
def get_model(neox_args, use_cache=False):
    """Build the model.

    Steps:
      1. Construct a GPT2ModelPipe with ``neox_args.use_mup`` temporarily set
         to False, so the mup init functions are not used before
         ``set_base_shapes``.
      2. If ``soft_prompt_tuning`` is enabled, insert a SoftEmbedding right
         after the word embeddings and freeze every other parameter.
      3. Convert to a sequential model when pipeline parallelism is off.
      4. If muP is enabled: optionally save base shapes (save_base_shapes
         exits the process), apply ``mup.set_base_shapes`` and re-run the
         muP weight initialization.

    ``neox_args.use_mup`` is restored to its original value even if steps
    1-3 raise, so the shared config is never left with muP switched off.

    Args:
        neox_args: configuration object.
        use_cache: forwarded to GPT2ModelPipe.

    Returns:
        The pipeline or sequential model.

    Raises:
        Exception: if muP is enabled but the ``mup`` package is missing.
        ValueError: if ``neox_args.deepspeed`` is False.
    """
    # Build model on cpu.
    print_rank_0("building GPT2 model ...")

    # Temporarily disable mup so that the base model does not use the mup init
    # functions before set_base_shapes is called below. If mup isn't being
    # used anyway, this has no effect.
    old_use_mup = neox_args.use_mup
    neox_args.use_mup = False
    try:
        model = GPT2ModelPipe(
            neox_args=neox_args,
            num_tokentypes=0,
            parallel_output=True,
            topology=mpu.get_topology(),
            use_cache=use_cache,
        )

        ### soft prompt tuning stuff ###
        soft_cfg = neox_args.soft_prompt_tuning
        if soft_cfg is not None and soft_cfg.get("enabled", False):
            soft_prompt = SoftEmbedding(
                neox_args,
                wte=getattr(model, "0").word_embeddings,
                n_tokens=soft_cfg.get("n_tokens", 10),
                init_string=soft_cfg.get("init_string", ""),
                init_range=soft_cfg.get("init_range", 0.5),
            )
            # insert the soft prompt layer directly after the word embeddings
            model.insert_layers(layers=soft_prompt, idx=1)

            # freeze everything but the soft prompt
            for name, param in model.named_parameters():
                if "soft_embedding" not in name:
                    param.requires_grad = False

        if not neox_args.is_pipe_parallel:
            # Export PipeParallel model to nn.Sequential model to avoid the
            # overhead of deepspeed's pipe parallel training
            model = model.to_sequential()
    finally:
        neox_args.use_mup = old_use_mup

    if neox_args.use_mup:
        try:
            import mup
        except ModuleNotFoundError:
            print("Please install mup https://github.com/microsoft/mup")
            raise Exception

        base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"

        if neox_args.save_base_shapes:
            save_base_shapes(neox_args, base_shapes, use_cache)

        mup.set_base_shapes(model, base_shapes)

        # Call the mup replacement init functions on the model now that
        # set_base_shapes has given each weight a .infshape attribute
        mup_weights_reinit(neox_args, model)

    if neox_args.deepspeed:
        # DeepSpeed handles CUDA, FP16, and DDP components.
        return model
    raise ValueError("Must be using deepspeed to run neox")
485

486

487
def get_optimizer(model, neox_args):
    """Set up the optimizer.

    Builds weight-decay / no-weight-decay parameter groups (optionally split
    into MoE groups), marks parameters lacking ``model_parallel`` as False,
    drops parameters that do not require gradients, and instantiates the
    optimizer named by ``neox_args.optimizer_type`` (case-insensitive).

    Notes on selection:
      * ``cpu_torch_adam`` -> ``torch.optim.Adam``; ``cpu_adam`` ->
        DeepSpeedCPUAdam. The choice is made on ``optimizer_type``;
        ``neox_args.optimizer`` is the config dict and never equals a name.
      * ``sgd`` uses MuSGD only when ``use_mup`` is set. Otherwise it uses
        ``torch.optim.SGD``, because MuSGD expects the ``.infshape`` metadata
        that get_model only attaches (via set_base_shapes) under muP.
      * ``onebitadam`` returns None; DeepSpeed must build it.

    Returns:
        ``(optimizer, param_groups)``, or ``(None, None)`` when
        ``neox_args.no_load_optim`` is set.

    Raises:
        SystemExit: if ``neox_args.optimizer`` is None.
        AssertionError: if muP is enabled with an optimizer other than adam/sgd.
        Exception: if a required optional package (mup, bitsandbytes) is missing.
        ValueError: for an unknown optimizer type, or if DeepSpeed is disabled.
    """
    if neox_args.no_load_optim:
        return None, None

    if neox_args.optimizer is None:
        print_rank_0(
            "ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)"
        )
        sys.exit()

    # Build parameter groups (weight decay and non-decay).
    param_groups = get_params_for_weight_decay_optimization(model, neox_args)
    print_rank_0(
        f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
    )

    if neox_args.create_moe_param_group:
        from deepspeed.moe.utils import (
            split_params_into_different_moe_groups_for_optimizer,
        )

        param_groups = split_params_into_different_moe_groups_for_optimizer(
            param_groups
        )

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group["params"]:
            if not hasattr(param, "model_parallel"):
                param.model_parallel = False

    # Filter out params that don't require a grad (for soft prompt tuning,
    # etc.). The group dicts are updated in place; empty groups are kept.
    for param_group in param_groups:
        param_group["params"] = [p for p in param_group["params"] if p.requires_grad]
    param_groups = list(param_groups)

    optimizer_type = neox_args.optimizer_type.lower()

    # If we're using mup, then the optimizer must be adam or sgd
    assert not neox_args.use_mup or optimizer_type in (
        "adam",
        "sgd",
    ), f"If use_mup == True, you must specify either the adam or sgd optimizers. You passed: {optimizer_type}"

    if optimizer_type in ("cpu_adam", "cpu_torch_adam"):
        if optimizer_type == "cpu_torch_adam":
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam

            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif optimizer_type == "onebitadam":
        # onebitadam needs to be instantiated within the deepspeed engine to work
        assert neox_args.deepspeed
        optimizer = None
    elif optimizer_type == "sm3":
        from .optimizers import SM3

        optimizer = SM3(param_groups, **neox_args.optimizer["params"])
    elif optimizer_type == "madgrad_wd":
        from .optimizers import madgrad_wd

        optimizer = madgrad_wd(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif optimizer_type == "lion":
        # if we want the deepspeed zero lion...megatron lion will throw DeepSpeed Error
        if neox_args.zero_optimization["stage"] != 0:
            from deepspeed.ops.lion import FusedLion

            lion_optimizer = FusedLion
        else:
            # if not zero
            from .optimizers import Lion

            lion_optimizer = Lion

        optimizer = lion_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif optimizer_type == "adam":
        if neox_args.use_mup:
            try:
                from mup import MuAdam
            except ModuleNotFoundError:
                print("Please install mup https://github.com/microsoft/mup")
                raise Exception
            adam_optimizer = MuAdam
        elif neox_args.use_bnb_optimizer:
            try:
                import bitsandbytes as bnb

                adam_optimizer = bnb.optim.Adam8bit
            except ModuleNotFoundError:
                print(
                    "Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes."
                )
                raise Exception
        else:
            try:
                # default to apex as it's slightly faster
                from apex.optimizers import FusedAdam as Adam
            except ImportError:
                # if apex isn't installed, use deepspeed's FusedAdam
                print(
                    "WARNING: APEX not installed - defaulting to deepspeed's fused adam"
                )
                from deepspeed.ops.adam import FusedAdam as Adam
            adam_optimizer = Adam
        optimizer = adam_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif optimizer_type == "sgd":
        if neox_args.use_mup:
            try:
                from mup import MuSGD
            except ModuleNotFoundError:
                print("Please install mup https://github.com/microsoft/mup")
                raise Exception
            sgd_optimizer = MuSGD
        else:
            sgd_optimizer = torch.optim.SGD
        optimizer = sgd_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    else:
        raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")

    if neox_args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer, param_groups
    raise ValueError("Must be using deepspeed to run neox")
634

635

636
def get_learning_rate_scheduler(optimizer, neox_args):
    """Build the AnnealingLR learning-rate scheduler for ``optimizer``.

    Returns None when ``neox_args.no_load_optim`` is set. It also returns
    None, after printing a warning, when DeepSpeed is used with the
    onebitadam optimizer; the scheduler must then come from the DeepSpeed
    config.

    The schedule spans ``lr_decay_iters`` iterations if that is set,
    otherwise ``train_iters``, clamped to at least 1. Warmup covers
    ``warmup * total`` iterations, and the scheduler starts at iteration 0.
    """
    if neox_args.no_load_optim:
        # TODO: this should be configured as a separate arg
        return None

    if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
        print_rank_0(
            "WARNING: onebitadam requires the lr scheduler be built by deepspeed - "
            "Make sure one is added to your deepspeed config"
        )
        return None

    total_iters = (
        neox_args.train_iters
        if neox_args.lr_decay_iters is None
        else neox_args.lr_decay_iters
    )
    total_iters = max(1, total_iters)

    return AnnealingLR(
        optimizer,
        start_lr=neox_args.lr,
        warmup_iter=neox_args.warmup * total_iters,
        total_iters=total_iters,
        decay_style=neox_args.lr_decay_style,
        last_iter=0,
        min_lr=neox_args.min_lr,
        use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=neox_args.override_lr_scheduler,
        use_mup=neox_args.use_mup,
    )
670

671

672
def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
    """Setup model and optimizer, wrap them with DeepSpeed, and load a checkpoint.

    Steps:
      * if ``neox_args.memory_profiling``, start CUDA memory-history recording;
      * build model, optimizer and LR scheduler (get_model, get_optimizer,
        get_learning_rate_scheduler) and pass them to ``deepspeed.initialize``;
      * install the batch function: get_batch_pipe (with an optional
        CurriculumScheduler) for pipeline models, get_batch_sequential otherwise;
      * if ``neox_args.load`` is set, load a checkpoint and store the returned
        iteration in ``neox_args.iteration``; otherwise set it to 0.

    Args:
        neox_args: configuration object.
        use_cache: forwarded to get_model.
        iteration: forwarded to load_checkpoint. If > 0, it is also used to
            advance the curriculum scheduler.

    Returns:
        ``(model, optimizer, lr_scheduler)`` as returned by deepspeed.initialize.

    Raises:
        ValueError: if ``neox_args.deepspeed`` is False.
    """
    if neox_args.memory_profiling:
        torch.cuda.memory._record_memory_history(
            True,
            # keep a maximum 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            trace_alloc_record_context=True,
        )

    model = get_model(neox_args=neox_args, use_cache=use_cache)
    optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
    lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args)

    if not neox_args.deepspeed:
        raise ValueError("Must be using deepspeed to run neox")

    print_rank_0("DeepSpeed is enabled.")
    if neox_args.no_load_optim:
        assert optimizer is None
        ds_model_params, ds_lr_scheduler = None, None
    else:
        # DeepSpeed needs the raw param groups only when it builds the
        # optimizer itself (e.g. onebitadam returns optimizer=None).
        ds_model_params = None if optimizer is not None else param_groups
        ds_lr_scheduler = lr_scheduler

    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=neox_args,
        lr_scheduler=ds_lr_scheduler,
        dist_init_required=False,
        model_parameters=ds_model_params,
        # config_params is deliberately not passed: it would conflict with the
        # --deepspeed_config argument required by autotuning.
        mpu=None if neox_args.is_pipe_parallel else mpu,
    )
    model.total_params = get_total_params(model.module)
    print_rank_0(f" > total params: {model.total_params:,}")

    if neox_args.is_pipe_parallel:
        model.set_has_attention_mask(True)
        curr_scheduler = None
        if neox_args.curriculum_learning:
            curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
            if iteration is not None and iteration > 0:
                curr_scheduler.update_difficulty(iteration)
        model.set_batch_fn(
            partial(get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler)
        )
    else:
        model.module.set_batch_fn(partial(get_batch_sequential, neox_args=neox_args))

    if neox_args.load is None:
        neox_args.iteration = 0
    else:
        neox_args.iteration = load_checkpoint(
            neox_args=neox_args,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            iteration=iteration,
        )
        print_rank_0(
            f"Loading checkpoint and starting from iteration {neox_args.iteration}"
        )

    # Needed for correct lr scheduling when resuming from a checkpoint; the
    # scheduler does not exist when this is called for inference.
    if lr_scheduler is not None:
        lr_scheduler.optimizer = model.optimizer

    return model, optimizer, lr_scheduler
752

753

754
def backward_step(neox_args, timers, optimizer, model, loss):
    """Backward step.

    Times ``model.backward(loss)`` under the ``backward-backward`` timer, then
    resets the ``backward-allreduce`` timer, since DeepSpeed's backward
    already performs the all-reduce. ``optimizer`` is not used.

    Raises:
        ValueError: if ``neox_args.deepspeed`` is False (after the
            ``backward-backward`` timer has been started).
    """
    timers("backward-backward").start()
    if not neox_args.deepspeed:
        raise ValueError("Must be using deepspeed to run neox")
    model.backward(loss)
    timers("backward-backward").stop()

    # DeepSpeed backward propagation already addressed all reduce
    # communication; reset the timer to avoid breaking timer logs below.
    timers("backward-allreduce").reset()
771

772

773
def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
    """Single training step (covers all gradient-accumulation micro-batches).

    With pipeline parallelism, DeepSpeed's pipeline engine schedules the
    forward/backward/optimizer work (see ``train_step_pipe``). Otherwise this
    loops ``neox_args.gradient_accumulation_steps`` times, running forward,
    backward and ``model.step()`` for each micro-batch.

    Args:
        neox_args: NeoX arguments.
        timers: Megatron ``Timers`` callable.
        data_iterator: training data iterator.
        model: DeepSpeed engine.
        optimizer: forwarded to ``backward_step``.
        lr_scheduler: not used in this function; kept for interface compatibility.

    Returns:
        tuple: ``(reduced_loss, skipped_iter)``. ``reduced_loss`` is a dict with
        key ``"lm_loss"``. ``skipped_iter`` is 1 when training in fp16 and
        ``model.optimizer.overflow`` is set, else 0.
    """

    def _in_profile_window():
        # Evaluated lazily (only after the relevant flag is checked) so the
        # step bounds are not compared when profiling is disabled.
        return (
            neox_args.iteration >= neox_args.profile_step_start
            and neox_args.iteration <= neox_args.profile_step_stop
        )

    # NVTX ranges are controlled by `profile`; memory snapshots by
    # `memory_profiling` (on both the pipe and non-pipe paths).
    use_nvtx = neox_args.profile and _in_profile_window()
    take_snapshot = (
        neox_args.memory_profiling
        and _in_profile_window()
        and torch.distributed.get_rank() == 0
    )

    # Pipeline parallelism schedules forward/backward/step
    if neox_args.is_pipe_parallel:
        reduced_loss = train_step_pipe(
            neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
        )
        if take_snapshot:
            save_snapshot(neox_args)
    else:
        losses = []
        for _ in range(neox_args.gradient_accumulation_steps):
            # Forward model for one step.
            timers("forward").start()
            loss = forward_step(
                neox_args=neox_args,
                timers=timers,
                data_iterator=data_iterator,
                model=model,
                is_train=True,
            )
            timers("forward").stop()
            losses.append(loss)

            # Calculate gradients, reduce across processes, and clip.
            if use_nvtx:
                torch.cuda.nvtx.range_push("Backward pass")
            timers("backward").start()
            backward_step(
                neox_args=neox_args,
                timers=timers,
                optimizer=optimizer,
                model=model,
                loss=loss,
            )
            timers("backward").stop()
            if use_nvtx:
                torch.cuda.nvtx.range_pop()

            # Update parameters. model.step() is called for every micro-batch;
            # the DeepSpeed engine decides when a gradient-accumulation
            # boundary is reached.
            if use_nvtx:
                torch.cuda.nvtx.range_push("Optimizer step")
            timers("optimizer").start()
            if neox_args.deepspeed:
                model.step()
            else:
                raise ValueError("Must be using deepspeed to run neox")
            timers("optimizer").stop()
            if use_nvtx:
                torch.cuda.nvtx.range_pop()

        # One snapshot per training step, taken after the last micro-batch.
        # save_snapshot always writes the same file, so per-micro-batch
        # snapshots would only have been overwritten.
        if take_snapshot:
            save_snapshot(neox_args)

        # reduces losses across machines for logging
        reduced_loss = {"lm_loss": reduce_losses(losses).mean()}

    skipped_iter = (
        1 if neox_args.precision == "fp16" and model.optimizer.overflow else 0
    )

    collect_loss_for_unit_test(reduced_loss["lm_loss"])
    return reduced_loss, skipped_iter
861

862

863
def train_step_pipe(neox_args, timers, model, data_iterator):
    """Run one training step through DeepSpeed's pipeline-parallel engine.

    ``model.train_batch`` is the only call that does work here. The
    Megatron-style per-phase timers are not driven by this function, so they
    are reset to keep timer logging consistent.

    Returns:
        dict: ``{"lm_loss": <value returned by model.train_batch>}``.
    """
    assert neox_args.deepspeed
    lm_loss = model.train_batch(data_iter=data_iterator)

    # Don't break Megatron's timers because we changed code paths.
    untouched_timers = (
        "forward",
        "backward",
        "allreduce",
        "optimizer",
        "batch generator",
        "data loader",
    )
    for timer_name in untouched_timers:
        timers(timer_name).reset()

    return dict(lm_loss=lm_loss)
880

881

882
def train(
    neox_args,
    timers,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
):
    """Train the model from ``neox_args.iteration`` up to ``neox_args.train_iters``.

    Each iteration:
      - runs :func:`train_step` and updates ``neox_args.iteration``;
      - logs via ``training_log``;
      - saves a checkpoint when the new iteration is in ``neox_args.save_iters``;
      - evaluates on ``valid_data_iterator`` every ``neox_args.eval_interval``
        iterations (when ``neox_args.do_valid`` is set);
      - exits the process every ``neox_args.exit_interval`` iterations. When
        ``neox_args.save`` is set, it first saves a checkpoint if one was not
        just written.

    Returns:
        int: the iteration number reached when the loop ends.
    """

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss; passed to training_log every iteration (presumably
    # accumulated there).
    total_loss_dict = {}

    # Iterations.
    iteration = neox_args.iteration

    timers("interval time").start()
    report_memory_flag = True

    # get noise scale logger (if neox_args.log_gradient_noise_scale is True)
    noise_scale_logger = get_noise_scale_logger(neox_args)

    # to monitor if we've skipped many iterations in a row and trigger an early exit
    overflow_monitor = OverflowMonitor(optimizer)
    while iteration < neox_args.train_iters:
        if neox_args.profile and iteration == neox_args.profile_step_start:
            torch.cuda.cudart().cudaProfilerStart()
        loss_dict, skipped_iter = train_step(
            neox_args=neox_args,
            timers=timers,
            data_iterator=train_data_iterator,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        if neox_args.profile and iteration == neox_args.profile_step_stop:
            torch.cuda.cudart().cudaProfilerStop()
        iteration += 1
        neox_args.iteration = iteration
        if neox_args.precision == "fp16":
            overflow_monitor.check(skipped_iter)  # check for repeated overflow
        if neox_args.log_gradient_noise_scale:  # log noise scale if applicable
            noise_scale_logger.update()

        # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
        # may have no tunable parameters on a specific rank
        if optimizer.param_groups:
            lr = optimizer.param_groups[0].get("lr", 0)
        else:
            lr = 0

        # Logging.
        report_memory_flag = training_log(
            neox_args=neox_args,
            timers=timers,
            loss_dict=loss_dict,
            total_loss_dict=total_loss_dict,
            learning_rate=lr,
            iteration=iteration,
            loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None,
            report_memory_flag=report_memory_flag,
            skipped_iter=skipped_iter,
            model=model,
            optimizer=optimizer,
            noise_scale_logger=noise_scale_logger,
        )

        # Checkpointing
        saved_checkpoint = False
        if neox_args.save and iteration in neox_args.save_iters:
            save_checkpoint(
                neox_args=neox_args,
                iteration=iteration,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
            saved_checkpoint = True

        # Evaluation
        if (
            neox_args.eval_interval
            and iteration % neox_args.eval_interval == 0
            and neox_args.do_valid
        ):
            prefix = "iteration {}".format(iteration)
            evaluate_and_print_results(
                neox_args=neox_args,
                prefix=prefix,
                forward_step_func=forward_step,
                data_iterator=valid_data_iterator,
                model=model,
                iteration=iteration,
                verbose=False,
                timers=timers,
            )

        if neox_args.exit_interval and iteration % neox_args.exit_interval == 0:
            # Persist progress before exiting. Otherwise all work since the last
            # scheduled checkpoint (save_iters) would be lost. The exit
            # condition is identical on every rank, so all ranks reach this
            # save together.
            if neox_args.save and not saved_checkpoint:
                save_checkpoint(
                    neox_args=neox_args,
                    iteration=iteration,
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                )
            torch.distributed.barrier()
            time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            rank = torch.distributed.get_rank()
            print_rank_0(
                "rank: {} | time: {} | exiting the program at iteration {}".format(
                    rank, time_str, iteration
                )
            )
            sys.exit()

    return iteration
992

993

994
def evaluate(
    neox_args, forward_step_fn, data_iterator, model, verbose=False, timers=None
):
    """Evaluation.

    Args:
        neox_args: NeoX Arguments.
        forward_step_fn: function with args `neox_args, timers,
            data_iterator & model` that will run a forward pass on the model.
        data_iterator: Iterator that iterates over batches of data. Should return
            data in the form:
                {'text': np.array([tokens], dtype=np.int64)}
            where the size of the array is the model's context size + 1
            (`get_batch` transforms it into inputs / labels).
        model: model to evaluate. It is put in eval mode and returned to train
            mode at the end.
        verbose: print progress every ``neox_args.log_interval`` iterations.
        timers: forwarded to ``forward_step_fn``.

    Returns:
        dict: ``"lm_loss"`` (mean of the losses after ``reduce_losses``) and
        ``"lm_loss_ppl"`` (``exp(lm_loss)``). Additional entries:
          - ``"lm_loss_char_lvl_ppl"`` when ``neox_args.char_level_ppl`` is set;
          - the eval-harness ``"results"`` entries when ``neox_args.eval_tasks``
            is set.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    losses = []
    if neox_args.char_level_ppl:
        # Wrap the iterator so consumed characters can be counted.
        data_iterator = CharCounter(data_iterator, neox_args.tokenizer)

    # Although we're not accumulating gradients here, we count one iter as
    # train_batch_size_per_gpu * g.a.s to be consistent with deepspeed's pipe
    # parallel engine. Pipe parallel already takes gradient_accumulation_steps
    # into account, so use 1 forward call per iteration there.
    forwards_per_iter = (
        1 if neox_args.is_pipe_parallel else neox_args.gradient_accumulation_steps
    )
    reset_checkpoint_buffers = (
        neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing
    )

    with torch.no_grad():
        for iteration in range(1, neox_args.eval_iters + 1):
            if verbose and iteration % neox_args.log_interval == 0:
                print_rank_0(
                    "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
                )

            for _ in range(forwards_per_iter):
                # Forward evaluation
                loss = forward_step_fn(
                    model=model,
                    data_iterator=data_iterator,
                    neox_args=neox_args,
                    timers=timers,
                )
                losses.append(loss)

                # When contiguous memory optimizations are enabled, the buffers
                # allocated by the optimizations are deallocated during backward
                # pass. In the absence of a backward pass the buffers must be
                # reset after EACH forward pass, not only once per outer
                # iteration.
                if reset_checkpoint_buffers:
                    deepspeed.checkpointing.reset()

    # reduces losses across processes for logging & run eval harness tasks
    eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
    eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])

    if neox_args.char_level_ppl:
        # calculate character level perplexity from the per-token loss
        tokens_per_char = data_iterator.tokens_per_char()
        print_rank_0(f"Counting chars took {data_iterator.total_time} seconds")
        eval_results["lm_loss_char_lvl_ppl"] = math.exp(
            eval_results["lm_loss"] * tokens_per_char
        )

    if neox_args.eval_tasks:
        from eval_tasks import run_eval_harness

        eval_results.update(
            run_eval_harness(
                model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks
            ).get("results")
        )
    # Move model back to the train mode.
    model.train()
    return eval_results
1072

1073

1074
def collect_loss_for_unit_test(lm_ss):
    """No-op hook that receives the ``lm_loss`` of each training step.

    It exists as a separate module-level function only so unit tests can
    intercept the loss with ``unittest.mock.patch``. It deliberately does
    nothing in normal runs.
    """
    return None
1077

1078

1079
def evaluate_and_print_results(
    neox_args,
    prefix,
    forward_step_func,
    data_iterator,
    model,
    iteration,
    verbose=False,
    timers=None,
    chart_name="validation",
):
    """Helper function to evaluate and dump results on screen.

    Runs :func:`evaluate` and logs each metric through ``tb_wandb_log`` under
    ``"{chart_name}/{name}"``. Dict-valued entries are flattened to
    ``"{key}_{subkey}"``. A one-line summary framed by dashes is printed
    via ``print_rank_0``.
    """
    total_loss_dict = evaluate(
        neox_args=neox_args,
        forward_step_fn=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        verbose=verbose,
        timers=timers,
    )

    summary_parts = []
    for key, value in total_loss_dict.items():
        if isinstance(value, dict):
            if neox_args.eval_tasks and "results" in value:
                value = value["results"]
                # Rank-0 only, like the rest of this function's output
                # (a bare print() would repeat this on every rank).
                print_rank_0(value)
            metrics = [
                ("_".join([key, sub_key]), sub_value)
                for sub_key, sub_value in value.items()
            ]
        else:
            metrics = [(key, value)]

        for name, metric in metrics:
            summary_parts.append(f"{name} value: {metric:.6E} | ")
            tb_wandb_log(
                f"{chart_name}/{name}",
                metric,
                iteration,
                use_wandb=neox_args.use_wandb,
                tensorboard_writer=neox_args.tensorboard_writer,
            )

    string = f" {chart_name} results at {prefix} | " + "".join(summary_parts)
    length = len(string) + 1
    print_rank_0("-" * length)
    print_rank_0(string)
    print_rank_0("-" * length)
1129

1130

1131
def save_snapshot(neox_args):
    """Write the current CUDA caching-allocator snapshot to disk.

    The snapshot is pickled to
    ``<neox_args.memory_profiling_path>/mem_snapshot.pickle``. The directory is
    created if needed, and the file is overwritten on every call.

    Raises:
        AssertionError: if ``neox_args.memory_profiling_path`` is None.
    """
    # Local imports so this function does not depend on `os` / `dump` being
    # imported at module level.
    import os
    import pickle

    assert (
        neox_args.memory_profiling_path is not None
    ), "Must pass memory_profiling_path config arg to use profiling"
    # NOTE: `_snapshot` is a private PyTorch API.
    snapshot = torch.cuda.memory._snapshot()
    snapshot_dir = neox_args.memory_profiling_path
    # exist_ok=True avoids the race between an existence check and creation.
    os.makedirs(snapshot_dir, exist_ok=True)
    with open(os.path.join(snapshot_dir, "mem_snapshot.pickle"), "wb") as f:
        pickle.dump(snapshot, f)
1141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.