"""Pretrain utilities."""

from datetime import datetime
from functools import partial
from pickle import dump

import torch

from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

from megatron import print_rank_0, mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import build_train_valid_test_data_iterators
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.logging import tb_wandb_log, training_log
from megatron.model import (
    get_params_for_weight_decay_optimization,
)
from megatron.model.gpt2_model import cross_entropy
from megatron.utils import (
    get_ltor_masks_and_position_ids,
    get_noise_scale_logger,
)
63
def mup_weights_reinit(neox_args, model):
    """Re-initialize / rescale model weights for muP.

    Walks every submodule and, where the layer opts in, rescales its
    parameters and re-runs its muP-aware weight initialization.

    Args:
        neox_args: NeoX configuration object, forwarded to each layer's
            ``mup_reinitialize_weights``.
        model: the model whose ``modules()`` are visited.
    """

    def has_method(o, name):
        # True only if the attribute exists AND is callable.
        return callable(getattr(o, name, None))

    for layer in model.modules():
        # Layers flag themselves via `mup_rescale_parameters`; a missing or
        # falsy attribute means "leave this layer alone".
        if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters:
            layer._rescale_parameters()

        if has_method(layer, "mup_reinitialize_weights"):
            layer.mup_reinitialize_weights(neox_args)
76
def save_base_shapes(neox_args, base_shapes, use_cache):
79
neox_args.use_mup = False
81
base_model = GPT2ModelPipe(
85
topology=mpu.get_topology(),
89
if not neox_args.is_pipe_parallel:
90
base_model = base_model.to_sequential()
94
except ModuleNotFoundError:
95
print("Please install mup https://github.com/microsoft/mup")
98
base_shapes = mup.get_shapes(base_model)
102
old_hidden_size = neox_args.hidden_size
103
neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale
105
delta_model = GPT2ModelPipe(
108
parallel_output=True,
109
topology=mpu.get_topology(),
113
if not neox_args.is_pipe_parallel:
114
delta_model = delta_model.to_sequential()
116
delta_shapes = mup.get_shapes(delta_model)
119
neox_args.use_mup = True
120
neox_args.hidden_size = old_hidden_size
122
save_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
123
print(f"saving base shapes at {save_shapes}")
124
mup.make_base_shapes(base_shapes, delta_shapes, savefile=save_shapes)
125
print(f"base shapes saved...exiting")
129
def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator):
130
from megatron.mup_substitute import get_coord_data
131
from mup.coord_check import plot_coord_data
133
def lazy_model(hidden_size):
135
old_hidden_size = neox_args.hidden_size
136
neox_args.hidden_size = hidden_size
138
model, optimizer, _ = setup_model_and_optimizer(
139
neox_args=neox_args, use_cache=False
142
neox_args.hidden_size = old_hidden_size
151
for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)):
152
models[hidden_size] = lazy_model(hidden_size)
154
neox_args.use_mup = True
155
df_up = get_coord_data(
156
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True
158
neox_args.use_mup = False
159
df_sp = get_coord_data(
160
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False
163
plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg")
164
plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg")
166
print_rank_0("Saved coord check plots... exiting")
170
def pretrain(neox_args):
171
"""Main training program.
173
This function will run the following in the order provided:
174
1) initialize Megatron.
175
2) setup model, optimizer and lr schedule
176
3) call train_val_test_data_provider to get train/val/test datasets.
180
neox_args: an instance of NeoXArgs containing the configuration for pretrain
184
init_wandb(neox_args=neox_args)
186
use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer
190
initialize_megatron(neox_args=neox_args)
193
timers("model and optimizer").start()
194
model, optimizer, lr_scheduler = setup_model_and_optimizer(
195
neox_args=neox_args, use_cache=False, iteration=neox_args.iteration
197
timers("model and optimizer").stop()
200
timers("train/valid/test data iterators").start()
205
) = build_train_valid_test_data_iterators(neox_args=neox_args)
206
timers("train/valid/test data iterators").stop()
208
if neox_args.use_mup and neox_args.coord_check:
209
mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)
212
print_rank_0("done with setups ...")
213
timers.log(["model and optimizer", "train/valid/test data iterators"])
214
print_rank_0("training ...")
216
iteration = neox_args.iteration
218
if neox_args.save and 0 in neox_args.save_iters and iteration == 0:
224
lr_scheduler=lr_scheduler,
227
if neox_args.do_train and neox_args.train_iters > 0:
233
lr_scheduler=lr_scheduler,
234
train_data_iterator=train_data_iterator,
235
valid_data_iterator=valid_data_iterator,
238
if neox_args.do_valid:
239
prefix = "the end of training for val data"
240
evaluate_and_print_results(
243
forward_step_func=forward_step,
244
data_iterator=valid_data_iterator,
251
if neox_args.save and iteration != 0:
257
lr_scheduler=lr_scheduler,
260
if neox_args.do_test:
262
prefix = "the end of training for test data"
263
evaluate_and_print_results(
266
forward_step_func=forward_step,
267
data_iterator=test_data_iterator,
276
def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Support function for get_batch / get_batch pipe (to avoid code repetition)"""
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack. `text` holds seq_len + 1 token ids; input/target views below.
    tokens_ = data_b["text"].long()
    if "label" in data_b:
        # Clamp negative label ids (loss-ignore markers such as -100) to 0 so
        # they remain valid token indices; the loss mask below re-excludes
        # them from the loss computation.
        label_ids = data_b["label"].long()
        labels = torch.where(
            label_ids >= 0,
            label_ids,
            torch.zeros_like(label_ids),
        )[:, 1:].contiguous()
    else:
        # No explicit labels: next-token prediction on the shifted input.
        labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
        sliding_window_width=neox_args.sliding_window_width,
    )

    # With explicit labels, any position whose label is < 0 is dropped from
    # the loss, overriding the default mask.
    if "label" in data_b:
        loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype)
    return tokens, labels, loss_mask, attention_mask, position_ids
305
def get_batch(neox_args, data_iterator):
    """Generate a batch"""

    # Items and their type.
    keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
    datatype = torch.int64

    # Only ranks holding the iterator read data; other ranks pass None and
    # receive the batch via the broadcast inside _get_batch.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    return _get_batch(
        neox_args=neox_args,
        tokenizer=neox_args.tokenizer,
        keys=keys,
        data=data,
        datatype=datatype,
    )
326
def get_batch_pipe(data, neox_args, curr_scheduler=None):
    """A modification of get_batch() to work with the latest batch instead of an iterator."""
    # Items and their type.
    keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
    datatype = torch.int64

    tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
        neox_args, neox_args.tokenizer, keys, data, datatype
    )
    if curr_scheduler is not None:
        # Sequence-length curriculum: possibly train on a truncated sequence.
        # iteration + 1 because the engine bumps neox_args.iteration after this.
        curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
        if curriculum_seqlen < tokens.size()[1]:
            # tokens / position_ids / labels / loss_mask are [batch, seqlen].
            tokens = tokens[:, :curriculum_seqlen].contiguous()
            position_ids = position_ids[:, :curriculum_seqlen].contiguous()
            if labels is not None:
                labels = labels[:, :curriculum_seqlen].contiguous()
            if loss_mask is not None:
                loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
            # attention_mask is [..., seqlen, seqlen]; truncate both axes.
            attention_mask = attention_mask[
                :, :, :curriculum_seqlen, :curriculum_seqlen
            ].contiguous()

    # Pipeline engine expects (inputs, targets) tuples.
    return (tokens, position_ids, attention_mask), (labels, loss_mask)
357
def get_batch_sequential(forward_input, neox_args):
    """A modification of get_batch() to work with the latest batch instead of an iterator."""
    # NOTE(review): forward_input[0] is assumed to be the token ids tensor
    # (it is passed as `data` to the mask builder) — confirm with the caller.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=forward_input[0],
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
    )
    return (forward_input[0], forward_input[1], attention_mask)
368
data_iterator, model, neox_args, timers, return_logits=False, is_train=False
371
if neox_args.is_pipe_parallel:
372
return model.eval_batch(data_iterator, return_logits=return_logits)
375
if neox_args.memory_profiling and neox_args.it:
376
torch.cuda.nvtx.range_push(f"Get batch")
377
if timers is not None:
378
timers("batch generator").start()
379
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
380
neox_args=neox_args, data_iterator=data_iterator
383
if timers is not None:
384
timers("batch generator").stop()
385
if neox_args.memory_profiling:
386
torch.cuda.nvtx.range_pop()
388
if neox_args.memory_profiling:
389
torch.cuda.nvtx.range_push(f"Forward pass")
391
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
392
if type(maybe_tuple) is tuple:
393
outputs, moe_losses = maybe_tuple
395
outputs = maybe_tuple
399
and neox_args.curriculum_learning
400
and neox_args.curriculum_seqlen < neox_args.seq_length
402
loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
403
labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
404
main_loss = cross_entropy(
405
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
407
if neox_args.num_experts > 1:
408
moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses)
411
loss = main_loss + moe_loss
412
if neox_args.memory_profiling:
413
torch.cuda.nvtx.range_pop()
419
def get_model(neox_args, use_cache=False):
420
"""Build the model."""
423
print_rank_0("building GPT2 model ...")
427
old_use_mup = neox_args.use_mup
428
neox_args.use_mup = False
429
model = GPT2ModelPipe(
432
parallel_output=True,
433
topology=mpu.get_topology(),
438
if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(
441
soft_prompt = SoftEmbedding(
443
wte=getattr(model, "0").word_embeddings,
444
n_tokens=neox_args.soft_prompt_tuning.get("n_tokens", 10),
445
init_string=neox_args.soft_prompt_tuning.get("init_string", ""),
446
init_range=neox_args.soft_prompt_tuning.get("init_range", 0.5),
449
layers=soft_prompt, idx=1
453
for name, param in model.named_parameters():
454
if not "soft_embedding" in name:
455
param.requires_grad = False
457
if not neox_args.is_pipe_parallel:
459
model = model.to_sequential()
461
neox_args.use_mup = old_use_mup
463
if neox_args.use_mup:
466
except ModuleNotFoundError:
467
print("Please install mup https://github.com/microsoft/mup")
470
base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
472
if neox_args.save_base_shapes:
473
save_base_shapes(neox_args, base_shapes, use_cache)
475
mup.set_base_shapes(model, base_shapes)
478
mup_weights_reinit(neox_args, model)
480
if neox_args.deepspeed:
484
raise ValueError("Must be using deepspeed to run neox")
487
def get_optimizer(model, neox_args):
488
"""Set up the optimizer."""
489
if neox_args.no_load_optim:
492
if neox_args.optimizer is None:
494
f"ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)"
498
param_groups = get_params_for_weight_decay_optimization(model, neox_args)
500
f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
503
if neox_args.create_moe_param_group:
504
from deepspeed.moe.utils import (
506
split_params_into_different_moe_groups_for_optimizer,
509
param_groups = split_params_into_different_moe_groups_for_optimizer(
514
for param_group in param_groups:
515
for param in param_group["params"]:
516
if not hasattr(param, "model_parallel"):
517
param.model_parallel = False
521
for param_group in param_groups:
522
trainable_params = [p for p in param_group["params"] if p.requires_grad]
523
param_group["params"] = trainable_params
524
_param_groups.append(param_group)
525
param_groups = _param_groups
528
assert not neox_args.use_mup or (
529
neox_args.optimizer_type.lower() == "adam"
530
or neox_args.optimizer_type.lower() == "sgd"
531
), f"If use_mup == True, you must specify either the adam or sgd optimizers. You passed: {neox_args.optimizer_type.lower()}"
533
if neox_args.optimizer_type.lower() in ["cpu_adam", "cpu_torch_adam"]:
534
if neox_args.optimizer == "cpu_torch_adam":
535
cpu_adam_optimizer = torch.optim.Adam
537
from deepspeed.ops.adam import DeepSpeedCPUAdam
539
cpu_adam_optimizer = DeepSpeedCPUAdam
540
optimizer = cpu_adam_optimizer(
542
weight_decay=neox_args.weight_decay,
543
**neox_args.optimizer["params"],
545
elif neox_args.optimizer_type.lower() == "onebitadam":
546
assert neox_args.deepspeed
549
elif neox_args.optimizer_type.lower() == "sm3":
550
from .optimizers import SM3
552
optimizer = SM3(param_groups, **neox_args.optimizer["params"])
553
elif neox_args.optimizer_type.lower() == "madgrad_wd":
554
from .optimizers import madgrad_wd
556
optimizer = madgrad_wd(
558
weight_decay=neox_args.weight_decay,
559
**neox_args.optimizer["params"],
561
elif neox_args.optimizer_type.lower() == "lion":
563
if neox_args.zero_optimization["stage"] != 0:
564
from deepspeed.ops.lion import FusedLion
566
lion_optimizer = FusedLion
569
from .optimizers import Lion
571
lion_optimizer = Lion
573
optimizer = lion_optimizer(
575
weight_decay=neox_args.weight_decay,
576
**neox_args.optimizer["params"],
578
elif neox_args.optimizer_type.lower() == "adam":
580
if neox_args.use_mup:
582
from mup import MuAdam
584
adam_optimizer = MuAdam
585
except ModuleNotFoundError:
586
print("Please install mup https://github.com/microsoft/mup")
589
if neox_args.use_bnb_optimizer:
591
import bitsandbytes as bnb
593
adam_optimizer = bnb.optim.Adam8bit
594
except ModuleNotFoundError:
596
"Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes."
602
from apex.optimizers import FusedAdam as Adam
606
"WARNING: APEX not installed - defaulting to deepspeed's fused adam"
608
from deepspeed.ops.adam import FusedAdam as Adam
609
adam_optimizer = Adam
610
optimizer = adam_optimizer(
612
weight_decay=neox_args.weight_decay,
613
**neox_args.optimizer["params"],
615
elif neox_args.optimizer_type.lower() == "sgd":
617
from mup import MuSGD
618
except ModuleNotFoundError:
619
print("Please install mup https://github.com/microsoft/mup")
623
weight_decay=neox_args.weight_decay,
624
**neox_args.optimizer["params"],
627
raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")
629
if neox_args.deepspeed:
631
return optimizer, param_groups
633
raise ValueError("Must be using deepspeed to run neox")
636
def get_learning_rate_scheduler(optimizer, neox_args):
    """Build the learning rate scheduler."""
    if neox_args.no_load_optim:
        # No optimizer to schedule (e.g. inference-only runs).
        return None
    if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
        print_rank_0(
            "WARNING: onebitadam requires the lr scheduler be built by deepspeed - "
            "Make sure one is added to your deepspeed config"
        )
        return None

    # Add linear learning rate scheduler.
    if neox_args.lr_decay_iters is not None:
        num_iters = neox_args.lr_decay_iters
    else:
        num_iters = neox_args.train_iters
    # Guard against a zero-length decay schedule.
    num_iters = max(1, num_iters)
    init_step = 0
    # `warmup` is a fraction of the total decay iterations.
    warmup_iter = neox_args.warmup * num_iters
    lr_scheduler = AnnealingLR(
        optimizer,
        start_lr=neox_args.lr,
        warmup_iter=warmup_iter,
        total_iters=num_iters,
        decay_style=neox_args.lr_decay_style,
        last_iter=init_step,
        min_lr=neox_args.min_lr,
        use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=neox_args.override_lr_scheduler,
        use_mup=neox_args.use_mup,
    )

    return lr_scheduler
672
def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
673
"""Setup memory profiler"""
674
if neox_args.memory_profiling:
675
torch.cuda.memory._record_memory_history(
678
trace_alloc_max_entries=100000,
679
trace_alloc_record_context=True,
682
"""Setup model and optimizer."""
683
model = get_model(neox_args=neox_args, use_cache=use_cache)
684
optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
685
lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args)
687
if neox_args.deepspeed:
688
print_rank_0("DeepSpeed is enabled.")
689
if neox_args.no_load_optim:
690
assert optimizer is None
694
_model_params = param_groups if optimizer is None else None
695
_lr_scheduler = lr_scheduler
697
model, optimizer, _, lr_scheduler = deepspeed.initialize(
701
lr_scheduler=_lr_scheduler,
702
dist_init_required=False,
703
model_parameters=_model_params,
706
mpu=mpu if not neox_args.is_pipe_parallel else None,
708
model.total_params = get_total_params(model.module)
709
print_rank_0(f' > total params: {"{:,}".format(model.total_params)}')
711
if neox_args.is_pipe_parallel:
712
model.set_has_attention_mask(True)
713
if neox_args.curriculum_learning:
714
curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
715
if iteration is not None and iteration > 0:
716
curr_scheduler.update_difficulty(iteration)
718
curr_scheduler = None
721
get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
725
model.module.set_batch_fn(
726
partial(get_batch_sequential, neox_args=neox_args)
730
raise ValueError("Must be using deepspeed to run neox")
732
if neox_args.load is not None:
733
neox_args.iteration = load_checkpoint(
737
lr_scheduler=lr_scheduler,
741
f"Loading checkpoint and starting from iteration {neox_args.iteration}"
744
neox_args.iteration = 0
748
if lr_scheduler is not None:
749
lr_scheduler.optimizer = model.optimizer
751
return model, optimizer, lr_scheduler
754
def backward_step(neox_args, timers, optimizer, model, loss):
758
timers("backward-backward").start()
759
if neox_args.deepspeed:
762
raise ValueError("Must be using deepspeed to run neox")
763
timers("backward-backward").stop()
765
if neox_args.deepspeed:
768
timers("backward-allreduce").reset()
770
raise ValueError("Must be using deepspeed to run neox")
773
def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
774
"""Single training step."""
777
if neox_args.is_pipe_parallel:
778
reduced_loss = train_step_pipe(
779
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
782
neox_args.memory_profiling
783
and neox_args.iteration >= neox_args.profile_step_start
784
and neox_args.iteration <= neox_args.profile_step_stop
785
and torch.distributed.get_rank() == 0
787
save_snapshot(neox_args)
790
for _ in range(neox_args.gradient_accumulation_steps):
792
timers("forward").start()
796
data_iterator=data_iterator,
800
timers("forward").stop()
805
and neox_args.iteration >= neox_args.profile_step_start
806
and neox_args.iteration <= neox_args.profile_step_stop
808
torch.cuda.nvtx.range_push(f"Backward pass")
809
timers("backward").start()
817
timers("backward").stop()
820
and neox_args.iteration >= neox_args.profile_step_start
821
and neox_args.iteration <= neox_args.profile_step_stop
823
torch.cuda.nvtx.range_pop()
827
and neox_args.iteration >= neox_args.profile_step_start
828
and neox_args.iteration <= neox_args.profile_step_stop
830
torch.cuda.nvtx.range_push(f"Optimizer step")
831
timers("optimizer").start()
832
if neox_args.deepspeed:
835
raise ValueError("Must be using deepspeed to run neox")
836
timers("optimizer").stop()
839
and neox_args.iteration >= neox_args.profile_step_start
840
and neox_args.iteration <= neox_args.profile_step_stop
842
torch.cuda.nvtx.range_pop()
845
and neox_args.iteration >= neox_args.profile_step_start
846
and neox_args.iteration <= neox_args.profile_step_stop
847
and torch.distributed.get_rank() == 0
849
save_snapshot(neox_args)
851
"lm_loss": reduce_losses(losses).mean()
854
if neox_args.precision == "fp16" and model.optimizer.overflow:
859
collect_loss_for_unit_test(reduced_loss["lm_loss"])
860
return reduced_loss, skipped_iter
863
def train_step_pipe(neox_args, timers, model, data_iterator):
    """Single training step with DeepSpeed's pipeline parallel engine."""
    assert neox_args.deepspeed
    # The pipeline engine performs forward, backward and the optimizer step
    # internally and returns the (already reduced) training loss.
    loss = model.train_batch(data_iter=data_iterator)
    loss_dict = {"lm_loss": loss}
    # Don't break Megatron's timers: the pipeline engine uses its own timing,
    # so reset ours to keep interval logging consistent.
    for t in timers.timers.keys():
        timers(t).reset()
    return loss_dict
891
"""Train the model function."""
900
iteration = neox_args.iteration
902
timers("interval time").start()
903
report_memory_flag = True
906
noise_scale_logger = get_noise_scale_logger(neox_args)
909
overflow_monitor = OverflowMonitor(optimizer)
910
while iteration < neox_args.train_iters:
911
if neox_args.profile and iteration == neox_args.profile_step_start:
912
torch.cuda.cudart().cudaProfilerStart()
913
loss_dict, skipped_iter = train_step(
916
data_iterator=train_data_iterator,
919
lr_scheduler=lr_scheduler,
921
if neox_args.profile and iteration == neox_args.profile_step_stop:
922
torch.cuda.cudart().cudaProfilerStop()
924
neox_args.iteration = iteration
925
if neox_args.precision == "fp16":
926
overflow_monitor.check(skipped_iter)
927
if neox_args.log_gradient_noise_scale:
928
noise_scale_logger.update()
932
if optimizer.param_groups:
933
lr = optimizer.param_groups[0].get("lr", 0)
938
report_memory_flag = training_log(
942
total_loss_dict=total_loss_dict,
945
loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None,
946
report_memory_flag=report_memory_flag,
947
skipped_iter=skipped_iter,
950
noise_scale_logger=noise_scale_logger,
954
if neox_args.save and iteration in neox_args.save_iters:
960
lr_scheduler=lr_scheduler,
964
neox_args.eval_interval
965
and iteration % neox_args.eval_interval == 0
966
and neox_args.do_valid
968
prefix = "iteration {}".format(iteration)
969
evaluate_and_print_results(
972
forward_step_func=forward_step,
973
data_iterator=valid_data_iterator,
980
if neox_args.exit_interval and iteration % neox_args.exit_interval == 0:
981
torch.distributed.barrier()
982
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
983
rank = torch.distributed.get_rank()
985
"rank: {} | time: {} | exiting the program at iteration {}".format(
986
rank, time_str, iteration
995
neox_args, forward_step_fn, data_iterator, model, verbose=False, timers=None
998
neox_args: NeoX Arguments
999
forward_step_fn: function with args `neox_args, timers,
1000
data_iterator & model that will run a forward pass on the model
1001
data_iterator: Iterator that iterates over batches of data. Should return data in the form:
1002
{'text': np.array([tokens], dtype=np.int64)}
1003
where the size of the array is the model's context size + 1
1004
(`get_batch` transforms it into inputs / labels)
1009
if neox_args.char_level_ppl:
1010
data_iterator = CharCounter(data_iterator, neox_args.tokenizer)
1012
with torch.no_grad():
1014
while iteration < neox_args.eval_iters:
1016
if verbose and iteration % neox_args.log_interval == 0:
1018
"Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
1026
if neox_args.is_pipe_parallel
1027
else neox_args.gradient_accumulation_steps
1030
loss = forward_step_fn(
1032
data_iterator=data_iterator,
1033
neox_args=neox_args,
1042
if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing:
1043
deepspeed.checkpointing.reset()
1046
eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
1047
eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])
1049
if neox_args.char_level_ppl:
1053
tokens_per_char = data_iterator.tokens_per_char()
1054
print_rank_0(f"Counting chars took {data_iterator.total_time} seconds")
1056
data_iterator = data_iterator.data_iterator
1057
eval_results["lm_loss_char_lvl_ppl"] = math.exp(
1058
eval_results["lm_loss"] * tokens_per_char
1061
if neox_args.eval_tasks:
1062
from eval_tasks import run_eval_harness
1064
eval_results.update(
1066
model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks
1074
def collect_loss_for_unit_test(lm_ss):
1079
def evaluate_and_print_results(
1088
chart_name="validation",
1090
"""Helper function to evaluate and dump results on screen."""
1091
total_loss_dict = evaluate(
1092
neox_args=neox_args,
1093
forward_step_fn=forward_step_func,
1094
data_iterator=data_iterator,
1099
string = f" {chart_name} results at {prefix} | "
1100
for k, v in total_loss_dict.items():
1101
if isinstance(v, dict):
1102
if neox_args.eval_tasks and "results" in v:
1105
for k2, v2 in v.items():
1106
k3 = "_".join([k, k2])
1107
string += f"{k3} value: {v2:.6E} | "
1109
f"{chart_name}/{k3}",
1112
use_wandb=neox_args.use_wandb,
1113
tensorboard_writer=neox_args.tensorboard_writer,
1116
string += f"{k} value: {v:.6E} | "
1118
f"{chart_name}/{k}",
1121
use_wandb=neox_args.use_wandb,
1122
tensorboard_writer=neox_args.tensorboard_writer,
1125
length = len(string) + 1
1126
print_rank_0("-" * length)
1127
print_rank_0(string)
1128
print_rank_0("-" * length)
1131
def save_snapshot(neox_args):
1133
neox_args.memory_profiling_path is not None
1134
), "Must pass memory_profiling_path config arg to use profiling"
1135
snapshot = torch.cuda.memory._snapshot()
1136
snapshot_path = os.path.join(neox_args.memory_profiling_path)
1137
if not os.path.exists(snapshot_path):
1138
os.makedirs(snapshot_path)
1139
with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f: