gpt-neox / checkpointing.py (482 lines, 17.1 KB)

# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
    import boto3
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3"
    )
try:
    import hf_transfer
except ModuleNotFoundError:
    print(
        "For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer"
    )
import torch
from glob import glob

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat


def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
55
    arguments and the one retrieved from checkpoint."""
56

    assert isinstance(checkpoint_args, dict), "args stored in checkpoint is a dict"
    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
            checkpoint_arg_name, checkpoint_arg_value, args_value
        )
        assert checkpoint_arg_value == args_value, error_message


def do_forward_pass(neox_args, model, inference=False):

    # set to eval mode
    model_was_in_train = model.training
    model.eval()

    # get context tokens
    # always forward full batch size
    context_tokens_tensor = (
        torch.arange(neox_args.seq_length + 1)
        .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
        .cuda()
    )

    # forward
    if inference:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        model_inputs = (
            tokens,
            position_ids,
            attention_mask,
            torch.Tensor(),
        )
        logits, _ = forward_model(neox_args, model, model_inputs)
    elif neox_args.is_pipe_parallel:
        data_iterator = iter([{"text": context_tokens_tensor}])
        _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
    else:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        logits = model((tokens, position_ids, attention_mask))

    # reset to train mode, if model was in training before
    if model_was_in_train:
        model.train()

    if logits is not None:
        logits = logits.detach().cpu()[
            0
        ]  # just return first batch item (they are all equal)

    return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    # do forward pass with loaded checkpoint
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # check
    if (
        logits is not None and checkpoint_logits is not None
    ):  # this could be the case for non-final pipeline stages
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
                )
            assert (
                torch.isclose(logits, checkpoint_logits).all().item()
            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exists."""
133
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}".format(
            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
        ),
        "model_optim_rng.pt",
    )


def get_checkpoint_tag(iteration: int) -> str:
    return f"global_step{iteration}"


def delete_old_checkpoints(save_dir, n_to_keep):
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
        if save_dir.endswith("/"):
            save_dir = save_dir.rstrip("/")
        all_ckpts = natural_sort(
            [
                i
                for i in glob(f"{save_dir}/*")
                if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
            ]
        )
        n_to_delete = len(all_ckpts) - n_to_keep
        if n_to_delete > 0:
            to_delete = all_ckpts[:n_to_delete]
            print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
            for ckpt in to_delete:
                try:
                    shutil.rmtree(ckpt)
                except FileNotFoundError:
                    pass
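

# Editor's illustrative sketch (not part of the original module): keeping the two
# newest checkpoints out of global_step100/200/300 deletes only global_step100,
# because natural_sort orders the step numbers numerically rather than
# lexicographically. torch.distributed must already be initialized (as it is during
# training) for the rank check at the top of delete_old_checkpoints() to pass; the
# directory name below is a made-up value.
def _example_delete_old_checkpoints(tmp_dir="/tmp/ckpt_demo"):
    for step in (100, 200, 300):
        os.makedirs(os.path.join(tmp_dir, f"global_step{step}"), exist_ok=True)
    delete_old_checkpoints(tmp_dir, n_to_keep=2)
    assert sorted(os.listdir(tmp_dir)) == ["global_step200", "global_step300"]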


def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }
    # rng states.
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = get_checkpoint_tag(iteration)

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                if isinstance(config_data, str):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)
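

# Editor's note (illustrative, not part of the original module): the shard files
# inside a tag directory are written by DeepSpeed's model.save_checkpoint(), so
# their exact names depend on the ZeRO stage and parallelism configuration. For a
# made-up run with neox_args.save="/tmp/ckpts" at iteration 1000, the layout is
# roughly:
#
#   /tmp/ckpts/
#       latest                          # tag bookkeeping file written by DeepSpeed
#       global_step1000/
#           mp_rank_00_model_states.pt  # model/client state (one per MP rank)
#           zero_pp_rank_0_mp_rank_00_optim_states.pt  # optimizer shards (ZeRO)
#           configs/                    # copies of neox_args.config_files (saved above)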


def multiprocessing_starmap(func, args, num_processes=None):
    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
    Args:
        func (Callable): Function to call
        args (Iterable): Iterable of arguments to pass to `func`
        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
    """
    import multiprocessing

    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
    with multiprocessing.get_context("spawn").Pool(
        processes=num_processes
    ) as process_pool:
        process_pool.starmap(func, args)
        process_pool.terminate()
        process_pool.join()
        del process_pool
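

# Editor's illustrative usage note (not part of the original module): each element
# of `args` is one tuple of positional arguments for `func`, so a call like the
# one below (with made-up file names) runs two uploads in separate spawned
# processes:
#
#   multiprocessing_starmap(
#       _upload,
#       [
#           ("a.pt", "s3://my-bucket/run1/a.pt"),
#           ("b.pt", "s3://my-bucket/run1/b.pt"),
#       ],
#       num_processes=2,
#   )
#
# Because the `spawn` start method re-imports this module in each worker, `func`
# must be a module-level function (as _upload is) rather than a lambda or closure.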


def _upload(
    file_path: str,
    s3_key: str,
    chunk_size: int = 104_857_600,
    max_files: int = 64,
    parallel_failures: int = 63,
    max_retries: int = 5,
):
    """Upload local file to S3 using `hf_transfer` library
    Args:
        file_path (str): Local filename to upload
        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
        chunk_size (int, optional): Chunk size to use for multipart upload.
            Defaults to 100MiB = 104_857_600
        max_files (int, optional): Number of open file handles, which determines
            the maximum number of parallel uploads. Defaults to 64
        parallel_failures (int, optional): Number of maximum failures of different
            chunks in parallel (cannot exceed max_files). Defaults to 63
        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
    """
    s3 = boto3.client("s3")
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")

    # 1. Init multipart upload and obtain unique upload identifier
    upload = s3.create_multipart_upload(
        ACL="bucket-owner-full-control",
        Bucket=bucket,
        Key=key,
    )
    upload_id = upload["UploadId"]

    # 2. Generate presigned URLs for each part
    file_size = os.stat(file_path).st_size
    urls = []
    nb_parts = math.ceil(file_size / chunk_size)
    for part_number in range(1, nb_parts + 1):
        params = {
            "Bucket": bucket,
            "Key": key,
            "PartNumber": part_number,
            "UploadId": upload_id,
        }
        urls.append(
            s3.generate_presigned_url(
                ClientMethod="upload_part", Params=params, ExpiresIn=86400
            )
        )

    # 3. Upload parts in parallel
    responses = hf_transfer.multipart_upload(
        file_path=file_path,
        parts_urls=urls,
        chunk_size=chunk_size,
        max_files=max_files,
        parallel_failures=parallel_failures,
        max_retries=max_retries,
    )

    # 4. Complete multipart upload request with ETag values
    etag_with_parts = []
    for part_number, header in enumerate(responses):
        etag = header.get("etag")
        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
    parts = {"Parts": etag_with_parts}
    s3.complete_multipart_upload(
        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
    )
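

# Editor's illustrative sketch (not part of the original module): traces the
# bucket/key split performed at the top of _upload() for a made-up S3 key,
# without touching S3 or the network.
def _example_s3_key_parsing():
    s3_key = "s3://my-bucket/run1/global_step1000/mp_rank_00_model_states.pt"
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")
    assert bucket == "my-bucket"
    assert key == "run1/global_step1000/mp_rank_00_model_states.pt"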


def upload_checkpoint(iteration, neox_args):
    local_checkpoint_path = os.path.join(
        os.path.abspath(neox_args.save), get_checkpoint_tag(iteration)
    )
    local_checkpoint_list = sorted(
        filter(
            lambda x: os.path.isfile(x),
            [str(p) for p in Path(local_checkpoint_path).rglob("*")],
        )
    )
    remote_checkpoint_path = os.path.join(
        neox_args.s3_path,
        os.path.basename(neox_args.save),
        get_checkpoint_tag(iteration),
    )
    remote_checkpoint_list = [
        os.path.join(
            remote_checkpoint_path,
            os.path.relpath(local_checkpoint, local_checkpoint_path),
        )
        for local_checkpoint in local_checkpoint_list
    ]
    inputs = zip(
        local_checkpoint_list,
        remote_checkpoint_list,
        [neox_args.s3_chunk_size] * len(local_checkpoint_list),
    )

    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`..."
    )
    start = time.time()
    multiprocessing_starmap(_upload, inputs)
    total_time = time.time() - start
    print_rank_0(
        f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s"
    )


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""

    if neox_args.deepspeed:
        save_ds_checkpoint(iteration, model, neox_args)
    else:
        raise ValueError("Must be using deepspeed to use neox")

    torch.distributed.barrier()
    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
    if upload_to_s3:
        upload_checkpoint(iteration, neox_args)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = (
            not neox_args.no_load_optim
        )  # TODO: These should be configured by separate args
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = get_checkpoint_tag(iteration)
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            load_module_only=not load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if an iteration is specified, we want to raise an error here rather than
            # continuing silently, since we are trying to load a specific checkpoint
            if iteration is not None:
                available_checkpoints = sorted(
                    [
                        int(i.name.replace("global_step", ""))
                        for i in Path(neox_args.load).glob("global_step*")
                    ]
                )
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if no checkpoint was loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        if "iteration" in state_dict:
            iteration = state_dict["iteration"]
        else:
            iteration = state_dict.get(
                "total_iters"
            )  # total_iters backward compatible with older checkpoints
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load optimizer from checkpoint {}. "
472
                "Specify --no-load-rng or --finetune to prevent "
473
                "attempting to load the optimizer state, "
474
                "exiting ...".format(checkpoint_name)
475
            )
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print("  successfully loaded {}".format(checkpoint_name))

    return iteration
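

# Editor's illustrative usage note (not part of the original module): how these
# entry points are typically wired together by the surrounding training code.
# `neox_args`, `model`, `optimizer` and `lr_scheduler` come from the megatron /
# DeepSpeed setup elsewhere in gpt-neox; `should_save` below is a made-up stand-in
# for whatever save-frequency check the training loop applies.
#
#   iteration = load_checkpoint(neox_args, model, optimizer, lr_scheduler)
#   while iteration < neox_args.train_iters:
#       ...  # one training step
#       iteration += 1
#       if should_save(iteration):
#           save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler)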