# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
    import boto3
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3"
    )
try:
    import hf_transfer
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer"
    )

import torch
from glob import glob

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat


def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same in the input
    arguments and the ones retrieved from a checkpoint."""

    assert isinstance(checkpoint_args, dict), "args stored in checkpoint must be a dict"
    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
            checkpoint_arg_name, checkpoint_arg_value, args_value
        )
        assert checkpoint_arg_value == args_value, error_message


def do_forward_pass(neox_args, model, inference=False):
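    """Run one deterministic forward pass over a fixed `arange` token batch and
    return the logits of the first batch item (or None on pipeline stages that
    produce no logits). Used to validate that a checkpoint reloads exactly."""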

    # set to eval mode
    model_was_in_train = model.training
    model.eval()

    # get context tokens
    # always forward full batch size
    context_tokens_tensor = (
        torch.arange(neox_args.seq_length + 1)
        .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
        .cuda()
    )

    # forward
    if inference:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        model_inputs = (
            tokens,
            position_ids,
            attention_mask,
            torch.Tensor(),
        )
        logits, _ = forward_model(neox_args, model, model_inputs)
    elif neox_args.is_pipe_parallel:
        data_iterator = iter([{"text": context_tokens_tensor}])
        _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
    else:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        logits = model((tokens, position_ids, attention_mask))

    # reset to train mode, if model was in training before
    if model_was_in_train:
        model.train()

    if logits is not None:
        logits = logits.detach().cpu()[
            0
        ]  # just return first batch item (they are all equal)

    return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
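    """Re-run the validation forward pass and compare against the logits saved
    in the checkpoint: warn if results differ at all, and assert that they are
    at least numerically close."""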
    # do forward pass with loaded checkpoint
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # check
    if (
        logits is not None and checkpoint_logits is not None
    ):  # this could be the case for non-final pipeline stages
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield the exact same result"
                )
            assert (
                torch.isclose(logits, checkpoint_logits).all().item()
            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}".format(
            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
        ),
        "model_optim_rng.pt",
    )


def get_checkpoint_tag(iteration: int) -> str:
    return f"global_step{iteration}"


def delete_old_checkpoints(save_dir, n_to_keep):
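    """Delete the oldest `global_step*` directories under `save_dir`, keeping
    only the newest `n_to_keep`; runs on global rank 0 only."""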
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
        if save_dir.endswith("/"):
            save_dir = save_dir.strip("/")
        all_ckpts = natural_sort(
            [
                i
                for i in glob(f"{save_dir}/*")
                if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
            ]
        )
        n_to_delete = len(all_ckpts) - n_to_keep
        if n_to_delete > 0:
            to_delete = all_ckpts[:n_to_delete]
            print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
            for ckpt in to_delete:
                try:
                    shutil.rmtree(ckpt)
                except FileNotFoundError:
                    pass


def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }
    # rng states.
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = get_checkpoint_tag(iteration)
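
    # `sd` travels as DeepSpeed `client_state` and comes back as `state_dict`
    # from `model.load_checkpoint()` in `load_checkpoint` below.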

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                if isinstance(config_data, str):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)


def multiprocessing_starmap(func, args, num_processes=None):
    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
    Args:
        func (Callable): Function to call
        args (Iterable): Iterable of arguments to pass to `func`
        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
    """
    import multiprocessing

    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
    with multiprocessing.get_context("spawn").Pool(
        processes=num_processes
    ) as process_pool:
        process_pool.starmap(func, args)
        process_pool.terminate()
        process_pool.join()
        del process_pool
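
# Each tuple in `args` becomes one `func(*tuple)` call in its own spawned
# worker; `upload_checkpoint` below calls
# `multiprocessing_starmap(_upload, inputs)` with (file_path, s3_key,
# chunk_size) tuples.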


def _upload(
    file_path: str,
    s3_key: str,
    chunk_size: int = 104_857_600,
    max_files: int = 64,
    parallel_failures: int = 63,
    max_retries: int = 5,
):
    """Upload local file to S3 using `hf_transfer` library
    Args:
        file_path (str): Local filename to upload
        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
        chunk_size (int, optional): Chunk size to use for multipart upload.
            Defaults to 100MiB = 104_857_600
        max_files (int, optional): Number of open file handles, which determines
            the maximum number of parallel uploads. Defaults to 64
        parallel_failures (int, optional): Number of maximum failures of different
            chunks in parallel (cannot exceed max_files). Defaults to 63
        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
    """
    s3 = boto3.client("s3")
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")
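    # e.g. s3_key = "s3://my-bucket/run/global_step1000/mp_rank_00_model_states.pt"
    # -> bucket = "my-bucket", key = "run/global_step1000/mp_rank_00_model_states.pt"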

    # 1. Init multipart upload and obtain unique upload identifier
    upload = s3.create_multipart_upload(
        ACL="bucket-owner-full-control",
        Bucket=bucket,
        Key=key,
    )
    upload_id = upload["UploadId"]

    # 2. Generate presigned URLs for each part
    file_size = os.stat(file_path).st_size
    urls = []
    nb_parts = math.ceil(file_size / chunk_size)
    for part_number in range(1, nb_parts + 1):
        params = {
            "Bucket": bucket,
            "Key": key,
            "PartNumber": part_number,
            "UploadId": upload_id,
        }
        urls.append(
            s3.generate_presigned_url(
                ClientMethod="upload_part", Params=params, ExpiresIn=86400
            )
        )

    # 3. Upload parts in parallel
    responses = hf_transfer.multipart_upload(
        file_path=file_path,
        parts_urls=urls,
        chunk_size=chunk_size,
        max_files=max_files,
        parallel_failures=parallel_failures,
        max_retries=max_retries,
    )

    # 4. Complete multipart upload request with ETag values
    etag_with_parts = []
    for part_number, header in enumerate(responses):
        etag = header.get("etag")
        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
    parts = {"Parts": etag_with_parts}
    s3.complete_multipart_upload(
        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
    )


def upload_checkpoint(iteration, neox_args):
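    """Upload the local checkpoint directory for `iteration` to the matching
    location under `neox_args.s3_path`, one `_upload` task per file."""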
    local_checkpoint_path = os.path.join(
        os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)
    )
    local_checkpoint_list = sorted(
        filter(
            lambda x: os.path.isfile(x),
            [str(p) for p in Path(local_checkpoint_path).rglob("*")],
        )
    )
    remote_checkpoint_path = os.path.join(
        neox_args.s3_path,
        os.path.basename(neox_args.save),
        get_checkpoint_tag(iteration),
    )
    remote_checkpoint_list = [
        os.path.join(
            remote_checkpoint_path,
            os.path.relpath(local_checkpoint, local_checkpoint_path),
        )
        for local_checkpoint in local_checkpoint_list
    ]
    inputs = zip(
        local_checkpoint_list,
        remote_checkpoint_list,
        [neox_args.s3_chunk_size] * len(local_checkpoint_list),
    )

    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`..."
    )
    start = time.time()
    multiprocessing_starmap(_upload, inputs)
    total_time = time.time() - start
    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s"
    )


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""

    if neox_args.deepspeed:
        save_ds_checkpoint(iteration, model, neox_args)
    else:
        raise ValueError("Must be using deepspeed to use neox")

    torch.distributed.barrier()
    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
    if upload_to_s3:
        upload_checkpoint(iteration, neox_args)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = (
            not neox_args.no_load_optim
        )  # TODO: These should be configured by separate args
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = get_checkpoint_tag(iteration)
        else:
            tag = None
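        # With tag=None, DeepSpeed resolves which checkpoint to load from the
        # `latest` file it wrote inside `neox_args.load` at save time.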
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            load_module_only=not load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if an iteration is specified, we want to raise an error here rather than
            # continuing silently, since we are trying to load a specific checkpoint
            if iteration is not None:
                available_checkpoints = sorted(
                    [
                        int(i.name.replace("global_step", ""))
                        for i in Path(neox_args.load).glob("global_step*")
                    ]
                )
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if no checkpoint loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        if "iteration" in state_dict:
            iteration = state_dict["iteration"]
        else:
            iteration = state_dict.get(
                "total_iters"
            )  # total_iters backward compatible with older checkpoints
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name)
            )
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return iteration